chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
# litellm/proxy/search_endpoints/__init__.py
|
||||
|
||||
from .search_tool_registry import SearchToolRegistry
|
||||
|
||||
__all__ = [
|
||||
"SearchToolRegistry",
|
||||
]
|
||||
@@ -0,0 +1,266 @@
|
||||
#### Search Endpoints #####
|
||||
|
||||
import orjson
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import ORJSONResponse
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/search/{search_tool_name}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["search"],
|
||||
)
|
||||
@router.post(
|
||||
"/search/{search_tool_name}",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["search"],
|
||||
)
|
||||
@router.post(
|
||||
"/v1/search",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["search"],
|
||||
)
|
||||
@router.post(
|
||||
"/search",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["search"],
|
||||
)
|
||||
async def search(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
search_tool_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Search endpoint for performing web searches.
|
||||
|
||||
Follows the Perplexity Search API spec:
|
||||
https://docs.perplexity.ai/api-reference/search-post
|
||||
|
||||
The search_tool_name can be passed either:
|
||||
1. In the URL path: /v1/search/{search_tool_name}
|
||||
2. In the request body: {"search_tool_name": "..."}
|
||||
|
||||
Example with search_tool_name in URL (recommended - keeps body Perplexity-compatible):
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/v1/search/litellm-search" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"query": "latest AI developments 2024",
|
||||
"max_results": 5,
|
||||
"search_domain_filter": ["arxiv.org", "nature.com"],
|
||||
"country": "US"
|
||||
}'
|
||||
```
|
||||
|
||||
Example with search_tool_name in body:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/v1/search" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"search_tool_name": "litellm-search",
|
||||
"query": "latest AI developments 2024",
|
||||
"max_results": 5,
|
||||
"search_domain_filter": ["arxiv.org", "nature.com"],
|
||||
"country": "US"
|
||||
}'
|
||||
```
|
||||
|
||||
Request Body Parameters (when search_tool_name not in URL):
|
||||
- search_tool_name (str, required if not in URL): Name of the search tool configured in router
|
||||
- query (str or list[str], required): Search query
|
||||
- max_results (int, optional): Maximum number of results (1-20), default 10
|
||||
- search_domain_filter (list[str], optional): List of domains to filter (max 20)
|
||||
- max_tokens_per_page (int, optional): Max tokens per page, default 1024
|
||||
- country (str, optional): Country code filter (e.g., 'US', 'GB', 'DE')
|
||||
|
||||
When using URL path parameter, only Perplexity-compatible parameters are needed in body:
|
||||
- query (str or list[str], required): Search query
|
||||
- max_results (int, optional): Maximum number of results (1-20), default 10
|
||||
- search_domain_filter (list[str], optional): List of domains to filter (max 20)
|
||||
- max_tokens_per_page (int, optional): Max tokens per page, default 1024
|
||||
- country (str, optional): Country code filter (e.g., 'US', 'GB', 'DE')
|
||||
|
||||
Response follows Perplexity Search API format:
|
||||
```json
|
||||
{
|
||||
"object": "search",
|
||||
"results": [
|
||||
{
|
||||
"title": "Result title",
|
||||
"url": "https://example.com",
|
||||
"snippet": "Result snippet...",
|
||||
"date": "2024-01-01",
|
||||
"last_updated": "2024-01-01"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
llm_router,
|
||||
proxy_config,
|
||||
proxy_logging_obj,
|
||||
select_data_generator,
|
||||
user_api_base,
|
||||
user_max_tokens,
|
||||
user_model,
|
||||
user_request_timeout,
|
||||
user_temperature,
|
||||
version,
|
||||
)
|
||||
|
||||
# Read request body
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
# If search_tool_name is provided in URL path, use it (takes precedence over body)
|
||||
if search_tool_name is not None:
|
||||
data["search_tool_name"] = search_tool_name
|
||||
|
||||
if "search_tool_name" in data and data["search_tool_name"]:
|
||||
data["model"] = data["search_tool_name"]
|
||||
|
||||
if llm_router is not None and hasattr(llm_router, "search_tools"):
|
||||
search_tool_name_value = data["search_tool_name"]
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Search endpoint - Looking for search_tool_name: {search_tool_name_value}. "
|
||||
f"Available search tools in router: {[tool.get('search_tool_name') for tool in llm_router.search_tools]}. "
|
||||
f"Total search tools: {len(llm_router.search_tools)}"
|
||||
)
|
||||
|
||||
matching_tools = [
|
||||
tool
|
||||
for tool in llm_router.search_tools
|
||||
if tool.get("search_tool_name") == search_tool_name_value
|
||||
]
|
||||
|
||||
if matching_tools:
|
||||
search_tool = matching_tools[0]
|
||||
search_provider = search_tool.get("litellm_params", {}).get(
|
||||
"search_provider"
|
||||
)
|
||||
|
||||
if search_provider:
|
||||
data["custom_llm_provider"] = search_provider
|
||||
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
data["metadata"]["model_group"] = search_tool_name_value
|
||||
|
||||
# Process request using ProxyBaseLLMRequestProcessing
|
||||
processor = ProxyBaseLLMRequestProcessing(data=data)
|
||||
try:
|
||||
return await processor.base_process_llm_request(
|
||||
request=request,
|
||||
fastapi_response=fastapi_response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
route_type="asearch",
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
llm_router=llm_router,
|
||||
general_settings=general_settings,
|
||||
proxy_config=proxy_config,
|
||||
select_data_generator=select_data_generator,
|
||||
model=None,
|
||||
user_model=user_model,
|
||||
user_temperature=user_temperature,
|
||||
user_request_timeout=user_request_timeout,
|
||||
user_max_tokens=user_max_tokens,
|
||||
user_api_base=user_api_base,
|
||||
version=version,
|
||||
)
|
||||
except Exception as e:
|
||||
raise await processor._handle_llm_api_exception(
|
||||
e=e,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
version=version,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/search/tools",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["search"],
|
||||
)
|
||||
@router.get(
|
||||
"/search/tools",
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_class=ORJSONResponse,
|
||||
tags=["search"],
|
||||
)
|
||||
async def list_search_tools(
|
||||
request: Request,
|
||||
fastapi_response: Response,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
List all available search tools configured in the router.
|
||||
|
||||
This endpoint returns the search tools that are currently loaded and available
|
||||
for use with the /v1/search endpoint.
|
||||
|
||||
Example:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/v1/search/tools" \
|
||||
-H "Authorization: Bearer sk-1234"
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"search_tool_name": "litellm-search",
|
||||
"search_provider": "perplexity",
|
||||
"description": "Perplexity search tool"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
try:
|
||||
search_tools_list = []
|
||||
|
||||
if llm_router is not None and hasattr(llm_router, "search_tools"):
|
||||
for tool in llm_router.search_tools:
|
||||
tool_info = {
|
||||
"search_tool_name": tool.get("search_tool_name"),
|
||||
"search_provider": tool.get("litellm_params", {}).get(
|
||||
"search_provider"
|
||||
),
|
||||
}
|
||||
|
||||
# Add description if available
|
||||
if "search_tool_info" in tool and tool["search_tool_info"]:
|
||||
description = tool["search_tool_info"].get("description")
|
||||
if description:
|
||||
tool_info["description"] = description
|
||||
|
||||
search_tools_list.append(tool_info)
|
||||
|
||||
return {"object": "list", "data": search_tools_list}
|
||||
except Exception as e:
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
verbose_proxy_logger.exception(f"Error listing search tools: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
CRUD ENDPOINTS FOR SEARCH TOOLS
|
||||
"""
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.search_endpoints.search_tool_registry import SearchToolRegistry
|
||||
from litellm.types.search import (
|
||||
ListSearchToolsResponse,
|
||||
SearchTool,
|
||||
SearchToolInfoResponse,
|
||||
)
|
||||
from litellm.types.utils import SearchProviders
|
||||
|
||||
#### SEARCH TOOLS ENDPOINTS ####
|
||||
|
||||
router = APIRouter()
|
||||
SEARCH_TOOL_REGISTRY = SearchToolRegistry()
|
||||
|
||||
|
||||
def _convert_datetime_to_str(value: Union[datetime, str, None]) -> Union[str, None]:
|
||||
"""
|
||||
Convert datetime object to ISO format string.
|
||||
|
||||
Args:
|
||||
value: datetime object, string, or None
|
||||
|
||||
Returns:
|
||||
ISO format string or original value if already string or None
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value.isoformat()
|
||||
return value
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search_tools/list",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=ListSearchToolsResponse,
|
||||
)
|
||||
async def list_search_tools():
|
||||
"""
|
||||
List all search tools that are available in the database and config file.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/search_tools/list" -H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"search_tools": [
|
||||
{
|
||||
"search_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"search_tool_name": "litellm-search",
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-***",
|
||||
"api_base": "https://api.perplexity.ai"
|
||||
},
|
||||
"search_tool_info": {
|
||||
"description": "Perplexity search tool"
|
||||
},
|
||||
"created_at": "2023-11-09T12:34:56.789Z",
|
||||
"updated_at": "2023-11-09T12:34:56.789Z",
|
||||
"is_from_config": false
|
||||
},
|
||||
{
|
||||
"search_tool_name": "config-search-tool",
|
||||
"litellm_params": {
|
||||
"search_provider": "tavily",
|
||||
"api_key": "tvly-***"
|
||||
},
|
||||
"is_from_config": true
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
|
||||
from litellm.proxy.proxy_server import prisma_client, proxy_config
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
search_tools_from_db = await SEARCH_TOOL_REGISTRY.get_all_search_tools_from_db(
|
||||
prisma_client=prisma_client
|
||||
)
|
||||
|
||||
db_tool_names = {tool.get("search_tool_name") for tool in search_tools_from_db}
|
||||
|
||||
search_tool_configs: List[SearchToolInfoResponse] = []
|
||||
|
||||
config_search_tools = []
|
||||
|
||||
try:
|
||||
config = await proxy_config.get_config()
|
||||
parsed_tools = proxy_config.parse_search_tools(config)
|
||||
if parsed_tools:
|
||||
config_search_tools = parsed_tools
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Could not get config-defined search tools: {e}"
|
||||
)
|
||||
|
||||
for search_tool in config_search_tools:
|
||||
tool_name = search_tool.get("search_tool_name")
|
||||
if tool_name:
|
||||
litellm_params_dict = dict(search_tool.get("litellm_params", {}))
|
||||
masked_litellm_params_dict = _get_masked_values(
|
||||
litellm_params_dict,
|
||||
unmasked_length=4,
|
||||
number_of_asterisks=4,
|
||||
)
|
||||
|
||||
search_tool_configs.append(
|
||||
SearchToolInfoResponse(
|
||||
search_tool_id=None,
|
||||
search_tool_name=tool_name,
|
||||
litellm_params=masked_litellm_params_dict,
|
||||
search_tool_info=search_tool.get("search_tool_info"),
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
is_from_config=True,
|
||||
)
|
||||
)
|
||||
|
||||
search_tool_configs = [
|
||||
tool
|
||||
for tool in search_tool_configs
|
||||
if tool.get("search_tool_name") not in db_tool_names
|
||||
]
|
||||
|
||||
for search_tool in search_tools_from_db:
|
||||
litellm_params_dict = dict(search_tool.get("litellm_params", {}))
|
||||
masked_litellm_params_dict = _get_masked_values(
|
||||
litellm_params_dict,
|
||||
unmasked_length=4,
|
||||
number_of_asterisks=4,
|
||||
)
|
||||
|
||||
search_tool_configs.append(
|
||||
SearchToolInfoResponse(
|
||||
search_tool_id=search_tool.get("search_tool_id"),
|
||||
search_tool_name=search_tool.get("search_tool_name", ""),
|
||||
litellm_params=masked_litellm_params_dict,
|
||||
search_tool_info=search_tool.get("search_tool_info"),
|
||||
created_at=_convert_datetime_to_str(search_tool.get("created_at")),
|
||||
updated_at=_convert_datetime_to_str(search_tool.get("updated_at")),
|
||||
is_from_config=False,
|
||||
)
|
||||
)
|
||||
|
||||
return ListSearchToolsResponse(search_tools=search_tool_configs)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting search tools: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class CreateSearchToolRequest(BaseModel):
|
||||
search_tool: SearchTool
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search_tools",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def create_search_tool(request: CreateSearchToolRequest):
|
||||
"""
|
||||
Create a new search tool.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/search_tools" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"search_tool": {
|
||||
"search_tool_name": "litellm-search",
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-..."
|
||||
},
|
||||
"search_tool_info": {
|
||||
"description": "Perplexity search tool"
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"search_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"search_tool_name": "litellm-search",
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-..."
|
||||
},
|
||||
"search_tool_info": {
|
||||
"description": "Perplexity search tool"
|
||||
},
|
||||
"created_at": "2023-11-09T12:34:56.789Z",
|
||||
"updated_at": "2023-11-09T12:34:56.789Z"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
result = await SEARCH_TOOL_REGISTRY.add_search_tool_to_db(
|
||||
search_tool=request.search_tool, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Successfully added search tool '{result.get('search_tool_name')}' to database. "
|
||||
f"Router will be updated by the cron job."
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding search tool to db: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class UpdateSearchToolRequest(BaseModel):
|
||||
search_tool: SearchTool
|
||||
|
||||
|
||||
@router.put(
|
||||
"/search_tools/{search_tool_id}",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def update_search_tool(search_tool_id: str, request: UpdateSearchToolRequest):
|
||||
"""
|
||||
Update an existing search tool.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X PUT "http://localhost:4000/search_tools/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"search_tool": {
|
||||
"search_tool_name": "updated-search",
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-new-key"
|
||||
},
|
||||
"search_tool_info": {
|
||||
"description": "Updated search tool"
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"search_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"search_tool_name": "updated-search",
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-new-key"
|
||||
},
|
||||
"search_tool_info": {
|
||||
"description": "Updated search tool"
|
||||
},
|
||||
"created_at": "2023-11-09T12:34:56.789Z",
|
||||
"updated_at": "2023-11-09T13:45:12.345Z"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
# Check if search tool exists
|
||||
existing_tool = await SEARCH_TOOL_REGISTRY.get_search_tool_by_id_from_db(
|
||||
search_tool_id=search_tool_id, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
if existing_tool is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Search tool with ID {search_tool_id} not found",
|
||||
)
|
||||
|
||||
result = await SEARCH_TOOL_REGISTRY.update_search_tool_in_db(
|
||||
search_tool_id=search_tool_id,
|
||||
search_tool=request.search_tool,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Successfully updated search tool '{result.get('search_tool_name')}' in database. "
|
||||
f"Router will be updated by the cron job."
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating search tool: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/search_tools/{search_tool_id}",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_search_tool(search_tool_id: str):
|
||||
"""
|
||||
Delete a search tool.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:4000/search_tools/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"message": "Search tool 123e4567-e89b-12d3-a456-426614174000 deleted successfully",
|
||||
"search_tool_name": "litellm-search"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
# Check if search tool exists
|
||||
existing_tool = await SEARCH_TOOL_REGISTRY.get_search_tool_by_id_from_db(
|
||||
search_tool_id=search_tool_id, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
if existing_tool is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Search tool with ID {search_tool_id} not found",
|
||||
)
|
||||
|
||||
result = await SEARCH_TOOL_REGISTRY.delete_search_tool_from_db(
|
||||
search_tool_id=search_tool_id, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Successfully deleted search tool from database. "
|
||||
"Router will be updated by the cron job."
|
||||
)
|
||||
|
||||
return result
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting search tool: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search_tools/{search_tool_id}",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_search_tool_info(search_tool_id: str):
|
||||
"""
|
||||
Get detailed information about a specific search tool by ID.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/search_tools/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"search_tool_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"search_tool_name": "litellm-search",
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-***"
|
||||
},
|
||||
"search_tool_info": {
|
||||
"description": "Perplexity search tool"
|
||||
},
|
||||
"created_at": "2023-11-09T12:34:56.789Z",
|
||||
"updated_at": "2023-11-09T12:34:56.789Z"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import _get_masked_values
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Prisma client not initialized")
|
||||
|
||||
try:
|
||||
result = await SEARCH_TOOL_REGISTRY.get_search_tool_by_id_from_db(
|
||||
search_tool_id=search_tool_id, prisma_client=prisma_client
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Search tool with ID {search_tool_id} not found",
|
||||
)
|
||||
|
||||
# Mask sensitive data
|
||||
litellm_params_dict = dict(result.get("litellm_params", {}))
|
||||
masked_litellm_params_dict = _get_masked_values(
|
||||
litellm_params_dict,
|
||||
unmasked_length=4,
|
||||
number_of_asterisks=4,
|
||||
)
|
||||
|
||||
return SearchToolInfoResponse(
|
||||
search_tool_id=result.get("search_tool_id"),
|
||||
search_tool_name=result.get("search_tool_name", ""),
|
||||
litellm_params=masked_litellm_params_dict,
|
||||
search_tool_info=result.get("search_tool_info"),
|
||||
created_at=_convert_datetime_to_str(result.get("created_at")),
|
||||
updated_at=_convert_datetime_to_str(result.get("updated_at")),
|
||||
is_from_config=False, # This endpoint only returns DB tools
|
||||
)
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting search tool info: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
class TestSearchToolConnectionRequest(BaseModel):
|
||||
litellm_params: Dict[str, Any]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/search_tools/test_connection",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def test_search_tool_connection(request: TestSearchToolConnectionRequest):
|
||||
"""
|
||||
Test connection to a search provider with the given configuration.
|
||||
|
||||
Makes a simple test search query to verify the API key and configuration are valid.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/search_tools/test_connection" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"litellm_params": {
|
||||
"search_provider": "perplexity",
|
||||
"api_key": "sk-..."
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
Example Response (Success):
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"message": "Successfully connected to perplexity search provider",
|
||||
"test_query": "test",
|
||||
"results_count": 5
|
||||
}
|
||||
```
|
||||
|
||||
Example Response (Failure):
|
||||
```json
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Authentication failed: Invalid API key",
|
||||
"error_type": "AuthenticationError"
|
||||
}
|
||||
```
|
||||
"""
|
||||
try:
|
||||
from litellm.search import asearch
|
||||
|
||||
# Extract params from request
|
||||
litellm_params = request.litellm_params
|
||||
search_provider = litellm_params.get("search_provider")
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
|
||||
if not search_provider:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="search_provider is required in litellm_params"
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Testing connection to search provider: {search_provider}"
|
||||
)
|
||||
|
||||
# Make a simple test search query with max_results=1 to minimize cost
|
||||
test_query = "test"
|
||||
response = await asearch(
|
||||
query=test_query,
|
||||
search_provider=search_provider,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
max_results=1, # Minimize results to reduce cost
|
||||
timeout=10.0, # 10 second timeout for test
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Successfully tested connection to {search_provider} search provider"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Successfully connected to {search_provider} search provider",
|
||||
"test_query": test_query,
|
||||
"results_count": len(response.results)
|
||||
if response and response.results
|
||||
else 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
|
||||
verbose_proxy_logger.exception(
|
||||
f"Failed to connect to search provider: {error_message}"
|
||||
)
|
||||
|
||||
# Return error details in a structured format
|
||||
return {
|
||||
"status": "error",
|
||||
"message": error_message,
|
||||
"error_type": error_type,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/search_tools/ui/available_providers",
|
||||
tags=["Search Tools"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_available_search_providers():
|
||||
"""
|
||||
Get the list of available search providers with their configuration fields.
|
||||
|
||||
Auto-discovers search providers and their UI-friendly names from transformation configs.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/search_tools/ui/available_providers" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"providers": [
|
||||
{
|
||||
"provider_name": "perplexity",
|
||||
"ui_friendly_name": "Perplexity"
|
||||
},
|
||||
{
|
||||
"provider_name": "tavily",
|
||||
"ui_friendly_name": "Tavily"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
"""
|
||||
try:
|
||||
from litellm.utils import ProviderConfigManager
|
||||
|
||||
available_providers = []
|
||||
|
||||
# Auto-discover providers from SearchProviders enum
|
||||
for provider in SearchProviders:
|
||||
try:
|
||||
# Get the config class for this provider
|
||||
config = ProviderConfigManager.get_provider_search_config(
|
||||
provider=provider
|
||||
)
|
||||
|
||||
if config is not None:
|
||||
# Get the UI-friendly name from the config class
|
||||
ui_name = config.ui_friendly_name()
|
||||
|
||||
available_providers.append(
|
||||
{
|
||||
"provider_name": provider.value,
|
||||
"ui_friendly_name": ui_name,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Could not get config for search provider {provider.value}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return {"providers": available_providers}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting available search providers: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Search Tool Registry for managing search tool configurations.
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.search import SearchTool
|
||||
|
||||
|
||||
class SearchToolRegistry:
|
||||
"""
|
||||
Handles adding, removing, and getting search tools in DB + in memory.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _convert_prisma_to_dict(prisma_obj) -> dict:
|
||||
"""
|
||||
Convert Prisma result to dict with datetime objects as ISO format strings.
|
||||
|
||||
Args:
|
||||
prisma_obj: Prisma model instance
|
||||
|
||||
Returns:
|
||||
Dict with datetime fields converted to ISO strings
|
||||
"""
|
||||
result = dict(prisma_obj)
|
||||
# Convert datetime objects to ISO format strings
|
||||
if "created_at" in result and result["created_at"]:
|
||||
result["created_at"] = result["created_at"].isoformat()
|
||||
if "updated_at" in result and result["updated_at"]:
|
||||
result["updated_at"] = result["updated_at"].isoformat()
|
||||
return result
|
||||
|
||||
###########################################################
|
||||
########### DB management helpers for search tools ########
|
||||
###########################################################
|
||||
|
||||
async def add_search_tool_to_db(
|
||||
self, search_tool: SearchTool, prisma_client: PrismaClient
|
||||
):
|
||||
"""
|
||||
Add a search tool to the database.
|
||||
|
||||
Args:
|
||||
search_tool: Search tool configuration
|
||||
prisma_client: Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with created search tool data
|
||||
"""
|
||||
try:
|
||||
search_tool_name = search_tool.get("search_tool_name")
|
||||
litellm_params: str = safe_dumps(
|
||||
dict(search_tool.get("litellm_params", {}))
|
||||
)
|
||||
search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {}))
|
||||
|
||||
# Create search tool in DB
|
||||
created_search_tool = (
|
||||
await prisma_client.db.litellm_searchtoolstable.create(
|
||||
data={
|
||||
"search_tool_name": search_tool_name,
|
||||
"litellm_params": litellm_params,
|
||||
"search_tool_info": search_tool_info,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Add search_tool_id to the returned search tool object
|
||||
search_tool_dict = dict(search_tool)
|
||||
search_tool_dict["search_tool_id"] = created_search_tool.search_tool_id
|
||||
search_tool_dict["created_at"] = created_search_tool.created_at.isoformat()
|
||||
search_tool_dict["updated_at"] = created_search_tool.updated_at.isoformat()
|
||||
|
||||
return search_tool_dict
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding search tool to DB: {str(e)}")
|
||||
raise Exception(f"Error adding search tool to DB: {str(e)}")
|
||||
|
||||
async def delete_search_tool_from_db(
|
||||
self, search_tool_id: str, prisma_client: PrismaClient
|
||||
):
|
||||
"""
|
||||
Delete a search tool from the database.
|
||||
|
||||
Args:
|
||||
search_tool_id: ID of search tool to delete
|
||||
prisma_client: Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with success message
|
||||
"""
|
||||
try:
|
||||
# Get search tool before deletion for response
|
||||
existing_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
|
||||
where={"search_tool_id": search_tool_id}
|
||||
)
|
||||
|
||||
if not existing_tool:
|
||||
raise Exception(f"Search tool with ID {search_tool_id} not found")
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_searchtoolstable.delete(
|
||||
where={"search_tool_id": search_tool_id}
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Search tool {search_tool_id} deleted successfully",
|
||||
"search_tool_name": existing_tool.search_tool_name,
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error deleting search tool from DB: {str(e)}"
|
||||
)
|
||||
raise Exception(f"Error deleting search tool from DB: {str(e)}")
|
||||
|
||||
async def update_search_tool_in_db(
|
||||
self, search_tool_id: str, search_tool: SearchTool, prisma_client: PrismaClient
|
||||
):
|
||||
"""
|
||||
Update a search tool in the database.
|
||||
|
||||
Args:
|
||||
search_tool_id: ID of search tool to update
|
||||
search_tool: Updated search tool configuration
|
||||
prisma_client: Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with updated search tool data
|
||||
"""
|
||||
try:
|
||||
search_tool_name = search_tool.get("search_tool_name")
|
||||
litellm_params: str = safe_dumps(
|
||||
dict(search_tool.get("litellm_params", {}))
|
||||
)
|
||||
search_tool_info: str = safe_dumps(search_tool.get("search_tool_info", {}))
|
||||
|
||||
# Update in DB
|
||||
updated_search_tool = (
|
||||
await prisma_client.db.litellm_searchtoolstable.update(
|
||||
where={"search_tool_id": search_tool_id},
|
||||
data={
|
||||
"search_tool_name": search_tool_name,
|
||||
"litellm_params": litellm_params,
|
||||
"search_tool_info": search_tool_info,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Convert to dict with ISO formatted datetimes
|
||||
return self._convert_prisma_to_dict(updated_search_tool)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error updating search tool in DB: {str(e)}"
|
||||
)
|
||||
raise Exception(f"Error updating search tool in DB: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
async def get_all_search_tools_from_db(
|
||||
prisma_client: PrismaClient,
|
||||
) -> List[SearchTool]:
|
||||
"""
|
||||
Get all search tools from the database.
|
||||
|
||||
Args:
|
||||
prisma_client: Prisma client instance
|
||||
|
||||
Returns:
|
||||
List of search tool configurations
|
||||
"""
|
||||
try:
|
||||
search_tools_from_db = (
|
||||
await prisma_client.db.litellm_searchtoolstable.find_many(
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
)
|
||||
|
||||
search_tools: List[SearchTool] = []
|
||||
for search_tool in search_tools_from_db:
|
||||
# Convert Prisma result to dict with ISO formatted datetimes
|
||||
search_tool_dict = SearchToolRegistry._convert_prisma_to_dict(
|
||||
search_tool
|
||||
)
|
||||
search_tools.append(SearchTool(**search_tool_dict)) # type: ignore
|
||||
|
||||
return search_tools
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error getting search tools from DB: {str(e)}"
|
||||
)
|
||||
raise Exception(f"Error getting search tools from DB: {str(e)}")
|
||||
|
||||
async def get_search_tool_by_id_from_db(
|
||||
self, search_tool_id: str, prisma_client: PrismaClient
|
||||
) -> Optional[SearchTool]:
|
||||
"""
|
||||
Get a search tool by its ID from the database.
|
||||
|
||||
Args:
|
||||
search_tool_id: ID of search tool to retrieve
|
||||
prisma_client: Prisma client instance
|
||||
|
||||
Returns:
|
||||
Search tool configuration or None if not found
|
||||
"""
|
||||
try:
|
||||
search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
|
||||
where={"search_tool_id": search_tool_id}
|
||||
)
|
||||
|
||||
if not search_tool:
|
||||
return None
|
||||
|
||||
# Convert Prisma result to dict with ISO formatted datetimes
|
||||
search_tool_dict = self._convert_prisma_to_dict(search_tool)
|
||||
return SearchTool(**search_tool_dict) # type: ignore
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error getting search tool from DB: {str(e)}"
|
||||
)
|
||||
raise Exception(f"Error getting search tool from DB: {str(e)}")
|
||||
|
||||
async def get_search_tool_by_name_from_db(
|
||||
self, search_tool_name: str, prisma_client: PrismaClient
|
||||
) -> Optional[SearchTool]:
|
||||
"""
|
||||
Get a search tool by its name from the database.
|
||||
|
||||
Args:
|
||||
search_tool_name: Name of search tool to retrieve
|
||||
prisma_client: Prisma client instance
|
||||
|
||||
Returns:
|
||||
Search tool configuration or None if not found
|
||||
"""
|
||||
try:
|
||||
search_tool = await prisma_client.db.litellm_searchtoolstable.find_unique(
|
||||
where={"search_tool_name": search_tool_name}
|
||||
)
|
||||
|
||||
if not search_tool:
|
||||
return None
|
||||
|
||||
# Convert Prisma result to dict with ISO formatted datetimes
|
||||
search_tool_dict = self._convert_prisma_to_dict(search_tool)
|
||||
return SearchTool(**search_tool_dict) # type: ignore
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error getting search tool from DB: {str(e)}"
|
||||
)
|
||||
raise Exception(f"Error getting search tool from DB: {str(e)}")
|
||||
Reference in New Issue
Block a user