chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
"""Black Forest Labs provider package root.

Re-exports the public configuration classes, the provider error type, and
the API constants so callers can import them from one place.
"""
from .common_utils import (
    DEFAULT_API_BASE,
    DEFAULT_MAX_POLLING_TIME,
    DEFAULT_POLLING_INTERVAL,
    IMAGE_EDIT_MODELS,
    IMAGE_GENERATION_MODELS,
    BlackForestLabsError,
)
from .image_edit import BlackForestLabsImageEditConfig
from .image_generation import BlackForestLabsImageGenerationConfig

# Public API of this package (kept sorted alphabetically).
__all__ = [
    "BlackForestLabsError",
    "BlackForestLabsImageEditConfig",
    "BlackForestLabsImageGenerationConfig",
    "DEFAULT_API_BASE",
    "DEFAULT_MAX_POLLING_TIME",
    "DEFAULT_POLLING_INTERVAL",
    "IMAGE_EDIT_MODELS",
    "IMAGE_GENERATION_MODELS",
]

View File

@@ -0,0 +1,42 @@
"""
Black Forest Labs Common Utilities
Common utilities, constants, and error handling for Black Forest Labs API.
"""
from typing import Dict
from litellm.llms.base_llm.chat.transformation import BaseLLMException
class BlackForestLabsError(BaseLLMException):
    """Exception class for Black Forest Labs API errors.

    Thin subclass of BaseLLMException; raised throughout this provider
    with ``status_code`` and ``message`` keyword arguments so BFL
    failures flow through litellm's shared exception handling.
    """

    pass
# API Constants
# Default base URL for the Black Forest Labs REST API; can be overridden
# via the BFL_API_BASE secret or an explicit api_base argument.
DEFAULT_API_BASE = "https://api.bfl.ai"

# Polling configuration
# BFL tasks are asynchronous: the initial request returns a polling URL
# which the handlers poll at this cadence until the result is ready.
DEFAULT_POLLING_INTERVAL = 1.5  # seconds between polls
DEFAULT_MAX_POLLING_TIME = 300  # 5 minutes total before giving up

# Model to endpoint mapping for image edit
IMAGE_EDIT_MODELS: Dict[str, str] = {
    "flux-kontext-pro": "/v1/flux-kontext-pro",
    "flux-kontext-max": "/v1/flux-kontext-max",
    "flux-pro-1.0-fill": "/v1/flux-pro-1.0-fill",
    "flux-pro-1.0-expand": "/v1/flux-pro-1.0-expand",
}

# Model to endpoint mapping for image generation
IMAGE_GENERATION_MODELS: Dict[str, str] = {
    "flux-pro-1.1": "/v1/flux-pro-1.1",
    "flux-pro-1.1-ultra": "/v1/flux-pro-1.1-ultra",
    "flux-dev": "/v1/flux-dev",
    "flux-pro": "/v1/flux-pro",
    # Kontext models support both text-to-image and image editing
    "flux-kontext-pro": "/v1/flux-kontext-pro",
    "flux-kontext-max": "/v1/flux-kontext-max",
}

View File

@@ -0,0 +1,8 @@
"""Public exports of the Black Forest Labs image-edit sub-package."""
from .handler import BlackForestLabsImageEdit, bfl_image_edit
from .transformation import BlackForestLabsImageEditConfig

# Names re-exported from this sub-package.
__all__ = [
    "BlackForestLabsImageEditConfig",
    "BlackForestLabsImageEdit",
    "bfl_image_edit",
]

View File

@@ -0,0 +1,464 @@
"""
Black Forest Labs Image Edit Handler
Handles image edit requests for Black Forest Labs models.
BFL uses an async polling pattern - the initial request returns a task ID,
then we poll until the result is ready.
"""
import asyncio
import time
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageResponse
from ..common_utils import (
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
BlackForestLabsError,
)
from .transformation import BlackForestLabsImageEditConfig
class BlackForestLabsImageEdit:
    """
    Black Forest Labs Image Edit handler.

    Handles the HTTP requests and polling logic, delegating data
    transformation to the BlackForestLabsImageEditConfig class.

    BFL uses an async task pattern: the initial POST returns a
    ``polling_url`` which is polled until the task status is "Ready",
    a terminal failure status is reported, or ``max_wait`` elapses.
    """

    # Poll statuses that mean the task has terminally failed.
    _FAILURE_STATUSES = (
        "Error",
        "Failed",
        "Content Moderated",
        "Request Moderated",
    )

    def __init__(self):
        self.config = BlackForestLabsImageEditConfig()

    @staticmethod
    def _resolve_litellm_params(
        litellm_params: Union[GenericLiteLLMParams, Dict],
    ):
        """Return (api_key, api_base, params_dict) from dict or object form."""
        if isinstance(litellm_params, dict):
            return (
                litellm_params.get("api_key"),
                litellm_params.get("api_base"),
                litellm_params,
            )
        return (
            litellm_params.api_key,
            litellm_params.api_base,
            dict(litellm_params),
        )

    @staticmethod
    def _select_image(
        image: Union[FileTypes, List[FileTypes]],
    ) -> FileTypes:
        """Return the single image to send; raise on an empty list.

        BFL's edit endpoints accept one input image, so for a list input
        only the first element is used.
        """
        if isinstance(image, list):
            if not image:
                raise BlackForestLabsError(status_code=400, message="No image provided")
            return image[0]
        return image

    @staticmethod
    def _extract_polling_url(initial_response: httpx.Response) -> str:
        """Validate the initial BFL response and return its polling_url.

        Raises:
            BlackForestLabsError: on an HTTP error status, unparsable JSON,
                an ``errors`` payload, or a missing ``polling_url``.
        """
        if initial_response.status_code >= 400:
            raise BlackForestLabsError(
                status_code=initial_response.status_code,
                message=f"BFL initial request failed: {initial_response.text}",
            )
        try:
            response_data = initial_response.json()
        except Exception as e:
            raise BlackForestLabsError(
                status_code=initial_response.status_code,
                message=f"Error parsing initial response: {e}",
            )
        # Some failures come back with a 2xx status but an "errors" body.
        if "errors" in response_data:
            raise BlackForestLabsError(
                status_code=initial_response.status_code,
                message=f"BFL error: {response_data['errors']}",
            )
        polling_url = response_data.get("polling_url")
        if not polling_url:
            raise BlackForestLabsError(
                status_code=500,
                message="No polling_url in BFL response",
            )
        return polling_url

    def _prepare_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Union[FileTypes, List[FileTypes]],
        image_edit_optional_request_params: Dict,
        api_key: Optional[str],
        api_base: Optional[str],
        litellm_params_dict: Dict,
        extra_headers: Optional[Dict[str, Any]],
        logging_obj: LiteLLMLoggingObj,
    ):
        """Build (headers, url, json_body) for the initial request and log it.

        Shared by the sync and async entry points so the two paths cannot
        drift apart.
        """
        headers = self.config.validate_environment(
            api_key=api_key,
            headers=image_edit_optional_request_params.get("extra_headers", {}) or {},
            model=model,
        )
        if extra_headers:
            headers.update(extra_headers)
        complete_url = self.config.get_complete_url(
            model=model,
            api_base=api_base,
            litellm_params=litellm_params_dict,
        )
        data, _ = self.config.transform_image_edit_request(
            model=model,
            prompt=prompt or "",
            image=self._select_image(image),
            image_edit_optional_request_params=image_edit_optional_request_params,
            litellm_params=litellm_params_dict,
            headers=headers,
        )
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": complete_url,
                "headers": headers,
            },
        )
        return headers, complete_url, data

    def image_edit(
        self,
        model: str,
        image: Union[FileTypes, List[FileTypes]],
        prompt: Optional[str],
        image_edit_optional_request_params: Dict,
        litellm_params: Union[GenericLiteLLMParams, Dict],
        logging_obj: LiteLLMLoggingObj,
        timeout: Optional[Union[float, httpx.Timeout]],
        extra_headers: Optional[Dict[str, Any]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        aimage_edit: bool = False,
    ) -> Union[ImageResponse, Any]:
        """
        Main entry point for image edit requests.

        Args:
            model: The model to use (e.g., "black_forest_labs/flux-kontext-pro")
            image: The image(s) to edit (only the first of a list is used)
            prompt: The edit instruction
            image_edit_optional_request_params: Optional parameters for the request
            litellm_params: LiteLLM parameters including api_key, api_base
            logging_obj: Logging object
            timeout: Timeout for the initial POST request
            extra_headers: Additional headers
            client: HTTP client to use
            aimage_edit: If True, return an awaitable coroutine

        Returns:
            ImageResponse, or a coroutine if aimage_edit=True
        """
        api_key, api_base, litellm_params_dict = self._resolve_litellm_params(
            litellm_params
        )
        if aimage_edit:
            return self.async_image_edit(
                model=model,
                image=image,
                prompt=prompt,
                image_edit_optional_request_params=image_edit_optional_request_params,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                timeout=timeout,
                extra_headers=extra_headers,
                client=client if isinstance(client, AsyncHTTPHandler) else None,
            )
        # Sync path: fall back to a fresh sync client unless one was given.
        if client is None or not isinstance(client, HTTPHandler):
            sync_client = _get_httpx_client()
        else:
            sync_client = client
        headers, complete_url, data = self._prepare_request(
            model=model,
            prompt=prompt,
            image=image,
            image_edit_optional_request_params=image_edit_optional_request_params,
            api_key=api_key,
            api_base=api_base,
            litellm_params_dict=litellm_params_dict,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
        )
        try:
            response = sync_client.post(
                url=complete_url,
                headers=headers,
                json=data,
                timeout=timeout,
            )
        except Exception as e:
            raise BlackForestLabsError(
                status_code=500,
                message=f"Request failed: {str(e)}",
            )
        final_response = self._poll_for_result_sync(
            initial_response=response,
            headers=headers,
            sync_client=sync_client,
        )
        return self.config.transform_image_edit_response(
            model=model,
            raw_response=final_response,
            logging_obj=logging_obj,
        )

    async def async_image_edit(
        self,
        model: str,
        image: Union[FileTypes, List[FileTypes]],
        prompt: Optional[str],
        image_edit_optional_request_params: Dict,
        litellm_params: Union[GenericLiteLLMParams, Dict],
        logging_obj: LiteLLMLoggingObj,
        timeout: Optional[Union[float, httpx.Timeout]],
        extra_headers: Optional[Dict[str, Any]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        """
        Async version of image edit; mirrors ``image_edit`` exactly but
        uses the async client and async polling.
        """
        api_key, api_base, litellm_params_dict = self._resolve_litellm_params(
            litellm_params
        )
        if client is None:
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
            )
        else:
            async_client = client
        headers, complete_url, data = self._prepare_request(
            model=model,
            prompt=prompt,
            image=image,
            image_edit_optional_request_params=image_edit_optional_request_params,
            api_key=api_key,
            api_base=api_base,
            litellm_params_dict=litellm_params_dict,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
        )
        try:
            response = await async_client.post(
                url=complete_url,
                headers=headers,
                json=data,
                timeout=timeout,
            )
        except Exception as e:
            raise BlackForestLabsError(
                status_code=500,
                message=f"Request failed: {str(e)}",
            )
        final_response = await self._poll_for_result_async(
            initial_response=response,
            headers=headers,
            async_client=async_client,
        )
        return self.config.transform_image_edit_response(
            model=model,
            raw_response=final_response,
            logging_obj=logging_obj,
        )

    def _poll_for_result_sync(
        self,
        initial_response: httpx.Response,
        headers: dict,
        sync_client: HTTPHandler,
        max_wait: float = DEFAULT_MAX_POLLING_TIME,
        interval: float = DEFAULT_POLLING_INTERVAL,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> httpx.Response:
        """
        Poll BFL API until the result is ready (sync version).

        Args:
            initial_response: Response from the initial POST; must contain
                a ``polling_url``.
            headers: Headers from the initial request; only ``x-key`` is
                forwarded to the polling endpoint.
            sync_client: HTTP client used for polling.
            max_wait: Maximum total time to wait, in seconds.
            interval: Sleep between polls, in seconds.
            timeout: Accepted for interface symmetry but currently unused;
                polling requests use the client's default timeout.
                NOTE(review): confirm whether the polling client's ``get``
                supports a per-request timeout before wiring this through.

        Returns:
            The final response whose status is "Ready".

        Raises:
            BlackForestLabsError: on HTTP errors, terminal task failure,
                or when ``max_wait`` elapses.
        """
        polling_url = self._extract_polling_url(initial_response)
        # Only the auth header is needed for polling.
        polling_headers = {"x-key": headers.get("x-key", "")}
        start_time = time.time()
        verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
        while time.time() - start_time < max_wait:
            response = sync_client.get(
                url=polling_url,
                headers=polling_headers,
            )
            if response.status_code != 200:
                raise BlackForestLabsError(
                    status_code=response.status_code,
                    message=f"Polling failed: {response.text}",
                )
            status = response.json().get("status")
            verbose_logger.debug(f"BFL poll status: {status}")
            if status == "Ready":
                return response
            if status in self._FAILURE_STATUSES:
                raise BlackForestLabsError(
                    status_code=400,
                    message=f"Image generation failed: {status}",
                )
            time.sleep(interval)
        raise BlackForestLabsError(
            status_code=408,
            message=f"Polling timed out after {max_wait} seconds",
        )

    async def _poll_for_result_async(
        self,
        initial_response: httpx.Response,
        headers: dict,
        async_client: AsyncHTTPHandler,
        max_wait: float = DEFAULT_MAX_POLLING_TIME,
        interval: float = DEFAULT_POLLING_INTERVAL,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> httpx.Response:
        """
        Poll BFL API until the result is ready (async version).

        See ``_poll_for_result_sync`` for argument semantics; this variant
        awaits the client and sleeps with ``asyncio.sleep``.
        """
        polling_url = self._extract_polling_url(initial_response)
        # Only the auth header is needed for polling.
        polling_headers = {"x-key": headers.get("x-key", "")}
        start_time = time.time()
        verbose_logger.debug(f"BFL starting async polling at {polling_url}")
        while time.time() - start_time < max_wait:
            response = await async_client.get(
                url=polling_url,
                headers=polling_headers,
            )
            if response.status_code != 200:
                raise BlackForestLabsError(
                    status_code=response.status_code,
                    message=f"Polling failed: {response.text}",
                )
            status = response.json().get("status")
            verbose_logger.debug(f"BFL poll status: {status}")
            if status == "Ready":
                return response
            if status in self._FAILURE_STATUSES:
                raise BlackForestLabsError(
                    status_code=400,
                    message=f"Image generation failed: {status}",
                )
            await asyncio.sleep(interval)
        raise BlackForestLabsError(
            status_code=408,
            message=f"Polling timed out after {max_wait} seconds",
        )
# Module-level singleton used by images/main.py, so BFL image-edit calls
# share one handler instead of constructing one per request.
bfl_image_edit = BlackForestLabsImageEdit()

View File

@@ -0,0 +1,323 @@
"""
Black Forest Labs Image Edit Configuration
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
for image editing endpoints (flux-kontext-pro, flux-kontext-max, etc.).
API Reference: https://docs.bfl.ai/
"""
import base64
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from httpx._types import RequestFiles
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.images.main import ImageEditOptionalRequestParams
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import FileTypes, ImageObject, ImageResponse
from ..common_utils import (
DEFAULT_API_BASE,
IMAGE_EDIT_MODELS,
BlackForestLabsError,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class BlackForestLabsImageEditConfig(BaseImageEditConfig):
    """
    Configuration for Black Forest Labs image editing.

    Supports:
    - flux-kontext-pro: General image editing with prompts
    - flux-kontext-max: Premium quality editing
    - flux-pro-1.0-fill: Inpainting with mask
    - flux-pro-1.0-expand: Outpainting (expand image borders)

    Note: HTTP requests and polling are handled by the handler (handler.py).
    This class only handles data transformation.
    """

    # Parameter names BFL's edit endpoints accept directly.  Shared by
    # map_openai_params and transform_image_edit_request so the two lists
    # cannot drift apart.  ("mask" is intentionally absent: it is
    # base64-encoded separately in transform_image_edit_request.)
    _BFL_PARAM_NAMES: List[str] = [
        "seed",
        "output_format",
        "safety_tolerance",
        "prompt_upsampling",
        # Kontext-specific
        "aspect_ratio",
        # Fill/Inpaint-specific
        "steps",
        "guidance",
        "grow_mask",
        # Expand-specific
        "top",
        "bottom",
        "left",
        "right",
    ]

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Return list of OpenAI params supported by Black Forest Labs.

        Note: BFL uses different parameter names; these are mapped in
        map_openai_params.
        """
        return ["mask"] + self._BFL_PARAM_NAMES

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Map OpenAI parameters to Black Forest Labs parameters.

        BFL-specific params are passed through directly; a default
        output_format of "png" is applied when none is given.
        """
        # Convert TypedDict to a regular dict for uniform access.
        params_dict = dict(image_edit_optional_params)
        optional_params: Dict[str, Any] = {
            name: params_dict[name]
            for name in self._BFL_PARAM_NAMES
            if params_dict.get(name) is not None
        }
        # Set default output format when the caller did not choose one.
        optional_params.setdefault("output_format", "png")
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """
        Validate environment and set up headers for Black Forest Labs.

        BFL authenticates via the ``x-key`` header.  The key is resolved
        from the explicit argument, then the BFL_API_KEY secret, then
        BLACK_FOREST_LABS_API_KEY.

        Raises:
            BlackForestLabsError: (401) when no API key can be resolved.
        """
        final_api_key: Optional[str] = (
            api_key
            or get_secret_str("BFL_API_KEY")
            or get_secret_str("BLACK_FOREST_LABS_API_KEY")
        )
        if not final_api_key:
            raise BlackForestLabsError(
                status_code=401,
                message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
            )
        headers["x-key"] = final_api_key
        headers["Content-Type"] = "application/json"
        headers["Accept"] = "application/json"
        return headers

    def use_multipart_form_data(self) -> bool:
        """BFL uses JSON requests, not multipart/form-data."""
        return False

    def _get_model_endpoint(self, model: str) -> str:
        """
        Get the API endpoint path for a given model.

        Strips any provider prefix (e.g. "black_forest_labs/flux-kontext-pro")
        before looking the model up in IMAGE_EDIT_MODELS.

        Raises:
            ValueError: if the model is not a known BFL image-edit model.
        """
        model_name = model.lower()
        if "/" in model_name:
            model_name = model_name.split("/")[-1]
        if model_name in IMAGE_EDIT_MODELS:
            return IMAGE_EDIT_MODELS[model_name]
        raise ValueError(
            f"Unknown BFL image edit model: {model_name}. "
            f"Supported models: {list(IMAGE_EDIT_MODELS.keys())}"
        )

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for the Black Forest Labs API request.

        Resolution order for the base URL: explicit api_base, the
        BFL_API_BASE secret, then DEFAULT_API_BASE.
        """
        base_url: str = api_base or get_secret_str("BFL_API_BASE") or DEFAULT_API_BASE
        base_url = base_url.rstrip("/")
        endpoint = self._get_model_endpoint(model)
        return f"{base_url}{endpoint}"

    def _read_image_bytes(
        self,
        image: Any,
        depth: int = 0,
        max_depth: int = DEFAULT_MAX_RECURSE_DEPTH,
    ) -> bytes:
        """Read raw image bytes from bytes, a URL, a file path, a
        file-like object, or (recursively) the first element of a list.

        Raises:
            ValueError: on unsupported input types or when the recursion
                limit is exceeded.
        """
        if depth > max_depth:
            raise ValueError(
                f"Max recursion depth {max_depth} reached while reading image bytes for Black Forest Labs image edit."
            )
        if isinstance(image, bytes):
            return image
        elif isinstance(image, list):
            # If it's a list, take the first image
            return self._read_image_bytes(image[0], depth=depth + 1, max_depth=max_depth)
        elif isinstance(image, str):
            if image.startswith(("http://", "https://")):
                # Download image from URL
                response = httpx.get(image, timeout=60.0)
                response.raise_for_status()
                return response.content
            else:
                # Assume it's a file path
                with open(image, "rb") as f:
                    return f.read()
        elif hasattr(image, "read"):
            # File-like object: read from the start, then restore the
            # caller's stream position so the object stays reusable.
            pos = getattr(image, "tell", lambda: 0)()
            if hasattr(image, "seek"):
                image.seek(0)
            data = image.read()
            if hasattr(image, "seek"):
                image.seek(pos)
            return data
        else:
            raise ValueError(
                f"Unsupported image type: {type(image)}. "
                "Expected bytes, str (URL or file path), or file-like object."
            )

    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Transform OpenAI-style request to Black Forest Labs request format.

        BFL uses a JSON body with base64-encoded images, not
        multipart/form-data, so the returned files list is always empty.
        """
        # Read and encode the input image.
        image_bytes = self._read_image_bytes(image)
        b64_image = base64.b64encode(image_bytes).decode("utf-8")
        request_body: Dict[str, Any] = {
            "prompt": prompt,
            "input_image": b64_image,
        }
        # Copy through only BFL-recognized optional parameters.
        for key in self._BFL_PARAM_NAMES:
            value = image_edit_optional_request_params.get(key)
            if value is not None:
                request_body[key] = value
        # Handle mask if provided (for inpainting).  Guard against an
        # explicit mask=None, which previously crashed in
        # _read_image_bytes with an unsupported-type error.
        mask = image_edit_optional_request_params.get("mask")
        if mask is not None:
            mask_bytes = self._read_image_bytes(mask)
            request_body["mask"] = base64.b64encode(mask_bytes).decode("utf-8")
        # BFL uses JSON, not multipart - return empty files
        return request_body, []

    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ImageResponse:
        """
        Transform Black Forest Labs response to OpenAI-compatible ImageResponse.

        This is called with the FINAL polled response (after handler does
        polling).  The response contains:
        {"status": "Ready", "result": {"sample": "https://..."}}

        Raises:
            BlackForestLabsError: on unparsable JSON or a missing image URL.
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise BlackForestLabsError(
                status_code=raw_response.status_code,
                message=f"Error parsing BFL response: {e}",
            )
        # Get image URL from result
        image_url = response_data.get("result", {}).get("sample")
        if not image_url:
            raise BlackForestLabsError(
                status_code=500,
                message="No image URL in BFL result",
            )
        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(url=image_url)],
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BlackForestLabsError:
        """Return the appropriate error class for Black Forest Labs."""
        return BlackForestLabsError(
            status_code=status_code,
            message=error_message,
        )

View File

@@ -0,0 +1,12 @@
"""Public exports of the Black Forest Labs image-generation sub-package."""
from .handler import BlackForestLabsImageGeneration, bfl_image_generation
from .transformation import (
    BlackForestLabsImageGenerationConfig,
    get_black_forest_labs_image_generation_config,
)

# Names re-exported from this sub-package.
__all__ = [
    "BlackForestLabsImageGenerationConfig",
    "get_black_forest_labs_image_generation_config",
    "BlackForestLabsImageGeneration",
    "bfl_image_generation",
]

View File

@@ -0,0 +1,450 @@
"""
Black Forest Labs Image Generation Handler
Handles image generation requests for Black Forest Labs models.
BFL uses an async polling pattern - the initial request returns a task ID,
then we poll until the result is ready.
"""
import asyncio
import time
from typing import Any, Dict, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import ImageResponse
from ..common_utils import (
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
BlackForestLabsError,
)
from .transformation import BlackForestLabsImageGenerationConfig
class BlackForestLabsImageGeneration:
    """
    Black Forest Labs Image Generation handler.

    Handles the HTTP requests and polling logic, delegating data
    transformation to the BlackForestLabsImageGenerationConfig class.

    BFL uses an async task pattern: the initial POST returns a
    ``polling_url`` which is polled until the task status is "Ready",
    a terminal failure status is reported, or ``max_wait`` elapses.
    """

    # Poll statuses that mean the task has terminally failed.
    _FAILURE_STATUSES = (
        "Error",
        "Failed",
        "Content Moderated",
        "Request Moderated",
    )

    def __init__(self):
        self.config = BlackForestLabsImageGenerationConfig()

    @staticmethod
    def _resolve_litellm_params(
        litellm_params: Union[GenericLiteLLMParams, Dict],
    ):
        """Return (api_key, api_base, params_dict) from dict or object form."""
        if isinstance(litellm_params, dict):
            return (
                litellm_params.get("api_key"),
                litellm_params.get("api_base"),
                litellm_params,
            )
        return (
            litellm_params.api_key,
            litellm_params.api_base,
            dict(litellm_params),
        )

    @staticmethod
    def _extract_polling_url(initial_response: httpx.Response) -> str:
        """Validate the initial BFL response and return its polling_url.

        Raises:
            BlackForestLabsError: on an HTTP error status, unparsable JSON,
                an ``errors`` payload, or a missing ``polling_url``.
        """
        if initial_response.status_code >= 400:
            raise BlackForestLabsError(
                status_code=initial_response.status_code,
                message=f"BFL initial request failed: {initial_response.text}",
            )
        try:
            response_data = initial_response.json()
        except Exception as e:
            raise BlackForestLabsError(
                status_code=initial_response.status_code,
                message=f"Error parsing initial response: {e}",
            )
        # Some failures come back with a 2xx status but an "errors" body.
        if "errors" in response_data:
            raise BlackForestLabsError(
                status_code=initial_response.status_code,
                message=f"BFL error: {response_data['errors']}",
            )
        polling_url = response_data.get("polling_url")
        if not polling_url:
            raise BlackForestLabsError(
                status_code=500,
                message="No polling_url in BFL response",
            )
        return polling_url

    def _prepare_request(
        self,
        model: str,
        prompt: str,
        optional_params: Dict,
        api_key: Optional[str],
        api_base: Optional[str],
        litellm_params_dict: Dict,
        extra_headers: Optional[Dict[str, Any]],
        logging_obj: LiteLLMLoggingObj,
    ):
        """Build (headers, url, json_body) for the initial request and log it.

        Shared by the sync and async entry points so the two paths cannot
        drift apart.
        """
        headers = self.config.validate_environment(
            api_key=api_key,
            headers={},
            model=model,
            messages=[],
            optional_params=optional_params,
            litellm_params=litellm_params_dict,
        )
        if extra_headers:
            headers.update(extra_headers)
        complete_url = self.config.get_complete_url(
            api_base=api_base,
            api_key=api_key,
            model=model,
            optional_params=optional_params,
            litellm_params=litellm_params_dict,
        )
        data = self.config.transform_image_generation_request(
            model=model,
            prompt=prompt,
            optional_params=optional_params,
            litellm_params=litellm_params_dict,
            headers=headers,
        )
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": complete_url,
                "headers": headers,
            },
        )
        return headers, complete_url, data

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: Dict,
        litellm_params: Union[GenericLiteLLMParams, Dict],
        logging_obj: LiteLLMLoggingObj,
        timeout: Optional[Union[float, httpx.Timeout]],
        extra_headers: Optional[Dict[str, Any]] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        aimg_generation: bool = False,
    ) -> Union[ImageResponse, Any]:
        """
        Main entry point for image generation requests.

        Args:
            model: The model to use (e.g., "black_forest_labs/flux-pro-1.1")
            prompt: The text prompt for image generation
            model_response: ImageResponse object to populate
            optional_params: Optional parameters for the request
            litellm_params: LiteLLM parameters including api_key, api_base
            logging_obj: Logging object
            timeout: Timeout for the initial POST request
            extra_headers: Additional headers
            client: HTTP client to use
            aimg_generation: If True, return an awaitable coroutine

        Returns:
            ImageResponse, or a coroutine if aimg_generation=True
        """
        api_key, api_base, litellm_params_dict = self._resolve_litellm_params(
            litellm_params
        )
        if aimg_generation:
            return self.async_image_generation(
                model=model,
                prompt=prompt,
                model_response=model_response,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logging_obj=logging_obj,
                timeout=timeout,
                extra_headers=extra_headers,
                client=client if isinstance(client, AsyncHTTPHandler) else None,
            )
        # Sync path: fall back to a fresh sync client unless one was given.
        if client is None or not isinstance(client, HTTPHandler):
            sync_client = _get_httpx_client()
        else:
            sync_client = client
        headers, complete_url, data = self._prepare_request(
            model=model,
            prompt=prompt,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            litellm_params_dict=litellm_params_dict,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
        )
        try:
            response = sync_client.post(
                url=complete_url,
                headers=headers,
                json=data,
                timeout=timeout,
            )
        except Exception as e:
            raise BlackForestLabsError(
                status_code=500,
                message=f"Request failed: {str(e)}",
            )
        final_response = self._poll_for_result_sync(
            initial_response=response,
            headers=headers,
            sync_client=sync_client,
        )
        return self.config.transform_image_generation_response(
            model=model,
            raw_response=final_response,
            model_response=model_response,
            logging_obj=logging_obj,
        )

    async def async_image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: Dict,
        litellm_params: Union[GenericLiteLLMParams, Dict],
        logging_obj: LiteLLMLoggingObj,
        timeout: Optional[Union[float, httpx.Timeout]],
        extra_headers: Optional[Dict[str, Any]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        """
        Async version of image generation; mirrors ``image_generation``
        exactly but uses the async client and async polling.
        """
        api_key, api_base, litellm_params_dict = self._resolve_litellm_params(
            litellm_params
        )
        if client is None:
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.BLACK_FOREST_LABS,
            )
        else:
            async_client = client
        headers, complete_url, data = self._prepare_request(
            model=model,
            prompt=prompt,
            optional_params=optional_params,
            api_key=api_key,
            api_base=api_base,
            litellm_params_dict=litellm_params_dict,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
        )
        try:
            response = await async_client.post(
                url=complete_url,
                headers=headers,
                json=data,
                timeout=timeout,
            )
        except Exception as e:
            raise BlackForestLabsError(
                status_code=500,
                message=f"Request failed: {str(e)}",
            )
        final_response = await self._poll_for_result_async(
            initial_response=response,
            headers=headers,
            async_client=async_client,
        )
        return self.config.transform_image_generation_response(
            model=model,
            raw_response=final_response,
            model_response=model_response,
            logging_obj=logging_obj,
        )

    def _poll_for_result_sync(
        self,
        initial_response: httpx.Response,
        headers: dict,
        sync_client: HTTPHandler,
        max_wait: float = DEFAULT_MAX_POLLING_TIME,
        interval: float = DEFAULT_POLLING_INTERVAL,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> httpx.Response:
        """
        Poll BFL API until the result is ready (sync version).

        ``timeout`` is accepted for interface symmetry but currently
        unused; polling requests use the client's default timeout.
        NOTE(review): confirm whether the polling client's ``get``
        supports a per-request timeout before wiring this through.

        Raises:
            BlackForestLabsError: on HTTP errors, terminal task failure,
                or when ``max_wait`` elapses.
        """
        polling_url = self._extract_polling_url(initial_response)
        # Only the auth header is needed for polling.
        polling_headers = {"x-key": headers.get("x-key", "")}
        start_time = time.time()
        verbose_logger.debug(f"BFL starting sync polling at {polling_url}")
        while time.time() - start_time < max_wait:
            response = sync_client.get(
                url=polling_url,
                headers=polling_headers,
            )
            if response.status_code != 200:
                raise BlackForestLabsError(
                    status_code=response.status_code,
                    message=f"Polling failed: {response.text}",
                )
            status = response.json().get("status")
            verbose_logger.debug(f"BFL poll status: {status}")
            if status == "Ready":
                return response
            if status in self._FAILURE_STATUSES:
                raise BlackForestLabsError(
                    status_code=400,
                    message=f"Image generation failed: {status}",
                )
            time.sleep(interval)
        raise BlackForestLabsError(
            status_code=408,
            message=f"Polling timed out after {max_wait} seconds",
        )

    async def _poll_for_result_async(
        self,
        initial_response: httpx.Response,
        headers: dict,
        async_client: AsyncHTTPHandler,
        max_wait: float = DEFAULT_MAX_POLLING_TIME,
        interval: float = DEFAULT_POLLING_INTERVAL,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
    ) -> httpx.Response:
        """
        Poll BFL API until the result is ready (async version).

        See ``_poll_for_result_sync`` for argument semantics; this variant
        awaits the client and sleeps with ``asyncio.sleep``.
        """
        polling_url = self._extract_polling_url(initial_response)
        # Only the auth header is needed for polling.
        polling_headers = {"x-key": headers.get("x-key", "")}
        start_time = time.time()
        verbose_logger.debug(f"BFL starting async polling at {polling_url}")
        while time.time() - start_time < max_wait:
            response = await async_client.get(
                url=polling_url,
                headers=polling_headers,
            )
            if response.status_code != 200:
                raise BlackForestLabsError(
                    status_code=response.status_code,
                    message=f"Polling failed: {response.text}",
                )
            status = response.json().get("status")
            verbose_logger.debug(f"BFL poll status: {status}")
            if status == "Ready":
                return response
            if status in self._FAILURE_STATUSES:
                raise BlackForestLabsError(
                    status_code=400,
                    message=f"Image generation failed: {status}",
                )
            await asyncio.sleep(interval)
        raise BlackForestLabsError(
            status_code=408,
            message=f"Polling timed out after {max_wait} seconds",
        )
# Module-level singleton used by images/main.py, so BFL image-generation
# calls share one handler instead of constructing one per request.
bfl_image_generation = BlackForestLabsImageGeneration()

View File

@@ -0,0 +1,327 @@
"""
Black Forest Labs Image Generation Configuration
Handles transformation between OpenAI-compatible format and Black Forest Labs API format
for image generation endpoints (flux-pro-1.1, flux-pro-1.1-ultra, flux-dev, flux-pro).
API Reference: https://docs.bfl.ai/
"""
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageGenerationOptionalParams,
)
from litellm.types.utils import ImageObject, ImageResponse
from ..common_utils import (
DEFAULT_API_BASE,
IMAGE_GENERATION_MODELS,
BlackForestLabsError,
)
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # The logging type is only referenced in annotations, so at runtime we
    # fall back to Any and avoid importing litellm's logging machinery.
    LiteLLMLoggingObj = Any
class BlackForestLabsImageGenerationConfig(BaseImageGenerationConfig):
    """
    Configuration for Black Forest Labs image generation (text-to-image).

    Supports:
    - flux-pro-1.1: Fast & reliable standard generation
    - flux-pro-1.1-ultra: Ultra high-resolution (up to 4MP)
    - flux-dev: Development/open-source variant
    - flux-pro: Original pro model

    Note: HTTP requests and polling are handled by the handler (handler.py).
    This class only handles data transformation.
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """
        Return the list of parameters accepted for Black Forest Labs models.

        The list mixes true OpenAI params ("n", "size", "quality", "seed")
        with BFL-native params that are passed through verbatim; translation
        between the two vocabularies happens in ``map_openai_params``.
        """
        return [
            "n",  # number of images; only ultra models honor >1 (as num_images)
            "size",  # mapped to BFL width/height
            "quality",  # "hd" maps to raw mode on ultra models
            "seed",
            "output_format",
            "safety_tolerance",
            "prompt_upsampling",
            "raw",
            "num_images",
            "image_url",
            "image_prompt_strength",
            "aspect_ratio",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI parameters to Black Forest Labs parameters.

        BFL-specific params are passed through directly; OpenAI-style params
        ("size", "n", "quality") are translated to their BFL equivalents.

        Args:
            non_default_params: Params the caller explicitly set.
            optional_params: Accumulator dict; existing keys are not overwritten.
            model: Model name (used to detect ultra-only capabilities).
            drop_params: When True, silently drop unsupported params instead
                of raising.

        Raises:
            ValueError: For an unsupported param when drop_params is False.
        """
        supported_params = self.get_supported_openai_params(model)
        is_ultra = "ultra" in model.lower()
        for k, v in non_default_params.items():
            if k in optional_params:
                # Caller already set this explicitly; do not overwrite.
                continue
            if k not in supported_params:
                if drop_params:
                    continue
                raise ValueError(
                    f"Parameter {k} is not supported for model {model}. "
                    f"Supported parameters are {supported_params}. "
                    f"Set drop_params=True to drop unsupported parameters."
                )
            if k == "size":
                # Handle "size" in its own branch even when falsy: previously a
                # falsy size (e.g. None) fell through to the generic
                # pass-through and injected a raw "size" key into the params.
                if v:
                    self._map_size_param(v, optional_params)
            elif k == "n":
                if is_ultra:
                    optional_params["num_images"] = v
                # non-ultra: silently skip (n=1 is BFL default)
            elif k == "quality":
                if v == "hd" and is_ultra:
                    optional_params["raw"] = True
                # other quality values have no BFL mapping
            else:
                optional_params[k] = v
        return optional_params

    def _map_size_param(self, size: str, optional_params: dict) -> None:
        """
        Translate an OpenAI "WIDTHxHEIGHT" size string to BFL width/height.

        Sizes without an "x"/"X" separator (e.g. "auto") are silently ignored
        so BFL's server-side defaults apply. The separator check is
        case-insensitive (previously "1024X1024" was silently dropped while
        "1024x1024" parsed).

        Raises:
            ValueError: If the size contains a separator but is not exactly
                two integers.
        """
        normalized = size.lower()
        if "x" not in normalized:
            # e.g. "auto" — let BFL pick defaults.
            return
        try:
            width, height = map(int, normalized.split("x"))
        except ValueError:
            raise ValueError(
                f"Invalid size format: '{size}'. Expected format 'WIDTHxHEIGHT' (e.g., '1024x1024')."
            )
        optional_params["width"] = width
        optional_params["height"] = height

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Resolve the API key and attach Black Forest Labs auth headers.

        BFL authenticates with an ``x-key`` header. The key is resolved, in
        order, from: the explicit api_key argument, the BFL_API_KEY secret,
        or the BLACK_FOREST_LABS_API_KEY secret.

        Raises:
            BlackForestLabsError: (401) when no API key can be resolved.
        """
        final_api_key: Optional[str] = (
            api_key
            or get_secret_str("BFL_API_KEY")
            or get_secret_str("BLACK_FOREST_LABS_API_KEY")
        )
        if not final_api_key:
            raise BlackForestLabsError(
                status_code=401,
                message="BFL_API_KEY is not set. Please set it via environment variable or pass api_key parameter.",
            )
        headers["x-key"] = final_api_key
        headers["Content-Type"] = "application/json"
        headers["Accept"] = "application/json"
        return headers

    def _get_model_endpoint(self, model: str) -> str:
        """
        Resolve the API endpoint path for a given model.

        Strips a provider prefix if present (e.g.
        "black_forest_labs/flux-pro-1.1" -> "flux-pro-1.1") before lookup.

        Raises:
            ValueError: If the model is not a known BFL generation model.
        """
        model_name = model.lower()
        if "/" in model_name:
            model_name = model_name.split("/")[-1]
        if model_name in IMAGE_GENERATION_MODELS:
            return IMAGE_GENERATION_MODELS[model_name]
        raise ValueError(
            f"Unknown BFL image generation model: {model_name}. "
            f"Supported models: {list(IMAGE_GENERATION_MODELS.keys())}"
        )

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the full request URL: base URL + model-specific endpoint.

        The base URL is resolved from the api_base argument, the BFL_API_BASE
        secret, or the library default, in that order.
        """
        base_url: str = api_base or get_secret_str("BFL_API_BASE") or DEFAULT_API_BASE
        base_url = base_url.rstrip("/")
        endpoint = self._get_model_endpoint(model)
        return f"{base_url}{endpoint}"

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform OpenAI-style request to Black Forest Labs request format.

        Only whitelisted BFL parameters are copied from optional_params so
        internal bookkeeping keys are never forwarded.

        https://docs.bfl.ai/flux_models/flux_1_1_pro
        """
        request_body: Dict[str, Any] = {
            "prompt": prompt,
        }
        # BFL-native params that may be forwarded verbatim.
        bfl_params = [
            "width",
            "height",
            "aspect_ratio",
            "seed",
            "output_format",
            "safety_tolerance",
            "prompt_upsampling",
            # Ultra-specific
            "raw",
            "num_images",
            "image_url",
            "image_prompt_strength",
        ]
        for param in bfl_params:
            if param in optional_params and optional_params[param] is not None:
                request_body[param] = optional_params[param]
        # Default to PNG output when the caller did not pick a format.
        request_body.setdefault("output_format", "png")
        return request_body

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """
        Transform Black Forest Labs response to OpenAI-compatible ImageResponse.

        This is called with the FINAL polled response (after handler does
        polling). The response contains:
        ``{"status": "Ready", "result": {"sample": "https://..."}}``
        ``result`` may also be a list when multiple images are returned.

        Raises:
            BlackForestLabsError: If the body is not JSON or contains no
                image URL.
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise BlackForestLabsError(
                status_code=raw_response.status_code,
                message=f"Error parsing BFL response: {e}",
            )
        result = response_data.get("result", {})
        if not model_response.data:
            model_response.data = []
        # Handle single image (sample) or multiple images
        if isinstance(result, dict) and "sample" in result:
            model_response.data.append(ImageObject(url=result["sample"]))
        elif isinstance(result, list):
            # Multiple images: entries are either plain URLs or {"url": ...}.
            for img in result:
                if isinstance(img, str):
                    model_response.data.append(ImageObject(url=img))
                elif isinstance(img, dict) and "url" in img:
                    model_response.data.append(ImageObject(url=img["url"]))
        if not model_response.data:
            raise BlackForestLabsError(
                status_code=500,
                message="No image URL in BFL result",
            )
        model_response.created = int(time.time())
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BlackForestLabsError:
        """Return the appropriate error class for Black Forest Labs."""
        return BlackForestLabsError(
            status_code=status_code,
            message=error_message,
        )
def get_black_forest_labs_image_generation_config(
    model: str,
) -> BlackForestLabsImageGenerationConfig:
    """
    Build the image-generation config object for a Black Forest Labs model.

    A single config class currently covers every supported model; the
    *model* argument is accepted so per-model configs can be introduced
    later without changing call sites.
    """
    config = BlackForestLabsImageGenerationConfig()
    return config