# Page header residue from the file viewer (not part of the module):
# vilarin's picture · Upload 16 files · 021dc80 verified · raw · history blame · 8.14 kB
import io
import os
import time
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image
# Base URL that every API call in this module is issued against.
API_URL: str = "https://api.bfl.ml"

# Maps the user-facing model name to the endpoint path segment that is
# appended to f"{API_URL}/v1/" when a generation request is posted.
API_ENDPOINTS: dict[str, str] = {
    "flux.1-pro": "flux-pro",
    "flux.1-dev": "flux-dev",
    "flux.1.1-pro": "flux-pro-1.1",
}
class ApiException(Exception):
def __init__(self, status_code: int, detail: str | list[dict] | None = None):
super().__init__()
self.detail = detail
self.status_code = status_code
def __str__(self) -> str:
return self.__repr__()
def __repr__(self) -> str:
if self.detail is None:
message = None
elif isinstance(self.detail, str):
message = self.detail
else:
message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
class ImageRequest:
    """Handle for a single image generation request against the API.

    The instance stores the request payload, posts it to the endpoint that
    belongs to the chosen model, polls for the result and exposes the
    generated image as URL, raw bytes or PIL image.
    """

    def __init__(
        self,
        # api inputs
        prompt: str,
        name: str = "flux.1.1-pro",
        width: int | None = None,
        height: int | None = None,
        num_steps: int | None = None,
        prompt_upsampling: bool | None = None,
        seed: int | None = None,
        guidance: float | None = None,
        interval: float | None = None,
        safety_tolerance: int | None = None,
        # behavior of this class
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manages an image generation request to the API.

        All parameters not specified will use the API defaults.

        Args:
            prompt: Text prompt for image generation.
            width: Width of the generated image in pixels. Must be a multiple of 32
                between 256 and 1440.
            height: Height of the generated image in pixels. Must be a multiple of 32
                between 256 and 1440.
            name: Which model version to use
            num_steps: Number of steps for the image generation process.
                Between 1 and 50. Not supported for flux.1.1-pro.
            prompt_upsampling: Whether to perform upsampling on the prompt.
            seed: Optional seed for reproducibility.
            guidance: Guidance scale for image generation.
                Between 1.5 and 5. Not supported for flux.1.1-pro.
            interval: Interval parameter of the API. Between 1 and 4.
                Not supported for flux.1-dev and flux.1.1-pro.
            safety_tolerance: Tolerance level for input and output moderation.
                Between 0 and 6, 0 being most strict, 6 being least strict.
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key if not provided by the environment
                (read from the ``BFL_API_KEY`` environment variable).

        Raises:
            ValueError: For invalid input, when `validate`
            ApiException: For errors raised from the API
        """
        if validate:
            self._validate_inputs(
                name=name,
                width=width,
                height=height,
                num_steps=num_steps,
                guidance=guidance,
                interval=interval,
                safety_tolerance=safety_tolerance,
            )

        self.name = name
        payload = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
            "seed": seed,
            "guidance": guidance,
            "interval": interval,
            "safety_tolerance": safety_tolerance,
        }
        # Unset parameters are left out entirely so the API applies its defaults.
        self.request_json = {key: value for key, value in payload.items() if value is not None}

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        # An explicitly passed key wins over the environment variable.
        self.api_key = os.environ.get("BFL_API_KEY") if api_key is None else api_key

        if launch:
            self.request()

    @staticmethod
    def _validate_inputs(
        name: str,
        width: int | None,
        height: int | None,
        num_steps: int | None,
        guidance: float | None,
        interval: float | None,
        safety_tolerance: int | None,
    ) -> None:
        """Raise ValueError for the first input that violates the client-side checks."""
        if name not in API_ENDPOINTS:
            raise ValueError(f"Invalid model {name}")
        if width is not None and width % 32 != 0:
            raise ValueError(f"width must be divisible by 32, got {width}")
        if width is not None and not (256 <= width <= 1440):
            raise ValueError(f"width must be between 256 and 1440, got {width}")
        if height is not None and height % 32 != 0:
            raise ValueError(f"height must be divisible by 32, got {height}")
        if height is not None and not (256 <= height <= 1440):
            raise ValueError(f"height must be between 256 and 1440, got {height}")
        if num_steps is not None and not (1 <= num_steps <= 50):
            raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
        if guidance is not None and not (1.5 <= guidance <= 5.0):
            # The message now states the same upper bound the check enforces.
            raise ValueError(f"guidance must be between 1.5 and 5, got {guidance}")
        if interval is not None and not (1.0 <= interval <= 4.0):
            raise ValueError(f"interval must be between 1 and 4, got {interval}")
        if safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0):
            # Report the offending value itself (previously `interval` was printed).
            raise ValueError(f"safety_tolerance must be between 0 and 6, got {safety_tolerance}")

        # Model specific restrictions.
        if name == "flux.1-dev" and interval is not None:
            raise ValueError("Interval is not supported for flux.1-dev")
        if name == "flux.1.1-pro" and (
            interval is not None or num_steps is not None or guidance is not None
        ):
            raise ValueError("Interval, num_steps and guidance are not supported for flux.1.1-pro")

    @staticmethod
    def _json_or_raise(response) -> dict:
        """Decode the JSON object in *response*, raising ApiException if there is none."""
        try:
            payload = response.json()
        except ValueError as err:
            # The body is not JSON (e.g. an HTML error page or an empty body).
            # Surface it as an API error together with the HTTP status code
            # instead of leaking a decoding error to the caller.
            raise ApiException(status_code=response.status_code, detail=response.text) from err
        if not isinstance(payload, dict):
            # The callers look up keys on the payload; anything else is unexpected.
            raise ApiException(status_code=response.status_code, detail=str(payload))
        return payload

    def request(self):
        """
        Request to generate the image.

        Does nothing when the request has already been submitted.

        Raises:
            ApiException: If the API does not answer with status code 200 or
                the response body is not a JSON object.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_URL}/v1/{API_ENDPOINTS[self.name]}",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        result = self._json_or_raise(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if that has not happened yet, then polls
        the result endpoint every 0.5 seconds while the status is "Pending".

        Returns:
            The ``result`` entry of the API response once the status is "Ready".

        Raises:
            ApiException: If the response carries no status, is not a JSON
                object, or reports a status other than "Ready"/"Pending".
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_URL}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            result = self._json_or_raise(response)
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes.

        The download happens on first access and is cached afterwards.

        Raises:
            ApiException: If downloading from the image url does not return 200.
        """
        if self._image_bytes is None:
            response = requests.get(self.url)
            if response.status_code == 200:
                self._image_bytes = response.content
            else:
                raise ApiException(status_code=response.status_code)
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str | os.PathLike):
        """
        Save the generated image to a local path

        The file extension of the image url is appended to `path` unless
        `path` already ends with it. Missing parent directories are created.

        Args:
            path: Target location, as a string or path-like object.
        """
        target = os.fspath(path)
        # Take the suffix from the path component of the url only, so that a
        # query string (e.g. on signed urls) does not end up in the file name.
        suffix = Path(urlparse(self.url).path).suffix
        if not target.endswith(suffix):
            target = target + suffix
        # Download before touching the file system so that a failed download
        # does not leave an empty file behind.
        data = self.bytes
        Path(target).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(target, "wb") as file:
            file.write(data)
if __name__ == "__main__":
    # Running the module as a script hands the ImageRequest class to
    # python-fire, which builds a command line interface from it.
    import fire

    fire.Fire(ImageRequest)