chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,74 @@
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
|
||||
from .bria_transformation import FalAIBriaConfig
|
||||
from .flux_pro_v11_transformation import FalAIFluxProV11Config
|
||||
from .flux_pro_v11_ultra_transformation import FalAIFluxProV11UltraConfig
|
||||
from .flux_schnell_transformation import FalAIFluxSchnellConfig
|
||||
from .imagen4_transformation import FalAIImagen4Config
|
||||
from .recraft_v3_transformation import FalAIRecraftV3Config
|
||||
from .ideogram_v3_transformation import FalAIIdeogramV3Config
|
||||
from .stable_diffusion_transformation import FalAIStableDiffusionConfig
|
||||
from .transformation import FalAIBaseConfig, FalAIImageGenerationConfig
|
||||
from .bytedance_transformation import (
|
||||
FalAIBytedanceSeedreamV3Config,
|
||||
FalAIBytedanceDreaminaV31Config,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FalAIBaseConfig",
|
||||
"FalAIImageGenerationConfig",
|
||||
"FalAIImagen4Config",
|
||||
"FalAIRecraftV3Config",
|
||||
"FalAIBriaConfig",
|
||||
"FalAIFluxProV11Config",
|
||||
"FalAIFluxProV11UltraConfig",
|
||||
"FalAIFluxSchnellConfig",
|
||||
"FalAIStableDiffusionConfig",
|
||||
"FalAIBytedanceSeedreamV3Config",
|
||||
"FalAIBytedanceDreaminaV31Config",
|
||||
"FalAIIdeogramV3Config",
|
||||
]
|
||||
|
||||
|
||||
def get_fal_ai_image_generation_config(model: str) -> BaseImageGenerationConfig:
|
||||
"""
|
||||
Get the appropriate Fal AI image generation configuration based on the model.
|
||||
|
||||
Args:
|
||||
model: The Fal AI model name (e.g., "fal-ai/imagen4/preview", "fal-ai/recraft/v3/text-to-image")
|
||||
|
||||
Returns:
|
||||
The appropriate configuration class for the specified model
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
# Map model names to their corresponding configuration classes
|
||||
if "imagen4" in model_lower or "imagen-4" in model_lower:
|
||||
return FalAIImagen4Config()
|
||||
elif "recraft" in model_lower:
|
||||
return FalAIRecraftV3Config()
|
||||
elif "bria" in model_lower:
|
||||
return FalAIBriaConfig()
|
||||
elif "flux-pro" in model_lower:
|
||||
if "ultra" in model_lower:
|
||||
return FalAIFluxProV11UltraConfig()
|
||||
return FalAIFluxProV11Config()
|
||||
elif (
|
||||
"flux/schnell" in model_lower
|
||||
or "flux-schnell" in model_lower
|
||||
or "schnell" in model_lower
|
||||
):
|
||||
return FalAIFluxSchnellConfig()
|
||||
elif "bytedance/seedream" in model_lower:
|
||||
return FalAIBytedanceSeedreamV3Config()
|
||||
elif "bytedance/dreamina" in model_lower:
|
||||
return FalAIBytedanceDreaminaV31Config()
|
||||
elif "ideogram" in model_lower:
|
||||
return FalAIIdeogramV3Config()
|
||||
elif "stable-diffusion" in model_lower:
|
||||
return FalAIStableDiffusionConfig()
|
||||
|
||||
# Default to generic Fal AI configuration
|
||||
return FalAIImageGenerationConfig()
|
||||
@@ -0,0 +1,231 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from .transformation import FalAIBaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIBriaConfig(FalAIBaseConfig):
|
||||
"""
|
||||
Configuration for Bria Text-to-Image 3.2 model.
|
||||
|
||||
Bria 3.2 is a commercial-grade text-to-image model with prompt enhancement
|
||||
and multiple aspect ratio options.
|
||||
|
||||
Model endpoint: bria/text-to-image/3.2
|
||||
Documentation: https://fal.ai/models/bria/text-to-image/3.2
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "bria/text-to-image/3.2"
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Get supported OpenAI parameters for Bria 3.2.
|
||||
"""
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Bria 3.2 parameters.
|
||||
|
||||
Mappings:
|
||||
- size -> aspect_ratio (1:1, 2:3, 3:2, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9)
|
||||
- response_format -> ignored (Bria returns URLs)
|
||||
- n -> ignored (Bria doesn't support multiple images in one call)
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
# Map OpenAI params to Bria params
|
||||
param_mapping = {
|
||||
"size": "aspect_ratio",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
# Use mapped parameter name if exists
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
# Transform specific parameters
|
||||
if k == "response_format":
|
||||
# Bria always returns URLs, so we can ignore this
|
||||
continue
|
||||
elif k == "n":
|
||||
# Bria doesn't support multiple images, ignore
|
||||
continue
|
||||
elif k == "size":
|
||||
# Map OpenAI size format to Bria aspect ratio
|
||||
mapped_value = self._map_aspect_ratio(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_aspect_ratio(self, size: str) -> str:
|
||||
"""
|
||||
Map OpenAI size format to Bria aspect ratio format.
|
||||
|
||||
OpenAI format: "1024x1024", "1792x1024", etc.
|
||||
Bria format: "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9"
|
||||
"""
|
||||
# Map common OpenAI sizes to Bria aspect ratios
|
||||
size_to_aspect_ratio = {
|
||||
"1024x1024": "1:1",
|
||||
"512x512": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1024x768": "4:3",
|
||||
"768x1024": "3:4",
|
||||
"1280x960": "4:3",
|
||||
"960x1280": "3:4",
|
||||
}
|
||||
|
||||
if size in size_to_aspect_ratio:
|
||||
return size_to_aspect_ratio[size]
|
||||
|
||||
# Parse custom size format "WIDTHxHEIGHT" and calculate aspect ratio
|
||||
if "x" in size:
|
||||
try:
|
||||
width_str, height_str = size.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
|
||||
# Calculate aspect ratio and find closest match
|
||||
ratio = width / height
|
||||
|
||||
# Map to closest supported aspect ratio
|
||||
if 0.95 <= ratio <= 1.05: # Close to 1:1
|
||||
return "1:1"
|
||||
elif ratio >= 1.7: # Close to 16:9
|
||||
return "16:9"
|
||||
elif ratio <= 0.6: # Close to 9:16
|
||||
return "9:16"
|
||||
elif 1.3 <= ratio <= 1.4: # Close to 4:3
|
||||
return "4:3"
|
||||
elif 0.7 <= ratio <= 0.8: # Close to 3:4
|
||||
return "3:4"
|
||||
elif 1.45 <= ratio <= 1.55: # Close to 3:2
|
||||
return "3:2"
|
||||
elif 0.65 <= ratio <= 0.7: # Close to 2:3
|
||||
return "2:3"
|
||||
elif 1.2 <= ratio <= 1.3: # Close to 5:4
|
||||
return "5:4"
|
||||
elif 0.75 <= ratio <= 0.85: # Close to 4:5
|
||||
return "4:5"
|
||||
except (ValueError, AttributeError, ZeroDivisionError):
|
||||
pass
|
||||
|
||||
# Default to 1:1
|
||||
return "1:1"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to Bria 3.2 request body.
|
||||
|
||||
Required parameters:
|
||||
- prompt: Prompt for image generation
|
||||
|
||||
Optional parameters:
|
||||
- aspect_ratio: "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9" (default: "1:1")
|
||||
- prompt_enhancer: Improve the prompt (default: true)
|
||||
- sync_mode: Return image directly in response (default: false)
|
||||
- truncate_prompt: Truncate the prompt (default: true)
|
||||
- guidance_scale: Guidance scale 1-10 (default: 5)
|
||||
- num_inference_steps: Inference steps 20-50 (default: 30)
|
||||
- seed: Random seed for reproducibility (default: 5555)
|
||||
- negative_prompt: Negative prompt string
|
||||
"""
|
||||
bria_request_body = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return bria_request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the Bria 3.2 response to litellm ImageResponse format.
|
||||
|
||||
Expected response format:
|
||||
{
|
||||
"image": {
|
||||
"url": "https://...",
|
||||
"content_type": "image/png",
|
||||
"file_name": "...",
|
||||
"file_size": 123456,
|
||||
"width": 1024,
|
||||
"height": 1024
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle Bria response format - uses "image" (singular) not "images"
|
||||
image_data = response_data.get("image")
|
||||
if image_data and isinstance(image_data, dict):
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data.get("url", None),
|
||||
b64_json=None, # Bria returns URLs only
|
||||
)
|
||||
)
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,104 @@
|
||||
from typing import Any
|
||||
|
||||
from .flux_pro_v11_ultra_transformation import FalAIFluxProV11UltraConfig
|
||||
|
||||
|
||||
class FalAIBytedanceBaseConfig(FalAIFluxProV11UltraConfig):
|
||||
"""
|
||||
Shared configuration for Fal AI ByteDance text-to-image models that follow
|
||||
the Flux Schnell style parameter mapping.
|
||||
|
||||
These models accept the OpenAI-compatible `size` parameter in LiteLLM
|
||||
requests but expect `image_size` enums or custom size objects on Fal AI.
|
||||
"""
|
||||
|
||||
_OPENAI_SIZE_TO_IMAGE_SIZE = {
|
||||
"1024x1024": "square_hd",
|
||||
"512x512": "square",
|
||||
"1792x1024": "landscape_16_9",
|
||||
"1024x1792": "portrait_16_9",
|
||||
"1024x768": "landscape_4_3",
|
||||
"768x1024": "portrait_4_3",
|
||||
}
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
param_mapping = {
|
||||
"n": "num_images",
|
||||
"response_format": "output_format",
|
||||
"size": "image_size",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
if k == "response_format":
|
||||
if mapped_value in ["b64_json", "url"]:
|
||||
mapped_value = "jpeg"
|
||||
elif k == "size":
|
||||
mapped_value = self._map_image_size(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_image_size(self, size: Any) -> Any:
|
||||
if isinstance(size, dict):
|
||||
return size
|
||||
|
||||
if not isinstance(size, str):
|
||||
return size
|
||||
|
||||
if size in self._OPENAI_SIZE_TO_IMAGE_SIZE:
|
||||
return self._OPENAI_SIZE_TO_IMAGE_SIZE[size]
|
||||
|
||||
if "x" in size:
|
||||
try:
|
||||
width_str, height_str = size.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
return {"width": width, "height": height}
|
||||
except (ValueError, AttributeError, ZeroDivisionError):
|
||||
pass
|
||||
|
||||
return "landscape_4_3"
|
||||
|
||||
|
||||
class FalAIBytedanceSeedreamV3Config(FalAIBytedanceBaseConfig):
|
||||
"""
|
||||
Configuration for Fal AI ByteDance Seedream v3 text-to-image model.
|
||||
|
||||
Model endpoint: fal-ai/bytedance/seedream/v3/text-to-image
|
||||
Documentation: https://fal.ai/models/fal-ai/bytedance/seedream/v3/text-to-image
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/bytedance/seedream/v3/text-to-image"
|
||||
|
||||
|
||||
class FalAIBytedanceDreaminaV31Config(FalAIBytedanceBaseConfig):
|
||||
"""
|
||||
Configuration for Fal AI ByteDance Dreamina v3.1 text-to-image model.
|
||||
|
||||
Model endpoint: fal-ai/bytedance/dreamina/v3.1/text-to-image
|
||||
Documentation: https://fal.ai/models/fal-ai/bytedance/dreamina/v3.1/text-to-image
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/bytedance/dreamina/v3.1/text-to-image"
|
||||
@@ -0,0 +1,89 @@
|
||||
from typing import Any
|
||||
|
||||
from .flux_pro_v11_ultra_transformation import FalAIFluxProV11UltraConfig
|
||||
|
||||
|
||||
class FalAIFluxProV11Config(FalAIFluxProV11UltraConfig):
|
||||
"""
|
||||
Configuration for Fal AI Flux Pro v1.1 model.
|
||||
|
||||
FLUX Pro v1.1 leverages the same overall request/response structure as the
|
||||
Ultra variant but expects the `image_size` parameter instead of
|
||||
`aspect_ratio`.
|
||||
|
||||
Model endpoint: fal-ai/flux-pro/v1.1
|
||||
Documentation: https://fal.ai/models/fal-ai/flux-pro/v1.1
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/flux-pro/v1.1"
|
||||
|
||||
_OPENAI_SIZE_TO_IMAGE_SIZE = {
|
||||
"1024x1024": "square_hd",
|
||||
"512x512": "square",
|
||||
"1792x1024": "landscape_16_9",
|
||||
"1024x1792": "portrait_16_9",
|
||||
"1024x768": "landscape_4_3",
|
||||
"768x1024": "portrait_4_3",
|
||||
}
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Override size handling to map to Flux Pro v1.1 image_size enums/object.
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
param_mapping = {
|
||||
"n": "num_images",
|
||||
"response_format": "output_format",
|
||||
"size": "image_size",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
if k == "response_format":
|
||||
if mapped_value in ["b64_json", "url"]:
|
||||
mapped_value = "jpeg"
|
||||
elif k == "size":
|
||||
mapped_value = self._map_image_size(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_image_size(self, size: Any) -> Any:
|
||||
if isinstance(size, dict):
|
||||
return size
|
||||
if not isinstance(size, str):
|
||||
return size
|
||||
|
||||
if size in self._OPENAI_SIZE_TO_IMAGE_SIZE:
|
||||
return self._OPENAI_SIZE_TO_IMAGE_SIZE[size]
|
||||
|
||||
if "x" in size:
|
||||
try:
|
||||
width_str, height_str = size.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
return {"width": width, "height": height}
|
||||
except (ValueError, AttributeError, ZeroDivisionError):
|
||||
pass
|
||||
|
||||
return "landscape_4_3"
|
||||
@@ -0,0 +1,263 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from .transformation import FalAIBaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIFluxProV11UltraConfig(FalAIBaseConfig):
|
||||
"""
|
||||
Configuration for Fal AI Flux Pro v1.1-ultra model.
|
||||
|
||||
FLUX Pro v1.1-ultra is a high-quality text-to-image model with enhanced detail
|
||||
and support for image prompts.
|
||||
|
||||
Model endpoint: fal-ai/flux-pro/v1.1-ultra
|
||||
Documentation: https://fal.ai/models/fal-ai/flux-pro/v1.1-ultra
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/flux-pro/v1.1-ultra"
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Get supported OpenAI parameters for Flux Pro v1.1-ultra.
|
||||
"""
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Flux Pro v1.1-ultra parameters.
|
||||
|
||||
Mappings:
|
||||
- n -> num_images (1-4, default 1)
|
||||
- response_format -> output_format (jpeg or png)
|
||||
- size -> aspect_ratio (21:9, 16:9, 4:3, 3:2, 1:1, 2:3, 3:4, 9:16, 9:21)
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
# Map OpenAI params to Flux Pro v1.1-ultra params
|
||||
param_mapping = {
|
||||
"n": "num_images",
|
||||
"response_format": "output_format",
|
||||
"size": "aspect_ratio",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
# Use mapped parameter name if exists
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
# Transform specific parameters
|
||||
if k == "response_format":
|
||||
# Map OpenAI response formats to image formats
|
||||
if mapped_value in ["b64_json", "url"]:
|
||||
mapped_value = "jpeg"
|
||||
elif k == "size":
|
||||
# Map OpenAI size format to Flux aspect ratio
|
||||
mapped_value = self._map_aspect_ratio(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_aspect_ratio(self, size: str) -> str:
|
||||
"""
|
||||
Map OpenAI size format to Flux Pro aspect ratio format.
|
||||
|
||||
OpenAI format: "1024x1024", "1792x1024", etc.
|
||||
Flux format: "21:9", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16", "9:21"
|
||||
|
||||
Default: "16:9"
|
||||
"""
|
||||
# Map common OpenAI sizes to Flux aspect ratios
|
||||
size_to_aspect_ratio = {
|
||||
"1024x1024": "1:1",
|
||||
"512x512": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1024x768": "4:3",
|
||||
"768x1024": "3:4",
|
||||
"1536x1024": "3:2",
|
||||
"1024x1536": "2:3",
|
||||
"2048x876": "21:9",
|
||||
"876x2048": "9:21",
|
||||
}
|
||||
|
||||
if size in size_to_aspect_ratio:
|
||||
return size_to_aspect_ratio[size]
|
||||
|
||||
# Parse custom size format "WIDTHxHEIGHT" and calculate aspect ratio
|
||||
if "x" in size:
|
||||
try:
|
||||
width_str, height_str = size.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
|
||||
# Calculate aspect ratio and find closest match
|
||||
ratio = width / height
|
||||
|
||||
# Map to closest supported aspect ratio
|
||||
if 0.95 <= ratio <= 1.05: # Close to 1:1
|
||||
return "1:1"
|
||||
elif ratio >= 2.3: # Close to 21:9
|
||||
return "21:9"
|
||||
elif 1.7 <= ratio < 2.3: # Close to 16:9
|
||||
return "16:9"
|
||||
elif 1.3 <= ratio < 1.7: # Close to 4:3
|
||||
return "4:3"
|
||||
elif 1.4 <= ratio < 1.6: # Close to 3:2
|
||||
return "3:2"
|
||||
elif 0.6 <= ratio < 0.7: # Close to 3:4
|
||||
return "3:4"
|
||||
elif 0.65 <= ratio < 0.75: # Close to 2:3
|
||||
return "2:3"
|
||||
elif 0.5 <= ratio < 0.6: # Close to 9:16
|
||||
return "9:16"
|
||||
elif ratio < 0.5: # Close to 9:21
|
||||
return "9:21"
|
||||
except (ValueError, AttributeError, ZeroDivisionError):
|
||||
pass
|
||||
|
||||
# Default to 16:9
|
||||
return "16:9"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to Flux Pro v1.1-ultra request body.
|
||||
|
||||
Required parameters:
|
||||
- prompt: The prompt to generate an image from
|
||||
|
||||
Optional parameters:
|
||||
- num_images: Number of images (1-4, default: 1)
|
||||
- aspect_ratio: Aspect ratio (default: "16:9")
|
||||
- raw: Generate less processed images (default: false)
|
||||
- output_format: "jpeg" or "png" (default: "jpeg")
|
||||
- image_url: Image URL for image-to-image generation
|
||||
- sync_mode: Return data URI (default: false)
|
||||
- safety_tolerance: Safety level "1"-"6" (default: "2")
|
||||
- enable_safety_checker: Enable safety checker (default: true)
|
||||
- seed: Random seed for reproducibility
|
||||
- image_prompt_strength: Strength of image prompt 0-1 (default: 0.1)
|
||||
- enhance_prompt: Enhance prompt for better results (default: false)
|
||||
"""
|
||||
flux_pro_request_body = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return flux_pro_request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the Flux Pro v1.1-ultra response to litellm ImageResponse format.
|
||||
|
||||
Expected response format:
|
||||
{
|
||||
"images": [
|
||||
{
|
||||
"url": "https://...",
|
||||
"width": 1024,
|
||||
"height": 768,
|
||||
"content_type": "image/jpeg"
|
||||
}
|
||||
],
|
||||
"timings": {"inference": 2.5, ...},
|
||||
"seed": 42,
|
||||
"has_nsfw_concepts": [false],
|
||||
"prompt": "original prompt"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle Flux Pro v1.1-ultra response format
|
||||
images = response_data.get("images", [])
|
||||
if isinstance(images, list):
|
||||
for image_data in images:
|
||||
if isinstance(image_data, dict):
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data.get("url", None),
|
||||
b64_json=None, # Flux Pro returns URLs only
|
||||
)
|
||||
)
|
||||
elif isinstance(image_data, str):
|
||||
# If images is just a list of URLs
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data,
|
||||
b64_json=None,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional metadata from Flux Pro response
|
||||
if hasattr(model_response, "_hidden_params"):
|
||||
if "seed" in response_data:
|
||||
model_response._hidden_params["seed"] = response_data["seed"]
|
||||
if "timings" in response_data:
|
||||
model_response._hidden_params["timings"] = response_data["timings"]
|
||||
if "has_nsfw_concepts" in response_data:
|
||||
model_response._hidden_params["has_nsfw_concepts"] = response_data[
|
||||
"has_nsfw_concepts"
|
||||
]
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,87 @@
|
||||
from typing import Any
|
||||
|
||||
from .flux_pro_v11_ultra_transformation import FalAIFluxProV11UltraConfig
|
||||
|
||||
|
||||
class FalAIFluxSchnellConfig(FalAIFluxProV11UltraConfig):
|
||||
"""
|
||||
Configuration for Fal AI Flux Schnell model.
|
||||
|
||||
Flux Schnell shares the same response format as Flux Pro models but expects
|
||||
the OpenAI `size` parameter to be translated into Fal AI's `image_size`
|
||||
enum/object.
|
||||
|
||||
Model endpoint: fal-ai/flux/schnell
|
||||
Documentation: https://fal.ai/models/fal-ai/flux/schnell
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/flux/schnell"
|
||||
|
||||
_OPENAI_SIZE_TO_IMAGE_SIZE = {
|
||||
"1024x1024": "square_hd",
|
||||
"512x512": "square",
|
||||
"1792x1024": "landscape_16_9",
|
||||
"1024x1792": "portrait_16_9",
|
||||
"1024x768": "landscape_4_3",
|
||||
"768x1024": "portrait_4_3",
|
||||
}
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
param_mapping = {
|
||||
"n": "num_images",
|
||||
"response_format": "output_format",
|
||||
"size": "image_size",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
if k == "response_format":
|
||||
if mapped_value in ["b64_json", "url"]:
|
||||
mapped_value = "jpeg"
|
||||
elif k == "size":
|
||||
mapped_value = self._map_image_size(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_image_size(self, size: Any) -> Any:
|
||||
if isinstance(size, dict):
|
||||
return size
|
||||
|
||||
if not isinstance(size, str):
|
||||
return size
|
||||
|
||||
if size in self._OPENAI_SIZE_TO_IMAGE_SIZE:
|
||||
return self._OPENAI_SIZE_TO_IMAGE_SIZE[size]
|
||||
|
||||
if "x" in size:
|
||||
try:
|
||||
width_str, height_str = size.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
return {"width": width, "height": height}
|
||||
except (ValueError, AttributeError, ZeroDivisionError):
|
||||
pass
|
||||
|
||||
return "landscape_4_3"
|
||||
@@ -0,0 +1,191 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from .transformation import FalAIBaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIIdeogramV3Config(FalAIBaseConfig):
|
||||
"""
|
||||
Configuration for fal-ai/ideogram/v3 image generation.
|
||||
|
||||
The Ideogram v3 endpoint exposes multiple generation modes (text-to-image,
|
||||
remixing, reframing, background replacement, character workflows, etc.).
|
||||
LiteLLM focuses on the text-to-image interface to maintain OpenAI parity.
|
||||
|
||||
Model endpoint: fal-ai/ideogram/v3
|
||||
Documentation: https://fal.ai/models/fal-ai/ideogram/v3
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/ideogram/v3"
|
||||
|
||||
_OPENAI_SIZE_TO_IMAGE_SIZE = {
|
||||
"1024x1024": "square_hd",
|
||||
"512x512": "square",
|
||||
"1024x768": "landscape_4_3",
|
||||
"768x1024": "portrait_4_3",
|
||||
"1536x1024": "landscape_16_9",
|
||||
"1024x1536": "portrait_16_9",
|
||||
}
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Ideogram v3 accepts the core OpenAI image parameters.
|
||||
"""
|
||||
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI-style parameters onto Ideogram's request schema.
|
||||
"""
|
||||
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k in optional_params:
|
||||
continue
|
||||
|
||||
if k not in supported_params:
|
||||
if drop_params:
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
value = non_default_params[k]
|
||||
|
||||
if k == "n":
|
||||
optional_params["num_images"] = value
|
||||
elif k == "size":
|
||||
optional_params["image_size"] = self._map_image_size(value)
|
||||
elif k == "response_format":
|
||||
# Ideogram always returns URLs; nothing to map but don't error.
|
||||
continue
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_image_size(self, size: Any) -> Any:
|
||||
if isinstance(size, dict):
|
||||
width = size.get("width")
|
||||
height = size.get("height")
|
||||
if isinstance(width, int) and isinstance(height, int):
|
||||
return {"width": width, "height": height}
|
||||
return size
|
||||
|
||||
if not isinstance(size, str):
|
||||
return size
|
||||
|
||||
normalized = size.strip()
|
||||
if normalized in self._OPENAI_SIZE_TO_IMAGE_SIZE:
|
||||
return self._OPENAI_SIZE_TO_IMAGE_SIZE[normalized]
|
||||
|
||||
if "x" in normalized:
|
||||
try:
|
||||
width_str, height_str = normalized.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
return {"width": width, "height": height}
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Fallback to a safe default that Ideogram accepts.
|
||||
return "square_hd"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Construct the request payload for Ideogram v3.
|
||||
|
||||
Required:
|
||||
- prompt: text prompt describing the scene.
|
||||
|
||||
Optional (subset):
|
||||
- rendering_speed, style_preset, style, style_codes, color_palette,
|
||||
image_urls, style_reference_images, expand_prompt, seed,
|
||||
negative_prompt, image_size, etc.
|
||||
"""
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Parse Ideogram v3 responses which contain a list of File objects.
|
||||
"""
|
||||
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
images = response_data.get("images", [])
|
||||
if isinstance(images, list):
|
||||
for image_entry in images:
|
||||
if isinstance(image_entry, dict):
|
||||
url = image_entry.get("url")
|
||||
else:
|
||||
url = image_entry
|
||||
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=url,
|
||||
b64_json=None,
|
||||
)
|
||||
)
|
||||
|
||||
if hasattr(model_response, "_hidden_params") and "seed" in response_data:
|
||||
model_response._hidden_params["seed"] = response_data["seed"]
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,242 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from .transformation import FalAIBaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIImagen4Config(FalAIBaseConfig):
|
||||
"""
|
||||
Configuration for Fal AI Imagen4 model.
|
||||
|
||||
Google's highest quality image generation model available through Fal AI.
|
||||
|
||||
Model variants:
|
||||
- fal-ai/imagen4/preview (Standard): $0.05 per image
|
||||
- fal-ai/imagen4/preview/fast (Fast): $0.02 per image
|
||||
- fal-ai/imagen4/preview/ultra (Ultra): $0.06 per image
|
||||
|
||||
Documentation: https://fal.ai/models/fal-ai/imagen4/preview
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/imagen4/preview"
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Get supported OpenAI parameters for Imagen4.
|
||||
"""
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Imagen4 parameters.
|
||||
|
||||
Mappings:
|
||||
- n -> num_images (1-4, default 1)
|
||||
- size -> aspect_ratio (1:1, 16:9, 9:16, 3:4, 4:3)
|
||||
- response_format -> ignored (Imagen4 returns URLs)
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
# Map OpenAI params to Imagen4 params
|
||||
param_mapping = {
|
||||
"n": "num_images",
|
||||
"size": "aspect_ratio",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
# Use mapped parameter name if exists
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
# Transform specific parameters
|
||||
if k == "response_format":
|
||||
# Imagen4 always returns URLs, so we can ignore this
|
||||
continue
|
||||
elif k == "size":
|
||||
# Map OpenAI size format to Imagen4 aspect ratio
|
||||
mapped_value = self._map_aspect_ratio(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_aspect_ratio(self, size: str) -> str:
|
||||
"""
|
||||
Map OpenAI size format to Imagen4 aspect ratio format.
|
||||
|
||||
OpenAI format: "1024x1024", "1792x1024", etc.
|
||||
Imagen4 format: "1:1", "16:9", "9:16", "3:4", "4:3"
|
||||
|
||||
Available aspect ratios:
|
||||
- 1:1 (default)
|
||||
- 16:9
|
||||
- 9:16
|
||||
- 3:4
|
||||
- 4:3
|
||||
"""
|
||||
# Map common OpenAI sizes to Imagen4 aspect ratios
|
||||
size_to_aspect_ratio = {
|
||||
"1024x1024": "1:1",
|
||||
"512x512": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1024x768": "4:3",
|
||||
"768x1024": "3:4",
|
||||
}
|
||||
|
||||
if size in size_to_aspect_ratio:
|
||||
return size_to_aspect_ratio[size]
|
||||
|
||||
# Parse custom size format "WIDTHxHEIGHT" and calculate aspect ratio
|
||||
if "x" in size:
|
||||
try:
|
||||
width_str, height_str = size.split("x")
|
||||
width = int(width_str)
|
||||
height = int(height_str)
|
||||
|
||||
# Calculate aspect ratio and find closest match
|
||||
ratio = width / height
|
||||
|
||||
# Map to closest supported aspect ratio
|
||||
if 0.95 <= ratio <= 1.05: # Close to 1:1
|
||||
return "1:1"
|
||||
elif ratio >= 1.7: # Close to 16:9
|
||||
return "16:9"
|
||||
elif ratio <= 0.6: # Close to 9:16
|
||||
return "9:16"
|
||||
elif ratio >= 1.2: # Close to 4:3
|
||||
return "4:3"
|
||||
elif ratio <= 0.8: # Close to 3:4
|
||||
return "3:4"
|
||||
except (ValueError, AttributeError, ZeroDivisionError):
|
||||
pass
|
||||
|
||||
# Default to 1:1
|
||||
return "1:1"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to Imagen4 request body.
|
||||
|
||||
Required parameters:
|
||||
- prompt: The text prompt describing what you want to see
|
||||
|
||||
Optional parameters:
|
||||
- aspect_ratio: "1:1", "16:9", "9:16", "3:4", "4:3" (default: "1:1")
|
||||
- num_images: Number of images (1-4, default: 1)
|
||||
- resolution: "1K" or "2K" (default: "1K")
|
||||
- seed: Random seed for reproducibility
|
||||
- negative_prompt: Description of what to discourage (default: "")
|
||||
"""
|
||||
imagen4_request_body = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return imagen4_request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the Imagen4 response to litellm ImageResponse format.
|
||||
|
||||
Expected response format:
|
||||
{
|
||||
"images": [
|
||||
{
|
||||
"url": "https://...",
|
||||
"content_type": "image/png",
|
||||
"file_name": "z9RV14K95DvU.png",
|
||||
"file_size": 4404019
|
||||
}
|
||||
],
|
||||
"seed": 42
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle Imagen4 response format
|
||||
images = response_data.get("images", [])
|
||||
if isinstance(images, list):
|
||||
for image_data in images:
|
||||
if isinstance(image_data, dict):
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data.get("url", None),
|
||||
b64_json=None, # Imagen4 returns URLs only
|
||||
)
|
||||
)
|
||||
elif isinstance(image_data, str):
|
||||
# If images is just a list of URLs
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data,
|
||||
b64_json=None,
|
||||
)
|
||||
)
|
||||
|
||||
# Add seed metadata from Imagen4 response
|
||||
if hasattr(model_response, "_hidden_params"):
|
||||
if "seed" in response_data:
|
||||
model_response._hidden_params["seed"] = response_data["seed"]
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,226 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from .transformation import FalAIBaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIRecraftV3Config(FalAIBaseConfig):
|
||||
"""
|
||||
Configuration for Fal AI Recraft v3 Text-to-Image model.
|
||||
|
||||
Recraft v3 is a text-to-image model with multiple style options including
|
||||
realistic images, digital illustrations, and vector illustrations.
|
||||
|
||||
Model endpoint: fal-ai/recraft/v3/text-to-image
|
||||
Documentation: https://fal.ai/models/fal-ai/recraft/v3/text-to-image
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "fal-ai/recraft/v3/text-to-image"
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Get supported OpenAI parameters for Recraft v3.
|
||||
"""
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Recraft v3 parameters.
|
||||
|
||||
Mappings:
|
||||
- size -> image_size (can be preset or custom width/height)
|
||||
- response_format -> ignored (Recraft returns URLs)
|
||||
- n -> ignored (Recraft doesn't support multiple images)
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
# Map OpenAI params to Recraft v3 params
|
||||
param_mapping = {
|
||||
"size": "image_size",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
# Use mapped parameter name if exists
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
# Transform specific parameters
|
||||
if k == "response_format":
|
||||
# Recraft always returns URLs, so we can ignore this
|
||||
continue
|
||||
elif k == "n":
|
||||
# Recraft doesn't support multiple images, ignore
|
||||
continue
|
||||
elif k == "size":
|
||||
# Map OpenAI size format to Recraft image_size
|
||||
mapped_value = self._map_image_size(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_image_size(self, size: str) -> Any:
|
||||
"""
|
||||
Map OpenAI size format to Recraft v3 image_size format.
|
||||
|
||||
OpenAI format: "1024x1024", "1792x1024", etc.
|
||||
Recraft format: Can be preset strings or {"width": int, "height": int}
|
||||
|
||||
Available presets:
|
||||
- square_hd (default)
|
||||
- square
|
||||
- portrait_4_3
|
||||
- portrait_16_9
|
||||
- landscape_4_3
|
||||
- landscape_16_9
|
||||
"""
|
||||
# Map common OpenAI sizes to Recraft presets
|
||||
size_mapping = {
|
||||
"1024x1024": "square_hd",
|
||||
"512x512": "square",
|
||||
"768x1024": "portrait_4_3",
|
||||
"576x1024": "portrait_16_9",
|
||||
"1024x768": "landscape_4_3",
|
||||
"1024x576": "landscape_16_9",
|
||||
}
|
||||
|
||||
if size in size_mapping:
|
||||
return size_mapping[size]
|
||||
|
||||
# Parse custom size format "WIDTHxHEIGHT"
|
||||
if "x" in size:
|
||||
try:
|
||||
width, height = size.split("x")
|
||||
return {
|
||||
"width": int(width),
|
||||
"height": int(height),
|
||||
}
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Default to square_hd
|
||||
return "square_hd"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to Recraft v3 request body.
|
||||
|
||||
Required parameters:
|
||||
- prompt: Text prompt (max 1000 characters)
|
||||
|
||||
Optional parameters:
|
||||
- image_size: Preset or {"width": int, "height": int} (default: "square_hd")
|
||||
- style: Style preset (default: "realistic_image")
|
||||
Options: "any", "realistic_image", "digital_illustration", "vector_illustration", etc.
|
||||
- colors: Array of RGB color objects [{"r": 0-255, "g": 0-255, "b": 0-255}]
|
||||
- enable_safety_checker: Enable safety checker (default: false)
|
||||
- style_id: UUID for custom style reference
|
||||
|
||||
Note: Vector illustrations cost 2X as much.
|
||||
"""
|
||||
recraft_request_body = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return recraft_request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the Recraft v3 response to litellm ImageResponse format.
|
||||
|
||||
Expected response format:
|
||||
{
|
||||
"images": [
|
||||
{
|
||||
"url": "https://...",
|
||||
"content_type": "image/webp",
|
||||
"file_name": "...",
|
||||
"file_size": 123456
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle Recraft v3 response format
|
||||
images = response_data.get("images", [])
|
||||
if isinstance(images, list):
|
||||
for image_data in images:
|
||||
if isinstance(image_data, dict):
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data.get("url", None),
|
||||
b64_json=None, # Recraft returns URLs only
|
||||
)
|
||||
)
|
||||
elif isinstance(image_data, str):
|
||||
# If images is just a list of URLs
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data,
|
||||
b64_json=None,
|
||||
)
|
||||
)
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,279 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
from .transformation import FalAIBaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIStableDiffusionConfig(FalAIBaseConfig):
|
||||
"""
|
||||
Configuration for Fal AI Stable Diffusion models.
|
||||
|
||||
Supports Stable Diffusion v3.5 variants and other Stable Diffusion models on Fal AI.
|
||||
|
||||
Example models:
|
||||
- fal-ai/stable-diffusion-v35-medium
|
||||
- fal-ai/stable-diffusion-v35-large
|
||||
|
||||
Documentation: https://fal.ai/models/fal-ai/stable-diffusion-v35-medium
|
||||
"""
|
||||
|
||||
IMAGE_GENERATION_ENDPOINT: str = "" # Will be set from model name
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete url for the request.
|
||||
|
||||
For Stable Diffusion models, extract the endpoint from the model name.
|
||||
"""
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
complete_url: str = (
|
||||
api_base or get_secret_str("FAL_AI_API_BASE") or self.DEFAULT_BASE_URL
|
||||
)
|
||||
|
||||
complete_url = complete_url.rstrip("/")
|
||||
|
||||
# Extract endpoint from model name
|
||||
# e.g., "fal-ai/stable-diffusion-v35-medium" or "stable-diffusion-v35-medium"
|
||||
endpoint = model
|
||||
if "/" in model and not model.startswith("fal-ai/"):
|
||||
# If model is like "custom/stable-diffusion-v35-medium", use full path
|
||||
endpoint = model
|
||||
elif not model.startswith("fal-ai/"):
|
||||
# If model is just "stable-diffusion-v35-medium", prepend fal-ai
|
||||
endpoint = f"fal-ai/{model}"
|
||||
|
||||
complete_url = f"{complete_url}/{endpoint}"
|
||||
return complete_url
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Get supported OpenAI parameters for Stable Diffusion models.
|
||||
"""
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Stable Diffusion parameters.
|
||||
|
||||
Mappings:
|
||||
- n -> num_images (1-4, default 1)
|
||||
- response_format -> output_format (jpeg or png)
|
||||
- size -> image_size (can be preset or custom width/height)
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
|
||||
# Map OpenAI params to Stable Diffusion params
|
||||
param_mapping = {
|
||||
"n": "num_images",
|
||||
"response_format": "output_format",
|
||||
"size": "image_size",
|
||||
}
|
||||
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
# Use mapped parameter name if exists
|
||||
mapped_key = param_mapping.get(k, k)
|
||||
mapped_value = non_default_params[k]
|
||||
|
||||
# Transform specific parameters
|
||||
if k == "response_format":
|
||||
# Map OpenAI response formats to image formats
|
||||
if mapped_value in ["b64_json", "url"]:
|
||||
mapped_value = "jpeg"
|
||||
elif k == "size":
|
||||
# Map OpenAI size format to Stable Diffusion image_size
|
||||
mapped_value = self._map_image_size(mapped_value)
|
||||
|
||||
optional_params[mapped_key] = mapped_value
|
||||
elif drop_params:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def _map_image_size(self, size: str) -> Any:
|
||||
"""
|
||||
Map OpenAI size format to Stable Diffusion image_size format.
|
||||
|
||||
OpenAI format: "1024x1024", "1792x1024", etc.
|
||||
Stable Diffusion format: Can be preset strings or {"width": int, "height": int}
|
||||
|
||||
Available presets:
|
||||
- square_hd
|
||||
- square
|
||||
- portrait_4_3
|
||||
- portrait_16_9
|
||||
- landscape_4_3 (default)
|
||||
- landscape_16_9
|
||||
"""
|
||||
# Map common OpenAI sizes to Stable Diffusion presets
|
||||
size_mapping = {
|
||||
"1024x1024": "square_hd",
|
||||
"512x512": "square",
|
||||
"768x1024": "portrait_4_3",
|
||||
"576x1024": "portrait_16_9",
|
||||
"1024x768": "landscape_4_3",
|
||||
"1024x576": "landscape_16_9",
|
||||
}
|
||||
|
||||
if size in size_mapping:
|
||||
return size_mapping[size]
|
||||
|
||||
# Parse custom size format "WIDTHxHEIGHT"
|
||||
if "x" in size:
|
||||
try:
|
||||
width, height = size.split("x")
|
||||
return {
|
||||
"width": int(width),
|
||||
"height": int(height),
|
||||
}
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
|
||||
# Default to landscape_4_3
|
||||
return "landscape_4_3"
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to Stable Diffusion request body.
|
||||
|
||||
Required parameters:
|
||||
- prompt: The prompt to generate an image from
|
||||
|
||||
Optional parameters:
|
||||
- num_images: Number of images (1-4, default: 1)
|
||||
- image_size: Size preset or {"width": int, "height": int} (default: landscape_4_3)
|
||||
- output_format: "jpeg" or "png" (default: jpeg)
|
||||
- sync_mode: Wait for image upload before returning (default: false)
|
||||
- guidance_scale: CFG scale 0-20 (default: 4.5)
|
||||
- num_inference_steps: Inference steps 1-50 (default: 40)
|
||||
- seed: Random seed for reproducibility
|
||||
- negative_prompt: Negative prompt string (default: "")
|
||||
- enable_safety_checker: Enable safety checker (default: true)
|
||||
"""
|
||||
stable_diffusion_request_body = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
return stable_diffusion_request_body
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the Stable Diffusion response to litellm ImageResponse format.
|
||||
|
||||
Expected response format:
|
||||
{
|
||||
"images": [
|
||||
{
|
||||
"url": "https://...",
|
||||
"width": 1024,
|
||||
"height": 768,
|
||||
"content_type": "image/jpeg"
|
||||
}
|
||||
],
|
||||
"timings": {"inference": 2.5, ...},
|
||||
"seed": 42,
|
||||
"has_nsfw_concepts": [false],
|
||||
"prompt": "original prompt"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle Stable Diffusion response format
|
||||
images = response_data.get("images", [])
|
||||
if isinstance(images, list):
|
||||
for image_data in images:
|
||||
if isinstance(image_data, dict):
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data.get("url", None),
|
||||
b64_json=None, # Stable Diffusion returns URLs only
|
||||
)
|
||||
)
|
||||
elif isinstance(image_data, str):
|
||||
# If images is just a list of URLs
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data,
|
||||
b64_json=None,
|
||||
)
|
||||
)
|
||||
|
||||
# Add additional metadata from Stable Diffusion response
|
||||
if hasattr(model_response, "_hidden_params"):
|
||||
if "seed" in response_data:
|
||||
model_response._hidden_params["seed"] = response_data["seed"]
|
||||
if "timings" in response_data:
|
||||
model_response._hidden_params["timings"] = response_data["timings"]
|
||||
if "has_nsfw_concepts" in response_data:
|
||||
model_response._hidden_params["has_nsfw_concepts"] = response_data[
|
||||
"has_nsfw_concepts"
|
||||
]
|
||||
|
||||
return model_response
|
||||
@@ -0,0 +1,175 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class FalAIBaseConfig(BaseImageGenerationConfig):
|
||||
"""
|
||||
Base configuration for Fal AI image generation models.
|
||||
Handles common functionality like URL construction and authentication.
|
||||
"""
|
||||
|
||||
DEFAULT_BASE_URL: str = "https://fal.run"
|
||||
IMAGE_GENERATION_ENDPOINT: str = ""
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
complete_url: str = (
|
||||
api_base or get_secret_str("FAL_AI_API_BASE") or self.DEFAULT_BASE_URL
|
||||
)
|
||||
|
||||
complete_url = complete_url.rstrip("/")
|
||||
if self.IMAGE_GENERATION_ENDPOINT:
|
||||
complete_url = f"{complete_url}/{self.IMAGE_GENERATION_ENDPOINT}"
|
||||
return complete_url
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
final_api_key: Optional[str] = api_key or get_secret_str("FAL_AI_API_KEY")
|
||||
if not final_api_key:
|
||||
raise ValueError("FAL_AI_API_KEY is not set")
|
||||
|
||||
headers["Authorization"] = f"Key {final_api_key}"
|
||||
return headers
|
||||
|
||||
def transform_image_generation_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ImageResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the image generation response to the litellm image response
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error transforming image generation response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Handle fal.ai response format
|
||||
images = response_data.get("images", [])
|
||||
if isinstance(images, list):
|
||||
for image_data in images:
|
||||
if isinstance(image_data, dict):
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data.get("url", None),
|
||||
b64_json=image_data.get("b64_json", None),
|
||||
)
|
||||
)
|
||||
elif isinstance(image_data, str):
|
||||
# If images is just a list of URLs
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
url=image_data,
|
||||
b64_json=None,
|
||||
)
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
class FalAIImageGenerationConfig(FalAIBaseConfig):
|
||||
"""
|
||||
Default Fal AI image generation configuration for generic models.
|
||||
"""
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIImageGenerationOptionalParams]:
|
||||
"""
|
||||
Get supported OpenAI parameters for fal.ai image generation
|
||||
"""
|
||||
return [
|
||||
"n",
|
||||
"response_format",
|
||||
"size",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
for k in non_default_params.keys():
|
||||
if k not in optional_params.keys():
|
||||
if k in supported_params:
|
||||
optional_params[k] = non_default_params[k]
|
||||
elif drop_params:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_image_generation_request(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the image generation request to the fal.ai image generation request body
|
||||
"""
|
||||
fal_ai_image_generation_request_body = {
|
||||
"prompt": prompt,
|
||||
**optional_params,
|
||||
}
|
||||
return fal_ai_image_generation_request_body
|
||||
Reference in New Issue
Block a user