File size: 8,135 Bytes
021dc80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import io
import os
import time
from pathlib import Path
from urllib.parse import urlparse

import requests
from PIL import Image

# Base URL of the BFL REST API; request paths are appended under `/v1/`.
API_URL: str = "https://api.bfl.ml"

# Maps the user-facing model name (the `name` argument of ImageRequest) to the
# endpoint path segment used for submitting a generation request.
API_ENDPOINTS: dict[str, str] = dict(
    [
        ("flux.1-pro", "flux-pro"),
        ("flux.1-dev", "flux-dev"),
        ("flux.1.1-pro", "flux-pro-1.1"),
    ]
)


class ApiException(Exception):
    def __init__(self, status_code: int, detail: str | list[dict] | None = None):
        super().__init__()
        self.detail = detail
        self.status_code = status_code

    def __str__(self) -> str:
        return self.__repr__()

    def __repr__(self) -> str:
        if self.detail is None:
            message = None
        elif isinstance(self.detail, str):
            message = self.detail
        else:
            message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
        return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"


class ImageRequest:
    """
    A single image generation request against the API.

    The request is submitted on construction unless `launch=False`. The result is
    fetched lazily, by polling, the first time `url`, `bytes`, `image` or `save` is used.
    """

    def __init__(
        self,
        # api inputs
        prompt: str,
        name: str = "flux.1.1-pro",
        width: int | None = None,
        height: int | None = None,
        num_steps: int | None = None,
        prompt_upsampling: bool | None = None,
        seed: int | None = None,
        guidance: float | None = None,
        interval: float | None = None,
        safety_tolerance: int | None = None,
        # behavior of this class
        validate: bool = True,
        launch: bool = True,
        api_key: str | None = None,
    ):
        """
        Manages an image generation request to the API.

        All parameters not specified will use the API defaults.

        Args:
            prompt: Text prompt for image generation.
            name: Which model version to use; one of "flux.1-pro", "flux.1-dev",
                "flux.1.1-pro".
            width: Width of the generated image in pixels. Must be a multiple of 32,
                between 256 and 1440.
            height: Height of the generated image in pixels. Must be a multiple of 32,
                between 256 and 1440.
            num_steps: Number of steps for the image generation process (1 to 50).
                Not supported by flux.1.1-pro.
            prompt_upsampling: Whether to perform upsampling on the prompt.
            seed: Optional seed for reproducibility.
            guidance: Guidance scale for image generation (1.5 to 5).
                Not supported by flux.1.1-pro.
            interval: Value sent to the API as `interval` (1 to 4). Only supported
                by flux.1-pro.
            safety_tolerance: Tolerance level for input and output moderation.
                 Between 0 and 6, 0 being most strict, 6 being least strict.
            validate: Run input validation
            launch: Directly launches request
            api_key: Your API key; defaults to the BFL_API_KEY environment variable.

        Raises:
            ValueError: For invalid input, when `validate`
            ApiException: For errors raised from the API (only when `launch`)
        """
        if validate:
            if name not in API_ENDPOINTS:
                raise ValueError(f"Invalid model {name}")
            elif width is not None and width % 32 != 0:
                raise ValueError(f"width must be divisible by 32, got {width}")
            elif width is not None and not (256 <= width <= 1440):
                raise ValueError(f"width must be between 256 and 1440, got {width}")
            elif height is not None and height % 32 != 0:
                raise ValueError(f"height must be divisible by 32, got {height}")
            elif height is not None and not (256 <= height <= 1440):
                raise ValueError(f"height must be between 256 and 1440, got {height}")
            elif num_steps is not None and not (1 <= num_steps <= 50):
                raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
            elif guidance is not None and not (1.5 <= guidance <= 5.0):
                raise ValueError(f"guidance must be between 1.5 and 5, got {guidance}")
            elif interval is not None and not (1.0 <= interval <= 4.0):
                raise ValueError(f"interval must be between 1 and 4, got {interval}")
            elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6):
                raise ValueError(f"safety_tolerance must be between 0 and 6, got {safety_tolerance}")

            # Model-specific restrictions on which optional inputs are accepted.
            if name == "flux.1-dev":
                if interval is not None:
                    raise ValueError("Interval is not supported for flux.1-dev")
            if name == "flux.1.1-pro":
                if interval is not None or num_steps is not None or guidance is not None:
                    raise ValueError("Interval, num_steps and guidance are not supported for flux.1.1-pro")

        self.name = name
        params = {
            "prompt": prompt,
            "width": width,
            "height": height,
            "steps": num_steps,
            "prompt_upsampling": prompt_upsampling,
            "seed": seed,
            "guidance": guidance,
            "interval": interval,
            "safety_tolerance": safety_tolerance,
        }
        # Omit unset parameters so the API applies its own defaults.
        self.request_json = {key: value for key, value in params.items() if value is not None}

        self.request_id: str | None = None
        self.result: dict | None = None
        self._image_bytes: bytes | None = None
        self._url: str | None = None
        if api_key is None:
            self.api_key = os.environ.get("BFL_API_KEY")
        else:
            self.api_key = api_key

        if launch:
            self.request()

    @staticmethod
    def _json_or_raise(response) -> dict:
        """
        Decode a response body as JSON, turning undecodable bodies into an ApiException.
        """
        try:
            return response.json()
        except ValueError:
            # requests' JSON decode errors subclass ValueError. Without this, a non-JSON
            # body (e.g. an HTML error page) would escape as a decode error rather than
            # the ApiException callers are told to expect.
            raise ApiException(status_code=response.status_code, detail=response.text) from None

    def request(self):
        """
        Request to generate the image. Does nothing if already requested.

        Raises:
            ApiException: If the API responds with a non-200 status or a non-JSON body.
        """
        if self.request_id is not None:
            return
        response = requests.post(
            f"{API_URL}/v1/{API_ENDPOINTS[self.name]}",
            headers={
                "accept": "application/json",
                "x-key": self.api_key,
                "Content-Type": "application/json",
            },
            json=self.request_json,
        )
        result = self._json_or_raise(response)
        if response.status_code != 200:
            raise ApiException(status_code=response.status_code, detail=result.get("detail"))
        # Reuse the already-decoded body instead of parsing the response a second time.
        self.request_id = result["id"]

    def retrieve(self) -> dict:
        """
        Wait for the generation to finish and retrieve response.

        Submits the request first if needed, then polls every 0.5 s while the reported
        status is "Pending". Caches and returns the "result" payload once it is "Ready".

        Raises:
            ApiException: If a response has no "status", reports any status other than
                "Pending"/"Ready", or is not valid JSON.
        """
        if self.request_id is None:
            self.request()
        while self.result is None:
            response = requests.get(
                f"{API_URL}/v1/get_result",
                headers={
                    "accept": "application/json",
                    "x-key": self.api_key,
                },
                params={
                    "id": self.request_id,
                },
            )
            result = self._json_or_raise(response)
            if "status" not in result:
                raise ApiException(status_code=response.status_code, detail=result.get("detail"))
            elif result["status"] == "Ready":
                self.result = result["result"]
            elif result["status"] == "Pending":
                time.sleep(0.5)
            else:
                # Report the actual HTTP status rather than assuming 200.
                raise ApiException(
                    status_code=response.status_code,
                    detail=f"API returned status '{result['status']}'",
                )
        return self.result

    @property
    def bytes(self) -> bytes:
        """
        Generated image as bytes (downloaded once from `url`, then cached).
        """
        if self._image_bytes is None:
            response = requests.get(self.url)
            if response.status_code == 200:
                self._image_bytes = response.content
            else:
                raise ApiException(status_code=response.status_code)
        return self._image_bytes

    @property
    def url(self) -> str:
        """
        Public url to retrieve the image from (the result's "sample" field, cached).
        """
        if self._url is None:
            result = self.retrieve()
            self._url = result["sample"]
        return self._url

    @property
    def image(self) -> Image.Image:
        """
        Load the image as a PIL Image
        """
        return Image.open(io.BytesIO(self.bytes))

    def save(self, path: str | os.PathLike[str]):
        """
        Save the generated image to a local path.

        The file extension of the image URL's path is appended to `path` unless `path`
        already ends with it. Missing parent directories are created.

        Args:
            path: Destination file path, with or without the extension.
        """
        path = os.fspath(path)
        # Take the extension from the URL's path component only; result URLs may carry
        # a query string (e.g. signed URLs), which `Path(url).suffix` would fold into
        # the "extension".
        suffix = Path(urlparse(self.url).path).suffix
        if not path.endswith(suffix):
            path = path + suffix
        Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as file:
            file.write(self.bytes)


if __name__ == "__main__":
    # Command-line entry point: python-fire maps CLI flags onto ImageRequest's
    # constructor arguments.
    import fire

    fire.Fire(ImageRequest)