chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
from .common_utils import (
|
||||
DEFAULT_API_BASE,
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
IMAGE_EDIT_MODELS,
|
||||
IMAGE_GENERATION_MODELS,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
from .image_edit import BlackForestLabsImageEditConfig
|
||||
from .image_generation import BlackForestLabsImageGenerationConfig
|
||||
|
||||
__all__ = [
|
||||
"BlackForestLabsError",
|
||||
"BlackForestLabsImageEditConfig",
|
||||
"BlackForestLabsImageGenerationConfig",
|
||||
"DEFAULT_API_BASE",
|
||||
"DEFAULT_MAX_POLLING_TIME",
|
||||
"DEFAULT_POLLING_INTERVAL",
|
||||
"IMAGE_EDIT_MODELS",
|
||||
"IMAGE_GENERATION_MODELS",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Black Forest Labs Common Utilities
|
||||
|
||||
Common utilities, constants, and error handling for Black Forest Labs API.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
|
||||
class BlackForestLabsError(BaseLLMException):
|
||||
"""Exception class for Black Forest Labs API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# API Constants
|
||||
DEFAULT_API_BASE = "https://api.bfl.ai"
|
||||
|
||||
# Polling configuration
|
||||
DEFAULT_POLLING_INTERVAL = 1.5 # seconds
|
||||
DEFAULT_MAX_POLLING_TIME = 300 # 5 minutes
|
||||
|
||||
# Model to endpoint mapping for image edit
|
||||
IMAGE_EDIT_MODELS: Dict[str, str] = {
|
||||
"flux-kontext-pro": "/v1/flux-kontext-pro",
|
||||
"flux-kontext-max": "/v1/flux-kontext-max",
|
||||
"flux-pro-1.0-fill": "/v1/flux-pro-1.0-fill",
|
||||
"flux-pro-1.0-expand": "/v1/flux-pro-1.0-expand",
|
||||
}
|
||||
|
||||
# Model to endpoint mapping for image generation
|
||||
IMAGE_GENERATION_MODELS: Dict[str, str] = {
|
||||
"flux-pro-1.1": "/v1/flux-pro-1.1",
|
||||
"flux-pro-1.1-ultra": "/v1/flux-pro-1.1-ultra",
|
||||
"flux-dev": "/v1/flux-dev",
|
||||
"flux-pro": "/v1/flux-pro",
|
||||
# Kontext models support both text-to-image and image editing
|
||||
"flux-kontext-pro": "/v1/flux-kontext-pro",
|
||||
"flux-kontext-max": "/v1/flux-kontext-max",
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
from .handler import BlackForestLabsImageEdit, bfl_image_edit
|
||||
from .transformation import BlackForestLabsImageEditConfig
|
||||
|
||||
__all__ = [
|
||||
"BlackForestLabsImageEditConfig",
|
||||
"BlackForestLabsImageEdit",
|
||||
"bfl_image_edit",
|
||||
]
|
||||
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Black Forest Labs Image Edit Handler
|
||||
|
||||
Handles image edit requests for Black Forest Labs models.
|
||||
BFL uses an async polling pattern - the initial request returns a task ID,
|
||||
then we poll until the result is ready.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
from .transformation import BlackForestLabsImageEditConfig
|
||||
|
||||
|
||||
class BlackForestLabsImageEdit:
|
||||
"""
|
||||
Black Forest Labs Image Edit handler.
|
||||
|
||||
Handles the HTTP requests and polling logic, delegating data transformation
|
||||
to the BlackForestLabsImageEditConfig class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = BlackForestLabsImageEditConfig()
|
||||
|
||||
def image_edit(
|
||||
self,
|
||||
model: str,
|
||||
image: Union[FileTypes, List[FileTypes]],
|
||||
prompt: Optional[str],
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
aimage_edit: bool = False,
|
||||
) -> Union[ImageResponse, Any]:
|
||||
"""
|
||||
Main entry point for image edit requests.
|
||||
|
||||
Args:
|
||||
model: The model to use (e.g., "black_forest_labs/flux-kontext-pro")
|
||||
image: The image(s) to edit
|
||||
prompt: The edit instruction
|
||||
image_edit_optional_request_params: Optional parameters for the request
|
||||
litellm_params: LiteLLM parameters including api_key, api_base
|
||||
logging_obj: Logging object
|
||||
timeout: Request timeout
|
||||
extra_headers: Additional headers
|
||||
client: HTTP client to use
|
||||
aimage_edit: If True, return async coroutine
|
||||
|
||||
Returns:
|
||||
ImageResponse or coroutine if aimage_edit=True
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if aimage_edit:
|
||||
return self.async_image_edit(
|
||||
model=model,
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
image_edit_optional_request_params=image_edit_optional_request_params,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
extra_headers=extra_headers,
|
||||
client=client if isinstance(client, AsyncHTTPHandler) else None,
|
||||
)
|
||||
|
||||
# Sync version
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_client = _get_httpx_client()
|
||||
else:
|
||||
sync_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
|
||||
model=model,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
# Handle image list vs single image
|
||||
if isinstance(image, list):
|
||||
if not image:
|
||||
raise BlackForestLabsError(status_code=400, message="No image provided")
|
||||
image_input = image[0]
|
||||
else:
|
||||
image_input = image
|
||||
data, _ = self.config.transform_image_edit_request(
|
||||
model=model,
|
||||
prompt=prompt or "",
|
||||
image=image_input,
|
||||
image_edit_optional_request_params=image_edit_optional_request_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = sync_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = self._poll_for_result_sync(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_edit_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async def async_image_edit(
|
||||
self,
|
||||
model: str,
|
||||
image: Union[FileTypes, List[FileTypes]],
|
||||
prompt: Optional[str],
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Async version of image edit.
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if client is None:
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
|
||||
)
|
||||
else:
|
||||
async_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
|
||||
model=model,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
if isinstance(image, list):
|
||||
if not image:
|
||||
raise BlackForestLabsError(status_code=400, message="No image provided")
|
||||
image_input = image[0]
|
||||
else:
|
||||
image_input = image
|
||||
data, _ = self.config.transform_image_edit_request(
|
||||
model=model,
|
||||
prompt=prompt or "",
|
||||
image=image_input,
|
||||
image_edit_optional_request_params=image_edit_optional_request_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = await async_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = await self._poll_for_result_async(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
async_client=async_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_edit_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
def _poll_for_result_sync(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
sync_client: HTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (sync version).
|
||||
|
||||
Args:
|
||||
initial_response: The initial response containing polling_url
|
||||
headers: Headers to use for polling (must include x-key)
|
||||
sync_client: HTTP client
|
||||
max_wait: Maximum time to wait in seconds
|
||||
interval: Polling interval in seconds
|
||||
timeout: Timeout for each individual polling request
|
||||
|
||||
Returns:
|
||||
Final response with completed result
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = sync_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in [
|
||||
"Error",
|
||||
"Failed",
|
||||
"Content Moderated",
|
||||
"Request Moderated",
|
||||
]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
time.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
async def _poll_for_result_async(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
async_client: AsyncHTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (async version).
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting async polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = await async_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in [
|
||||
"Error",
|
||||
"Failed",
|
||||
"Content Moderated",
|
||||
"Request Moderated",
|
||||
]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance for use in images/main.py
|
||||
bfl_image_edit = BlackForestLabsImageEdit()
|
||||
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Black Forest Labs Image Edit Configuration
|
||||
|
||||
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
|
||||
for image editing endpoints (flux-kontext-pro, flux-kontext-max, etc.).
|
||||
|
||||
API Reference: https://docs.bfl.ai/
|
||||
"""
|
||||
|
||||
import base64
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_API_BASE,
|
||||
IMAGE_EDIT_MODELS,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BlackForestLabsImageEditConfig(BaseImageEditConfig):
|
||||
"""
|
||||
Configuration for Black Forest Labs image editing.
|
||||
|
||||
Supports:
|
||||
- flux-kontext-pro: General image editing with prompts
|
||||
- flux-kontext-max: Premium quality editing
|
||||
- flux-pro-1.0-fill: Inpainting with mask
|
||||
- flux-pro-1.0-expand: Outpainting (expand image borders)
|
||||
|
||||
Note: HTTP requests and polling are handled by the handler (handler.py).
|
||||
This class only handles data transformation.
|
||||
"""
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
Return list of OpenAI params supported by Black Forest Labs.
|
||||
|
||||
Note: BFL uses different parameter names, these are mapped in map_openai_params.
|
||||
"""
|
||||
return [
|
||||
"mask",
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
"aspect_ratio",
|
||||
"steps",
|
||||
"guidance",
|
||||
"grow_mask",
|
||||
"top",
|
||||
"bottom",
|
||||
"left",
|
||||
"right",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
image_edit_optional_params: ImageEditOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""
|
||||
Map OpenAI parameters to Black Forest Labs parameters.
|
||||
|
||||
BFL-specific params are passed through directly.
|
||||
"""
|
||||
optional_params: Dict[str, Any] = {}
|
||||
|
||||
# Pass through BFL-specific params
|
||||
bfl_params = [
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
# Kontext-specific
|
||||
"aspect_ratio",
|
||||
# Fill/Inpaint-specific
|
||||
"steps",
|
||||
"guidance",
|
||||
"grow_mask",
|
||||
# Expand-specific
|
||||
"top",
|
||||
"bottom",
|
||||
"left",
|
||||
"right",
|
||||
]
|
||||
|
||||
# Convert TypedDict to regular dict for access
|
||||
params_dict = dict(image_edit_optional_params)
|
||||
|
||||
for param in bfl_params:
|
||||
if param in params_dict:
|
||||
value = params_dict[param]
|
||||
if value is not None:
|
||||
optional_params[param] = value
|
||||
|
||||
# Set default output format
|
||||
if "output_format" not in optional_params:
|
||||
optional_params["output_format"] = "png"
|
||||
|
||||
return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate environment and set up headers for Black Forest Labs.
|
||||
|
||||
BFL uses x-key header for authentication.
|
||||
"""
|
||||
final_api_key: Optional[str] = (
|
||||
api_key
|
||||
or get_secret_str("BFL_API_KEY")
|
||||
or get_secret_str("BLACK_FOREST_LABS_API_KEY")
|
||||
)
|
||||
|
||||
if not final_api_key:
|
||||
raise BlackForestLabsError(
|
||||
status_code=401,
|
||||
message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
|
||||
)
|
||||
|
||||
headers["x-key"] = final_api_key
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["Accept"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
def use_multipart_form_data(self) -> bool:
|
||||
"""
|
||||
BFL uses JSON requests, not multipart/form-data.
|
||||
"""
|
||||
return False
|
||||
|
||||
def _get_model_endpoint(self, model: str) -> str:
|
||||
"""
|
||||
Get the API endpoint for a given model.
|
||||
"""
|
||||
# Remove provider prefix if present (e.g., "black_forest_labs/flux-kontext-pro")
|
||||
model_name = model.lower()
|
||||
if "/" in model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
# Check if model is in our mapping
|
||||
if model_name in IMAGE_EDIT_MODELS:
|
||||
return IMAGE_EDIT_MODELS[model_name]
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown BFL image edit model: {model_name}. "
|
||||
f"Supported models: {list(IMAGE_EDIT_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the Black Forest Labs API request.
|
||||
"""
|
||||
base_url: str = api_base or get_secret_str("BFL_API_BASE") or DEFAULT_API_BASE
|
||||
base_url = base_url.rstrip("/")
|
||||
|
||||
endpoint = self._get_model_endpoint(model)
|
||||
return f"{base_url}{endpoint}"
|
||||
|
||||
def _read_image_bytes(
|
||||
self,
|
||||
image: Any,
|
||||
depth: int = 0,
|
||||
max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
|
||||
) -> bytes:
|
||||
"""Read image bytes from various input types."""
|
||||
if depth > max_depth:
|
||||
raise ValueError(
|
||||
f"Max recursion depth {max_depth} reached while reading image bytes for Black Forest Labs image edit."
|
||||
)
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
elif isinstance(image, list):
|
||||
# If it's a list, take the first image
|
||||
return self._read_image_bytes(image[0], depth=depth + 1, max_depth=max_depth)
|
||||
elif isinstance(image, str):
|
||||
if image.startswith(("http://", "https://")):
|
||||
# Download image from URL
|
||||
response = httpx.get(image, timeout=60.0)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
# Assume it's a file path
|
||||
with open(image, "rb") as f:
|
||||
return f.read()
|
||||
elif hasattr(image, "read"):
|
||||
# File-like object
|
||||
pos = getattr(image, "tell", lambda: 0)()
|
||||
if hasattr(image, "seek"):
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
if hasattr(image, "seek"):
|
||||
image.seek(pos)
|
||||
return data
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported image type: {type(image)}. "
|
||||
"Expected bytes, str (URL or file path), or file-like object."
|
||||
)
|
||||
|
||||
def transform_image_edit_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Optional[str],
|
||||
image: Optional[FileTypes],
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[Dict, RequestFiles]:
|
||||
"""
|
||||
Transform OpenAI-style request to Black Forest Labs request format.
|
||||
|
||||
BFL uses JSON body with base64-encoded images, not multipart/form-data.
|
||||
"""
|
||||
# Read and encode image
|
||||
image_bytes = self._read_image_bytes(image)
|
||||
b64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
# Build request body
|
||||
request_body: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"input_image": b64_image,
|
||||
}
|
||||
|
||||
# Add optional params (only BFL-recognized parameters)
|
||||
bfl_request_params = [
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
"aspect_ratio",
|
||||
"steps",
|
||||
"guidance",
|
||||
"grow_mask",
|
||||
"top",
|
||||
"bottom",
|
||||
"left",
|
||||
"right",
|
||||
]
|
||||
for key, value in image_edit_optional_request_params.items():
|
||||
if key in bfl_request_params and value is not None:
|
||||
request_body[key] = value
|
||||
|
||||
# Handle mask if provided (for inpainting)
|
||||
if "mask" in image_edit_optional_request_params:
|
||||
mask = image_edit_optional_request_params["mask"]
|
||||
mask_bytes = self._read_image_bytes(mask)
|
||||
request_body["mask"] = base64.b64encode(mask_bytes).decode("utf-8")
|
||||
|
||||
# BFL uses JSON, not multipart - return empty files
|
||||
return request_body, []
|
||||
|
||||
def transform_image_edit_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform Black Forest Labs response to OpenAI-compatible ImageResponse.
|
||||
|
||||
This is called with the FINAL polled response (after handler does polling).
|
||||
The response contains: {"status": "Ready", "result": {"sample": "https://..."}}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=raw_response.status_code,
|
||||
message=f"Error parsing BFL response: {e}",
|
||||
)
|
||||
|
||||
# Get image URL from result
|
||||
image_url = response_data.get("result", {}).get("sample")
|
||||
if not image_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No image URL in BFL result",
|
||||
)
|
||||
|
||||
# Build ImageResponse
|
||||
return ImageResponse(
|
||||
created=int(time.time()),
|
||||
data=[ImageObject(url=image_url)],
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BlackForestLabsError:
|
||||
"""Return the appropriate error class for Black Forest Labs."""
|
||||
return BlackForestLabsError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
@@ -0,0 +1,12 @@
|
||||
from .handler import BlackForestLabsImageGeneration, bfl_image_generation
|
||||
from .transformation import (
|
||||
BlackForestLabsImageGenerationConfig,
|
||||
get_black_forest_labs_image_generation_config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BlackForestLabsImageGenerationConfig",
|
||||
"get_black_forest_labs_image_generation_config",
|
||||
"BlackForestLabsImageGeneration",
|
||||
"bfl_image_generation",
|
||||
]
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Black Forest Labs Image Generation Handler
|
||||
|
||||
Handles image generation requests for Black Forest Labs models.
|
||||
BFL uses an async polling pattern - the initial request returns a task ID,
|
||||
then we poll until the result is ready.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
from .transformation import BlackForestLabsImageGenerationConfig
|
||||
|
||||
|
||||
class BlackForestLabsImageGeneration:
|
||||
"""
|
||||
Black Forest Labs Image Generation handler.
|
||||
|
||||
Handles the HTTP requests and polling logic, delegating data transformation
|
||||
to the BlackForestLabsImageGenerationConfig class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config = BlackForestLabsImageGenerationConfig()
|
||||
|
||||
def image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
aimg_generation: bool = False,
|
||||
) -> Union[ImageResponse, Any]:
|
||||
"""
|
||||
Main entry point for image generation requests.
|
||||
|
||||
Args:
|
||||
model: The model to use (e.g., "black_forest_labs/flux-pro-1.1")
|
||||
prompt: The text prompt for image generation
|
||||
model_response: ImageResponse object to populate
|
||||
optional_params: Optional parameters for the request
|
||||
litellm_params: LiteLLM parameters including api_key, api_base
|
||||
logging_obj: Logging object
|
||||
timeout: Request timeout
|
||||
extra_headers: Additional headers
|
||||
client: HTTP client to use
|
||||
aimg_generation: If True, return async coroutine
|
||||
|
||||
Returns:
|
||||
ImageResponse or coroutine if aimg_generation=True
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if aimg_generation:
|
||||
return self.async_image_generation(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
model_response=model_response,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
logging_obj=logging_obj,
|
||||
timeout=timeout,
|
||||
extra_headers=extra_headers,
|
||||
client=client if isinstance(client, AsyncHTTPHandler) else None,
|
||||
)
|
||||
|
||||
# Sync version
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
sync_client = _get_httpx_client()
|
||||
else:
|
||||
sync_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers={},
|
||||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
data = self.config.transform_image_generation_request(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = sync_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = self._poll_for_result_sync(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_generation_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async def async_image_generation(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
model_response: ImageResponse,
|
||||
optional_params: Dict,
|
||||
litellm_params: Union[GenericLiteLLMParams, Dict],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
extra_headers: Optional[Dict[str, Any]] = None,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Async version of image generation.
|
||||
"""
|
||||
# Handle litellm_params as dict or object
|
||||
if isinstance(litellm_params, dict):
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_base = litellm_params.get("api_base")
|
||||
litellm_params_dict = litellm_params
|
||||
else:
|
||||
api_key = litellm_params.api_key
|
||||
api_base = litellm_params.api_base
|
||||
litellm_params_dict = dict(litellm_params)
|
||||
|
||||
if client is None:
|
||||
async_client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
|
||||
)
|
||||
else:
|
||||
async_client = client
|
||||
|
||||
# Validate environment and get headers
|
||||
headers = self.config.validate_environment(
|
||||
api_key=api_key,
|
||||
headers={},
|
||||
model=model,
|
||||
messages=[],
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
if extra_headers:
|
||||
headers.update(extra_headers)
|
||||
|
||||
# Get complete URL
|
||||
complete_url = self.config.get_complete_url(
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
)
|
||||
|
||||
# Transform request
|
||||
data = self.config.transform_image_generation_request(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Logging
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"api_base": complete_url,
|
||||
"headers": headers,
|
||||
},
|
||||
)
|
||||
|
||||
# Make initial request
|
||||
try:
|
||||
response = await async_client.post(
|
||||
url=complete_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message=f"Request failed: {str(e)}",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
final_response = await self._poll_for_result_async(
|
||||
initial_response=response,
|
||||
headers=headers,
|
||||
async_client=async_client,
|
||||
)
|
||||
|
||||
# Transform response
|
||||
return self.config.transform_image_generation_response(
|
||||
model=model,
|
||||
raw_response=final_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
def _poll_for_result_sync(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
sync_client: HTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (sync version).
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = sync_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in [
|
||||
"Error",
|
||||
"Failed",
|
||||
"Content Moderated",
|
||||
"Request Moderated",
|
||||
]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
time.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
async def _poll_for_result_async(
|
||||
self,
|
||||
initial_response: httpx.Response,
|
||||
headers: dict,
|
||||
async_client: AsyncHTTPHandler,
|
||||
max_wait: float = DEFAULT_MAX_POLLING_TIME,
|
||||
interval: float = DEFAULT_POLLING_INTERVAL,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
) -> httpx.Response:
|
||||
"""
|
||||
Poll BFL API until result is ready (async version).
|
||||
"""
|
||||
# Validate initial response status code
|
||||
if initial_response.status_code >= 400:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL initial request failed: {initial_response.text}",
|
||||
)
|
||||
|
||||
# Parse initial response to get polling URL
|
||||
try:
|
||||
response_data = initial_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"Error parsing initial response: {e}",
|
||||
)
|
||||
|
||||
# Check for immediate errors
|
||||
if "errors" in response_data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=initial_response.status_code,
|
||||
message=f"BFL error: {response_data['errors']}",
|
||||
)
|
||||
|
||||
polling_url = response_data.get("polling_url")
|
||||
if not polling_url:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
start_time = time.time()
|
||||
verbose_logger.debug(f"BFL starting async polling at {polling_url}")
|
||||
|
||||
while time.time() - start_time < max_wait:
|
||||
response = await async_client.get(
|
||||
url=polling_url,
|
||||
headers=polling_headers,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise BlackForestLabsError(
|
||||
status_code=response.status_code,
|
||||
message=f"Polling failed: {response.text}",
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
|
||||
verbose_logger.debug(f"BFL poll status: {status}")
|
||||
|
||||
if status == "Ready":
|
||||
return response
|
||||
elif status in [
|
||||
"Error",
|
||||
"Failed",
|
||||
"Content Moderated",
|
||||
"Request Moderated",
|
||||
]:
|
||||
raise BlackForestLabsError(
|
||||
status_code=400,
|
||||
message=f"Image generation failed: {status}",
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
raise BlackForestLabsError(
|
||||
status_code=408,
|
||||
message=f"Polling timed out after {max_wait} seconds",
|
||||
)
|
||||
|
||||
|
||||
# Singleton instance for use in images/main.py
|
||||
bfl_image_generation = BlackForestLabsImageGeneration()
|
||||
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
Black Forest Labs Image Generation Configuration
|
||||
|
||||
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
|
||||
for image generation endpoints (flux-pro-1.1, flux-pro-1.1-ultra, flux-dev, flux-pro).
|
||||
|
||||
API Reference: https://docs.bfl.ai/
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from ..common_utils import (
|
||||
DEFAULT_API_BASE,
|
||||
IMAGE_GENERATION_MODELS,
|
||||
BlackForestLabsError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BlackForestLabsImageGenerationConfig(BaseImageGenerationConfig):
|
||||
"""
|
||||
Configuration for Black Forest Labs image generation (text-to-image).
|
||||
|
||||
Supports:
|
||||
- flux-pro-1.1: Fast & reliable standard generation
|
||||
- flux-pro-1.1-ultra: Ultra high-resolution (up to 4MP)
|
||||
- flux-dev: Development/open-source variant
|
||||
- flux-pro: Original pro model
|
||||
|
||||
Note: HTTP requests and polling are handled by the handler (handler.py).
|
||||
This class only handles data transformation.
|
||||
"""
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Return list of OpenAI params supported by Black Forest Labs.
|
||||
|
||||
Note: BFL uses different parameter names, these are mapped in map_openai_params.
|
||||
"""
|
||||
return [
|
||||
"n", # Number of images (BFL returns 1 per request, but ultra supports up to 4)
|
||||
"size", # Maps to width/height or aspect_ratio
|
||||
"quality", # Maps to raw mode for ultra
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
"raw",
|
||||
"num_images",
|
||||
"image_url",
|
||||
"image_prompt_strength",
|
||||
"aspect_ratio",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Black Forest Labs parameters.
|
||||
|
||||
BFL-specific params are passed through directly.
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
for k, v in non_default_params.items():
|
||||
if k in optional_params:
|
||||
continue
|
||||
|
||||
if k in supported_params:
|
||||
# Map OpenAI 'size' to BFL width/height
|
||||
if k == "size" and v:
|
||||
self._map_size_param(v, optional_params)
|
||||
elif k == "n":
|
||||
if "ultra" in model.lower():
|
||||
optional_params["num_images"] = v
|
||||
# non-ultra: silently skip (n=1 is BFL default)
|
||||
elif k == "quality":
|
||||
if v == "hd" and "ultra" in model.lower():
|
||||
optional_params["raw"] = True
|
||||
# other quality values have no BFL mapping
|
||||
else:
|
||||
optional_params[k] = v
|
||||
elif not drop_params:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
f"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_size_param(self, size: str, optional_params: dict) -> None:
|
||||
"""Map OpenAI size parameter to BFL width/height."""
|
||||
# Common size mappings
|
||||
size_mapping = {
|
||||
"1024x1024": (1024, 1024),
|
||||
"1792x1024": (1792, 1024),
|
||||
"1024x1792": (1024, 1792),
|
||||
"512x512": (512, 512),
|
||||
"256x256": (256, 256),
|
||||
}
|
||||
|
||||
if size in size_mapping:
|
||||
width, height = size_mapping[size]
|
||||
optional_params["width"] = width
|
||||
optional_params["height"] = height
|
||||
elif "x" in size:
|
||||
# Parse custom size
|
||||
try:
|
||||
width, height = map(int, size.lower().split("x"))
|
||||
optional_params["width"] = width
|
||||
optional_params["height"] = height
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid size format: '{size}'. Expected format 'WIDTHxHEIGHT' (e.g., '1024x1024')."
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate environment and set up headers for Black Forest Labs.
|
||||
|
||||
BFL uses x-key header for authentication.
|
||||
"""
|
||||
final_api_key: Optional[str] = (
|
||||
api_key
|
||||
or get_secret_str("BFL_API_KEY")
|
||||
or get_secret_str("BLACK_FOREST_LABS_API_KEY")
|
||||
)
|
||||
|
||||
if not final_api_key:
|
||||
raise BlackForestLabsError(
|
||||
status_code=401,
|
||||
message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
|
||||
)
|
||||
|
||||
headers["x-key"] = final_api_key
|
||||
headers["Content-Type"] = "application/json"
|
||||
headers["Accept"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
def _get_model_endpoint(self, model: str) -> str:
|
||||
"""
|
||||
Get the API endpoint for a given model.
|
||||
"""
|
||||
# Remove provider prefix if present (e.g., "black_forest_labs/flux-pro-1.1")
|
||||
model_name = model.lower()
|
||||
if "/" in model_name:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
# Check if model is in our mapping
|
||||
if model_name in IMAGE_GENERATION_MODELS:
|
||||
return IMAGE_GENERATION_MODELS[model_name]
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown BFL image generation model: {model_name}. "
|
||||
f"Supported models: {list(IMAGE_GENERATION_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the Black Forest Labs API request.
|
||||
"""
|
||||
base_url: str = api_base or get_secret_str("BFL_API_BASE") or DEFAULT_API_BASE
|
||||
base_url = base_url.rstrip("/")
|
||||
|
||||
endpoint = self._get_model_endpoint(model)
|
||||
return f"{base_url}{endpoint}"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform OpenAI-style request to Black Forest Labs request format.
|
||||
|
||||
https://docs.bfl.ai/flux_models/flux_1_1_pro
|
||||
"""
|
||||
# Build request body with prompt
|
||||
request_body: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
|
||||
# BFL-specific params that can be passed through
|
||||
bfl_params = [
|
||||
"width",
|
||||
"height",
|
||||
"aspect_ratio",
|
||||
"seed",
|
||||
"output_format",
|
||||
"safety_tolerance",
|
||||
"prompt_upsampling",
|
||||
# Ultra-specific
|
||||
"raw",
|
||||
"num_images",
|
||||
"image_url",
|
||||
"image_prompt_strength",
|
||||
]
|
||||
|
||||
for param in bfl_params:
|
||||
if param in optional_params and optional_params[param] is not None:
|
||||
request_body[param] = optional_params[param]
|
||||
|
||||
# Set default output format if not specified
|
||||
if "output_format" not in request_body:
|
||||
request_body["output_format"] = "png"
|
||||
|
||||
return request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform Black Forest Labs response to OpenAI-compatible ImageResponse.
|
||||
|
||||
This is called with the FINAL polled response (after handler does polling).
|
||||
The response contains: {"status": "Ready", "result": {"sample": "https://..."}}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise BlackForestLabsError(
|
||||
status_code=raw_response.status_code,
|
||||
message=f"Error parsing BFL response: {e}",
|
||||
)
|
||||
|
||||
result = response_data.get("result", {})
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle single image (sample) or multiple images
|
||||
if isinstance(result, dict) and "sample" in result:
|
||||
model_response.data.append(ImageObject(url=result["sample"]))
|
||||
elif isinstance(result, list):
|
||||
# Multiple images returned
|
||||
for img in result:
|
||||
if isinstance(img, str):
|
||||
model_response.data.append(ImageObject(url=img))
|
||||
elif isinstance(img, dict) and "url" in img:
|
||||
model_response.data.append(ImageObject(url=img["url"]))
|
||||
|
||||
if not model_response.data:
|
||||
raise BlackForestLabsError(
|
||||
status_code=500,
|
||||
message="No image URL in BFL result",
|
||||
)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BlackForestLabsError:
|
||||
"""Return the appropriate error class for Black Forest Labs."""
|
||||
return BlackForestLabsError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
|
||||
|
||||
def get_black_forest_labs_image_generation_config(
|
||||
model: str,
|
||||
) -> BlackForestLabsImageGenerationConfig:
|
||||
"""
|
||||
Get the appropriate image generation config for a Black Forest Labs model.
|
||||
|
||||
Currently returns a single config class, but can be extended
|
||||
for model-specific configurations if needed.
|
||||
"""
|
||||
return BlackForestLabsImageGenerationConfig()
|
||||
Reference in New Issue
Block a user