import json
from typing import Any, List, Literal, Optional, Tuple

import litellm
from litellm._logging import verbose_logger
from litellm.types.llms.openai import Batch
from litellm.types.utils import CallTypes, ModelInfo, Usage
from litellm.utils import token_counter


async def calculate_batch_cost_and_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ],
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> Tuple[float, Usage, List[str]]:
    """
    Calculate the cost and usage of a batch.

    Args:
        model_info: Optional deployment-level model info with custom batch
            pricing. Threaded through to batch_cost_calculator so that
            deployment-specific pricing (e.g. input_cost_per_token_batches)
            is used instead of the global cost map.
    """
    batch_cost = _batch_cost_calculator(
        custom_llm_provider=custom_llm_provider,
        file_content_dictionary=file_content_dictionary,
        model_name=model_name,
        model_info=model_info,
    )
    batch_usage = _get_batch_job_total_usage_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
    batch_models = _get_batch_models_from_file_content(
        file_content_dictionary, model_name
    )

    return batch_cost, batch_usage, batch_models


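# Usage sketch (illustrative, not executed; `parsed_output_lines` is a
# placeholder for output-file lines already parsed via
# _get_file_content_as_dictionary, and the result values are hypothetical):
#
#     >>> cost, usage, models = await calculate_batch_cost_and_usage(
#     ...     file_content_dictionary=parsed_output_lines,
#     ...     custom_llm_provider="openai",
#     ... )
#     >>> cost, usage.total_tokens, models
#     (0.0042, 1200, ['gpt-4o'])

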
async def _handle_completed_batch(
    batch: Batch,
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ],
    model_name: Optional[str] = None,
    litellm_params: Optional[dict] = None,
) -> Tuple[float, Usage, List[str]]:
    """Helper function to process a completed batch and handle logging.

    Args:
        batch: The batch object
        custom_llm_provider: The LLM provider
        model_name: Optional model name
        litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)
    """
    # Get batch results
    file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
        batch, custom_llm_provider, litellm_params=litellm_params
    )

    # Calculate costs and usage
    batch_cost = _batch_cost_calculator(
        custom_llm_provider=custom_llm_provider,
        file_content_dictionary=file_content_dictionary,
        model_name=model_name,
    )
    batch_usage = _get_batch_job_total_usage_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )

    batch_models = _get_batch_models_from_file_content(
        file_content_dictionary, model_name
    )

    return batch_cost, batch_usage, batch_models


def _get_batch_models_from_file_content(
    file_content_dictionary: List[dict],
    model_name: Optional[str] = None,
) -> List[str]:
    """
    Get the models from the file content
    """
    if model_name:
        return [model_name]
    batch_models = []
    for _item in file_content_dictionary:
        if _batch_response_was_successful(_item):
            _response_body = _get_response_from_batch_job_output_file(_item)
            _model = _response_body.get("model")
            if _model:
                batch_models.append(_model)
    return batch_models


def _batch_cost_calculator(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Calculate the cost of a batch based on the output file id
    """
    # Handle Vertex AI with specialized method
    if custom_llm_provider == "vertex_ai" and model_name:
        batch_cost, _ = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )
        verbose_logger.debug("vertex_ai_total_cost=%s", batch_cost)
        return batch_cost

    # For other providers, use the existing logic
    total_cost = _get_batch_job_cost_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_info=model_info,
    )
    verbose_logger.debug("total_cost=%s", total_cost)
    return total_cost


def calculate_vertex_ai_batch_cost_and_usage(
    vertex_ai_batch_responses: List[dict],
    model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
    """
    Calculate both cost and usage from Vertex AI batch responses.

    Vertex AI batch output lines have format:
    {"request": ..., "status": "", "response": {"candidates": [...], "usageMetadata": {...}}}

    usageMetadata contains promptTokenCount, candidatesTokenCount, totalTokenCount.
    """
    from litellm.cost_calculator import batch_cost_calculator

    total_cost = 0.0
    total_tokens = 0
    prompt_tokens = 0
    completion_tokens = 0
    actual_model_name = model_name or "gemini-2.0-flash-001"

    for response in vertex_ai_batch_responses:
        response_body = response.get("response")
        if response_body is None:
            continue

        usage_metadata = response_body.get("usageMetadata", {})
        _prompt = usage_metadata.get("promptTokenCount", 0) or 0
        _completion = usage_metadata.get("candidatesTokenCount", 0) or 0
        _total = usage_metadata.get("totalTokenCount", 0) or (_prompt + _completion)

        line_usage = Usage(
            prompt_tokens=_prompt,
            completion_tokens=_completion,
            total_tokens=_total,
        )

        try:
            p_cost, c_cost = batch_cost_calculator(
                usage=line_usage,
                model=actual_model_name,
                custom_llm_provider="vertex_ai",
            )
            total_cost += p_cost + c_cost
        except Exception as e:
            verbose_logger.debug(
                "vertex_ai batch cost calculation error for line: %s", str(e)
            )

        prompt_tokens += _prompt
        completion_tokens += _completion
        total_tokens += _total

    verbose_logger.info(
        "vertex_ai batch cost: cost=%s, prompt=%d, completion=%d, total=%d",
        total_cost,
        prompt_tokens,
        completion_tokens,
        total_tokens,
    )

    return total_cost, Usage(
        total_tokens=total_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )


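# Illustrative example (not executed; token values are hypothetical): a single
# Vertex AI batch output line as consumed above. Only "response.usageMetadata"
# is read for token counts.
#
#     {
#         "request": {...},
#         "status": "",
#         "response": {
#             "candidates": [{"content": {"parts": [{"text": "..."}]}}],
#             "usageMetadata": {
#                 "promptTokenCount": 25,
#                 "candidatesTokenCount": 100,
#                 "totalTokenCount": 125
#             }
#         }
#     }
#
# With two such lines, the returned Usage would aggregate to
# prompt_tokens=50, completion_tokens=200, total_tokens=250.

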
async def _get_batch_output_file_content_as_dictionary(
    batch: Batch,
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    litellm_params: Optional[dict] = None,
) -> List[dict]:
    """
    Get the batch output file content as a list of dictionaries

    Args:
        batch: The batch object
        custom_llm_provider: The LLM provider
        litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)
            Required for Azure and other providers that need authentication
    """
    from litellm.files.main import afile_content
    from litellm.proxy.openai_files_endpoints.common_utils import (
        _is_base64_encoded_unified_file_id,
    )

    if custom_llm_provider == "vertex_ai":
        raise ValueError("Vertex AI does not support file content retrieval")

    if batch.output_file_id is None:
        raise ValueError("Output file id is None; cannot retrieve file content")

    file_id = batch.output_file_id
    is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
    if is_base64_unified_file_id:
        try:
            file_id = is_base64_unified_file_id.split("llm_output_file_id,")[1].split(
                ";"
            )[0]
            verbose_logger.debug(
                f"Extracted LLM output file ID from unified file ID: {file_id}"
            )
        except (IndexError, AttributeError) as e:
            verbose_logger.error(
                f"Failed to extract LLM output file ID from unified file ID: {batch.output_file_id}, error: {e}"
            )

    # Build kwargs for afile_content with credentials from litellm_params
    file_content_kwargs = {
        "file_id": file_id,
        "custom_llm_provider": custom_llm_provider,
    }

    # Extract and add credentials for file access
    credentials = _extract_file_access_credentials(litellm_params)
    file_content_kwargs.update(credentials)

    _file_content = await afile_content(**file_content_kwargs)  # type: ignore[reportArgumentType]
    return _get_file_content_as_dictionary(_file_content.content)


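# Illustrative sketch (assumption: the decoded unified file ID embeds the
# provider file ID as "...llm_output_file_id,<provider_file_id>;...", which is
# what the split above extracts; the value below is hypothetical):
#
#     decoded = "model_id,gpt-4o;llm_output_file_id,file-abc123;extra,..."
#     decoded.split("llm_output_file_id,")[1].split(";")[0]  # -> "file-abc123"

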
def _extract_file_access_credentials(litellm_params: Optional[dict]) -> dict:
    """
    Extract credentials from litellm_params for file access operations.

    This method extracts relevant authentication and configuration parameters
    needed for accessing files across different providers (Azure, Vertex AI, etc.).

    Args:
        litellm_params: Dictionary containing litellm parameters with credentials

    Returns:
        Dictionary containing only the credentials needed for file access
    """
    credentials = {}

    if litellm_params:
        # List of credential keys that should be passed to file operations
        credential_keys = [
            "api_key",
            "api_base",
            "api_version",
            "organization",
            "azure_ad_token",
            "azure_ad_token_provider",
            "vertex_project",
            "vertex_location",
            "vertex_credentials",
            "timeout",
            "max_retries",
        ]
        for key in credential_keys:
            if key in litellm_params:
                credentials[key] = litellm_params[key]

    return credentials


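# Example (illustrative; the values are hypothetical): only whitelisted keys
# survive the filtering, so "stream" is dropped here.
#
#     >>> _extract_file_access_credentials(
#     ...     {"api_key": "sk-test", "api_base": "https://example.azure.com", "stream": True}
#     ... )
#     {'api_key': 'sk-test', 'api_base': 'https://example.azure.com'}

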
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
    """
    Get the file content as a list of dictionaries from JSON Lines format
    """
    _file_content_str = file_content.decode("utf-8")
    # Split by newlines and parse each line as a separate JSON object
    json_objects = []
    for line in _file_content_str.strip().split("\n"):
        if line:  # Skip empty lines
            json_objects.append(json.loads(line))
    verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
    return json_objects


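# Example (illustrative): two JSONL lines in, two dicts out.
#
#     >>> _get_file_content_as_dictionary(b'{"a": 1}\n{"b": 2}\n')
#     [{'a': 1}, {'b': 2}]

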
def _get_batch_job_cost_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Get the cost of a batch job from the file content
    """
    from litellm.cost_calculator import batch_cost_calculator

    try:
        total_cost: float = 0.0
        # parse the file content as json
        verbose_logger.debug(
            "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
        )
        for _item in file_content_dictionary:
            if _batch_response_was_successful(_item):
                _response_body = _get_response_from_batch_job_output_file(_item)
                if model_info is not None:
                    usage = _get_batch_job_usage_from_response_body(_response_body)
                    model = _response_body.get("model", "")
                    prompt_cost, completion_cost = batch_cost_calculator(
                        usage=usage,
                        model=model,
                        custom_llm_provider=custom_llm_provider,
                        model_info=model_info,
                    )
                    total_cost += prompt_cost + completion_cost
                else:
                    total_cost += litellm.completion_cost(
                        completion_response=_response_body,
                        custom_llm_provider=custom_llm_provider,
                        call_type=CallTypes.aretrieve_batch.value,
                    )
        verbose_logger.debug("total_cost=%s", total_cost)
        return total_cost
    except Exception as e:
        verbose_logger.error(
            "error in _get_batch_job_cost_from_file_content: %s", str(e)
        )
        raise e


def _get_batch_job_total_usage_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Get the tokens of a batch job from the file content
    """
    # Handle Vertex AI with specialized method
    if custom_llm_provider == "vertex_ai" and model_name:
        _, batch_usage = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )
        return batch_usage

    # For other providers, use the existing logic
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    for _item in file_content_dictionary:
        if _batch_response_was_successful(_item):
            _response_body = _get_response_from_batch_job_output_file(_item)
            usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
            total_tokens += usage.total_tokens
            prompt_tokens += usage.prompt_tokens
            completion_tokens += usage.completion_tokens
    return Usage(
        total_tokens=total_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )


def _get_batch_job_input_file_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """
    Count the number of tokens in the input file

    Used for batch rate limiting to count the number of tokens in the input file
    """
    prompt_tokens: int = 0
    # Input files contain only requests, so completion tokens stay at 0.
    completion_tokens: int = 0

    for _item in file_content_dictionary:
        body = _item.get("body", {})
        model = body.get("model", model_name or "")
        messages = body.get("messages", [])

        if messages:
            item_tokens = token_counter(model=model, messages=messages)
            prompt_tokens += item_tokens

    return Usage(
        total_tokens=prompt_tokens + completion_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )


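# Illustrative example (hypothetical request line): a batch input file line in
# the OpenAI format, whose "body.messages" is what gets token-counted above.
#
#     {
#         "custom_id": "request-1",
#         "method": "POST",
#         "url": "/v1/chat/completions",
#         "body": {
#             "model": "gpt-4o",
#             "messages": [{"role": "user", "content": "Hello!"}]
#         }
#     }

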
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
    """
    Get the tokens of a batch job from the response body
    """
    _usage_dict = response_body.get("usage", None) or {}
    usage: Usage = Usage(**_usage_dict)
    return usage


def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
    """
    Get the response from the batch job output file
    """
    _response: dict = batch_job_output_file.get("response", None) or {}
    _response_body = _response.get("body", None) or {}
    return _response_body


def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
    """
    Check if the batch job response status == 200
    """
    _response: dict = batch_job_output_file.get("response", None) or {}
    return _response.get("status_code", None) == 200


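# Illustrative example (hypothetical values): the OpenAI-format batch output
# line shape that the two helpers above walk. _batch_response_was_successful
# checks "response.status_code"; _get_response_from_batch_job_output_file
# returns "response.body".
#
#     {
#         "id": "batch_req_abc",
#         "custom_id": "request-1",
#         "response": {
#             "status_code": 200,
#             "body": {
#                 "model": "gpt-4o",
#                 "usage": {
#                     "prompt_tokens": 10,
#                     "completion_tokens": 20,
#                     "total_tokens": 30
#                 }
#             }
#         },
#         "error": null
#     }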