vilarin committed on
Commit 021dc80 · verified · 1 Parent(s): 05fa434

Upload 16 files

flux/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ try:
2
+ from ._version import (
3
+ version as __version__, # type: ignore
4
+ version_tuple,
5
+ )
6
+ except ImportError:
7
+ __version__ = "unknown (no version information available)"
8
+ version_tuple = (0, 0, "unknown", "noinfo")
9
+
10
+ from pathlib import Path
11
+
12
+ PACKAGE = __package__.replace("_", "-")
13
+ PACKAGE_ROOT = Path(__file__).parent
flux/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .cli import app
2
+
3
+ if __name__ == "__main__":
4
+ app()
flux/api.py ADDED
@@ -0,0 +1,225 @@
1
+ import io
2
+ import os
3
+ import time
4
+ from pathlib import Path
5
+
6
+ import requests
7
+ from PIL import Image
8
+
9
+ API_URL = "https://api.bfl.ml"
10
+ API_ENDPOINTS = {
11
+ "flux.1-pro": "flux-pro",
12
+ "flux.1-dev": "flux-dev",
13
+ "flux.1.1-pro": "flux-pro-1.1",
14
+ }
15
+
16
+
17
+ class ApiException(Exception):
18
+ def __init__(self, status_code: int, detail: str | list[dict] | None = None):
19
+ super().__init__()
20
+ self.detail = detail
21
+ self.status_code = status_code
22
+
23
+ def __str__(self) -> str:
24
+ return self.__repr__()
25
+
26
+ def __repr__(self) -> str:
27
+ if self.detail is None:
28
+ message = None
29
+ elif isinstance(self.detail, str):
30
+ message = self.detail
31
+ else:
32
+ message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
33
+ return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
34
+
35
+
36
+ class ImageRequest:
37
+ def __init__(
38
+ self,
39
+ # api inputs
40
+ prompt: str,
41
+ name: str = "flux.1.1-pro",
42
+ width: int | None = None,
43
+ height: int | None = None,
44
+ num_steps: int | None = None,
45
+ prompt_upsampling: bool | None = None,
46
+ seed: int | None = None,
47
+ guidance: float | None = None,
48
+ interval: float | None = None,
49
+ safety_tolerance: int | None = None,
50
+ # behavior of this class
51
+ validate: bool = True,
52
+ launch: bool = True,
53
+ api_key: str | None = None,
54
+ ):
55
+ """
56
+ Manages an image generation request to the API.
57
+
58
+ All parameters not specified will use the API defaults.
59
+
60
+ Args:
61
+ prompt: Text prompt for image generation.
62
+ width: Width of the generated image in pixels. Must be a multiple of 32.
63
+ height: Height of the generated image in pixels. Must be a multiple of 32.
64
+ name: Which model version to use
65
+ num_steps: Number of steps for the image generation process.
66
+ prompt_upsampling: Whether to perform upsampling on the prompt.
67
+ seed: Optional seed for reproducibility.
68
+ guidance: Guidance scale for image generation.
69
+ safety_tolerance: Tolerance level for input and output moderation.
70
+ Between 0 and 6, 0 being most strict, 6 being least strict.
71
+ validate: Run input validation
72
+ launch: Directly launches request
73
+ api_key: Your API key if not provided by the environment
74
+
75
+ Raises:
76
+ ValueError: For invalid input, when `validate` is set
77
+ ApiException: For errors raised from the API
78
+ """
79
+ if validate:
80
+ if name not in API_ENDPOINTS.keys():
81
+ raise ValueError(f"Invalid model {name}")
82
+ elif width is not None and width % 32 != 0:
83
+ raise ValueError(f"width must be divisible by 32, got {width}")
84
+ elif width is not None and not (256 <= width <= 1440):
85
+ raise ValueError(f"width must be between 256 and 1440, got {width}")
86
+ elif height is not None and height % 32 != 0:
87
+ raise ValueError(f"height must be divisible by 32, got {height}")
88
+ elif height is not None and not (256 <= height <= 1440):
89
+ raise ValueError(f"height must be between 256 and 1440, got {height}")
90
+ elif num_steps is not None and not (1 <= num_steps <= 50):
91
+ raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
92
+ elif guidance is not None and not (1.5 <= guidance <= 5.0):
93
+ raise ValueError(f"guidance must be between 1.5 and 4, got {guidance}")
94
+ elif interval is not None and not (1.0 <= interval <= 4.0):
95
+ raise ValueError(f"interval must be between 1 and 4, got {interval}")
96
+ elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0):
97
+ raise ValueError(f"safety_tolerance must be between 0 and 6, got {interval}")
98
+
99
+ if name == "flux.1-dev":
100
+ if interval is not None:
101
+ raise ValueError("Interval is not supported for flux.1-dev")
102
+ if name == "flux.1.1-pro":
103
+ if interval is not None or num_steps is not None or guidance is not None:
104
+ raise ValueError("Interval, num_steps and guidance are not supported for " "flux.1.1-pro")
105
+
106
+ self.name = name
107
+ self.request_json = {
108
+ "prompt": prompt,
109
+ "width": width,
110
+ "height": height,
111
+ "steps": num_steps,
112
+ "prompt_upsampling": prompt_upsampling,
113
+ "seed": seed,
114
+ "guidance": guidance,
115
+ "interval": interval,
116
+ "safety_tolerance": safety_tolerance,
117
+ }
118
+ self.request_json = {key: value for key, value in self.request_json.items() if value is not None}
119
+
120
+ self.request_id: str | None = None
121
+ self.result: dict | None = None
122
+ self._image_bytes: bytes | None = None
123
+ self._url: str | None = None
124
+ if api_key is None:
125
+ self.api_key = os.environ.get("BFL_API_KEY")
126
+ else:
127
+ self.api_key = api_key
128
+
129
+ if launch:
130
+ self.request()
131
+
132
+ def request(self):
133
+ """
134
+ Request to generate the image.
135
+ """
136
+ if self.request_id is not None:
137
+ return
138
+ response = requests.post(
139
+ f"{API_URL}/v1/{API_ENDPOINTS[self.name]}",
140
+ headers={
141
+ "accept": "application/json",
142
+ "x-key": self.api_key,
143
+ "Content-Type": "application/json",
144
+ },
145
+ json=self.request_json,
146
+ )
147
+ result = response.json()
148
+ if response.status_code != 200:
149
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
150
+ self.request_id = result["id"]
151
+
152
+ def retrieve(self) -> dict:
153
+ """
154
+ Wait for the generation to finish and retrieve response.
155
+ """
156
+ if self.request_id is None:
157
+ self.request()
158
+ while self.result is None:
159
+ response = requests.get(
160
+ f"{API_URL}/v1/get_result",
161
+ headers={
162
+ "accept": "application/json",
163
+ "x-key": self.api_key,
164
+ },
165
+ params={
166
+ "id": self.request_id,
167
+ },
168
+ )
169
+ result = response.json()
170
+ if "status" not in result:
171
+ raise ApiException(status_code=response.status_code, detail=result.get("detail"))
172
+ elif result["status"] == "Ready":
173
+ self.result = result["result"]
174
+ elif result["status"] == "Pending":
175
+ time.sleep(0.5)
176
+ else:
177
+ raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
178
+ return self.result
179
+
180
+ @property
181
+ def bytes(self) -> bytes:
182
+ """
183
+ Generated image as bytes.
184
+ """
185
+ if self._image_bytes is None:
186
+ response = requests.get(self.url)
187
+ if response.status_code == 200:
188
+ self._image_bytes = response.content
189
+ else:
190
+ raise ApiException(status_code=response.status_code)
191
+ return self._image_bytes
192
+
193
+ @property
194
+ def url(self) -> str:
195
+ """
196
+ Public url to retrieve the image from
197
+ """
198
+ if self._url is None:
199
+ result = self.retrieve()
200
+ self._url = result["sample"]
201
+ return self._url
202
+
203
+ @property
204
+ def image(self) -> Image.Image:
205
+ """
206
+ Load the image as a PIL Image
207
+ """
208
+ return Image.open(io.BytesIO(self.bytes))
209
+
210
+ def save(self, path: str):
211
+ """
212
+ Save the generated image to a local path
213
+ """
214
+ suffix = Path(self.url).suffix
215
+ if not path.endswith(suffix):
216
+ path = path + suffix
217
+ Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
218
+ with open(path, "wb") as file:
219
+ file.write(self.bytes)
220
+
221
+
222
+ if __name__ == "__main__":
223
+ from fire import Fire
224
+
225
+ Fire(ImageRequest)
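
For reference, a minimal usage sketch of the `ImageRequest` class added above (assumes a valid key is exported as `BFL_API_KEY` or passed via `api_key`; the prompt and output path are placeholders):

    from flux.api import ImageRequest

    # the request is launched on construction (launch=True by default)
    request = ImageRequest(prompt="a cabin in the woods", name="flux.1.1-pro", width=1024, height=768)
    print(request.url)         # public URL of the generated sample
    request.save("cabin.jpg")  # downloads the image bytes and writes them to disk
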
flux/cli.py ADDED
@@ -0,0 +1,238 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from transformers import pipeline
10
+
11
+ from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
12
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
13
+
14
+ NSFW_THRESHOLD = 0.85
15
+
16
+
17
+ @dataclass
18
+ class SamplingOptions:
19
+ prompt: str
20
+ width: int
21
+ height: int
22
+ num_steps: int
23
+ guidance: float
24
+ seed: int | None
25
+
26
+
27
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
28
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
29
+ usage = (
30
+ "Usage: Either write your prompt directly, leave this field empty "
31
+ "to repeat the prompt or write a command starting with a slash:\n"
32
+ "- '/w <width>' will set the width of the generated image\n"
33
+ "- '/h <height>' will set the height of the generated image\n"
34
+ "- '/s <seed>' sets the next seed\n"
35
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
36
+ "- '/n <steps>' sets the number of steps\n"
37
+ "- '/q' to quit"
38
+ )
39
+
40
+ while (prompt := input(user_question)).startswith("/"):
41
+ if prompt.startswith("/w"):
42
+ if prompt.count(" ") != 1:
43
+ print(f"Got invalid command '{prompt}'\n{usage}")
44
+ continue
45
+ _, width = prompt.split()
46
+ options.width = 16 * (int(width) // 16)
47
+ print(
48
+ f"Setting resolution to {options.width} x {options.height} "
49
+ f"({options.height *options.width/1e6:.2f}MP)"
50
+ )
51
+ elif prompt.startswith("/h"):
52
+ if prompt.count(" ") != 1:
53
+ print(f"Got invalid command '{prompt}'\n{usage}")
54
+ continue
55
+ _, height = prompt.split()
56
+ options.height = 16 * (int(height) // 16)
57
+ print(
58
+ f"Setting resolution to {options.width} x {options.height} "
59
+ f"({options.height *options.width/1e6:.2f}MP)"
60
+ )
61
+ elif prompt.startswith("/g"):
62
+ if prompt.count(" ") != 1:
63
+ print(f"Got invalid command '{prompt}'\n{usage}")
64
+ continue
65
+ _, guidance = prompt.split()
66
+ options.guidance = float(guidance)
67
+ print(f"Setting guidance to {options.guidance}")
68
+ elif prompt.startswith("/s"):
69
+ if prompt.count(" ") != 1:
70
+ print(f"Got invalid command '{prompt}'\n{usage}")
71
+ continue
72
+ _, seed = prompt.split()
73
+ options.seed = int(seed)
74
+ print(f"Setting seed to {options.seed}")
75
+ elif prompt.startswith("/n"):
76
+ if prompt.count(" ") != 1:
77
+ print(f"Got invalid command '{prompt}'\n{usage}")
78
+ continue
79
+ _, steps = prompt.split()
80
+ options.num_steps = int(steps)
81
+ print(f"Setting number of steps to {options.num_steps}")
82
+ elif prompt.startswith("/q"):
83
+ print("Quitting")
84
+ return None
85
+ else:
86
+ if not prompt.startswith("/h"):
87
+ print(f"Got invalid command '{prompt}'\n{usage}")
88
+ print(usage)
89
+ if prompt != "":
90
+ options.prompt = prompt
91
+ return options
92
+
93
+
94
+ @torch.inference_mode()
95
+ def main(
96
+ name: str = "flux-schnell",
97
+ width: int = 1360,
98
+ height: int = 768,
99
+ seed: int | None = None,
100
+ prompt: str = (
101
+ "a photo of a forest with mist swirling around the tree trunks. The word "
102
+ '"FLUX" is painted over it in big, red brush strokes with visible texture'
103
+ ),
104
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
105
+ num_steps: int | None = None,
106
+ loop: bool = False,
107
+ guidance: float = 3.5,
108
+ offload: bool = False,
109
+ output_dir: str = "output",
110
+ add_sampling_metadata: bool = True,
111
+ ):
112
+ """
113
+ Sample the flux model. Either interactively (set `--loop`) or run for a
114
+ single image.
115
+
116
+ Args:
117
+ name: Name of the model to load
118
+ height: height of the sample in pixels (should be a multiple of 16)
119
+ width: width of the sample in pixels (should be a multiple of 16)
120
+ seed: Set a seed for sampling
121
+ output_name: where to save the output image, `{idx}` will be replaced
122
+ by the index of the sample
123
+ prompt: Prompt used for sampling
124
+ device: Pytorch device
125
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
126
+ loop: start an interactive session and sample multiple times
127
+ guidance: guidance value used for guidance distillation
128
+ add_sampling_metadata: Add the prompt to the image Exif metadata
129
+ """
130
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
131
+
132
+ if name not in configs:
133
+ available = ", ".join(configs.keys())
134
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
135
+
136
+ torch_device = torch.device(device)
137
+ if num_steps is None:
138
+ num_steps = 4 if name == "flux-schnell" else 50
139
+
140
+ # allow for packing and conversion to latent space
141
+ height = 16 * (height // 16)
142
+ width = 16 * (width // 16)
143
+
144
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
145
+ if not os.path.exists(output_dir):
146
+ os.makedirs(output_dir)
147
+ idx = 0
148
+ else:
149
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
150
+ if len(fns) > 0:
151
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
152
+ else:
153
+ idx = 0
154
+
155
+ # init all components
156
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
157
+ clip = load_clip(torch_device)
158
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
159
+ ae = load_ae(name, device="cpu" if offload else torch_device)
160
+
161
+ rng = torch.Generator(device="cpu")
162
+ opts = SamplingOptions(
163
+ prompt=prompt,
164
+ width=width,
165
+ height=height,
166
+ num_steps=num_steps,
167
+ guidance=guidance,
168
+ seed=seed,
169
+ )
170
+
171
+ if loop:
172
+ opts = parse_prompt(opts)
173
+
174
+ while opts is not None:
175
+ if opts.seed is None:
176
+ opts.seed = rng.seed()
177
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
178
+ t0 = time.perf_counter()
179
+
180
+ # prepare input
181
+ x = get_noise(
182
+ 1,
183
+ opts.height,
184
+ opts.width,
185
+ device=torch_device,
186
+ dtype=torch.bfloat16,
187
+ seed=opts.seed,
188
+ )
189
+ opts.seed = None
190
+ if offload:
191
+ ae = ae.cpu()
192
+ torch.cuda.empty_cache()
193
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
194
+ inp = prepare(t5, clip, x, prompt=opts.prompt)
195
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
196
+
197
+ # offload TEs to CPU, load model to gpu
198
+ if offload:
199
+ t5, clip = t5.cpu(), clip.cpu()
200
+ torch.cuda.empty_cache()
201
+ model = model.to(torch_device)
202
+
203
+ # denoise initial noise
204
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
205
+
206
+ # offload model, load autoencoder to gpu
207
+ if offload:
208
+ model.cpu()
209
+ torch.cuda.empty_cache()
210
+ ae.decoder.to(x.device)
211
+
212
+ # decode latents to pixel space
213
+ x = unpack(x.float(), opts.height, opts.width)
214
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
215
+ x = ae.decode(x)
216
+
217
+ if torch.cuda.is_available():
218
+ torch.cuda.synchronize()
219
+ t1 = time.perf_counter()
220
+
221
+ fn = output_name.format(idx=idx)
222
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
223
+
224
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
225
+
226
+ if loop:
227
+ print("-" * 80)
228
+ opts = parse_prompt(opts)
229
+ else:
230
+ opts = None
231
+
232
+
233
+ def app():
234
+ Fire(main)
235
+
236
+
237
+ if __name__ == "__main__":
238
+ app()
flux/cli_control.py ADDED
@@ -0,0 +1,347 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from transformers import pipeline
10
+
11
+ from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
13
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
+
15
+
16
+ @dataclass
17
+ class SamplingOptions:
18
+ prompt: str
19
+ width: int
20
+ height: int
21
+ num_steps: int
22
+ guidance: float
23
+ seed: int | None
24
+ img_cond_path: str
25
+ lora_scale: float | None
26
+
27
+
28
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
29
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
30
+ usage = (
31
+ "Usage: Either write your prompt directly, leave this field empty "
32
+ "to repeat the prompt or write a command starting with a slash:\n"
33
+ "- '/w <width>' will set the width of the generated image\n"
34
+ "- '/h <height>' will set the height of the generated image\n"
35
+ "- '/s <seed>' sets the next seed\n"
36
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
37
+ "- '/n <steps>' sets the number of steps\n"
38
+ "- '/q' to quit"
39
+ )
40
+
41
+ while (prompt := input(user_question)).startswith("/"):
42
+ if prompt.startswith("/w"):
43
+ if prompt.count(" ") != 1:
44
+ print(f"Got invalid command '{prompt}'\n{usage}")
45
+ continue
46
+ _, width = prompt.split()
47
+ options.width = 16 * (int(width) // 16)
48
+ print(
49
+ f"Setting resolution to {options.width} x {options.height} "
50
+ f"({options.height *options.width/1e6:.2f}MP)"
51
+ )
52
+ elif prompt.startswith("/h"):
53
+ if prompt.count(" ") != 1:
54
+ print(f"Got invalid command '{prompt}'\n{usage}")
55
+ continue
56
+ _, height = prompt.split()
57
+ options.height = 16 * (int(height) // 16)
58
+ print(
59
+ f"Setting resolution to {options.width} x {options.height} "
60
+ f"({options.height *options.width/1e6:.2f}MP)"
61
+ )
62
+ elif prompt.startswith("/g"):
63
+ if prompt.count(" ") != 1:
64
+ print(f"Got invalid command '{prompt}'\n{usage}")
65
+ continue
66
+ _, guidance = prompt.split()
67
+ options.guidance = float(guidance)
68
+ print(f"Setting guidance to {options.guidance}")
69
+ elif prompt.startswith("/s"):
70
+ if prompt.count(" ") != 1:
71
+ print(f"Got invalid command '{prompt}'\n{usage}")
72
+ continue
73
+ _, seed = prompt.split()
74
+ options.seed = int(seed)
75
+ print(f"Setting seed to {options.seed}")
76
+ elif prompt.startswith("/n"):
77
+ if prompt.count(" ") != 1:
78
+ print(f"Got invalid command '{prompt}'\n{usage}")
79
+ continue
80
+ _, steps = prompt.split()
81
+ options.num_steps = int(steps)
82
+ print(f"Setting number of steps to {options.num_steps}")
83
+ elif prompt.startswith("/q"):
84
+ print("Quitting")
85
+ return None
86
+ else:
87
+ if not prompt.startswith("/h"):
88
+ print(f"Got invalid command '{prompt}'\n{usage}")
89
+ print(usage)
90
+ if prompt != "":
91
+ options.prompt = prompt
92
+ return options
93
+
94
+
95
+ def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
96
+ if options is None:
97
+ return None
98
+
99
+ user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
100
+ usage = (
101
+ "Usage: Either write your prompt directly, leave this field empty "
102
+ "to repeat the conditioning image or write a command starting with a slash:\n"
103
+ "- '/q' to quit"
104
+ )
105
+
106
+ while True:
107
+ img_cond_path = input(user_question)
108
+
109
+ if img_cond_path.startswith("/"):
110
+ if img_cond_path.startswith("/q"):
111
+ print("Quitting")
112
+ return None
113
+ else:
114
+ if not img_cond_path.startswith("/h"):
115
+ print(f"Got invalid command '{img_cond_path}'\n{usage}")
116
+ print(usage)
117
+ continue
118
+
119
+ if img_cond_path == "":
120
+ break
121
+
122
+ if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
123
+ (".jpg", ".jpeg", ".png", ".webp")
124
+ ):
125
+ print(f"File '{img_cond_path}' does not exist or is not a valid image file")
126
+ continue
127
+
128
+ options.img_cond_path = img_cond_path
129
+ break
130
+
131
+ return options
132
+
133
+
134
+ def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
135
+ changed = False
136
+
137
+ if options is None:
138
+ return None, changed
139
+
140
+ user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
141
+ usage = (
142
+ "Usage: Either write your prompt directly, leave this field empty "
143
+ "to repeat the lora scale or write a command starting with a slash:\n"
144
+ "- '/q' to quit"
145
+ )
146
+
147
+ while (prompt := input(user_question)).startswith("/"):
148
+ if prompt.startswith("/q"):
149
+ print("Quitting")
150
+ return None, changed
151
+ else:
152
+ if not prompt.startswith("/h"):
153
+ print(f"Got invalid command '{prompt}'\n{usage}")
154
+ print(usage)
155
+ if prompt != "":
156
+ options.lora_scale = float(prompt)
157
+ changed = True
158
+ return options, changed
159
+
160
+
161
+ @torch.inference_mode()
162
+ def main(
163
+ name: str,
164
+ width: int = 1024,
165
+ height: int = 1024,
166
+ seed: int | None = None,
167
+ prompt: str = "a robot made out of gold",
168
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
169
+ num_steps: int = 50,
170
+ loop: bool = False,
171
+ guidance: float | None = None,
172
+ offload: bool = False,
173
+ output_dir: str = "output",
174
+ add_sampling_metadata: bool = True,
175
+ img_cond_path: str = "assets/robot.webp",
176
+ lora_scale: float | None = 0.85,
177
+ ):
178
+ """
179
+ Sample the flux model. Either interactively (set `--loop`) or run for a
180
+ single image.
181
+
182
+ Args:
183
+ height: height of the sample in pixels (should be a multiple of 16)
184
+ width: width of the sample in pixels (should be a multiple of 16)
185
+ seed: Set a seed for sampling
186
+ output_name: where to save the output image, `{idx}` will be replaced
187
+ by the index of the sample
188
+ prompt: Prompt used for sampling
189
+ device: Pytorch device
190
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
191
+ loop: start an interactive session and sample multiple times
192
+ guidance: guidance value used for guidance distillation
193
+ add_sampling_metadata: Add the prompt to the image Exif metadata
194
+ img_cond_path: path to conditioning image (jpeg/png/webp)
195
+ """
196
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
197
+
198
+ assert name in [
199
+ "flux-dev-canny",
200
+ "flux-dev-depth",
201
+ "flux-dev-canny-lora",
202
+ "flux-dev-depth-lora",
203
+ ], f"Got unknown model name: {name}"
204
+ if guidance is None:
205
+ if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
206
+ guidance = 30.0
207
+ elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
208
+ guidance = 10.0
209
+ else:
210
+ raise NotImplementedError()
211
+
212
+ if name not in configs:
213
+ available = ", ".join(configs.keys())
214
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
215
+
216
+ torch_device = torch.device(device)
217
+
218
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
219
+ if not os.path.exists(output_dir):
220
+ os.makedirs(output_dir)
221
+ idx = 0
222
+ else:
223
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
224
+ if len(fns) > 0:
225
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
226
+ else:
227
+ idx = 0
228
+
229
+ # init all components
230
+ t5 = load_t5(torch_device, max_length=512)
231
+ clip = load_clip(torch_device)
232
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
233
+ ae = load_ae(name, device="cpu" if offload else torch_device)
234
+
235
+ # set lora scale
236
+ if "lora" in name and lora_scale is not None:
237
+ for _, module in model.named_modules():
238
+ if hasattr(module, "set_scale"):
239
+ module.set_scale(lora_scale)
240
+
241
+ if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
242
+ img_embedder = DepthImageEncoder(torch_device)
243
+ elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
244
+ img_embedder = CannyImageEncoder(torch_device)
245
+ else:
246
+ raise NotImplementedError()
247
+
248
+ rng = torch.Generator(device="cpu")
249
+ opts = SamplingOptions(
250
+ prompt=prompt,
251
+ width=width,
252
+ height=height,
253
+ num_steps=num_steps,
254
+ guidance=guidance,
255
+ seed=seed,
256
+ img_cond_path=img_cond_path,
257
+ lora_scale=lora_scale,
258
+ )
259
+
260
+ if loop:
261
+ opts = parse_prompt(opts)
262
+ opts = parse_img_cond_path(opts)
263
+ if "lora" in name:
264
+ opts, changed = parse_lora_scale(opts)
265
+ if changed:
266
+ # update the lora scale:
267
+ for _, module in model.named_modules():
268
+ if hasattr(module, "set_scale"):
269
+ module.set_scale(opts.lora_scale)
270
+
271
+ while opts is not None:
272
+ if opts.seed is None:
273
+ opts.seed = rng.seed()
274
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
275
+ t0 = time.perf_counter()
276
+
277
+ # prepare input
278
+ x = get_noise(
279
+ 1,
280
+ opts.height,
281
+ opts.width,
282
+ device=torch_device,
283
+ dtype=torch.bfloat16,
284
+ seed=opts.seed,
285
+ )
286
+ opts.seed = None
287
+ if offload:
288
+ t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
289
+ inp = prepare_control(
290
+ t5,
291
+ clip,
292
+ x,
293
+ prompt=opts.prompt,
294
+ ae=ae,
295
+ encoder=img_embedder,
296
+ img_cond_path=opts.img_cond_path,
297
+ )
298
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
299
+
300
+ # offload TEs and AE to CPU, load model to gpu
301
+ if offload:
302
+ t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
303
+ torch.cuda.empty_cache()
304
+ model = model.to(torch_device)
305
+
306
+ # denoise initial noise
307
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
308
+
309
+ # offload model, load autoencoder to gpu
310
+ if offload:
311
+ model.cpu()
312
+ torch.cuda.empty_cache()
313
+ ae.decoder.to(x.device)
314
+
315
+ # decode latents to pixel space
316
+ x = unpack(x.float(), opts.height, opts.width)
317
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
318
+ x = ae.decode(x)
319
+
320
+ if torch.cuda.is_available():
321
+ torch.cuda.synchronize()
322
+ t1 = time.perf_counter()
323
+ print(f"Done in {t1 - t0:.1f}s")
324
+
325
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
326
+
327
+ if loop:
328
+ print("-" * 80)
329
+ opts = parse_prompt(opts)
330
+ opts = parse_img_cond_path(opts)
331
+ if "lora" in name:
332
+ opts, changed = parse_lora_scale(opts)
333
+ if changed:
334
+ # update the lora scale:
335
+ for _, module in model.named_modules():
336
+ if hasattr(module, "set_scale"):
337
+ module.set_scale(opts.lora_scale)
338
+ else:
339
+ opts = None
340
+
341
+
342
+ def app():
343
+ Fire(main)
344
+
345
+
346
+ if __name__ == "__main__":
347
+ app()
flux/cli_fill.py ADDED
@@ -0,0 +1,334 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from PIL import Image
10
+ from transformers import pipeline
11
+
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
13
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
+
15
+
16
+ @dataclass
17
+ class SamplingOptions:
18
+ prompt: str
19
+ width: int
20
+ height: int
21
+ num_steps: int
22
+ guidance: float
23
+ seed: int | None
24
+ img_cond_path: str
25
+ img_mask_path: str
26
+
27
+
28
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
29
+ user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
30
+ usage = (
31
+ "Usage: Either write your prompt directly, leave this field empty "
32
+ "to repeat the prompt or write a command starting with a slash:\n"
33
+ "- '/s <seed>' sets the next seed\n"
34
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
35
+ "- '/n <steps>' sets the number of steps\n"
36
+ "- '/q' to quit"
37
+ )
38
+
39
+ while (prompt := input(user_question)).startswith("/"):
40
+ if prompt.startswith("/g"):
41
+ if prompt.count(" ") != 1:
42
+ print(f"Got invalid command '{prompt}'\n{usage}")
43
+ continue
44
+ _, guidance = prompt.split()
45
+ options.guidance = float(guidance)
46
+ print(f"Setting guidance to {options.guidance}")
47
+ elif prompt.startswith("/s"):
48
+ if prompt.count(" ") != 1:
49
+ print(f"Got invalid command '{prompt}'\n{usage}")
50
+ continue
51
+ _, seed = prompt.split()
52
+ options.seed = int(seed)
53
+ print(f"Setting seed to {options.seed}")
54
+ elif prompt.startswith("/n"):
55
+ if prompt.count(" ") != 1:
56
+ print(f"Got invalid command '{prompt}'\n{usage}")
57
+ continue
58
+ _, steps = prompt.split()
59
+ options.num_steps = int(steps)
60
+ print(f"Setting number of steps to {options.num_steps}")
61
+ elif prompt.startswith("/q"):
62
+ print("Quitting")
63
+ return None
64
+ else:
65
+ if not prompt.startswith("/h"):
66
+ print(f"Got invalid command '{prompt}'\n{usage}")
67
+ print(usage)
68
+ if prompt != "":
69
+ options.prompt = prompt
70
+ return options
71
+
72
+
73
+ def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
74
+ if options is None:
75
+ return None
76
+
77
+ user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
78
+ usage = (
79
+ "Usage: Either write your prompt directly, leave this field empty "
80
+ "to repeat the conditioning image or write a command starting with a slash:\n"
81
+ "- '/q' to quit"
82
+ )
83
+
84
+ while True:
85
+ img_cond_path = input(user_question)
86
+
87
+ if img_cond_path.startswith("/"):
88
+ if img_cond_path.startswith("/q"):
89
+ print("Quitting")
90
+ return None
91
+ else:
92
+ if not img_cond_path.startswith("/h"):
93
+ print(f"Got invalid command '{img_cond_path}'\n{usage}")
94
+ print(usage)
95
+ continue
96
+
97
+ if img_cond_path == "":
98
+ break
99
+
100
+ if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
101
+ (".jpg", ".jpeg", ".png", ".webp")
102
+ ):
103
+ print(f"File '{img_cond_path}' does not exist or is not a valid image file")
104
+ continue
105
+ else:
106
+ with Image.open(img_cond_path) as img:
107
+ width, height = img.size
108
+
109
+ if width % 32 != 0 or height % 32 != 0:
110
+ print(f"Image dimensions must be divisible by 32, got {width}x{height}")
111
+ continue
112
+
113
+ options.img_cond_path = img_cond_path
114
+ break
115
+
116
+ return options
117
+
118
+
119
+ def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
120
+ if options is None:
121
+ return None
122
+
123
+ user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
124
+ usage = (
125
+ "Usage: Either write your prompt directly, leave this field empty "
126
+ "to repeat the conditioning mask or write a command starting with a slash:\n"
127
+ "- '/q' to quit"
128
+ )
129
+
130
+ while True:
131
+ img_mask_path = input(user_question)
132
+
133
+ if img_mask_path.startswith("/"):
134
+ if img_mask_path.startswith("/q"):
135
+ print("Quitting")
136
+ return None
137
+ else:
138
+ if not img_mask_path.startswith("/h"):
139
+ print(f"Got invalid command '{img_mask_path}'\n{usage}")
140
+ print(usage)
141
+ continue
142
+
143
+ if img_mask_path == "":
144
+ break
145
+
146
+ if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
147
+ (".jpg", ".jpeg", ".png", ".webp")
148
+ ):
149
+ print(f"File '{img_mask_path}' does not exist or is not a valid image file")
150
+ continue
151
+ else:
152
+ with Image.open(img_mask_path) as img:
153
+ width, height = img.size
154
+
155
+ if width % 32 != 0 or height % 32 != 0:
156
+ print(f"Image dimensions must be divisible by 32, got {width}x{height}")
157
+ continue
158
+ else:
159
+ with Image.open(options.img_cond_path) as img_cond:
160
+ img_cond_width, img_cond_height = img_cond.size
161
+
162
+ if width != img_cond_width or height != img_cond_height:
163
+ print(
164
+ f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
165
+ )
166
+ continue
167
+
168
+ options.img_mask_path = img_mask_path
169
+ break
170
+
171
+ return options
172
+
173
+
174
+ @torch.inference_mode()
175
+ def main(
176
+ seed: int | None = None,
177
+ prompt: str = "a white paper cup",
178
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
179
+ num_steps: int = 50,
180
+ loop: bool = False,
181
+ guidance: float = 30.0,
182
+ offload: bool = False,
183
+ output_dir: str = "output",
184
+ add_sampling_metadata: bool = True,
185
+ img_cond_path: str = "assets/cup.png",
186
+ img_mask_path: str = "assets/cup_mask.png",
187
+ ):
188
+ """
189
+ Sample the flux model. Either interactively (set `--loop`) or run for a
190
+ single image. This demo assumes that the conditioning image and mask have
191
+ the same shape and that height and width are divisible by 32.
192
+
193
+ Args:
194
+ seed: Set a seed for sampling
195
+ output_dir: where to save the output images; `{idx}` in the file name will be replaced
196
+ by the index of the sample
197
+ prompt: Prompt used for sampling
198
+ device: Pytorch device
199
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
200
+ loop: start an interactive session and sample multiple times
201
+ guidance: guidance value used for guidance distillation
202
+ add_sampling_metadata: Add the prompt to the image Exif metadata
203
+ img_cond_path: path to conditioning image (jpeg/png/webp)
204
+ img_mask_path: path to conditioning mask (jpeg/png/webp)
205
+ """
206
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
207
+
208
+ name = "flux-dev-fill"
209
+ if name not in configs:
210
+ available = ", ".join(configs.keys())
211
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
212
+
213
+ torch_device = torch.device(device)
214
+
215
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
216
+ if not os.path.exists(output_dir):
217
+ os.makedirs(output_dir)
218
+ idx = 0
219
+ else:
220
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
221
+ if len(fns) > 0:
222
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
223
+ else:
224
+ idx = 0
225
+
226
+ # init all components
227
+ t5 = load_t5(torch_device, max_length=128)
228
+ clip = load_clip(torch_device)
229
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
230
+ ae = load_ae(name, device="cpu" if offload else torch_device)
231
+
232
+ rng = torch.Generator(device="cpu")
233
+ with Image.open(img_cond_path) as img:
234
+ width, height = img.size
235
+ opts = SamplingOptions(
236
+ prompt=prompt,
237
+ width=width,
238
+ height=height,
239
+ num_steps=num_steps,
240
+ guidance=guidance,
241
+ seed=seed,
242
+ img_cond_path=img_cond_path,
243
+ img_mask_path=img_mask_path,
244
+ )
245
+
246
+ if loop:
247
+ opts = parse_prompt(opts)
248
+ opts = parse_img_cond_path(opts)
249
+
250
+ with Image.open(opts.img_cond_path) as img:
251
+ width, height = img.size
252
+ opts.height = height
253
+ opts.width = width
254
+
255
+ opts = parse_img_mask_path(opts)
256
+
257
+ while opts is not None:
258
+ if opts.seed is None:
259
+ opts.seed = rng.seed()
260
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
261
+ t0 = time.perf_counter()
262
+
263
+ # prepare input
264
+ x = get_noise(
265
+ 1,
266
+ opts.height,
267
+ opts.width,
268
+ device=torch_device,
269
+ dtype=torch.bfloat16,
270
+ seed=opts.seed,
271
+ )
272
+ opts.seed = None
273
+ if offload:
274
+ t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
275
+ inp = prepare_fill(
276
+ t5,
277
+ clip,
278
+ x,
279
+ prompt=opts.prompt,
280
+ ae=ae,
281
+ img_cond_path=opts.img_cond_path,
282
+ mask_path=opts.img_mask_path,
283
+ )
284
+
285
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
286
+
287
+ # offload TEs and AE to CPU, load model to gpu
288
+ if offload:
289
+ t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
290
+ torch.cuda.empty_cache()
291
+ model = model.to(torch_device)
292
+
293
+ # denoise initial noise
294
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
295
+
296
+ # offload model, load autoencoder to gpu
297
+ if offload:
298
+ model.cpu()
299
+ torch.cuda.empty_cache()
300
+ ae.decoder.to(x.device)
301
+
302
+ # decode latents to pixel space
303
+ x = unpack(x.float(), opts.height, opts.width)
304
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
305
+ x = ae.decode(x)
306
+
307
+ if torch.cuda.is_available():
308
+ torch.cuda.synchronize()
309
+ t1 = time.perf_counter()
310
+ print(f"Done in {t1 - t0:.1f}s")
311
+
312
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
313
+
314
+ if loop:
315
+ print("-" * 80)
316
+ opts = parse_prompt(opts)
317
+ opts = parse_img_cond_path(opts)
318
+
319
+ with Image.open(opts.img_cond_path) as img:
320
+ width, height = img.size
321
+ opts.height = height
322
+ opts.width = width
323
+
324
+ opts = parse_img_mask_path(opts)
325
+ else:
326
+ opts = None
327
+
328
+
329
+ def app():
330
+ Fire(main)
331
+
332
+
333
+ if __name__ == "__main__":
334
+ app()
flux/cli_redux.py ADDED
@@ -0,0 +1,279 @@
1
+ import os
2
+ import re
3
+ import time
4
+ from dataclasses import dataclass
5
+ from glob import iglob
6
+
7
+ import torch
8
+ from fire import Fire
9
+ from transformers import pipeline
10
+
11
+ from flux.modules.image_embedders import ReduxImageEncoder
12
+ from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
13
+ from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
+
15
+
16
+ @dataclass
17
+ class SamplingOptions:
18
+ prompt: str
19
+ width: int
20
+ height: int
21
+ num_steps: int
22
+ guidance: float
23
+ seed: int | None
24
+ img_cond_path: str
25
+
26
+
27
+ def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
28
+ user_question = "Write /h for help, /q to quit and leave empty to repeat):\n"
29
+ usage = (
30
+ "Usage: Leave this field empty to do nothing "
31
+ "or write a command starting with a slash:\n"
32
+ "- '/w <width>' will set the width of the generated image\n"
33
+ "- '/h <height>' will set the height of the generated image\n"
34
+ "- '/s <seed>' sets the next seed\n"
35
+ "- '/g <guidance>' sets the guidance (flux-dev only)\n"
36
+ "- '/n <steps>' sets the number of steps\n"
37
+ "- '/q' to quit"
38
+ )
39
+
40
+ while (prompt := input(user_question)).startswith("/"):
41
+ if prompt.startswith("/w"):
42
+ if prompt.count(" ") != 1:
43
+ print(f"Got invalid command '{prompt}'\n{usage}")
44
+ continue
45
+ _, width = prompt.split()
46
+ options.width = 16 * (int(width) // 16)
47
+ print(
48
+ f"Setting resolution to {options.width} x {options.height} "
49
+ f"({options.height *options.width/1e6:.2f}MP)"
50
+ )
51
+ elif prompt.startswith("/h"):
52
+ if prompt.count(" ") != 1:
53
+ print(f"Got invalid command '{prompt}'\n{usage}")
54
+ continue
55
+ _, height = prompt.split()
56
+ options.height = 16 * (int(height) // 16)
57
+ print(
58
+ f"Setting resolution to {options.width} x {options.height} "
59
+ f"({options.height *options.width/1e6:.2f}MP)"
60
+ )
61
+ elif prompt.startswith("/g"):
62
+ if prompt.count(" ") != 1:
63
+ print(f"Got invalid command '{prompt}'\n{usage}")
64
+ continue
65
+ _, guidance = prompt.split()
66
+ options.guidance = float(guidance)
67
+ print(f"Setting guidance to {options.guidance}")
68
+ elif prompt.startswith("/s"):
69
+ if prompt.count(" ") != 1:
70
+ print(f"Got invalid command '{prompt}'\n{usage}")
71
+ continue
72
+ _, seed = prompt.split()
73
+ options.seed = int(seed)
74
+ print(f"Setting seed to {options.seed}")
75
+ elif prompt.startswith("/n"):
76
+ if prompt.count(" ") != 1:
77
+ print(f"Got invalid command '{prompt}'\n{usage}")
78
+ continue
79
+ _, steps = prompt.split()
80
+ options.num_steps = int(steps)
81
+ print(f"Setting number of steps to {options.num_steps}")
82
+ elif prompt.startswith("/q"):
83
+ print("Quitting")
84
+ return None
85
+ else:
86
+ if not prompt.startswith("/h"):
87
+ print(f"Got invalid command '{prompt}'\n{usage}")
88
+ print(usage)
89
+ return options
90
+
91
+
92
+ def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
93
+ if options is None:
94
+ return None
95
+
96
+ user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
97
+ usage = (
98
+ "Usage: Either write your prompt directly, leave this field empty "
99
+ "to repeat the conditioning image or write a command starting with a slash:\n"
100
+ "- '/q' to quit"
101
+ )
102
+
103
+ while True:
104
+ img_cond_path = input(user_question)
105
+
106
+ if img_cond_path.startswith("/"):
107
+ if img_cond_path.startswith("/q"):
108
+ print("Quitting")
109
+ return None
110
+ else:
111
+ if not img_cond_path.startswith("/h"):
112
+ print(f"Got invalid command '{img_cond_path}'\n{usage}")
113
+ print(usage)
114
+ continue
115
+
116
+ if img_cond_path == "":
117
+ break
118
+
119
+ if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
120
+ (".jpg", ".jpeg", ".png", ".webp")
121
+ ):
122
+ print(f"File '{img_cond_path}' does not exist or is not a valid image file")
123
+ continue
124
+
125
+ options.img_cond_path = img_cond_path
126
+ break
127
+
128
+ return options
129
+
130
+
131
+ @torch.inference_mode()
132
+ def main(
133
+ name: str = "flux-dev",
134
+ width: int = 1360,
135
+ height: int = 768,
136
+ seed: int | None = None,
137
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
138
+ num_steps: int | None = None,
139
+ loop: bool = False,
140
+ guidance: float = 2.5,
141
+ offload: bool = False,
142
+ output_dir: str = "output",
143
+ add_sampling_metadata: bool = True,
144
+ img_cond_path: str = "assets/robot.webp",
145
+ ):
146
+ """
147
+ Sample the flux model. Either interactively (set `--loop`) or run for a
148
+ single image.
149
+
150
+ Args:
151
+ name: Name of the model to load
152
+ height: height of the sample in pixels (should be a multiple of 16)
153
+ width: width of the sample in pixels (should be a multiple of 16)
154
+ seed: Set a seed for sampling
155
+ output_dir: where to save the output images; `{idx}` in the file name will be replaced
156
+ by the index of the sample
157
+ prompt: Prompt used for sampling
158
+ device: Pytorch device
159
+ num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
160
+ loop: start an interactive session and sample multiple times
161
+ guidance: guidance value used for guidance distillation
162
+ add_sampling_metadata: Add the prompt to the image Exif metadata
163
+ img_cond_path: path to conditioning image (jpeg/png/webp)
164
+ """
165
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
166
+
167
+ if name not in configs:
168
+ available = ", ".join(configs.keys())
169
+ raise ValueError(f"Got unknown model name: {name}, chose from {available}")
170
+
171
+ torch_device = torch.device(device)
172
+ if num_steps is None:
173
+ num_steps = 4 if name == "flux-schnell" else 50
174
+
175
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
176
+ if not os.path.exists(output_dir):
177
+ os.makedirs(output_dir)
178
+ idx = 0
179
+ else:
180
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
181
+ if len(fns) > 0:
182
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
183
+ else:
184
+ idx = 0
185
+
186
+ # init all components
187
+ t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
188
+ clip = load_clip(torch_device)
189
+ model = load_flow_model(name, device="cpu" if offload else torch_device)
190
+ ae = load_ae(name, device="cpu" if offload else torch_device)
191
+ img_embedder = ReduxImageEncoder(torch_device)
192
+
193
+ rng = torch.Generator(device="cpu")
194
+ prompt = ""
195
+ opts = SamplingOptions(
196
+ prompt=prompt,
197
+ width=width,
198
+ height=height,
199
+ num_steps=num_steps,
200
+ guidance=guidance,
201
+ seed=seed,
202
+ img_cond_path=img_cond_path,
203
+ )
204
+
205
+ if loop:
206
+ opts = parse_prompt(opts)
207
+ opts = parse_img_cond_path(opts)
208
+
209
+ while opts is not None:
210
+ if opts.seed is None:
211
+ opts.seed = rng.seed()
212
+ print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
213
+ t0 = time.perf_counter()
214
+
215
+ # prepare input
216
+ x = get_noise(
217
+ 1,
218
+ opts.height,
219
+ opts.width,
220
+ device=torch_device,
221
+ dtype=torch.bfloat16,
222
+ seed=opts.seed,
223
+ )
224
+ opts.seed = None
225
+ if offload:
226
+ ae = ae.cpu()
227
+ torch.cuda.empty_cache()
228
+ t5, clip = t5.to(torch_device), clip.to(torch_device)
229
+ inp = prepare_redux(
230
+ t5,
231
+ clip,
232
+ x,
233
+ prompt=opts.prompt,
234
+ encoder=img_embedder,
235
+ img_cond_path=opts.img_cond_path,
236
+ )
237
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
238
+
239
+ # offload TEs to CPU, load model to gpu
240
+ if offload:
241
+ t5, clip = t5.cpu(), clip.cpu()
242
+ torch.cuda.empty_cache()
243
+ model = model.to(torch_device)
244
+
245
+ # denoise initial noise
246
+ x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
247
+
248
+ # offload model, load autoencoder to gpu
249
+ if offload:
250
+ model.cpu()
251
+ torch.cuda.empty_cache()
252
+ ae.decoder.to(x.device)
253
+
254
+ # decode latents to pixel space
255
+ x = unpack(x.float(), opts.height, opts.width)
256
+ with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
257
+ x = ae.decode(x)
258
+
259
+ if torch.cuda.is_available():
260
+ torch.cuda.synchronize()
261
+ t1 = time.perf_counter()
262
+ print(f"Done in {t1 - t0:.1f}s")
263
+
264
+ idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
265
+
266
+ if loop:
267
+ print("-" * 80)
268
+ opts = parse_prompt(opts)
269
+ opts = parse_img_cond_path(opts)
270
+ else:
271
+ opts = None
272
+
273
+
274
+ def app():
275
+ Fire(main)
276
+
277
+
278
+ if __name__ == "__main__":
279
+ app()
flux/math.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ q, k = apply_rope(q, k, pe)
8
+
9
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
+ x = rearrange(x, "B H L D -> B L (H D)")
11
+
12
+ return x
13
+
14
+
15
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
+ assert dim % 2 == 0
17
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
18
+ omega = 1.0 / (theta**scale)
19
+ out = torch.einsum("...n,d->...nd", pos, omega)
20
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
21
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
22
+ return out.float()
23
+
24
+
25
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
26
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
27
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
28
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
29
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
30
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
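
As a shape sanity check for the rotary-embedding helpers above, a small sketch (the sizes are arbitrary; the added singleton axis on `pe` mirrors the extra head axis the positional embedder introduces in the model):

    import torch
    from flux.math import attention, rope

    pos = torch.arange(4, dtype=torch.float64)[None]  # token positions, shape (1, 4)
    pe = rope(pos, dim=8, theta=10_000)               # (1, 4, 4, 2, 2) rotation matrices
    q = torch.randn(1, 2, 4, 8)                       # (batch, heads, seq_len, head_dim)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    out = attention(q, k, v, pe=pe[:, None])          # rotates q/k, attends, merges heads -> (1, 4, 16)
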
flux/model.py ADDED
@@ -0,0 +1,143 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flux.modules.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+ from flux.modules.lora import LinearLora, replace_linear_with_lora
15
+
16
+
17
+ @dataclass
18
+ class FluxParams:
19
+ in_channels: int
20
+ out_channels: int
21
+ vec_in_dim: int
22
+ context_in_dim: int
23
+ hidden_size: int
24
+ mlp_ratio: float
25
+ num_heads: int
26
+ depth: int
27
+ depth_single_blocks: int
28
+ axes_dim: list[int]
29
+ theta: int
30
+ qkv_bias: bool
31
+ guidance_embed: bool
32
+
33
+
34
+ class Flux(nn.Module):
35
+ """
36
+ Transformer model for flow matching on sequences.
37
+ """
38
+
39
+ def __init__(self, params: FluxParams):
40
+ super().__init__()
41
+
42
+ self.params = params
43
+ self.in_channels = params.in_channels
44
+ self.out_channels = params.out_channels
45
+ if params.hidden_size % params.num_heads != 0:
46
+ raise ValueError(
47
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
48
+ )
49
+ pe_dim = params.hidden_size // params.num_heads
50
+ if sum(params.axes_dim) != pe_dim:
51
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
52
+ self.hidden_size = params.hidden_size
53
+ self.num_heads = params.num_heads
54
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
55
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
56
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
57
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
58
+ self.guidance_in = (
59
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
60
+ )
61
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
62
+
63
+ self.double_blocks = nn.ModuleList(
64
+ [
65
+ DoubleStreamBlock(
66
+ self.hidden_size,
67
+ self.num_heads,
68
+ mlp_ratio=params.mlp_ratio,
69
+ qkv_bias=params.qkv_bias,
70
+ )
71
+ for _ in range(params.depth)
72
+ ]
73
+ )
74
+
75
+ self.single_blocks = nn.ModuleList(
76
+ [
77
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
78
+ for _ in range(params.depth_single_blocks)
79
+ ]
80
+ )
81
+
82
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
83
+
84
+ def forward(
85
+ self,
86
+ img: Tensor,
87
+ img_ids: Tensor,
88
+ txt: Tensor,
89
+ txt_ids: Tensor,
90
+ timesteps: Tensor,
91
+ y: Tensor,
92
+ guidance: Tensor | None = None,
93
+ ) -> Tensor:
94
+ if img.ndim != 3 or txt.ndim != 3:
95
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
96
+
97
+ # running on sequences img
98
+ img = self.img_in(img)
99
+ vec = self.time_in(timestep_embedding(timesteps, 256))
100
+ if self.params.guidance_embed:
101
+ if guidance is None:
102
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
103
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
104
+ vec = vec + self.vector_in(y)
105
+ txt = self.txt_in(txt)
106
+
107
+ ids = torch.cat((txt_ids, img_ids), dim=1)
108
+ pe = self.pe_embedder(ids)
109
+
110
+ for block in self.double_blocks:
111
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
112
+
113
+ img = torch.cat((txt, img), 1)
114
+ for block in self.single_blocks:
115
+ img = block(img, vec=vec, pe=pe)
116
+ img = img[:, txt.shape[1] :, ...]
117
+
118
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
119
+ return img
120
+
121
+
122
+ class FluxLoraWrapper(Flux):
123
+ def __init__(
124
+ self,
125
+ lora_rank: int = 128,
126
+ lora_scale: float = 1.0,
127
+ *args,
128
+ **kwargs,
129
+ ) -> None:
130
+ super().__init__(*args, **kwargs)
131
+
132
+ self.lora_rank = lora_rank
133
+
134
+ replace_linear_with_lora(
135
+ self,
136
+ max_rank=lora_rank,
137
+ scale=lora_scale,
138
+ )
139
+
140
+ def set_lora_scale(self, scale: float) -> None:
141
+ for module in self.modules():
142
+ if isinstance(module, LinearLora):
143
+ module.set_scale(scale=scale)
flux/modules/autoencoder.py ADDED
@@ -0,0 +1,312 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # z to block_in
239
+ h = self.conv_in(z)
240
+
241
+ # middle
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.attn_1(h)
244
+ h = self.mid.block_2(h)
245
+
246
+ # upsampling
247
+ for i_level in reversed(range(self.num_resolutions)):
248
+ for i_block in range(self.num_res_blocks + 1):
249
+ h = self.up[i_level].block[i_block](h)
250
+ if len(self.up[i_level].attn) > 0:
251
+ h = self.up[i_level].attn[i_block](h)
252
+ if i_level != 0:
253
+ h = self.up[i_level].upsample(h)
254
+
255
+ # end
256
+ h = self.norm_out(h)
257
+ h = swish(h)
258
+ h = self.conv_out(h)
259
+ return h
260
+
261
+
262
+ class DiagonalGaussian(nn.Module):
263
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
+ super().__init__()
265
+ self.sample = sample
266
+ self.chunk_dim = chunk_dim
267
+
268
+ def forward(self, z: Tensor) -> Tensor:
269
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
+ if self.sample:
271
+ std = torch.exp(0.5 * logvar)
272
+ return mean + std * torch.randn_like(mean)
273
+ else:
274
+ return mean
275
+
276
+
277
+ class AutoEncoder(nn.Module):
278
+ def __init__(self, params: AutoEncoderParams):
279
+ super().__init__()
280
+ self.encoder = Encoder(
281
+ resolution=params.resolution,
282
+ in_channels=params.in_channels,
283
+ ch=params.ch,
284
+ ch_mult=params.ch_mult,
285
+ num_res_blocks=params.num_res_blocks,
286
+ z_channels=params.z_channels,
287
+ )
288
+ self.decoder = Decoder(
289
+ resolution=params.resolution,
290
+ in_channels=params.in_channels,
291
+ ch=params.ch,
292
+ out_ch=params.out_ch,
293
+ ch_mult=params.ch_mult,
294
+ num_res_blocks=params.num_res_blocks,
295
+ z_channels=params.z_channels,
296
+ )
297
+ self.reg = DiagonalGaussian()
298
+
299
+ self.scale_factor = params.scale_factor
300
+ self.shift_factor = params.shift_factor
301
+
302
+ def encode(self, x: Tensor) -> Tensor:
303
+ z = self.reg(self.encoder(x))
304
+ z = self.scale_factor * (z - self.shift_factor)
305
+ return z
306
+
307
+ def decode(self, z: Tensor) -> Tensor:
308
+ z = z / self.scale_factor + self.shift_factor
309
+ return self.decoder(z)
310
+
311
+ def forward(self, x: Tensor) -> Tensor:
312
+ return self.decode(self.encode(x))
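
Note: the encoder halves the spatial resolution len(ch_mult) - 1 times and DiagonalGaussian samples from the predicted mean/log-variance, so encode() is stochastic by default. A small round-trip sketch with toy, assumed AutoEncoderParams (the shipped configs use ch=128, ch_mult=[1, 2, 4, 4], z_channels=16):

import torch

from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=64,
    in_channels=3,
    ch=32,
    out_ch=3,
    ch_mult=[1, 2],     # one downsampling step -> latents at half resolution
    num_res_blocks=1,
    z_channels=4,
    scale_factor=1.0,
    shift_factor=0.0,
)
ae = AutoEncoder(params)

x = torch.randn(1, 3, 64, 64)      # image in [-1, 1]
z = ae.encode(x)                   # sampled, scaled/shifted latent
print(z.shape)                     # torch.Size([1, 4, 32, 32])
print(ae.decode(z).shape)          # torch.Size([1, 3, 64, 64])
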
flux/modules/conditioner.py ADDED
@@ -0,0 +1,37 @@
1
+ from torch import Tensor, nn
2
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
3
+
4
+
5
+ class HFEmbedder(nn.Module):
6
+ def __init__(self, version: str, max_length: int, **hf_kwargs):
7
+ super().__init__()
8
+ self.is_clip = version.startswith("openai")
9
+ self.max_length = max_length
10
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
11
+
12
+ if self.is_clip:
13
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
14
+ self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
15
+ else:
16
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
17
+ self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
18
+
19
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
20
+
21
+ def forward(self, text: list[str]) -> Tensor:
22
+ batch_encoding = self.tokenizer(
23
+ text,
24
+ truncation=True,
25
+ max_length=self.max_length,
26
+ return_length=False,
27
+ return_overflowing_tokens=False,
28
+ padding="max_length",
29
+ return_tensors="pt",
30
+ )
31
+
32
+ outputs = self.hf_module(
33
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
34
+ attention_mask=None,
35
+ output_hidden_states=False,
36
+ )
37
+ return outputs[self.output_key]
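
Note: both text encoders sit behind this one interface; the CLIP variant returns the pooled sentence embedding, the T5 variant the full per-token hidden states. A usage sketch (the model names match the ones used by flux/util.py; running this downloads the Hugging Face weights, the T5-XXL ones being very large):

import torch

from flux.modules.conditioner import HFEmbedder

clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.float32)
t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.float32)

vec = clip(["a photo of a forest"])  # pooler_output: (1, 768)
txt = t5(["a photo of a forest"])    # last_hidden_state: (1, 512, 4096)
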
flux/modules/image_embedders.py ADDED
@@ -0,0 +1,103 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from PIL import Image
8
+ from safetensors.torch import load_file as load_sft
9
+ from torch import nn
10
+ from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
11
+
12
+ from flux.util import print_load_warning
13
+
14
+
15
+ class DepthImageEncoder:
16
+ depth_model_name = "LiheYoung/depth-anything-large-hf"
17
+
18
+ def __init__(self, device):
19
+ self.device = device
20
+ self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
21
+ self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
22
+
23
+ def __call__(self, img: torch.Tensor) -> torch.Tensor:
24
+ hw = img.shape[-2:]
25
+
26
+ img = torch.clamp(img, -1.0, 1.0)
27
+ img_byte = ((img + 1.0) * 127.5).byte()
28
+
29
+ img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
30
+ depth = self.depth_model(img.to(self.device)).predicted_depth
31
+ depth = repeat(depth, "b h w -> b 3 h w")
32
+ depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
33
+
34
+ depth = depth / 127.5 - 1.0
35
+ return depth
36
+
37
+
38
+ class CannyImageEncoder:
39
+ def __init__(
40
+ self,
41
+ device,
42
+ min_t: int = 50,
43
+ max_t: int = 200,
44
+ ):
45
+ self.device = device
46
+ self.min_t = min_t
47
+ self.max_t = max_t
48
+
49
+ def __call__(self, img: torch.Tensor) -> torch.Tensor:
50
+ assert img.shape[0] == 1, "Only batch size 1 is supported"
51
+
52
+ img = rearrange(img[0], "c h w -> h w c")
53
+ img = torch.clamp(img, -1.0, 1.0)
54
+ img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
55
+
56
+ # Apply Canny edge detection
57
+ canny = cv2.Canny(img_np, self.min_t, self.max_t)
58
+
59
+ # Convert back to torch tensor and reshape
60
+ canny = torch.from_numpy(canny).float() / 127.5 - 1.0
61
+ canny = rearrange(canny, "h w -> 1 1 h w")
62
+ canny = repeat(canny, "b 1 ... -> b 3 ...")
63
+ return canny.to(self.device)
64
+
65
+
66
+ class ReduxImageEncoder(nn.Module):
67
+ siglip_model_name = "google/siglip-so400m-patch14-384"
68
+
69
+ def __init__(
70
+ self,
71
+ device,
72
+ redux_dim: int = 1152,
73
+ txt_in_features: int = 4096,
74
+ redux_path: str | None = os.getenv("FLUX_REDUX"),
75
+ dtype=torch.bfloat16,
76
+ ) -> None:
77
+ assert redux_path is not None, "Redux path must be provided"
78
+
79
+ super().__init__()
80
+
81
+ self.redux_dim = redux_dim
82
+ self.device = device if isinstance(device, torch.device) else torch.device(device)
83
+ self.dtype = dtype
84
+
85
+ with self.device:
86
+ self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
87
+ self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
88
+
89
+ sd = load_sft(redux_path, device=str(device))
90
+ missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
91
+ print_load_warning(missing, unexpected)
92
+
93
+ self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
94
+ self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)
95
+
96
+ def __call__(self, x: Image.Image) -> torch.Tensor:
97
+ imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
98
+
99
+ _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state
100
+
101
+ projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
102
+
103
+ return projected_x
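
Note: of the three conditioners, only the Canny encoder needs no extra pretrained weights, which makes it handy for a quick check; it returns the edge map replicated to three channels and rescaled to [-1, 1], the same range the autoencoder expects. A sketch on a random input:

import torch

from flux.modules.image_embedders import CannyImageEncoder

encoder = CannyImageEncoder(device=torch.device("cpu"), min_t=50, max_t=200)

img = torch.rand(1, 3, 64, 64) * 2 - 1  # fake image in [-1, 1]; only batch size 1 is supported
edges = encoder(img)
print(edges.shape)                      # torch.Size([1, 3, 64, 64])
print(edges.min(), edges.max())         # edge map values are -1 or 1
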
flux/modules/layers.py ADDED
@@ -0,0 +1,253 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flux.math import attention, rope
9
+
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
+ t.device
41
+ )
42
+
43
+ args = t[:, None].float() * freqs[None]
44
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
+ if dim % 2:
46
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
+ if torch.is_floating_point(t):
48
+ embedding = embedding.to(t)
49
+ return embedding
50
+
51
+
52
+ class MLPEmbedder(nn.Module):
53
+ def __init__(self, in_dim: int, hidden_dim: int):
54
+ super().__init__()
55
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
56
+ self.silu = nn.SiLU()
57
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
58
+
59
+ def forward(self, x: Tensor) -> Tensor:
60
+ return self.out_layer(self.silu(self.in_layer(x)))
61
+
62
+
63
+ class RMSNorm(torch.nn.Module):
64
+ def __init__(self, dim: int):
65
+ super().__init__()
66
+ self.scale = nn.Parameter(torch.ones(dim))
67
+
68
+ def forward(self, x: Tensor):
69
+ x_dtype = x.dtype
70
+ x = x.float()
71
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
+ return (x * rrms).to(dtype=x_dtype) * self.scale
73
+
74
+
75
+ class QKNorm(torch.nn.Module):
76
+ def __init__(self, dim: int):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim)
79
+ self.key_norm = RMSNorm(dim)
80
+
81
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
82
+ q = self.query_norm(q)
83
+ k = self.key_norm(k)
84
+ return q.to(v), k.to(v)
85
+
86
+
87
+ class SelfAttention(nn.Module):
88
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
89
+ super().__init__()
90
+ self.num_heads = num_heads
91
+ head_dim = dim // num_heads
92
+
93
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
94
+ self.norm = QKNorm(head_dim)
95
+ self.proj = nn.Linear(dim, dim)
96
+
97
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
98
+ qkv = self.qkv(x)
99
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
100
+ q, k = self.norm(q, k, v)
101
+ x = attention(q, k, v, pe=pe)
102
+ x = self.proj(x)
103
+ return x
104
+
105
+
106
+ @dataclass
107
+ class ModulationOut:
108
+ shift: Tensor
109
+ scale: Tensor
110
+ gate: Tensor
111
+
112
+
113
+ class Modulation(nn.Module):
114
+ def __init__(self, dim: int, double: bool):
115
+ super().__init__()
116
+ self.is_double = double
117
+ self.multiplier = 6 if double else 3
118
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
119
+
120
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
121
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
122
+
123
+ return (
124
+ ModulationOut(*out[:3]),
125
+ ModulationOut(*out[3:]) if self.is_double else None,
126
+ )
127
+
128
+
129
+ class DoubleStreamBlock(nn.Module):
130
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
131
+ super().__init__()
132
+
133
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
134
+ self.num_heads = num_heads
135
+ self.hidden_size = hidden_size
136
+ self.img_mod = Modulation(hidden_size, double=True)
137
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
138
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
139
+
140
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
141
+ self.img_mlp = nn.Sequential(
142
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
143
+ nn.GELU(approximate="tanh"),
144
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
145
+ )
146
+
147
+ self.txt_mod = Modulation(hidden_size, double=True)
148
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
150
+
151
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
152
+ self.txt_mlp = nn.Sequential(
153
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
154
+ nn.GELU(approximate="tanh"),
155
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
156
+ )
157
+
158
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
159
+ img_mod1, img_mod2 = self.img_mod(vec)
160
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
161
+
162
+ # prepare image for attention
163
+ img_modulated = self.img_norm1(img)
164
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
165
+ img_qkv = self.img_attn.qkv(img_modulated)
166
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
167
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
168
+
169
+ # prepare txt for attention
170
+ txt_modulated = self.txt_norm1(txt)
171
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
172
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
173
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
175
+
176
+ # run actual attention
177
+ q = torch.cat((txt_q, img_q), dim=2)
178
+ k = torch.cat((txt_k, img_k), dim=2)
179
+ v = torch.cat((txt_v, img_v), dim=2)
180
+
181
+ attn = attention(q, k, v, pe=pe)
182
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
183
+
184
+ # calculate the img blocks
185
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
186
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
187
+
188
+ # calculate the txt blocks
189
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
190
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
191
+ return img, txt
192
+
193
+
194
+ class SingleStreamBlock(nn.Module):
195
+ """
196
+ A DiT block with parallel linear layers as described in
197
+ https://arxiv.org/abs/2302.05442, with an adapted modulation interface.
198
+ """
199
+
200
+ def __init__(
201
+ self,
202
+ hidden_size: int,
203
+ num_heads: int,
204
+ mlp_ratio: float = 4.0,
205
+ qk_scale: float | None = None,
206
+ ):
207
+ super().__init__()
208
+ self.hidden_dim = hidden_size
209
+ self.num_heads = num_heads
210
+ head_dim = hidden_size // num_heads
211
+ self.scale = qk_scale or head_dim**-0.5
212
+
213
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214
+ # qkv and mlp_in
215
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216
+ # proj and mlp_out
217
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218
+
219
+ self.norm = QKNorm(head_dim)
220
+
221
+ self.hidden_size = hidden_size
222
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223
+
224
+ self.mlp_act = nn.GELU(approximate="tanh")
225
+ self.modulation = Modulation(hidden_size, double=False)
226
+
227
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228
+ mod, _ = self.modulation(vec)
229
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231
+
232
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233
+ q, k = self.norm(q, k, v)
234
+
235
+ # compute attention
236
+ attn = attention(q, k, v, pe=pe)
237
+ # compute activation in mlp stream, cat again and run second linear layer
238
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239
+ return x + mod.gate * output
240
+
241
+
242
+ class LastLayer(nn.Module):
243
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244
+ super().__init__()
245
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248
+
249
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252
+ x = self.linear(x)
253
+ return x
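
Note: timestep_embedding first multiplies t by time_factor (1000 by default, since the sampler passes t in [0, 1]) and then applies the standard sin/cos embedding; in the model the result is projected to the hidden size by an MLPEmbedder. A small worked example (the hidden size of 32 is an arbitrary illustration value):

import torch

from flux.modules.layers import MLPEmbedder, timestep_embedding

# For i < dim // 2: emb[:, i] = cos(1000 * t / max_period ** (i / (dim // 2))),
# and the sin counterparts fill the second half.
t = torch.tensor([0.0, 0.5, 1.0])
emb = timestep_embedding(t, dim=256)
print(emb.shape)  # torch.Size([3, 256])

time_in = MLPEmbedder(in_dim=256, hidden_dim=32)
print(time_in(emb).shape)  # torch.Size([3, 32])
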
flux/modules/lora.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ def replace_linear_with_lora(
6
+ module: nn.Module,
7
+ max_rank: int,
8
+ scale: float = 1.0,
9
+ ) -> None:
10
+ for name, child in module.named_children():
11
+ if isinstance(child, nn.Linear):
12
+ new_lora = LinearLora(
13
+ in_features=child.in_features,
14
+ out_features=child.out_features,
15
+ bias=child.bias,
16
+ rank=max_rank,
17
+ scale=scale,
18
+ dtype=child.weight.dtype,
19
+ device=child.weight.device,
20
+ )
21
+
22
+ new_lora.weight = child.weight
23
+ new_lora.bias = child.bias if child.bias is not None else None
24
+
25
+ setattr(module, name, new_lora)
26
+ else:
27
+ replace_linear_with_lora(
28
+ module=child,
29
+ max_rank=max_rank,
30
+ scale=scale,
31
+ )
32
+
33
+
34
+ class LinearLora(nn.Linear):
35
+ def __init__(
36
+ self,
37
+ in_features: int,
38
+ out_features: int,
39
+ bias: bool,
40
+ rank: int,
41
+ dtype: torch.dtype,
42
+ device: torch.device,
43
+ lora_bias: bool = True,
44
+ scale: float = 1.0,
45
+ *args,
46
+ **kwargs,
47
+ ) -> None:
48
+ super().__init__(
49
+ in_features=in_features,
50
+ out_features=out_features,
51
+ bias=bias is not None,
52
+ device=device,
53
+ dtype=dtype,
54
+ *args,
55
+ **kwargs,
56
+ )
57
+
58
+ assert isinstance(scale, float), "scale must be a float"
59
+
60
+ self.scale = scale
61
+ self.rank = rank
62
+ self.lora_bias = lora_bias
63
+ self.dtype = dtype
64
+ self.device = device
65
+
66
+ if rank > (new_rank := min(self.out_features, self.in_features)):
67
+ self.rank = new_rank
68
+
69
+ self.lora_A = nn.Linear(
70
+ in_features=in_features,
71
+ out_features=self.rank,
72
+ bias=False,
73
+ dtype=dtype,
74
+ device=device,
75
+ )
76
+ self.lora_B = nn.Linear(
77
+ in_features=self.rank,
78
+ out_features=out_features,
79
+ bias=self.lora_bias,
80
+ dtype=dtype,
81
+ device=device,
82
+ )
83
+
84
+ def set_scale(self, scale: float) -> None:
85
+ assert isinstance(scale, float), "scale value must be a float"
86
+ self.scale = scale
87
+
88
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
89
+ base_out = super().forward(input)
90
+
91
+ _lora_out_B = self.lora_B(self.lora_A(input))
92
+ lora_update = _lora_out_B * self.scale
93
+
94
+ return base_out + lora_update
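
Note: replace_linear_with_lora swaps every nn.Linear in a module tree for a LinearLora that shares the original weight and bias, so with the adapter scale set to 0.0 the wrapped module reproduces the original outputs exactly. A self-contained check with toy layer sizes (not a Flux configuration):

import torch
from torch import nn

from flux.modules.lora import LinearLora, replace_linear_with_lora

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
x = torch.randn(2, 8)
y_ref = model(x)

replace_linear_with_lora(model, max_rank=4, scale=1.0)
assert all(isinstance(m, LinearLora) for m in model)

# The base weights are shared, so gating the adapter off recovers the original output.
for m in model:
    m.set_scale(0.0)
assert torch.allclose(model(x), y_ref)
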
flux/sampling.py ADDED
@@ -0,0 +1,282 @@
1
+ import math
2
+ from typing import Callable
3
+
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange, repeat
7
+ from PIL import Image
8
+ from torch import Tensor
9
+
10
+ from .model import Flux
11
+ from .modules.autoencoder import AutoEncoder
12
+ from .modules.conditioner import HFEmbedder
13
+ from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
14
+
15
+
16
+ def get_noise(
17
+ num_samples: int,
18
+ height: int,
19
+ width: int,
20
+ device: torch.device,
21
+ dtype: torch.dtype,
22
+ seed: int,
23
+ ):
24
+ return torch.randn(
25
+ num_samples,
26
+ 16,
27
+ # allow for packing
28
+ 2 * math.ceil(height / 16),
29
+ 2 * math.ceil(width / 16),
30
+ device=device,
31
+ dtype=dtype,
32
+ generator=torch.Generator(device=device).manual_seed(seed),
33
+ )
34
+
35
+
36
+ def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
37
+ bs, c, h, w = img.shape
38
+ if bs == 1 and not isinstance(prompt, str):
39
+ bs = len(prompt)
40
+
41
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
42
+ if img.shape[0] == 1 and bs > 1:
43
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
44
+
45
+ img_ids = torch.zeros(h // 2, w // 2, 3)
46
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
47
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
48
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
49
+
50
+ if isinstance(prompt, str):
51
+ prompt = [prompt]
52
+ txt = t5(prompt)
53
+ if txt.shape[0] == 1 and bs > 1:
54
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
55
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
56
+
57
+ vec = clip(prompt)
58
+ if vec.shape[0] == 1 and bs > 1:
59
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
60
+
61
+ return {
62
+ "img": img,
63
+ "img_ids": img_ids.to(img.device),
64
+ "txt": txt.to(img.device),
65
+ "txt_ids": txt_ids.to(img.device),
66
+ "vec": vec.to(img.device),
67
+ }
68
+
69
+
70
+ def prepare_control(
71
+ t5: HFEmbedder,
72
+ clip: HFEmbedder,
73
+ img: Tensor,
74
+ prompt: str | list[str],
75
+ ae: AutoEncoder,
76
+ encoder: DepthImageEncoder | CannyImageEncoder,
77
+ img_cond_path: str,
78
+ ) -> dict[str, Tensor]:
79
+ # load and encode the conditioning image
80
+ bs, _, h, w = img.shape
81
+ if bs == 1 and not isinstance(prompt, str):
82
+ bs = len(prompt)
83
+
84
+ img_cond = Image.open(img_cond_path).convert("RGB")
85
+
86
+ width = w * 8
87
+ height = h * 8
88
+ img_cond = img_cond.resize((width, height), Image.LANCZOS)
89
+ img_cond = np.array(img_cond)
90
+ img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
91
+ img_cond = rearrange(img_cond, "h w c -> 1 c h w")
92
+
93
+ with torch.no_grad():
94
+ img_cond = encoder(img_cond)
95
+ img_cond = ae.encode(img_cond)
96
+
97
+ img_cond = img_cond.to(torch.bfloat16)
98
+ img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
99
+ if img_cond.shape[0] == 1 and bs > 1:
100
+ img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
101
+
102
+ return_dict = prepare(t5, clip, img, prompt)
103
+ return_dict["img_cond"] = img_cond
104
+ return return_dict
105
+
106
+
107
+ def prepare_fill(
108
+ t5: HFEmbedder,
109
+ clip: HFEmbedder,
110
+ img: Tensor,
111
+ prompt: str | list[str],
112
+ ae: AutoEncoder,
113
+ img_cond_path: str,
114
+ mask_path: str,
115
+ ) -> dict[str, Tensor]:
116
+ # load and encode the conditioning image and the mask
117
+ bs, _, _, _ = img.shape
118
+ if bs == 1 and not isinstance(prompt, str):
119
+ bs = len(prompt)
120
+
121
+ img_cond = Image.open(img_cond_path).convert("RGB")
122
+ img_cond = np.array(img_cond)
123
+ img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
124
+ img_cond = rearrange(img_cond, "h w c -> 1 c h w")
125
+
126
+ mask = Image.open(mask_path).convert("L")
127
+ mask = np.array(mask)
128
+ mask = torch.from_numpy(mask).float() / 255.0
129
+ mask = rearrange(mask, "h w -> 1 1 h w")
130
+
131
+ with torch.no_grad():
132
+ img_cond = img_cond.to(img.device)
133
+ mask = mask.to(img.device)
134
+ img_cond = img_cond * (1 - mask)
135
+ img_cond = ae.encode(img_cond)
136
+ mask = mask[:, 0, :, :]
137
+ mask = mask.to(torch.bfloat16)
138
+ mask = rearrange(
139
+ mask,
140
+ "b (h ph) (w pw) -> b (ph pw) h w",
141
+ ph=8,
142
+ pw=8,
143
+ )
144
+ mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
145
+ if mask.shape[0] == 1 and bs > 1:
146
+ mask = repeat(mask, "1 ... -> bs ...", bs=bs)
147
+
148
+ img_cond = img_cond.to(torch.bfloat16)
149
+ img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
150
+ if img_cond.shape[0] == 1 and bs > 1:
151
+ img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
152
+
153
+ img_cond = torch.cat((img_cond, mask), dim=-1)
154
+
155
+ return_dict = prepare(t5, clip, img, prompt)
156
+ return_dict["img_cond"] = img_cond.to(img.device)
157
+ return return_dict
158
+
159
+
160
+ def prepare_redux(
161
+ t5: HFEmbedder,
162
+ clip: HFEmbedder,
163
+ img: Tensor,
164
+ prompt: str | list[str],
165
+ encoder: ReduxImageEncoder,
166
+ img_cond_path: str,
167
+ ) -> dict[str, Tensor]:
168
+ bs, _, h, w = img.shape
169
+ if bs == 1 and not isinstance(prompt, str):
170
+ bs = len(prompt)
171
+
172
+ img_cond = Image.open(img_cond_path).convert("RGB")
173
+ with torch.no_grad():
174
+ img_cond = encoder(img_cond)
175
+
176
+ img_cond = img_cond.to(torch.bfloat16)
177
+ if img_cond.shape[0] == 1 and bs > 1:
178
+ img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
179
+
180
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
181
+ if img.shape[0] == 1 and bs > 1:
182
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
183
+
184
+ img_ids = torch.zeros(h // 2, w // 2, 3)
185
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
186
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
187
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
188
+
189
+ if isinstance(prompt, str):
190
+ prompt = [prompt]
191
+ txt = t5(prompt)
192
+ txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
193
+ if txt.shape[0] == 1 and bs > 1:
194
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
195
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
196
+
197
+ vec = clip(prompt)
198
+ if vec.shape[0] == 1 and bs > 1:
199
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
200
+
201
+ return {
202
+ "img": img,
203
+ "img_ids": img_ids.to(img.device),
204
+ "txt": txt.to(img.device),
205
+ "txt_ids": txt_ids.to(img.device),
206
+ "vec": vec.to(img.device),
207
+ }
208
+
209
+
210
+ def time_shift(mu: float, sigma: float, t: Tensor):
211
+ return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
212
+
213
+
214
+ def get_lin_function(
215
+ x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
216
+ ) -> Callable[[float], float]:
217
+ m = (y2 - y1) / (x2 - x1)
218
+ b = y1 - m * x1
219
+ return lambda x: m * x + b
220
+
221
+
222
+ def get_schedule(
223
+ num_steps: int,
224
+ image_seq_len: int,
225
+ base_shift: float = 0.5,
226
+ max_shift: float = 1.15,
227
+ shift: bool = True,
228
+ ) -> list[float]:
229
+ # extra step for zero
230
+ timesteps = torch.linspace(1, 0, num_steps + 1)
231
+
232
+ # shifting the schedule to favor high timesteps for higher signal images
233
+ if shift:
234
+ # estimate mu based on linear estimation between two points
235
+ mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
236
+ timesteps = time_shift(mu, 1.0, timesteps)
237
+
238
+ return timesteps.tolist()
239
+
240
+
241
+ def denoise(
242
+ model: Flux,
243
+ # model input
244
+ img: Tensor,
245
+ img_ids: Tensor,
246
+ txt: Tensor,
247
+ txt_ids: Tensor,
248
+ vec: Tensor,
249
+ # sampling parameters
250
+ timesteps: list[float],
251
+ guidance: float = 4.0,
252
+ # extra img tokens
253
+ img_cond: Tensor | None = None,
254
+ ):
255
+ # this is ignored for schnell
256
+ guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
257
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
258
+ t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
259
+ pred = model(
260
+ img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
261
+ img_ids=img_ids,
262
+ txt=txt,
263
+ txt_ids=txt_ids,
264
+ y=vec,
265
+ timesteps=t_vec,
266
+ guidance=guidance_vec,
267
+ )
268
+
269
+ img = img + (t_prev - t_curr) * pred
270
+
271
+ return img
272
+
273
+
274
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
275
+ return rearrange(
276
+ x,
277
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
278
+ h=math.ceil(height / 16),
279
+ w=math.ceil(width / 16),
280
+ ph=2,
281
+ pw=2,
282
+ )
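
Note: the schedule and packing helpers need no model weights, so they can be inspected in isolation: get_schedule returns num_steps + 1 decreasing timesteps in [0, 1], optionally shifted towards 1 for longer token sequences, and denoise then takes plain Euler steps img = img + (t_prev - t_curr) * pred over consecutive timestep pairs. A sketch for a 1024x1024 image, whose 16-channel 128x128 noise tensor packs into a 4096-token sequence:

import torch
from einops import rearrange

from flux.sampling import get_noise, get_schedule, unpack

height = width = 1024
x = get_noise(1, height, width, device=torch.device("cpu"), dtype=torch.float32, seed=0)
print(x.shape)  # torch.Size([1, 16, 128, 128])

# pack into the token sequence the model consumes (as done in prepare())
packed = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
print(packed.shape)  # torch.Size([1, 4096, 64])

# shifted schedule: longer sequences push the timesteps closer to 1
print(get_schedule(num_steps=4, image_seq_len=packed.shape[1], shift=True))
print(get_schedule(num_steps=4, image_seq_len=packed.shape[1], shift=False))

# unpack() inverts the packing after denoising
print(unpack(packed, height, width).shape)  # torch.Size([1, 16, 128, 128])
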
flux/util.py ADDED
@@ -0,0 +1,447 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from huggingface_hub import hf_hub_download
7
+ from imwatermark import WatermarkEncoder
8
+ from PIL import ExifTags, Image
9
+ from safetensors.torch import load_file as load_sft
10
+
11
+ from flux.model import Flux, FluxLoraWrapper, FluxParams
12
+ from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
13
+ from flux.modules.conditioner import HFEmbedder
14
+
15
+
16
+ def save_image(
17
+ nsfw_classifier,
18
+ name: str,
19
+ output_name: str,
20
+ idx: int,
21
+ x: torch.Tensor,
22
+ add_sampling_metadata: bool,
23
+ prompt: str,
24
+ nsfw_threshold: float = 0.85,
25
+ ) -> int:
26
+ fn = output_name.format(idx=idx)
27
+ print(f"Saving {fn}")
28
+ # bring into PIL format and save
29
+ x = x.clamp(-1, 1)
30
+ x = embed_watermark(x.float())
31
+ x = rearrange(x[0], "c h w -> h w c")
32
+
33
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
34
+ nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
35
+
36
+ if nsfw_score < nsfw_threshold:
37
+ exif_data = Image.Exif()
38
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
39
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
40
+ exif_data[ExifTags.Base.Model] = name
41
+ if add_sampling_metadata:
42
+ exif_data[ExifTags.Base.ImageDescription] = prompt
43
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
44
+ idx += 1
45
+ else:
46
+ print("Your generated image may contain NSFW content.")
47
+
48
+ return idx
49
+
50
+
51
+ @dataclass
52
+ class ModelSpec:
53
+ params: FluxParams
54
+ ae_params: AutoEncoderParams
55
+ ckpt_path: str | None
56
+ lora_path: str | None
57
+ ae_path: str | None
58
+ repo_id: str | None
59
+ repo_flow: str | None
60
+ repo_ae: str | None
61
+
62
+
63
+ configs = {
64
+ "flux-dev": ModelSpec(
65
+ repo_id="black-forest-labs/FLUX.1-dev",
66
+ repo_flow="flux1-dev.safetensors",
67
+ repo_ae="ae.safetensors",
68
+ ckpt_path=os.getenv("FLUX_DEV"),
69
+ lora_path=None,
70
+ params=FluxParams(
71
+ in_channels=64,
72
+ out_channels=64,
73
+ vec_in_dim=768,
74
+ context_in_dim=4096,
75
+ hidden_size=3072,
76
+ mlp_ratio=4.0,
77
+ num_heads=24,
78
+ depth=19,
79
+ depth_single_blocks=38,
80
+ axes_dim=[16, 56, 56],
81
+ theta=10_000,
82
+ qkv_bias=True,
83
+ guidance_embed=True,
84
+ ),
85
+ ae_path=os.getenv("AE"),
86
+ ae_params=AutoEncoderParams(
87
+ resolution=256,
88
+ in_channels=3,
89
+ ch=128,
90
+ out_ch=3,
91
+ ch_mult=[1, 2, 4, 4],
92
+ num_res_blocks=2,
93
+ z_channels=16,
94
+ scale_factor=0.3611,
95
+ shift_factor=0.1159,
96
+ ),
97
+ ),
98
+ "flux-schnell": ModelSpec(
99
+ repo_id="black-forest-labs/FLUX.1-schnell",
100
+ repo_flow="flux1-schnell.safetensors",
101
+ repo_ae="ae.safetensors",
102
+ ckpt_path=os.getenv("FLUX_SCHNELL"),
103
+ lora_path=None,
104
+ params=FluxParams(
105
+ in_channels=64,
106
+ out_channels=64,
107
+ vec_in_dim=768,
108
+ context_in_dim=4096,
109
+ hidden_size=3072,
110
+ mlp_ratio=4.0,
111
+ num_heads=24,
112
+ depth=19,
113
+ depth_single_blocks=38,
114
+ axes_dim=[16, 56, 56],
115
+ theta=10_000,
116
+ qkv_bias=True,
117
+ guidance_embed=False,
118
+ ),
119
+ ae_path=os.getenv("AE"),
120
+ ae_params=AutoEncoderParams(
121
+ resolution=256,
122
+ in_channels=3,
123
+ ch=128,
124
+ out_ch=3,
125
+ ch_mult=[1, 2, 4, 4],
126
+ num_res_blocks=2,
127
+ z_channels=16,
128
+ scale_factor=0.3611,
129
+ shift_factor=0.1159,
130
+ ),
131
+ ),
132
+ "flux-dev-canny": ModelSpec(
133
+ repo_id="black-forest-labs/FLUX.1-Canny-dev",
134
+ repo_flow="flux1-canny-dev.safetensors",
135
+ repo_ae="ae.safetensors",
136
+ ckpt_path=os.getenv("FLUX_DEV_CANNY"),
137
+ lora_path=None,
138
+ params=FluxParams(
139
+ in_channels=128,
140
+ out_channels=64,
141
+ vec_in_dim=768,
142
+ context_in_dim=4096,
143
+ hidden_size=3072,
144
+ mlp_ratio=4.0,
145
+ num_heads=24,
146
+ depth=19,
147
+ depth_single_blocks=38,
148
+ axes_dim=[16, 56, 56],
149
+ theta=10_000,
150
+ qkv_bias=True,
151
+ guidance_embed=True,
152
+ ),
153
+ ae_path=os.getenv("AE"),
154
+ ae_params=AutoEncoderParams(
155
+ resolution=256,
156
+ in_channels=3,
157
+ ch=128,
158
+ out_ch=3,
159
+ ch_mult=[1, 2, 4, 4],
160
+ num_res_blocks=2,
161
+ z_channels=16,
162
+ scale_factor=0.3611,
163
+ shift_factor=0.1159,
164
+ ),
165
+ ),
166
+ "flux-dev-canny-lora": ModelSpec(
167
+ repo_id="black-forest-labs/FLUX.1-dev",
168
+ repo_flow="flux1-dev.safetensors",
169
+ repo_ae="ae.safetensors",
170
+ ckpt_path=os.getenv("FLUX_DEV"),
171
+ lora_path=os.getenv("FLUX_DEV_CANNY_LORA"),
172
+ params=FluxParams(
173
+ in_channels=128,
174
+ out_channels=64,
175
+ vec_in_dim=768,
176
+ context_in_dim=4096,
177
+ hidden_size=3072,
178
+ mlp_ratio=4.0,
179
+ num_heads=24,
180
+ depth=19,
181
+ depth_single_blocks=38,
182
+ axes_dim=[16, 56, 56],
183
+ theta=10_000,
184
+ qkv_bias=True,
185
+ guidance_embed=True,
186
+ ),
187
+ ae_path=os.getenv("AE"),
188
+ ae_params=AutoEncoderParams(
189
+ resolution=256,
190
+ in_channels=3,
191
+ ch=128,
192
+ out_ch=3,
193
+ ch_mult=[1, 2, 4, 4],
194
+ num_res_blocks=2,
195
+ z_channels=16,
196
+ scale_factor=0.3611,
197
+ shift_factor=0.1159,
198
+ ),
199
+ ),
200
+ "flux-dev-depth": ModelSpec(
201
+ repo_id="black-forest-labs/FLUX.1-Depth-dev",
202
+ repo_flow="flux1-depth-dev.safetensors",
203
+ repo_ae="ae.safetensors",
204
+ ckpt_path=os.getenv("FLUX_DEV_DEPTH"),
205
+ lora_path=None,
206
+ params=FluxParams(
207
+ in_channels=128,
208
+ out_channels=64,
209
+ vec_in_dim=768,
210
+ context_in_dim=4096,
211
+ hidden_size=3072,
212
+ mlp_ratio=4.0,
213
+ num_heads=24,
214
+ depth=19,
215
+ depth_single_blocks=38,
216
+ axes_dim=[16, 56, 56],
217
+ theta=10_000,
218
+ qkv_bias=True,
219
+ guidance_embed=True,
220
+ ),
221
+ ae_path=os.getenv("AE"),
222
+ ae_params=AutoEncoderParams(
223
+ resolution=256,
224
+ in_channels=3,
225
+ ch=128,
226
+ out_ch=3,
227
+ ch_mult=[1, 2, 4, 4],
228
+ num_res_blocks=2,
229
+ z_channels=16,
230
+ scale_factor=0.3611,
231
+ shift_factor=0.1159,
232
+ ),
233
+ ),
234
+ "flux-dev-depth-lora": ModelSpec(
235
+ repo_id="black-forest-labs/FLUX.1-dev",
236
+ repo_flow="flux1-dev.safetensors",
237
+ repo_ae="ae.safetensors",
238
+ ckpt_path=os.getenv("FLUX_DEV"),
239
+ lora_path=os.getenv("FLUX_DEV_DEPTH_LORA"),
240
+ params=FluxParams(
241
+ in_channels=128,
242
+ out_channels=64,
243
+ vec_in_dim=768,
244
+ context_in_dim=4096,
245
+ hidden_size=3072,
246
+ mlp_ratio=4.0,
247
+ num_heads=24,
248
+ depth=19,
249
+ depth_single_blocks=38,
250
+ axes_dim=[16, 56, 56],
251
+ theta=10_000,
252
+ qkv_bias=True,
253
+ guidance_embed=True,
254
+ ),
255
+ ae_path=os.getenv("AE"),
256
+ ae_params=AutoEncoderParams(
257
+ resolution=256,
258
+ in_channels=3,
259
+ ch=128,
260
+ out_ch=3,
261
+ ch_mult=[1, 2, 4, 4],
262
+ num_res_blocks=2,
263
+ z_channels=16,
264
+ scale_factor=0.3611,
265
+ shift_factor=0.1159,
266
+ ),
267
+ ),
268
+ "flux-dev-fill": ModelSpec(
269
+ repo_id="black-forest-labs/FLUX.1-Fill-dev",
270
+ repo_flow="flux1-fill-dev.safetensors",
271
+ repo_ae="ae.safetensors",
272
+ ckpt_path=os.getenv("FLUX_DEV_FILL"),
273
+ lora_path=None,
274
+ params=FluxParams(
275
+ in_channels=384,
276
+ out_channels=64,
277
+ vec_in_dim=768,
278
+ context_in_dim=4096,
279
+ hidden_size=3072,
280
+ mlp_ratio=4.0,
281
+ num_heads=24,
282
+ depth=19,
283
+ depth_single_blocks=38,
284
+ axes_dim=[16, 56, 56],
285
+ theta=10_000,
286
+ qkv_bias=True,
287
+ guidance_embed=True,
288
+ ),
289
+ ae_path=os.getenv("AE"),
290
+ ae_params=AutoEncoderParams(
291
+ resolution=256,
292
+ in_channels=3,
293
+ ch=128,
294
+ out_ch=3,
295
+ ch_mult=[1, 2, 4, 4],
296
+ num_res_blocks=2,
297
+ z_channels=16,
298
+ scale_factor=0.3611,
299
+ shift_factor=0.1159,
300
+ ),
301
+ ),
302
+ }
303
+
304
+
305
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
306
+ if len(missing) > 0 and len(unexpected) > 0:
307
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
308
+ print("\n" + "-" * 79 + "\n")
309
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
310
+ elif len(missing) > 0:
311
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
312
+ elif len(unexpected) > 0:
313
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
314
+
315
+
316
+ def load_flow_model(
317
+ name: str, device: str | torch.device = "cuda", hf_download: bool = True, verbose: bool = False
318
+ ) -> Flux:
319
+ # Loading Flux
320
+ print("Init model")
321
+ ckpt_path = configs[name].ckpt_path
322
+ lora_path = configs[name].lora_path
323
+ if (
324
+ ckpt_path is None
325
+ and configs[name].repo_id is not None
326
+ and configs[name].repo_flow is not None
327
+ and hf_download
328
+ ):
329
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
330
+
331
+ with torch.device("meta" if ckpt_path is not None else device):
332
+ if lora_path is not None:
333
+ model = FluxLoraWrapper(params=configs[name].params).to(torch.bfloat16)
334
+ else:
335
+ model = Flux(configs[name].params).to(torch.bfloat16)
336
+
337
+ if ckpt_path is not None:
338
+ print("Loading checkpoint")
339
+ # load_sft doesn't support torch.device
340
+ sd = load_sft(ckpt_path, device=str(device))
341
+ sd = optionally_expand_state_dict(model, sd)
342
+ missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
343
+ if verbose:
344
+ print_load_warning(missing, unexpected)
345
+
346
+ if configs[name].lora_path is not None:
347
+ print("Loading LoRA")
348
+ lora_sd = load_sft(configs[name].lora_path, device=str(device))
349
+ # loading the lora params + overwriting scale values in the norms
350
+ missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True)
351
+ if verbose:
352
+ print_load_warning(missing, unexpected)
353
+ return model
354
+
355
+
356
+ def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
357
+ # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
358
+ return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
359
+
360
+
361
+ def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
362
+ return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
363
+
364
+
365
+ def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
366
+ ckpt_path = configs[name].ae_path
367
+ if (
368
+ ckpt_path is None
369
+ and configs[name].repo_id is not None
370
+ and configs[name].repo_ae is not None
371
+ and hf_download
372
+ ):
373
+ ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
374
+
375
+ # Loading the autoencoder
376
+ print("Init AE")
377
+ with torch.device("meta" if ckpt_path is not None else device):
378
+ ae = AutoEncoder(configs[name].ae_params)
379
+
380
+ if ckpt_path is not None:
381
+ sd = load_sft(ckpt_path, device=str(device))
382
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
383
+ print_load_warning(missing, unexpected)
384
+ return ae
385
+
386
+
387
+ def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
388
+ """
389
+ Optionally expand the state dict to match the model's parameters shapes.
390
+ """
391
+ for name, param in model.named_parameters():
392
+ if name in state_dict:
393
+ if state_dict[name].shape != param.shape:
394
+ print(
395
+ f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}."
396
+ )
397
+ # expand with zeros:
398
+ expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
399
+ slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
400
+ expanded_state_dict_weight[slices] = state_dict[name]
401
+ state_dict[name] = expanded_state_dict_weight
402
+
403
+ return state_dict
404
+
405
+
406
+ class WatermarkEmbedder:
407
+ def __init__(self, watermark):
408
+ self.watermark = watermark
409
+ self.num_bits = len(WATERMARK_BITS)
410
+ self.encoder = WatermarkEncoder()
411
+ self.encoder.set_watermark("bits", self.watermark)
412
+
413
+ def __call__(self, image: torch.Tensor) -> torch.Tensor:
414
+ """
415
+ Adds a predefined watermark to the input image
416
+
417
+ Args:
418
+ image: ([N,] B, RGB, H, W) in range [-1, 1]
419
+
420
+ Returns:
421
+ same as input but watermarked
422
+ """
423
+ image = 0.5 * image + 0.5
424
+ squeeze = len(image.shape) == 4
425
+ if squeeze:
426
+ image = image[None, ...]
427
+ n = image.shape[0]
428
+ image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
429
+ # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
430
+ # the watermarking library expects input in cv2 BGR format
431
+ for k in range(image_np.shape[0]):
432
+ image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
433
+ image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
434
+ image.device
435
+ )
436
+ image = torch.clamp(image / 255, min=0.0, max=1.0)
437
+ if squeeze:
438
+ image = image[0]
439
+ image = 2 * image - 1
440
+ return image
441
+
442
+
443
+ # A fixed 48-bit message that was chosen at random
444
+ WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
445
+ # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
446
+ WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
447
+ embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
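
Note: a condensed text-to-image sketch showing how these loaders are typically combined (roughly along the lines of the repo's CLI, which is not part of this excerpt). Running it downloads the full FLUX.1 [schnell] weights, needs a large GPU, and omits the save_image / embed_watermark post-processing defined above; the step count and the schnell-specific shift/guidance handling are assumptions, not values taken from this file:

import torch

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import load_ae, load_clip, load_flow_model, load_t5

torch_device = torch.device("cuda")
name = "flux-schnell"

t5 = load_t5(torch_device, max_length=256)
clip = load_clip(torch_device)
model = load_flow_model(name, device=torch_device)
ae = load_ae(name, device=torch_device)

height, width = 1024, 1024
with torch.no_grad():
    x = get_noise(1, height, width, device=torch_device, dtype=torch.bfloat16, seed=0)
    inp = prepare(t5, clip, x, prompt="a photo of a forest with mist")
    # schnell is usually sampled in few steps and without the schedule shift;
    # the guidance value is ignored for schnell (see the comment in denoise()).
    timesteps = get_schedule(4, inp["img"].shape[1], shift=(name != "flux-schnell"))
    x = denoise(model, **inp, timesteps=timesteps, guidance=3.5)
    x = unpack(x.float(), height, width)
    with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
        img = ae.decode(x)  # (1, 3, H, W) in [-1, 1]
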