vilarin committed (verified)
Commit 6533101 · 1 parent: 6fdb3d2

Delete flux

flux/__init__.py DELETED
@@ -1,13 +0,0 @@
-try:
-    from ._version import (
-        version as __version__,  # type: ignore
-        version_tuple,
-    )
-except ImportError:
-    __version__ = "unknown (no version information available)"
-    version_tuple = (0, 0, "unknown", "noinfo")
-
-from pathlib import Path
-
-PACKAGE = __package__.replace("_", "-")
-PACKAGE_ROOT = Path(__file__).parent
 
flux/__main__.py DELETED
@@ -1,4 +0,0 @@
-from .cli import app
-
-if __name__ == "__main__":
-    app()
 
flux/api.py DELETED
@@ -1,225 +0,0 @@
1
- import io
2
- import os
3
- import time
4
- from pathlib import Path
5
-
6
- import requests
7
- from PIL import Image
8
-
9
- API_URL = "https://api.bfl.ml"
10
- API_ENDPOINTS = {
11
- "flux.1-pro": "flux-pro",
12
- "flux.1-dev": "flux-dev",
13
- "flux.1.1-pro": "flux-pro-1.1",
14
- }
15
-
16
-
17
- class ApiException(Exception):
18
- def __init__(self, status_code: int, detail: str | list[dict] | None = None):
19
- super().__init__()
20
- self.detail = detail
21
- self.status_code = status_code
22
-
23
- def __str__(self) -> str:
24
- return self.__repr__()
25
-
26
- def __repr__(self) -> str:
27
- if self.detail is None:
28
- message = None
29
- elif isinstance(self.detail, str):
30
- message = self.detail
31
- else:
32
- message = "[" + ",".join(d["msg"] for d in self.detail) + "]"
33
- return f"ApiException({self.status_code=}, {message=}, detail={self.detail})"
34
-
35
-
36
- class ImageRequest:
37
- def __init__(
38
- self,
39
- # api inputs
40
- prompt: str,
41
- name: str = "flux.1.1-pro",
42
- width: int | None = None,
43
- height: int | None = None,
44
- num_steps: int | None = None,
45
- prompt_upsampling: bool | None = None,
46
- seed: int | None = None,
47
- guidance: float | None = None,
48
- interval: float | None = None,
49
- safety_tolerance: int | None = None,
50
- # behavior of this class
51
- validate: bool = True,
52
- launch: bool = True,
53
- api_key: str | None = None,
54
- ):
55
- """
56
- Manages an image generation request to the API.
57
-
58
- All parameters not specified will use the API defaults.
59
-
60
- Args:
61
- prompt: Text prompt for image generation.
62
- width: Width of the generated image in pixels. Must be a multiple of 32.
63
- height: Height of the generated image in pixels. Must be a multiple of 32.
64
- name: Which model version to use
65
- num_steps: Number of steps for the image generation process.
66
- prompt_upsampling: Whether to perform upsampling on the prompt.
67
- seed: Optional seed for reproducibility.
68
- guidance: Guidance scale for image generation.
69
- safety_tolerance: Tolerance level for input and output moderation.
70
- Between 0 and 6, 0 being most strict, 6 being least strict.
71
- validate: Run input validation
72
- launch: Directly launches request
73
- api_key: Your API key if not provided by the environment
74
-
75
- Raises:
76
- ValueError: For invalid input, when `validate`
77
- ApiException: For errors raised from the API
78
- """
79
- if validate:
80
- if name not in API_ENDPOINTS.keys():
81
- raise ValueError(f"Invalid model {name}")
82
- elif width is not None and width % 32 != 0:
83
- raise ValueError(f"width must be divisible by 32, got {width}")
84
- elif width is not None and not (256 <= width <= 1440):
85
- raise ValueError(f"width must be between 256 and 1440, got {width}")
86
- elif height is not None and height % 32 != 0:
87
- raise ValueError(f"height must be divisible by 32, got {height}")
88
- elif height is not None and not (256 <= height <= 1440):
89
- raise ValueError(f"height must be between 256 and 1440, got {height}")
90
- elif num_steps is not None and not (1 <= num_steps <= 50):
91
- raise ValueError(f"steps must be between 1 and 50, got {num_steps}")
92
- elif guidance is not None and not (1.5 <= guidance <= 5.0):
93
- raise ValueError(f"guidance must be between 1.5 and 4, got {guidance}")
94
- elif interval is not None and not (1.0 <= interval <= 4.0):
95
- raise ValueError(f"interval must be between 1 and 4, got {interval}")
96
- elif safety_tolerance is not None and not (0 <= safety_tolerance <= 6.0):
97
- raise ValueError(f"safety_tolerance must be between 0 and 6, got {interval}")
98
-
99
- if name == "flux.1-dev":
100
- if interval is not None:
101
- raise ValueError("Interval is not supported for flux.1-dev")
102
- if name == "flux.1.1-pro":
103
- if interval is not None or num_steps is not None or guidance is not None:
104
- raise ValueError("Interval, num_steps and guidance are not supported for " "flux.1.1-pro")
105
-
106
- self.name = name
107
- self.request_json = {
108
- "prompt": prompt,
109
- "width": width,
110
- "height": height,
111
- "steps": num_steps,
112
- "prompt_upsampling": prompt_upsampling,
113
- "seed": seed,
114
- "guidance": guidance,
115
- "interval": interval,
116
- "safety_tolerance": safety_tolerance,
117
- }
118
- self.request_json = {key: value for key, value in self.request_json.items() if value is not None}
119
-
120
- self.request_id: str | None = None
121
- self.result: dict | None = None
122
- self._image_bytes: bytes | None = None
123
- self._url: str | None = None
124
- if api_key is None:
125
- self.api_key = os.environ.get("BFL_API_KEY")
126
- else:
127
- self.api_key = api_key
128
-
129
- if launch:
130
- self.request()
131
-
132
- def request(self):
133
- """
134
- Request to generate the image.
135
- """
136
- if self.request_id is not None:
137
- return
138
- response = requests.post(
139
- f"{API_URL}/v1/{API_ENDPOINTS[self.name]}",
140
- headers={
141
- "accept": "application/json",
142
- "x-key": self.api_key,
143
- "Content-Type": "application/json",
144
- },
145
- json=self.request_json,
146
- )
147
- result = response.json()
148
- if response.status_code != 200:
149
- raise ApiException(status_code=response.status_code, detail=result.get("detail"))
150
- self.request_id = response.json()["id"]
151
-
152
- def retrieve(self) -> dict:
153
- """
154
- Wait for the generation to finish and retrieve response.
155
- """
156
- if self.request_id is None:
157
- self.request()
158
- while self.result is None:
159
- response = requests.get(
160
- f"{API_URL}/v1/get_result",
161
- headers={
162
- "accept": "application/json",
163
- "x-key": self.api_key,
164
- },
165
- params={
166
- "id": self.request_id,
167
- },
168
- )
169
- result = response.json()
170
- if "status" not in result:
171
- raise ApiException(status_code=response.status_code, detail=result.get("detail"))
172
- elif result["status"] == "Ready":
173
- self.result = result["result"]
174
- elif result["status"] == "Pending":
175
- time.sleep(0.5)
176
- else:
177
- raise ApiException(status_code=200, detail=f"API returned status '{result['status']}'")
178
- return self.result
179
-
180
- @property
181
- def bytes(self) -> bytes:
182
- """
183
- Generated image as bytes.
184
- """
185
- if self._image_bytes is None:
186
- response = requests.get(self.url)
187
- if response.status_code == 200:
188
- self._image_bytes = response.content
189
- else:
190
- raise ApiException(status_code=response.status_code)
191
- return self._image_bytes
192
-
193
- @property
194
- def url(self) -> str:
195
- """
196
- Public url to retrieve the image from
197
- """
198
- if self._url is None:
199
- result = self.retrieve()
200
- self._url = result["sample"]
201
- return self._url
202
-
203
- @property
204
- def image(self) -> Image.Image:
205
- """
206
- Load the image as a PIL Image
207
- """
208
- return Image.open(io.BytesIO(self.bytes))
209
-
210
- def save(self, path: str):
211
- """
212
- Save the generated image to a local path
213
- """
214
- suffix = Path(self.url).suffix
215
- if not path.endswith(suffix):
216
- path = path + suffix
217
- Path(path).resolve().parent.mkdir(parents=True, exist_ok=True)
218
- with open(path, "wb") as file:
219
- file.write(self.bytes)
220
-
221
-
222
- if __name__ == "__main__":
223
- from fire import Fire
224
-
225
- Fire(ImageRequest)
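
For reference, a hedged sketch of how the removed flux/api.py client was typically driven. The BFL_API_KEY environment variable and the output filename are assumptions; everything else mirrors the deleted ImageRequest signature.

```python
# Sketch only: usage of the removed flux/api.py client.
# Assumes a valid BFL_API_KEY environment variable (or pass api_key=...).
from flux.api import ImageRequest

request = ImageRequest(
    prompt="a photo of a forest with mist",
    name="flux.1.1-pro",
    width=1024,    # must be a multiple of 32 in [256, 1440]
    height=768,
)
request.save("forest.jpg")  # blocks until the result is ready, then writes the image bytes
```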
 
flux/cli.py DELETED
@@ -1,238 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from dataclasses import dataclass
5
- from glob import iglob
6
-
7
- import torch
8
- from fire import Fire
9
- from transformers import pipeline
10
-
11
- from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
12
- from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
13
-
14
- NSFW_THRESHOLD = 0.85
15
-
16
-
17
- @dataclass
18
- class SamplingOptions:
19
- prompt: str
20
- width: int
21
- height: int
22
- num_steps: int
23
- guidance: float
24
- seed: int | None
25
-
26
-
27
- def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
28
- user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
29
- usage = (
30
- "Usage: Either write your prompt directly, leave this field empty "
31
- "to repeat the prompt or write a command starting with a slash:\n"
32
- "- '/w <width>' will set the width of the generated image\n"
33
- "- '/h <height>' will set the height of the generated image\n"
34
- "- '/s <seed>' sets the next seed\n"
35
- "- '/g <guidance>' sets the guidance (flux-dev only)\n"
36
- "- '/n <steps>' sets the number of steps\n"
37
- "- '/q' to quit"
38
- )
39
-
40
- while (prompt := input(user_question)).startswith("/"):
41
- if prompt.startswith("/w"):
42
- if prompt.count(" ") != 1:
43
- print(f"Got invalid command '{prompt}'\n{usage}")
44
- continue
45
- _, width = prompt.split()
46
- options.width = 16 * (int(width) // 16)
47
- print(
48
- f"Setting resolution to {options.width} x {options.height} "
49
- f"({options.height *options.width/1e6:.2f}MP)"
50
- )
51
- elif prompt.startswith("/h"):
52
- if prompt.count(" ") != 1:
53
- print(f"Got invalid command '{prompt}'\n{usage}")
54
- continue
55
- _, height = prompt.split()
56
- options.height = 16 * (int(height) // 16)
57
- print(
58
- f"Setting resolution to {options.width} x {options.height} "
59
- f"({options.height *options.width/1e6:.2f}MP)"
60
- )
61
- elif prompt.startswith("/g"):
62
- if prompt.count(" ") != 1:
63
- print(f"Got invalid command '{prompt}'\n{usage}")
64
- continue
65
- _, guidance = prompt.split()
66
- options.guidance = float(guidance)
67
- print(f"Setting guidance to {options.guidance}")
68
- elif prompt.startswith("/s"):
69
- if prompt.count(" ") != 1:
70
- print(f"Got invalid command '{prompt}'\n{usage}")
71
- continue
72
- _, seed = prompt.split()
73
- options.seed = int(seed)
74
- print(f"Setting seed to {options.seed}")
75
- elif prompt.startswith("/n"):
76
- if prompt.count(" ") != 1:
77
- print(f"Got invalid command '{prompt}'\n{usage}")
78
- continue
79
- _, steps = prompt.split()
80
- options.num_steps = int(steps)
81
- print(f"Setting number of steps to {options.num_steps}")
82
- elif prompt.startswith("/q"):
83
- print("Quitting")
84
- return None
85
- else:
86
- if not prompt.startswith("/h"):
87
- print(f"Got invalid command '{prompt}'\n{usage}")
88
- print(usage)
89
- if prompt != "":
90
- options.prompt = prompt
91
- return options
92
-
93
-
94
- @torch.inference_mode()
95
- def main(
96
- name: str = "flux-schnell",
97
- width: int = 1360,
98
- height: int = 768,
99
- seed: int | None = None,
100
- prompt: str = (
101
- "a photo of a forest with mist swirling around the tree trunks. The word "
102
- '"FLUX" is painted over it in big, red brush strokes with visible texture'
103
- ),
104
- device: str = "cuda" if torch.cuda.is_available() else "cpu",
105
- num_steps: int | None = None,
106
- loop: bool = False,
107
- guidance: float = 3.5,
108
- offload: bool = False,
109
- output_dir: str = "output",
110
- add_sampling_metadata: bool = True,
111
- ):
112
- """
113
- Sample the flux model. Either interactively (set `--loop`) or run for a
114
- single image.
115
-
116
- Args:
117
- name: Name of the model to load
118
- height: height of the sample in pixels (should be a multiple of 16)
119
- width: width of the sample in pixels (should be a multiple of 16)
120
- seed: Set a seed for sampling
121
- output_name: where to save the output image, `{idx}` will be replaced
122
- by the index of the sample
123
- prompt: Prompt used for sampling
124
- device: Pytorch device
125
- num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
126
- loop: start an interactive session and sample multiple times
127
- guidance: guidance value used for guidance distillation
128
- add_sampling_metadata: Add the prompt to the image Exif metadata
129
- """
130
- nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
131
-
132
- if name not in configs:
133
- available = ", ".join(configs.keys())
134
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
135
-
136
- torch_device = torch.device(device)
137
- if num_steps is None:
138
- num_steps = 4 if name == "flux-schnell" else 50
139
-
140
- # allow for packing and conversion to latent space
141
- height = 16 * (height // 16)
142
- width = 16 * (width // 16)
143
-
144
- output_name = os.path.join(output_dir, "img_{idx}.jpg")
145
- if not os.path.exists(output_dir):
146
- os.makedirs(output_dir)
147
- idx = 0
148
- else:
149
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
150
- if len(fns) > 0:
151
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
152
- else:
153
- idx = 0
154
-
155
- # init all components
156
- t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
157
- clip = load_clip(torch_device)
158
- model = load_flow_model(name, device="cpu" if offload else torch_device)
159
- ae = load_ae(name, device="cpu" if offload else torch_device)
160
-
161
- rng = torch.Generator(device="cpu")
162
- opts = SamplingOptions(
163
- prompt=prompt,
164
- width=width,
165
- height=height,
166
- num_steps=num_steps,
167
- guidance=guidance,
168
- seed=seed,
169
- )
170
-
171
- if loop:
172
- opts = parse_prompt(opts)
173
-
174
- while opts is not None:
175
- if opts.seed is None:
176
- opts.seed = rng.seed()
177
- print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
178
- t0 = time.perf_counter()
179
-
180
- # prepare input
181
- x = get_noise(
182
- 1,
183
- opts.height,
184
- opts.width,
185
- device=torch_device,
186
- dtype=torch.bfloat16,
187
- seed=opts.seed,
188
- )
189
- opts.seed = None
190
- if offload:
191
- ae = ae.cpu()
192
- torch.cuda.empty_cache()
193
- t5, clip = t5.to(torch_device), clip.to(torch_device)
194
- inp = prepare(t5, clip, x, prompt=opts.prompt)
195
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
196
-
197
- # offload TEs to CPU, load model to gpu
198
- if offload:
199
- t5, clip = t5.cpu(), clip.cpu()
200
- torch.cuda.empty_cache()
201
- model = model.to(torch_device)
202
-
203
- # denoise initial noise
204
- x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
205
-
206
- # offload model, load autoencoder to gpu
207
- if offload:
208
- model.cpu()
209
- torch.cuda.empty_cache()
210
- ae.decoder.to(x.device)
211
-
212
- # decode latents to pixel space
213
- x = unpack(x.float(), opts.height, opts.width)
214
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
215
- x = ae.decode(x)
216
-
217
- if torch.cuda.is_available():
218
- torch.cuda.synchronize()
219
- t1 = time.perf_counter()
220
-
221
- fn = output_name.format(idx=idx)
222
- print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
223
-
224
- idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
225
-
226
- if loop:
227
- print("-" * 80)
228
- opts = parse_prompt(opts)
229
- else:
230
- opts = None
231
-
232
-
233
- def app():
234
- Fire(main)
235
-
236
-
237
- if __name__ == "__main__":
238
- app()
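
The deleted CLI exposed `main` through Fire; a hedged sketch of an equivalent programmatic call, with argument names and defaults taken from the deleted function signature:

```python
# Sketch only: calls the removed flux/cli.py entry point directly instead of
# the Fire-based command line. Requires the model weights and a CUDA device.
from flux.cli import main

main(
    name="flux-schnell",   # 4 sampling steps by default for this model
    width=1360,
    height=768,
    prompt="a photo of a forest with mist",
    output_dir="output",   # images are written as output/img_{idx}.jpg
)
```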
 
flux/cli_control.py DELETED
@@ -1,347 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from dataclasses import dataclass
5
- from glob import iglob
6
-
7
- import torch
8
- from fire import Fire
9
- from transformers import pipeline
10
-
11
- from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
12
- from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
13
- from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
-
15
-
16
- @dataclass
17
- class SamplingOptions:
18
- prompt: str
19
- width: int
20
- height: int
21
- num_steps: int
22
- guidance: float
23
- seed: int | None
24
- img_cond_path: str
25
- lora_scale: float | None
26
-
27
-
28
- def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
29
- user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
30
- usage = (
31
- "Usage: Either write your prompt directly, leave this field empty "
32
- "to repeat the prompt or write a command starting with a slash:\n"
33
- "- '/w <width>' will set the width of the generated image\n"
34
- "- '/h <height>' will set the height of the generated image\n"
35
- "- '/s <seed>' sets the next seed\n"
36
- "- '/g <guidance>' sets the guidance (flux-dev only)\n"
37
- "- '/n <steps>' sets the number of steps\n"
38
- "- '/q' to quit"
39
- )
40
-
41
- while (prompt := input(user_question)).startswith("/"):
42
- if prompt.startswith("/w"):
43
- if prompt.count(" ") != 1:
44
- print(f"Got invalid command '{prompt}'\n{usage}")
45
- continue
46
- _, width = prompt.split()
47
- options.width = 16 * (int(width) // 16)
48
- print(
49
- f"Setting resolution to {options.width} x {options.height} "
50
- f"({options.height *options.width/1e6:.2f}MP)"
51
- )
52
- elif prompt.startswith("/h"):
53
- if prompt.count(" ") != 1:
54
- print(f"Got invalid command '{prompt}'\n{usage}")
55
- continue
56
- _, height = prompt.split()
57
- options.height = 16 * (int(height) // 16)
58
- print(
59
- f"Setting resolution to {options.width} x {options.height} "
60
- f"({options.height *options.width/1e6:.2f}MP)"
61
- )
62
- elif prompt.startswith("/g"):
63
- if prompt.count(" ") != 1:
64
- print(f"Got invalid command '{prompt}'\n{usage}")
65
- continue
66
- _, guidance = prompt.split()
67
- options.guidance = float(guidance)
68
- print(f"Setting guidance to {options.guidance}")
69
- elif prompt.startswith("/s"):
70
- if prompt.count(" ") != 1:
71
- print(f"Got invalid command '{prompt}'\n{usage}")
72
- continue
73
- _, seed = prompt.split()
74
- options.seed = int(seed)
75
- print(f"Setting seed to {options.seed}")
76
- elif prompt.startswith("/n"):
77
- if prompt.count(" ") != 1:
78
- print(f"Got invalid command '{prompt}'\n{usage}")
79
- continue
80
- _, steps = prompt.split()
81
- options.num_steps = int(steps)
82
- print(f"Setting number of steps to {options.num_steps}")
83
- elif prompt.startswith("/q"):
84
- print("Quitting")
85
- return None
86
- else:
87
- if not prompt.startswith("/h"):
88
- print(f"Got invalid command '{prompt}'\n{usage}")
89
- print(usage)
90
- if prompt != "":
91
- options.prompt = prompt
92
- return options
93
-
94
-
95
- def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
96
- if options is None:
97
- return None
98
-
99
- user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
100
- usage = (
101
- "Usage: Either write your prompt directly, leave this field empty "
102
- "to repeat the conditioning image or write a command starting with a slash:\n"
103
- "- '/q' to quit"
104
- )
105
-
106
- while True:
107
- img_cond_path = input(user_question)
108
-
109
- if img_cond_path.startswith("/"):
110
- if img_cond_path.startswith("/q"):
111
- print("Quitting")
112
- return None
113
- else:
114
- if not img_cond_path.startswith("/h"):
115
- print(f"Got invalid command '{img_cond_path}'\n{usage}")
116
- print(usage)
117
- continue
118
-
119
- if img_cond_path == "":
120
- break
121
-
122
- if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
123
- (".jpg", ".jpeg", ".png", ".webp")
124
- ):
125
- print(f"File '{img_cond_path}' does not exist or is not a valid image file")
126
- continue
127
-
128
- options.img_cond_path = img_cond_path
129
- break
130
-
131
- return options
132
-
133
-
134
- def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
135
- changed = False
136
-
137
- if options is None:
138
- return None, changed
139
-
140
- user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
141
- usage = (
142
- "Usage: Either write your prompt directly, leave this field empty "
143
- "to repeat the lora scale or write a command starting with a slash:\n"
144
- "- '/q' to quit"
145
- )
146
-
147
- while (prompt := input(user_question)).startswith("/"):
148
- if prompt.startswith("/q"):
149
- print("Quitting")
150
- return None, changed
151
- else:
152
- if not prompt.startswith("/h"):
153
- print(f"Got invalid command '{prompt}'\n{usage}")
154
- print(usage)
155
- if prompt != "":
156
- options.lora_scale = float(prompt)
157
- changed = True
158
- return options, changed
159
-
160
-
161
- @torch.inference_mode()
162
- def main(
163
- name: str,
164
- width: int = 1024,
165
- height: int = 1024,
166
- seed: int | None = None,
167
- prompt: str = "a robot made out of gold",
168
- device: str = "cuda" if torch.cuda.is_available() else "cpu",
169
- num_steps: int = 50,
170
- loop: bool = False,
171
- guidance: float | None = None,
172
- offload: bool = False,
173
- output_dir: str = "output",
174
- add_sampling_metadata: bool = True,
175
- img_cond_path: str = "assets/robot.webp",
176
- lora_scale: float | None = 0.85,
177
- ):
178
- """
179
- Sample the flux model. Either interactively (set `--loop`) or run for a
180
- single image.
181
-
182
- Args:
183
- height: height of the sample in pixels (should be a multiple of 16)
184
- width: width of the sample in pixels (should be a multiple of 16)
185
- seed: Set a seed for sampling
186
- output_name: where to save the output image, `{idx}` will be replaced
187
- by the index of the sample
188
- prompt: Prompt used for sampling
189
- device: Pytorch device
190
- num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
191
- loop: start an interactive session and sample multiple times
192
- guidance: guidance value used for guidance distillation
193
- add_sampling_metadata: Add the prompt to the image Exif metadata
194
- img_cond_path: path to conditioning image (jpeg/png/webp)
195
- """
196
- nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
197
-
198
- assert name in [
199
- "flux-dev-canny",
200
- "flux-dev-depth",
201
- "flux-dev-canny-lora",
202
- "flux-dev-depth-lora",
203
- ], f"Got unknown model name: {name}"
204
- if guidance is None:
205
- if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
206
- guidance = 30.0
207
- elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
208
- guidance = 10.0
209
- else:
210
- raise NotImplementedError()
211
-
212
- if name not in configs:
213
- available = ", ".join(configs.keys())
214
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
215
-
216
- torch_device = torch.device(device)
217
-
218
- output_name = os.path.join(output_dir, "img_{idx}.jpg")
219
- if not os.path.exists(output_dir):
220
- os.makedirs(output_dir)
221
- idx = 0
222
- else:
223
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
224
- if len(fns) > 0:
225
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
226
- else:
227
- idx = 0
228
-
229
- # init all components
230
- t5 = load_t5(torch_device, max_length=512)
231
- clip = load_clip(torch_device)
232
- model = load_flow_model(name, device="cpu" if offload else torch_device)
233
- ae = load_ae(name, device="cpu" if offload else torch_device)
234
-
235
- # set lora scale
236
- if "lora" in name and lora_scale is not None:
237
- for _, module in model.named_modules():
238
- if hasattr(module, "set_scale"):
239
- module.set_scale(lora_scale)
240
-
241
- if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
242
- img_embedder = DepthImageEncoder(torch_device)
243
- elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
244
- img_embedder = CannyImageEncoder(torch_device)
245
- else:
246
- raise NotImplementedError()
247
-
248
- rng = torch.Generator(device="cpu")
249
- opts = SamplingOptions(
250
- prompt=prompt,
251
- width=width,
252
- height=height,
253
- num_steps=num_steps,
254
- guidance=guidance,
255
- seed=seed,
256
- img_cond_path=img_cond_path,
257
- lora_scale=lora_scale,
258
- )
259
-
260
- if loop:
261
- opts = parse_prompt(opts)
262
- opts = parse_img_cond_path(opts)
263
- if "lora" in name:
264
- opts, changed = parse_lora_scale(opts)
265
- if changed:
266
- # update the lora scale:
267
- for _, module in model.named_modules():
268
- if hasattr(module, "set_scale"):
269
- module.set_scale(opts.lora_scale)
270
-
271
- while opts is not None:
272
- if opts.seed is None:
273
- opts.seed = rng.seed()
274
- print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
275
- t0 = time.perf_counter()
276
-
277
- # prepare input
278
- x = get_noise(
279
- 1,
280
- opts.height,
281
- opts.width,
282
- device=torch_device,
283
- dtype=torch.bfloat16,
284
- seed=opts.seed,
285
- )
286
- opts.seed = None
287
- if offload:
288
- t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
289
- inp = prepare_control(
290
- t5,
291
- clip,
292
- x,
293
- prompt=opts.prompt,
294
- ae=ae,
295
- encoder=img_embedder,
296
- img_cond_path=opts.img_cond_path,
297
- )
298
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
299
-
300
- # offload TEs and AE to CPU, load model to gpu
301
- if offload:
302
- t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
303
- torch.cuda.empty_cache()
304
- model = model.to(torch_device)
305
-
306
- # denoise initial noise
307
- x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
308
-
309
- # offload model, load autoencoder to gpu
310
- if offload:
311
- model.cpu()
312
- torch.cuda.empty_cache()
313
- ae.decoder.to(x.device)
314
-
315
- # decode latents to pixel space
316
- x = unpack(x.float(), opts.height, opts.width)
317
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
318
- x = ae.decode(x)
319
-
320
- if torch.cuda.is_available():
321
- torch.cuda.synchronize()
322
- t1 = time.perf_counter()
323
- print(f"Done in {t1 - t0:.1f}s")
324
-
325
- idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
326
-
327
- if loop:
328
- print("-" * 80)
329
- opts = parse_prompt(opts)
330
- opts = parse_img_cond_path(opts)
331
- if "lora" in name:
332
- opts, changed = parse_lora_scale(opts)
333
- if changed:
334
- # update the lora scale:
335
- for _, module in model.named_modules():
336
- if hasattr(module, "set_scale"):
337
- module.set_scale(opts.lora_scale)
338
- else:
339
- opts = None
340
-
341
-
342
- def app():
343
- Fire(main)
344
-
345
-
346
- if __name__ == "__main__":
347
- app()
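
A similar hedged sketch for the removed structural-conditioning CLI; the model names and default conditioning image are taken from the deleted code:

```python
# Sketch only: programmatic call of the removed flux/cli_control.py entry point.
from flux.cli_control import main

main(
    name="flux-dev-canny",              # or flux-dev-depth / the *-lora variants
    prompt="a robot made out of gold",
    img_cond_path="assets/robot.webp",  # control image; guidance defaults to 30.0 for canny
    num_steps=50,
)
```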
 
flux/cli_fill.py DELETED
@@ -1,334 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from dataclasses import dataclass
5
- from glob import iglob
6
-
7
- import torch
8
- from fire import Fire
9
- from PIL import Image
10
- from transformers import pipeline
11
-
12
- from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
13
- from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
-
15
-
16
- @dataclass
17
- class SamplingOptions:
18
- prompt: str
19
- width: int
20
- height: int
21
- num_steps: int
22
- guidance: float
23
- seed: int | None
24
- img_cond_path: str
25
- img_mask_path: str
26
-
27
-
28
- def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
29
- user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
30
- usage = (
31
- "Usage: Either write your prompt directly, leave this field empty "
32
- "to repeat the prompt or write a command starting with a slash:\n"
33
- "- '/s <seed>' sets the next seed\n"
34
- "- '/g <guidance>' sets the guidance (flux-dev only)\n"
35
- "- '/n <steps>' sets the number of steps\n"
36
- "- '/q' to quit"
37
- )
38
-
39
- while (prompt := input(user_question)).startswith("/"):
40
- if prompt.startswith("/g"):
41
- if prompt.count(" ") != 1:
42
- print(f"Got invalid command '{prompt}'\n{usage}")
43
- continue
44
- _, guidance = prompt.split()
45
- options.guidance = float(guidance)
46
- print(f"Setting guidance to {options.guidance}")
47
- elif prompt.startswith("/s"):
48
- if prompt.count(" ") != 1:
49
- print(f"Got invalid command '{prompt}'\n{usage}")
50
- continue
51
- _, seed = prompt.split()
52
- options.seed = int(seed)
53
- print(f"Setting seed to {options.seed}")
54
- elif prompt.startswith("/n"):
55
- if prompt.count(" ") != 1:
56
- print(f"Got invalid command '{prompt}'\n{usage}")
57
- continue
58
- _, steps = prompt.split()
59
- options.num_steps = int(steps)
60
- print(f"Setting number of steps to {options.num_steps}")
61
- elif prompt.startswith("/q"):
62
- print("Quitting")
63
- return None
64
- else:
65
- if not prompt.startswith("/h"):
66
- print(f"Got invalid command '{prompt}'\n{usage}")
67
- print(usage)
68
- if prompt != "":
69
- options.prompt = prompt
70
- return options
71
-
72
-
73
- def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
74
- if options is None:
75
- return None
76
-
77
- user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
78
- usage = (
79
- "Usage: Either write your prompt directly, leave this field empty "
80
- "to repeat the conditioning image or write a command starting with a slash:\n"
81
- "- '/q' to quit"
82
- )
83
-
84
- while True:
85
- img_cond_path = input(user_question)
86
-
87
- if img_cond_path.startswith("/"):
88
- if img_cond_path.startswith("/q"):
89
- print("Quitting")
90
- return None
91
- else:
92
- if not img_cond_path.startswith("/h"):
93
- print(f"Got invalid command '{img_cond_path}'\n{usage}")
94
- print(usage)
95
- continue
96
-
97
- if img_cond_path == "":
98
- break
99
-
100
- if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
101
- (".jpg", ".jpeg", ".png", ".webp")
102
- ):
103
- print(f"File '{img_cond_path}' does not exist or is not a valid image file")
104
- continue
105
- else:
106
- with Image.open(img_cond_path) as img:
107
- width, height = img.size
108
-
109
- if width % 32 != 0 or height % 32 != 0:
110
- print(f"Image dimensions must be divisible by 32, got {width}x{height}")
111
- continue
112
-
113
- options.img_cond_path = img_cond_path
114
- break
115
-
116
- return options
117
-
118
-
119
- def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
120
- if options is None:
121
- return None
122
-
123
- user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
124
- usage = (
125
- "Usage: Either write your prompt directly, leave this field empty "
126
- "to repeat the conditioning mask or write a command starting with a slash:\n"
127
- "- '/q' to quit"
128
- )
129
-
130
- while True:
131
- img_mask_path = input(user_question)
132
-
133
- if img_mask_path.startswith("/"):
134
- if img_mask_path.startswith("/q"):
135
- print("Quitting")
136
- return None
137
- else:
138
- if not img_mask_path.startswith("/h"):
139
- print(f"Got invalid command '{img_mask_path}'\n{usage}")
140
- print(usage)
141
- continue
142
-
143
- if img_mask_path == "":
144
- break
145
-
146
- if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
147
- (".jpg", ".jpeg", ".png", ".webp")
148
- ):
149
- print(f"File '{img_mask_path}' does not exist or is not a valid image file")
150
- continue
151
- else:
152
- with Image.open(img_mask_path) as img:
153
- width, height = img.size
154
-
155
- if width % 32 != 0 or height % 32 != 0:
156
- print(f"Image dimensions must be divisible by 32, got {width}x{height}")
157
- continue
158
- else:
159
- with Image.open(options.img_cond_path) as img_cond:
160
- img_cond_width, img_cond_height = img_cond.size
161
-
162
- if width != img_cond_width or height != img_cond_height:
163
- print(
164
- f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
165
- )
166
- continue
167
-
168
- options.img_mask_path = img_mask_path
169
- break
170
-
171
- return options
172
-
173
-
174
- @torch.inference_mode()
175
- def main(
176
- seed: int | None = None,
177
- prompt: str = "a white paper cup",
178
- device: str = "cuda" if torch.cuda.is_available() else "cpu",
179
- num_steps: int = 50,
180
- loop: bool = False,
181
- guidance: float = 30.0,
182
- offload: bool = False,
183
- output_dir: str = "output",
184
- add_sampling_metadata: bool = True,
185
- img_cond_path: str = "assets/cup.png",
186
- img_mask_path: str = "assets/cup_mask.png",
187
- ):
188
- """
189
- Sample the flux model. Either interactively (set `--loop`) or run for a
190
- single image. This demo assumes that the conditioning image and mask have
191
- the same shape and that height and width are divisible by 32.
192
-
193
- Args:
194
- seed: Set a seed for sampling
195
- output_name: where to save the output image, `{idx}` will be replaced
196
- by the index of the sample
197
- prompt: Prompt used for sampling
198
- device: Pytorch device
199
- num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
200
- loop: start an interactive session and sample multiple times
201
- guidance: guidance value used for guidance distillation
202
- add_sampling_metadata: Add the prompt to the image Exif metadata
203
- img_cond_path: path to conditioning image (jpeg/png/webp)
204
- img_mask_path: path to conditioning mask (jpeg/png/webp
205
- """
206
- nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
207
-
208
- name = "flux-dev-fill"
209
- if name not in configs:
210
- available = ", ".join(configs.keys())
211
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
212
-
213
- torch_device = torch.device(device)
214
-
215
- output_name = os.path.join(output_dir, "img_{idx}.jpg")
216
- if not os.path.exists(output_dir):
217
- os.makedirs(output_dir)
218
- idx = 0
219
- else:
220
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
221
- if len(fns) > 0:
222
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
223
- else:
224
- idx = 0
225
-
226
- # init all components
227
- t5 = load_t5(torch_device, max_length=128)
228
- clip = load_clip(torch_device)
229
- model = load_flow_model(name, device="cpu" if offload else torch_device)
230
- ae = load_ae(name, device="cpu" if offload else torch_device)
231
-
232
- rng = torch.Generator(device="cpu")
233
- with Image.open(img_cond_path) as img:
234
- width, height = img.size
235
- opts = SamplingOptions(
236
- prompt=prompt,
237
- width=width,
238
- height=height,
239
- num_steps=num_steps,
240
- guidance=guidance,
241
- seed=seed,
242
- img_cond_path=img_cond_path,
243
- img_mask_path=img_mask_path,
244
- )
245
-
246
- if loop:
247
- opts = parse_prompt(opts)
248
- opts = parse_img_cond_path(opts)
249
-
250
- with Image.open(opts.img_cond_path) as img:
251
- width, height = img.size
252
- opts.height = height
253
- opts.width = width
254
-
255
- opts = parse_img_mask_path(opts)
256
-
257
- while opts is not None:
258
- if opts.seed is None:
259
- opts.seed = rng.seed()
260
- print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
261
- t0 = time.perf_counter()
262
-
263
- # prepare input
264
- x = get_noise(
265
- 1,
266
- opts.height,
267
- opts.width,
268
- device=torch_device,
269
- dtype=torch.bfloat16,
270
- seed=opts.seed,
271
- )
272
- opts.seed = None
273
- if offload:
274
- t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
275
- inp = prepare_fill(
276
- t5,
277
- clip,
278
- x,
279
- prompt=opts.prompt,
280
- ae=ae,
281
- img_cond_path=opts.img_cond_path,
282
- mask_path=opts.img_mask_path,
283
- )
284
-
285
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
286
-
287
- # offload TEs and AE to CPU, load model to gpu
288
- if offload:
289
- t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
290
- torch.cuda.empty_cache()
291
- model = model.to(torch_device)
292
-
293
- # denoise initial noise
294
- x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
295
-
296
- # offload model, load autoencoder to gpu
297
- if offload:
298
- model.cpu()
299
- torch.cuda.empty_cache()
300
- ae.decoder.to(x.device)
301
-
302
- # decode latents to pixel space
303
- x = unpack(x.float(), opts.height, opts.width)
304
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
305
- x = ae.decode(x)
306
-
307
- if torch.cuda.is_available():
308
- torch.cuda.synchronize()
309
- t1 = time.perf_counter()
310
- print(f"Done in {t1 - t0:.1f}s")
311
-
312
- idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
313
-
314
- if loop:
315
- print("-" * 80)
316
- opts = parse_prompt(opts)
317
- opts = parse_img_cond_path(opts)
318
-
319
- with Image.open(opts.img_cond_path) as img:
320
- width, height = img.size
321
- opts.height = height
322
- opts.width = width
323
-
324
- opts = parse_img_mask_path(opts)
325
- else:
326
- opts = None
327
-
328
-
329
- def app():
330
- Fire(main)
331
-
332
-
333
- if __name__ == "__main__":
334
- app()
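
And for the removed inpainting CLI, which additionally takes a mask with the same dimensions as the conditioning image (both divisible by 32):

```python
# Sketch only: programmatic call of the removed flux/cli_fill.py entry point.
from flux.cli_fill import main

main(
    prompt="a white paper cup",
    img_cond_path="assets/cup.png",        # image to fill, dimensions divisible by 32
    img_mask_path="assets/cup_mask.png",   # mask with the same dimensions as the image
    guidance=30.0,
)
```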
 
flux/cli_redux.py DELETED
@@ -1,279 +0,0 @@
1
- import os
2
- import re
3
- import time
4
- from dataclasses import dataclass
5
- from glob import iglob
6
-
7
- import torch
8
- from fire import Fire
9
- from transformers import pipeline
10
-
11
- from flux.modules.image_embedders import ReduxImageEncoder
12
- from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
13
- from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
14
-
15
-
16
- @dataclass
17
- class SamplingOptions:
18
- prompt: str
19
- width: int
20
- height: int
21
- num_steps: int
22
- guidance: float
23
- seed: int | None
24
- img_cond_path: str
25
-
26
-
27
- def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
28
- user_question = "Write /h for help, /q to quit and leave empty to repeat):\n"
29
- usage = (
30
- "Usage: Leave this field empty to do nothing "
31
- "or write a command starting with a slash:\n"
32
- "- '/w <width>' will set the width of the generated image\n"
33
- "- '/h <height>' will set the height of the generated image\n"
34
- "- '/s <seed>' sets the next seed\n"
35
- "- '/g <guidance>' sets the guidance (flux-dev only)\n"
36
- "- '/n <steps>' sets the number of steps\n"
37
- "- '/q' to quit"
38
- )
39
-
40
- while (prompt := input(user_question)).startswith("/"):
41
- if prompt.startswith("/w"):
42
- if prompt.count(" ") != 1:
43
- print(f"Got invalid command '{prompt}'\n{usage}")
44
- continue
45
- _, width = prompt.split()
46
- options.width = 16 * (int(width) // 16)
47
- print(
48
- f"Setting resolution to {options.width} x {options.height} "
49
- f"({options.height *options.width/1e6:.2f}MP)"
50
- )
51
- elif prompt.startswith("/h"):
52
- if prompt.count(" ") != 1:
53
- print(f"Got invalid command '{prompt}'\n{usage}")
54
- continue
55
- _, height = prompt.split()
56
- options.height = 16 * (int(height) // 16)
57
- print(
58
- f"Setting resolution to {options.width} x {options.height} "
59
- f"({options.height *options.width/1e6:.2f}MP)"
60
- )
61
- elif prompt.startswith("/g"):
62
- if prompt.count(" ") != 1:
63
- print(f"Got invalid command '{prompt}'\n{usage}")
64
- continue
65
- _, guidance = prompt.split()
66
- options.guidance = float(guidance)
67
- print(f"Setting guidance to {options.guidance}")
68
- elif prompt.startswith("/s"):
69
- if prompt.count(" ") != 1:
70
- print(f"Got invalid command '{prompt}'\n{usage}")
71
- continue
72
- _, seed = prompt.split()
73
- options.seed = int(seed)
74
- print(f"Setting seed to {options.seed}")
75
- elif prompt.startswith("/n"):
76
- if prompt.count(" ") != 1:
77
- print(f"Got invalid command '{prompt}'\n{usage}")
78
- continue
79
- _, steps = prompt.split()
80
- options.num_steps = int(steps)
81
- print(f"Setting number of steps to {options.num_steps}")
82
- elif prompt.startswith("/q"):
83
- print("Quitting")
84
- return None
85
- else:
86
- if not prompt.startswith("/h"):
87
- print(f"Got invalid command '{prompt}'\n{usage}")
88
- print(usage)
89
- return options
90
-
91
-
92
- def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
93
- if options is None:
94
- return None
95
-
96
- user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
97
- usage = (
98
- "Usage: Either write your prompt directly, leave this field empty "
99
- "to repeat the conditioning image or write a command starting with a slash:\n"
100
- "- '/q' to quit"
101
- )
102
-
103
- while True:
104
- img_cond_path = input(user_question)
105
-
106
- if img_cond_path.startswith("/"):
107
- if img_cond_path.startswith("/q"):
108
- print("Quitting")
109
- return None
110
- else:
111
- if not img_cond_path.startswith("/h"):
112
- print(f"Got invalid command '{img_cond_path}'\n{usage}")
113
- print(usage)
114
- continue
115
-
116
- if img_cond_path == "":
117
- break
118
-
119
- if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
120
- (".jpg", ".jpeg", ".png", ".webp")
121
- ):
122
- print(f"File '{img_cond_path}' does not exist or is not a valid image file")
123
- continue
124
-
125
- options.img_cond_path = img_cond_path
126
- break
127
-
128
- return options
129
-
130
-
131
- @torch.inference_mode()
132
- def main(
133
- name: str = "flux-dev",
134
- width: int = 1360,
135
- height: int = 768,
136
- seed: int | None = None,
137
- device: str = "cuda" if torch.cuda.is_available() else "cpu",
138
- num_steps: int | None = None,
139
- loop: bool = False,
140
- guidance: float = 2.5,
141
- offload: bool = False,
142
- output_dir: str = "output",
143
- add_sampling_metadata: bool = True,
144
- img_cond_path: str = "assets/robot.webp",
145
- ):
146
- """
147
- Sample the flux model. Either interactively (set `--loop`) or run for a
148
- single image.
149
-
150
- Args:
151
- name: Name of the model to load
152
- height: height of the sample in pixels (should be a multiple of 16)
153
- width: width of the sample in pixels (should be a multiple of 16)
154
- seed: Set a seed for sampling
155
- output_name: where to save the output image, `{idx}` will be replaced
156
- by the index of the sample
157
- prompt: Prompt used for sampling
158
- device: Pytorch device
159
- num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
160
- loop: start an interactive session and sample multiple times
161
- guidance: guidance value used for guidance distillation
162
- add_sampling_metadata: Add the prompt to the image Exif metadata
163
- img_cond_path: path to conditioning image (jpeg/png/webp)
164
- """
165
- nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
166
-
167
- if name not in configs:
168
- available = ", ".join(configs.keys())
169
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
170
-
171
- torch_device = torch.device(device)
172
- if num_steps is None:
173
- num_steps = 4 if name == "flux-schnell" else 50
174
-
175
- output_name = os.path.join(output_dir, "img_{idx}.jpg")
176
- if not os.path.exists(output_dir):
177
- os.makedirs(output_dir)
178
- idx = 0
179
- else:
180
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
181
- if len(fns) > 0:
182
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
183
- else:
184
- idx = 0
185
-
186
- # init all components
187
- t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
188
- clip = load_clip(torch_device)
189
- model = load_flow_model(name, device="cpu" if offload else torch_device)
190
- ae = load_ae(name, device="cpu" if offload else torch_device)
191
- img_embedder = ReduxImageEncoder(torch_device)
192
-
193
- rng = torch.Generator(device="cpu")
194
- prompt = ""
195
- opts = SamplingOptions(
196
- prompt=prompt,
197
- width=width,
198
- height=height,
199
- num_steps=num_steps,
200
- guidance=guidance,
201
- seed=seed,
202
- img_cond_path=img_cond_path,
203
- )
204
-
205
- if loop:
206
- opts = parse_prompt(opts)
207
- opts = parse_img_cond_path(opts)
208
-
209
- while opts is not None:
210
- if opts.seed is None:
211
- opts.seed = rng.seed()
212
- print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
213
- t0 = time.perf_counter()
214
-
215
- # prepare input
216
- x = get_noise(
217
- 1,
218
- opts.height,
219
- opts.width,
220
- device=torch_device,
221
- dtype=torch.bfloat16,
222
- seed=opts.seed,
223
- )
224
- opts.seed = None
225
- if offload:
226
- ae = ae.cpu()
227
- torch.cuda.empty_cache()
228
- t5, clip = t5.to(torch_device), clip.to(torch_device)
229
- inp = prepare_redux(
230
- t5,
231
- clip,
232
- x,
233
- prompt=opts.prompt,
234
- encoder=img_embedder,
235
- img_cond_path=opts.img_cond_path,
236
- )
237
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
238
-
239
- # offload TEs to CPU, load model to gpu
240
- if offload:
241
- t5, clip = t5.cpu(), clip.cpu()
242
- torch.cuda.empty_cache()
243
- model = model.to(torch_device)
244
-
245
- # denoise initial noise
246
- x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
247
-
248
- # offload model, load autoencoder to gpu
249
- if offload:
250
- model.cpu()
251
- torch.cuda.empty_cache()
252
- ae.decoder.to(x.device)
253
-
254
- # decode latents to pixel space
255
- x = unpack(x.float(), opts.height, opts.width)
256
- with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
257
- x = ae.decode(x)
258
-
259
- if torch.cuda.is_available():
260
- torch.cuda.synchronize()
261
- t1 = time.perf_counter()
262
- print(f"Done in {t1 - t0:.1f}s")
263
-
264
- idx = save_image(nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt)
265
-
266
- if loop:
267
- print("-" * 80)
268
- opts = parse_prompt(opts)
269
- opts = parse_img_cond_path(opts)
270
- else:
271
- opts = None
272
-
273
-
274
- def app():
275
- Fire(main)
276
-
277
-
278
- if __name__ == "__main__":
279
- app()
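
The removed Redux CLI takes no text prompt; it samples variations conditioned only on a reference image. A hedged programmatic sketch based on the deleted defaults:

```python
# Sketch only: programmatic call of the removed flux/cli_redux.py entry point.
from flux.cli_redux import main

main(
    name="flux-dev",
    img_cond_path="assets/robot.webp",  # reference image encoded by ReduxImageEncoder
    guidance=2.5,
    loop=False,                         # single sample instead of an interactive session
)
```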
 
flux/math.py DELETED
@@ -1,30 +0,0 @@
1
- import torch
2
- from einops import rearrange
3
- from torch import Tensor
4
-
5
-
6
- def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
- q, k = apply_rope(q, k, pe)
8
-
9
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
- x = rearrange(x, "B H L D -> B L (H D)")
11
-
12
- return x
13
-
14
-
15
- def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
- assert dim % 2 == 0
17
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
18
- omega = 1.0 / (theta**scale)
19
- out = torch.einsum("...n,d->...nd", pos, omega)
20
- out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
21
- out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
22
- return out.float()
23
-
24
-
25
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
26
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
27
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
28
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
29
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
30
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 
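A shape sketch of the removed rotary-embedding helpers, with illustrative sizes only; the extra singleton axis on the rotary table lets it broadcast over attention heads:

```python
# Illustrative shapes for the removed flux/math.py helpers (sizes are made up).
import torch
from flux.math import attention, rope

B, H, L, D = 1, 24, 64, 128                 # batch, heads, tokens, head dim (even)
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

pos = torch.arange(L, dtype=torch.float64)[None, :]   # (B, L) token positions
pe = rope(pos, dim=D, theta=10_000)[:, None]          # (B, 1, L, D/2, 2, 2) rotary table

out = attention(q, k, v, pe=pe)             # rotary embedding + scaled dot-product attention
print(out.shape)                            # torch.Size([1, 64, 3072]) == (B, L, H * D)
```
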
flux/model.py DELETED
@@ -1,143 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- from torch import Tensor, nn
5
-
6
- from flux.modules.layers import (
7
- DoubleStreamBlock,
8
- EmbedND,
9
- LastLayer,
10
- MLPEmbedder,
11
- SingleStreamBlock,
12
- timestep_embedding,
13
- )
14
- from flux.modules.lora import LinearLora, replace_linear_with_lora
15
-
16
-
17
- @dataclass
18
- class FluxParams:
19
- in_channels: int
20
- out_channels: int
21
- vec_in_dim: int
22
- context_in_dim: int
23
- hidden_size: int
24
- mlp_ratio: float
25
- num_heads: int
26
- depth: int
27
- depth_single_blocks: int
28
- axes_dim: list[int]
29
- theta: int
30
- qkv_bias: bool
31
- guidance_embed: bool
32
-
33
-
34
- class Flux(nn.Module):
35
- """
36
- Transformer model for flow matching on sequences.
37
- """
38
-
39
- def __init__(self, params: FluxParams):
40
- super().__init__()
41
-
42
- self.params = params
43
- self.in_channels = params.in_channels
44
- self.out_channels = params.out_channels
45
- if params.hidden_size % params.num_heads != 0:
46
- raise ValueError(
47
- f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
48
- )
49
- pe_dim = params.hidden_size // params.num_heads
50
- if sum(params.axes_dim) != pe_dim:
51
- raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
52
- self.hidden_size = params.hidden_size
53
- self.num_heads = params.num_heads
54
- self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
55
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
56
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
57
- self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
58
- self.guidance_in = (
59
- MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
60
- )
61
- self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
62
-
63
- self.double_blocks = nn.ModuleList(
64
- [
65
- DoubleStreamBlock(
66
- self.hidden_size,
67
- self.num_heads,
68
- mlp_ratio=params.mlp_ratio,
69
- qkv_bias=params.qkv_bias,
70
- )
71
- for _ in range(params.depth)
72
- ]
73
- )
74
-
75
- self.single_blocks = nn.ModuleList(
76
- [
77
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
78
- for _ in range(params.depth_single_blocks)
79
- ]
80
- )
81
-
82
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
83
-
84
- def forward(
85
- self,
86
- img: Tensor,
87
- img_ids: Tensor,
88
- txt: Tensor,
89
- txt_ids: Tensor,
90
- timesteps: Tensor,
91
- y: Tensor,
92
- guidance: Tensor | None = None,
93
- ) -> Tensor:
94
- if img.ndim != 3 or txt.ndim != 3:
95
- raise ValueError("Input img and txt tensors must have 3 dimensions.")
96
-
97
- # running on sequences img
98
- img = self.img_in(img)
99
- vec = self.time_in(timestep_embedding(timesteps, 256))
100
- if self.params.guidance_embed:
101
- if guidance is None:
102
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
103
- vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
104
- vec = vec + self.vector_in(y)
105
- txt = self.txt_in(txt)
106
-
107
- ids = torch.cat((txt_ids, img_ids), dim=1)
108
- pe = self.pe_embedder(ids)
109
-
110
- for block in self.double_blocks:
111
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
112
-
113
- img = torch.cat((txt, img), 1)
114
- for block in self.single_blocks:
115
- img = block(img, vec=vec, pe=pe)
116
- img = img[:, txt.shape[1] :, ...]
117
-
118
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
119
- return img
120
-
121
-
122
- class FluxLoraWrapper(Flux):
123
- def __init__(
124
- self,
125
- lora_rank: int = 128,
126
- lora_scale: float = 1.0,
127
- *args,
128
- **kwargs,
129
- ) -> None:
130
- super().__init__(*args, **kwargs)
131
-
132
- self.lora_rank = lora_rank
133
-
134
- replace_linear_with_lora(
135
- self,
136
- max_rank=lora_rank,
137
- scale=lora_scale,
138
- )
139
-
140
- def set_lora_scale(self, scale: float) -> None:
141
- for module in self.modules():
142
- if isinstance(module, LinearLora):
143
- module.set_scale(scale=scale)
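
An illustrative instantiation of the removed transformer; the hyperparameters below are small made-up values chosen only to satisfy the constructor's divisibility checks, not the real flux-dev configuration:

```python
# Sketch only: tiny, made-up FluxParams (NOT the real flux-dev config).
from flux.model import Flux, FluxParams

params = FluxParams(
    in_channels=64,
    out_channels=64,
    vec_in_dim=768,
    context_in_dim=4096,
    hidden_size=256,          # must be divisible by num_heads
    mlp_ratio=4.0,
    num_heads=4,              # pe_dim = 256 // 4 = 64
    depth=2,
    depth_single_blocks=2,
    axes_dim=[16, 24, 24],    # must sum to pe_dim
    theta=10_000,
    qkv_bias=True,
    guidance_embed=False,     # skip the guidance embedder in this toy config
)
model = Flux(params)
print(sum(p.numel() for p in model.parameters()))  # rough size check of the toy model
```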
 
flux/modules/autoencoder.py DELETED
@@ -1,312 +0,0 @@
1
- from dataclasses import dataclass
2
-
3
- import torch
4
- from einops import rearrange
5
- from torch import Tensor, nn
6
-
7
-
8
- @dataclass
9
- class AutoEncoderParams:
10
- resolution: int
11
- in_channels: int
12
- ch: int
13
- out_ch: int
14
- ch_mult: list[int]
15
- num_res_blocks: int
16
- z_channels: int
17
- scale_factor: float
18
- shift_factor: float
19
-
20
-
21
- def swish(x: Tensor) -> Tensor:
22
- return x * torch.sigmoid(x)
23
-
24
-
25
- class AttnBlock(nn.Module):
26
- def __init__(self, in_channels: int):
27
- super().__init__()
28
- self.in_channels = in_channels
29
-
30
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
-
32
- self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
- self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
- self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
- self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
-
37
- def attention(self, h_: Tensor) -> Tensor:
38
- h_ = self.norm(h_)
39
- q = self.q(h_)
40
- k = self.k(h_)
41
- v = self.v(h_)
42
-
43
- b, c, h, w = q.shape
44
- q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
- k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
- v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
- h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
-
49
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
-
51
- def forward(self, x: Tensor) -> Tensor:
52
- return x + self.proj_out(self.attention(x))
53
-
54
-
55
- class ResnetBlock(nn.Module):
56
- def __init__(self, in_channels: int, out_channels: int):
57
- super().__init__()
58
- self.in_channels = in_channels
59
- out_channels = in_channels if out_channels is None else out_channels
60
- self.out_channels = out_channels
61
-
62
- self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
- self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
- if self.in_channels != self.out_channels:
67
- self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
-
69
- def forward(self, x):
70
- h = x
71
- h = self.norm1(h)
72
- h = swish(h)
73
- h = self.conv1(h)
74
-
75
- h = self.norm2(h)
76
- h = swish(h)
77
- h = self.conv2(h)
78
-
79
- if self.in_channels != self.out_channels:
80
- x = self.nin_shortcut(x)
81
-
82
- return x + h
83
-
84
-
85
- class Downsample(nn.Module):
86
- def __init__(self, in_channels: int):
87
- super().__init__()
88
- # no asymmetric padding in torch conv, must do it ourselves
89
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
-
91
- def forward(self, x: Tensor):
92
- pad = (0, 1, 0, 1)
93
- x = nn.functional.pad(x, pad, mode="constant", value=0)
94
- x = self.conv(x)
95
- return x
96
-
97
-
98
- class Upsample(nn.Module):
99
- def __init__(self, in_channels: int):
100
- super().__init__()
101
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
-
103
- def forward(self, x: Tensor):
104
- x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
- x = self.conv(x)
106
- return x
107
-
108
-
109
- class Encoder(nn.Module):
110
- def __init__(
111
- self,
112
- resolution: int,
113
- in_channels: int,
114
- ch: int,
115
- ch_mult: list[int],
116
- num_res_blocks: int,
117
- z_channels: int,
118
- ):
119
- super().__init__()
120
- self.ch = ch
121
- self.num_resolutions = len(ch_mult)
122
- self.num_res_blocks = num_res_blocks
123
- self.resolution = resolution
124
- self.in_channels = in_channels
125
- # downsampling
126
- self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
-
128
- curr_res = resolution
129
- in_ch_mult = (1,) + tuple(ch_mult)
130
- self.in_ch_mult = in_ch_mult
131
- self.down = nn.ModuleList()
132
- block_in = self.ch
133
- for i_level in range(self.num_resolutions):
134
- block = nn.ModuleList()
135
- attn = nn.ModuleList()
136
- block_in = ch * in_ch_mult[i_level]
137
- block_out = ch * ch_mult[i_level]
138
- for _ in range(self.num_res_blocks):
139
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
- block_in = block_out
141
- down = nn.Module()
142
- down.block = block
143
- down.attn = attn
144
- if i_level != self.num_resolutions - 1:
145
- down.downsample = Downsample(block_in)
146
- curr_res = curr_res // 2
147
- self.down.append(down)
148
-
149
- # middle
150
- self.mid = nn.Module()
151
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
- self.mid.attn_1 = AttnBlock(block_in)
153
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
-
155
- # end
156
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
- self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
-
159
- def forward(self, x: Tensor) -> Tensor:
160
- # downsampling
161
- hs = [self.conv_in(x)]
162
- for i_level in range(self.num_resolutions):
163
- for i_block in range(self.num_res_blocks):
164
- h = self.down[i_level].block[i_block](hs[-1])
165
- if len(self.down[i_level].attn) > 0:
166
- h = self.down[i_level].attn[i_block](h)
167
- hs.append(h)
168
- if i_level != self.num_resolutions - 1:
169
- hs.append(self.down[i_level].downsample(hs[-1]))
170
-
171
- # middle
172
- h = hs[-1]
173
- h = self.mid.block_1(h)
174
- h = self.mid.attn_1(h)
175
- h = self.mid.block_2(h)
176
- # end
177
- h = self.norm_out(h)
178
- h = swish(h)
179
- h = self.conv_out(h)
180
- return h
181
-
182
-
183
- class Decoder(nn.Module):
184
- def __init__(
185
- self,
186
- ch: int,
187
- out_ch: int,
188
- ch_mult: list[int],
189
- num_res_blocks: int,
190
- in_channels: int,
191
- resolution: int,
192
- z_channels: int,
193
- ):
194
- super().__init__()
195
- self.ch = ch
196
- self.num_resolutions = len(ch_mult)
197
- self.num_res_blocks = num_res_blocks
198
- self.resolution = resolution
199
- self.in_channels = in_channels
200
- self.ffactor = 2 ** (self.num_resolutions - 1)
201
-
202
- # compute in_ch_mult, block_in and curr_res at lowest res
203
- block_in = ch * ch_mult[self.num_resolutions - 1]
204
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
- self.z_shape = (1, z_channels, curr_res, curr_res)
206
-
207
- # z to block_in
208
- self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
-
210
- # middle
211
- self.mid = nn.Module()
212
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
- self.mid.attn_1 = AttnBlock(block_in)
214
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
-
216
- # upsampling
217
- self.up = nn.ModuleList()
218
- for i_level in reversed(range(self.num_resolutions)):
219
- block = nn.ModuleList()
220
- attn = nn.ModuleList()
221
- block_out = ch * ch_mult[i_level]
222
- for _ in range(self.num_res_blocks + 1):
223
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
- block_in = block_out
225
- up = nn.Module()
226
- up.block = block
227
- up.attn = attn
228
- if i_level != 0:
229
- up.upsample = Upsample(block_in)
230
- curr_res = curr_res * 2
231
- self.up.insert(0, up) # prepend to get consistent order
232
-
233
- # end
234
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
- self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
-
237
- def forward(self, z: Tensor) -> Tensor:
238
- # z to block_in
239
- h = self.conv_in(z)
240
-
241
- # middle
242
- h = self.mid.block_1(h)
243
- h = self.mid.attn_1(h)
244
- h = self.mid.block_2(h)
245
-
246
- # upsampling
247
- for i_level in reversed(range(self.num_resolutions)):
248
- for i_block in range(self.num_res_blocks + 1):
249
- h = self.up[i_level].block[i_block](h)
250
- if len(self.up[i_level].attn) > 0:
251
- h = self.up[i_level].attn[i_block](h)
252
- if i_level != 0:
253
- h = self.up[i_level].upsample(h)
254
-
255
- # end
256
- h = self.norm_out(h)
257
- h = swish(h)
258
- h = self.conv_out(h)
259
- return h
260
-
261
-
262
- class DiagonalGaussian(nn.Module):
263
- def __init__(self, sample: bool = True, chunk_dim: int = 1):
264
- super().__init__()
265
- self.sample = sample
266
- self.chunk_dim = chunk_dim
267
-
268
- def forward(self, z: Tensor) -> Tensor:
269
- mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
270
- if self.sample:
271
- std = torch.exp(0.5 * logvar)
272
- return mean + std * torch.randn_like(mean)
273
- else:
274
- return mean
275
-
276
-
277
- class AutoEncoder(nn.Module):
278
- def __init__(self, params: AutoEncoderParams):
279
- super().__init__()
280
- self.encoder = Encoder(
281
- resolution=params.resolution,
282
- in_channels=params.in_channels,
283
- ch=params.ch,
284
- ch_mult=params.ch_mult,
285
- num_res_blocks=params.num_res_blocks,
286
- z_channels=params.z_channels,
287
- )
288
- self.decoder = Decoder(
289
- resolution=params.resolution,
290
- in_channels=params.in_channels,
291
- ch=params.ch,
292
- out_ch=params.out_ch,
293
- ch_mult=params.ch_mult,
294
- num_res_blocks=params.num_res_blocks,
295
- z_channels=params.z_channels,
296
- )
297
- self.reg = DiagonalGaussian()
298
-
299
- self.scale_factor = params.scale_factor
300
- self.shift_factor = params.shift_factor
301
-
302
- def encode(self, x: Tensor) -> Tensor:
303
- z = self.reg(self.encoder(x))
304
- z = self.scale_factor * (z - self.shift_factor)
305
- return z
306
-
307
- def decode(self, z: Tensor) -> Tensor:
308
- z = z / self.scale_factor + self.shift_factor
309
- return self.decoder(z)
310
-
311
- def forward(self, x: Tensor) -> Tensor:
312
- return self.decode(self.encode(x))
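
A minimal round-trip sketch for the autoencoder above, assuming the module is imported from the parent revision of this repo (this commit deletes it). The parameter values are the ones shared by every config in flux/util.py below; the weights here are randomly initialized, so the point is only the shapes and the scale/shift handling in encode/decode.

import torch
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams

params = AutoEncoderParams(
    resolution=256, in_channels=3, ch=128, out_ch=3,
    ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16,
    scale_factor=0.3611, shift_factor=0.1159,
)
ae = AutoEncoder(params).eval()

x = torch.randn(1, 3, 64, 64)      # any [-1, 1] image tensor; 64px keeps the demo cheap
with torch.no_grad():
    z = ae.encode(x)               # (1, 16, 8, 8): three downsampling stages give an 8x smaller, scaled and shifted latent
    x_rec = ae.decode(z)           # (1, 3, 64, 64): decode undoes the scale/shift before running the Decoder

Real usage loads ae.safetensors through load_ae (defined in flux/util.py below) rather than relying on random initialization.
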
flux/modules/conditioner.py DELETED
@@ -1,37 +0,0 @@
1
- from torch import Tensor, nn
2
- from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
3
-
4
-
5
- class HFEmbedder(nn.Module):
6
- def __init__(self, version: str, max_length: int, **hf_kwargs):
7
- super().__init__()
8
- self.is_clip = version.startswith("openai")
9
- self.max_length = max_length
10
- self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
11
-
12
- if self.is_clip:
13
- self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
14
- self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
15
- else:
16
- self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
17
- self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
18
-
19
- self.hf_module = self.hf_module.eval().requires_grad_(False)
20
-
21
- def forward(self, text: list[str]) -> Tensor:
22
- batch_encoding = self.tokenizer(
23
- text,
24
- truncation=True,
25
- max_length=self.max_length,
26
- return_length=False,
27
- return_overflowing_tokens=False,
28
- padding="max_length",
29
- return_tensors="pt",
30
- )
31
-
32
- outputs = self.hf_module(
33
- input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
34
- attention_mask=None,
35
- output_hidden_states=False,
36
- )
37
- return outputs[self.output_key]
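
HFEmbedder above switches behaviour purely on the version string, which is easy to miss. The sketch below uses the exact model names and lengths that load_clip and load_t5 in flux/util.py pass in; downloading those weights is many gigabytes, so treat it as illustrative only.

import torch
from flux.modules.conditioner import HFEmbedder  # parent revision of this repo

# A version starting with "openai" selects the CLIP branch and returns the pooled output.
clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
vec = clip(["a photo of a forest"])     # (1, 768)

# Any other version selects the T5 branch and returns per-token states, padded to max_length.
t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)
txt = t5(["a photo of a forest"])       # (1, 512, 4096)

Note that no attention mask is passed to the encoder, so padding tokens do contribute to the T5 sequence output.
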
flux/modules/image_embedders.py DELETED
@@ -1,103 +0,0 @@
1
- import os
2
-
3
- import cv2
4
- import numpy as np
5
- import torch
6
- from einops import rearrange, repeat
7
- from PIL import Image
8
- from safetensors.torch import load_file as load_sft
9
- from torch import nn
10
- from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
11
-
12
- from flux.util import print_load_warning
13
-
14
-
15
- class DepthImageEncoder:
16
- depth_model_name = "LiheYoung/depth-anything-large-hf"
17
-
18
- def __init__(self, device):
19
- self.device = device
20
- self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device)
21
- self.processor = AutoProcessor.from_pretrained(self.depth_model_name)
22
-
23
- def __call__(self, img: torch.Tensor) -> torch.Tensor:
24
- hw = img.shape[-2:]
25
-
26
- img = torch.clamp(img, -1.0, 1.0)
27
- img_byte = ((img + 1.0) * 127.5).byte()
28
-
29
- img = self.processor(img_byte, return_tensors="pt")["pixel_values"]
30
- depth = self.depth_model(img.to(self.device)).predicted_depth
31
- depth = repeat(depth, "b h w -> b 3 h w")
32
- depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True)
33
-
34
- depth = depth / 127.5 - 1.0
35
- return depth
36
-
37
-
38
- class CannyImageEncoder:
39
- def __init__(
40
- self,
41
- device,
42
- min_t: int = 50,
43
- max_t: int = 200,
44
- ):
45
- self.device = device
46
- self.min_t = min_t
47
- self.max_t = max_t
48
-
49
- def __call__(self, img: torch.Tensor) -> torch.Tensor:
50
- assert img.shape[0] == 1, "Only batch size 1 is supported"
51
-
52
- img = rearrange(img[0], "c h w -> h w c")
53
- img = torch.clamp(img, -1.0, 1.0)
54
- img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8)
55
-
56
- # Apply Canny edge detection
57
- canny = cv2.Canny(img_np, self.min_t, self.max_t)
58
-
59
- # Convert back to torch tensor and reshape
60
- canny = torch.from_numpy(canny).float() / 127.5 - 1.0
61
- canny = rearrange(canny, "h w -> 1 1 h w")
62
- canny = repeat(canny, "b 1 ... -> b 3 ...")
63
- return canny.to(self.device)
64
-
65
-
66
- class ReduxImageEncoder(nn.Module):
67
- siglip_model_name = "google/siglip-so400m-patch14-384"
68
-
69
- def __init__(
70
- self,
71
- device,
72
- redux_dim: int = 1152,
73
- txt_in_features: int = 4096,
74
- redux_path: str | None = os.getenv("FLUX_REDUX"),
75
- dtype=torch.bfloat16,
76
- ) -> None:
77
- assert redux_path is not None, "Redux path must be provided"
78
-
79
- super().__init__()
80
-
81
- self.redux_dim = redux_dim
82
- self.device = device if isinstance(device, torch.device) else torch.device(device)
83
- self.dtype = dtype
84
-
85
- with self.device:
86
- self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
87
- self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
88
-
89
- sd = load_sft(redux_path, device=str(device))
90
- missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
91
- print_load_warning(missing, unexpected)
92
-
93
- self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype)
94
- self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name)
95
-
96
- def __call__(self, x: Image.Image) -> torch.Tensor:
97
- imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
98
-
99
- _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state
100
-
101
- projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
102
-
103
- return projected_x
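
The three embedders above serve the Depth, Canny and Redux model variants. The Canny encoder is the only one with no external weights, so it makes a compact example of the shared contract (batch size 1, [-1, 1] input, 3-channel conditioning map on the requested device); as before, the import assumes the parent revision of this repo.

import torch
from flux.modules.image_embedders import CannyImageEncoder  # parent revision of this repo

encoder = CannyImageEncoder(device="cpu", min_t=50, max_t=200)   # Canny hysteresis thresholds

img = torch.rand(1, 3, 512, 512) * 2 - 1   # a single image in [-1, 1], as the encoder expects
edges = encoder(img)                       # (1, 3, 512, 512): edge map repeated to 3 channels, back in [-1, 1]

DepthImageEncoder works on tensors in the same way but downloads Depth-Anything weights, while ReduxImageEncoder takes a PIL image and needs both the SigLIP weights and the FLUX_REDUX checkpoint.
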
flux/modules/layers.py DELETED
@@ -1,253 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- from einops import rearrange
6
- from torch import Tensor, nn
7
-
8
- from flux.math import attention, rope
9
-
10
-
11
- class EmbedND(nn.Module):
12
- def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
- super().__init__()
14
- self.dim = dim
15
- self.theta = theta
16
- self.axes_dim = axes_dim
17
-
18
- def forward(self, ids: Tensor) -> Tensor:
19
- n_axes = ids.shape[-1]
20
- emb = torch.cat(
21
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
- dim=-3,
23
- )
24
-
25
- return emb.unsqueeze(1)
26
-
27
-
28
- def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
- """
30
- Create sinusoidal timestep embeddings.
31
- :param t: a 1-D Tensor of N indices, one per batch element.
32
- These may be fractional.
33
- :param dim: the dimension of the output.
34
- :param max_period: controls the minimum frequency of the embeddings.
35
- :return: an (N, D) Tensor of positional embeddings.
36
- """
37
- t = time_factor * t
38
- half = dim // 2
39
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
40
- t.device
41
- )
42
-
43
- args = t[:, None].float() * freqs[None]
44
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
45
- if dim % 2:
46
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
47
- if torch.is_floating_point(t):
48
- embedding = embedding.to(t)
49
- return embedding
50
-
51
-
52
- class MLPEmbedder(nn.Module):
53
- def __init__(self, in_dim: int, hidden_dim: int):
54
- super().__init__()
55
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
56
- self.silu = nn.SiLU()
57
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
58
-
59
- def forward(self, x: Tensor) -> Tensor:
60
- return self.out_layer(self.silu(self.in_layer(x)))
61
-
62
-
63
- class RMSNorm(torch.nn.Module):
64
- def __init__(self, dim: int):
65
- super().__init__()
66
- self.scale = nn.Parameter(torch.ones(dim))
67
-
68
- def forward(self, x: Tensor):
69
- x_dtype = x.dtype
70
- x = x.float()
71
- rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
72
- return (x * rrms).to(dtype=x_dtype) * self.scale
73
-
74
-
75
- class QKNorm(torch.nn.Module):
76
- def __init__(self, dim: int):
77
- super().__init__()
78
- self.query_norm = RMSNorm(dim)
79
- self.key_norm = RMSNorm(dim)
80
-
81
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
82
- q = self.query_norm(q)
83
- k = self.key_norm(k)
84
- return q.to(v), k.to(v)
85
-
86
-
87
- class SelfAttention(nn.Module):
88
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
89
- super().__init__()
90
- self.num_heads = num_heads
91
- head_dim = dim // num_heads
92
-
93
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
94
- self.norm = QKNorm(head_dim)
95
- self.proj = nn.Linear(dim, dim)
96
-
97
- def forward(self, x: Tensor, pe: Tensor) -> Tensor:
98
- qkv = self.qkv(x)
99
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
100
- q, k = self.norm(q, k, v)
101
- x = attention(q, k, v, pe=pe)
102
- x = self.proj(x)
103
- return x
104
-
105
-
106
- @dataclass
107
- class ModulationOut:
108
- shift: Tensor
109
- scale: Tensor
110
- gate: Tensor
111
-
112
-
113
- class Modulation(nn.Module):
114
- def __init__(self, dim: int, double: bool):
115
- super().__init__()
116
- self.is_double = double
117
- self.multiplier = 6 if double else 3
118
- self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
119
-
120
- def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
121
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
122
-
123
- return (
124
- ModulationOut(*out[:3]),
125
- ModulationOut(*out[3:]) if self.is_double else None,
126
- )
127
-
128
-
129
- class DoubleStreamBlock(nn.Module):
130
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
131
- super().__init__()
132
-
133
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
134
- self.num_heads = num_heads
135
- self.hidden_size = hidden_size
136
- self.img_mod = Modulation(hidden_size, double=True)
137
- self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
138
- self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
139
-
140
- self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
141
- self.img_mlp = nn.Sequential(
142
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
143
- nn.GELU(approximate="tanh"),
144
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
145
- )
146
-
147
- self.txt_mod = Modulation(hidden_size, double=True)
148
- self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
- self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
150
-
151
- self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
152
- self.txt_mlp = nn.Sequential(
153
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
154
- nn.GELU(approximate="tanh"),
155
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
156
- )
157
-
158
- def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
159
- img_mod1, img_mod2 = self.img_mod(vec)
160
- txt_mod1, txt_mod2 = self.txt_mod(vec)
161
-
162
- # prepare image for attention
163
- img_modulated = self.img_norm1(img)
164
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
165
- img_qkv = self.img_attn.qkv(img_modulated)
166
- img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
167
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
168
-
169
- # prepare txt for attention
170
- txt_modulated = self.txt_norm1(txt)
171
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
172
- txt_qkv = self.txt_attn.qkv(txt_modulated)
173
- txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
175
-
176
- # run actual attention
177
- q = torch.cat((txt_q, img_q), dim=2)
178
- k = torch.cat((txt_k, img_k), dim=2)
179
- v = torch.cat((txt_v, img_v), dim=2)
180
-
181
- attn = attention(q, k, v, pe=pe)
182
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
183
-
184
- # calculate the img blocks
185
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
186
- img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
187
-
188
- # calculate the txt blocks
189
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
190
- txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
191
- return img, txt
192
-
193
-
194
- class SingleStreamBlock(nn.Module):
195
- """
196
- A DiT block with parallel linear layers as described in
197
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
198
- """
199
-
200
- def __init__(
201
- self,
202
- hidden_size: int,
203
- num_heads: int,
204
- mlp_ratio: float = 4.0,
205
- qk_scale: float | None = None,
206
- ):
207
- super().__init__()
208
- self.hidden_dim = hidden_size
209
- self.num_heads = num_heads
210
- head_dim = hidden_size // num_heads
211
- self.scale = qk_scale or head_dim**-0.5
212
-
213
- self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
214
- # qkv and mlp_in
215
- self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
216
- # proj and mlp_out
217
- self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
218
-
219
- self.norm = QKNorm(head_dim)
220
-
221
- self.hidden_size = hidden_size
222
- self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
223
-
224
- self.mlp_act = nn.GELU(approximate="tanh")
225
- self.modulation = Modulation(hidden_size, double=False)
226
-
227
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
228
- mod, _ = self.modulation(vec)
229
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
230
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
231
-
232
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
233
- q, k = self.norm(q, k, v)
234
-
235
- # compute attention
236
- attn = attention(q, k, v, pe=pe)
237
- # compute activation in mlp stream, cat again and run second linear layer
238
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
239
- return x + mod.gate * output
240
-
241
-
242
- class LastLayer(nn.Module):
243
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
244
- super().__init__()
245
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
247
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
248
-
249
- def forward(self, x: Tensor, vec: Tensor) -> Tensor:
250
- shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
251
- x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
252
- x = self.linear(x)
253
- return x
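
All three blocks above (DoubleStreamBlock, SingleStreamBlock, LastLayer) reuse the same adaLN-style modulation pattern, which can get lost in the tensor plumbing. The self-contained sketch below re-creates it with a toy hidden size (64 instead of the real 3072) purely for illustration; the tanh stands in for the actual attention or MLP branch.

import torch
from torch import nn

hidden = 64                                   # toy size, not the real 3072
lin = nn.Linear(hidden, 3 * hidden)           # mirrors Modulation(dim, double=False)
norm = nn.LayerNorm(hidden, elementwise_affine=False, eps=1e-6)

x = torch.randn(2, 10, hidden)                # (batch, tokens, hidden) token stream
vec = torch.randn(2, hidden)                  # timestep plus pooled-text conditioning vector

shift, scale, gate = lin(nn.functional.silu(vec))[:, None, :].chunk(3, dim=-1)
x_mod = (1 + scale) * norm(x) + shift         # modulated input that feeds attention / MLP
branch_out = torch.tanh(x_mod)                # placeholder for the real attention or MLP branch
x = x + gate * branch_out                     # gated residual update, exactly as in the blocks above

The double=True variant simply produces two such (shift, scale, gate) triples, one for the attention sub-block and one for the MLP.
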
flux/modules/lora.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- from torch import nn
3
-
4
-
5
- def replace_linear_with_lora(
6
- module: nn.Module,
7
- max_rank: int,
8
- scale: float = 1.0,
9
- ) -> None:
10
- for name, child in module.named_children():
11
- if isinstance(child, nn.Linear):
12
- new_lora = LinearLora(
13
- in_features=child.in_features,
14
- out_features=child.out_features,
15
- bias=child.bias,
16
- rank=max_rank,
17
- scale=scale,
18
- dtype=child.weight.dtype,
19
- device=child.weight.device,
20
- )
21
-
22
- new_lora.weight = child.weight
23
- new_lora.bias = child.bias if child.bias is not None else None
24
-
25
- setattr(module, name, new_lora)
26
- else:
27
- replace_linear_with_lora(
28
- module=child,
29
- max_rank=max_rank,
30
- scale=scale,
31
- )
32
-
33
-
34
- class LinearLora(nn.Linear):
35
- def __init__(
36
- self,
37
- in_features: int,
38
- out_features: int,
39
- bias: bool,
40
- rank: int,
41
- dtype: torch.dtype,
42
- device: torch.device,
43
- lora_bias: bool = True,
44
- scale: float = 1.0,
45
- *args,
46
- **kwargs,
47
- ) -> None:
48
- super().__init__(
49
- in_features=in_features,
50
- out_features=out_features,
51
- bias=bias is not None,
52
- device=device,
53
- dtype=dtype,
54
- *args,
55
- **kwargs,
56
- )
57
-
58
- assert isinstance(scale, float), "scale must be a float"
59
-
60
- self.scale = scale
61
- self.rank = rank
62
- self.lora_bias = lora_bias
63
- self.dtype = dtype
64
- self.device = device
65
-
66
- if rank > (new_rank := min(self.out_features, self.in_features)):
67
- self.rank = new_rank
68
-
69
- self.lora_A = nn.Linear(
70
- in_features=in_features,
71
- out_features=self.rank,
72
- bias=False,
73
- dtype=dtype,
74
- device=device,
75
- )
76
- self.lora_B = nn.Linear(
77
- in_features=self.rank,
78
- out_features=out_features,
79
- bias=self.lora_bias,
80
- dtype=dtype,
81
- device=device,
82
- )
83
-
84
- def set_scale(self, scale: float) -> None:
85
- assert isinstance(scale, float), "scalar value must be a float"
86
- self.scale = scale
87
-
88
- def forward(self, input: torch.Tensor) -> torch.Tensor:
89
- base_out = super().forward(input)
90
-
91
- _lora_out_B = self.lora_B(self.lora_A(input))
92
- lora_update = _lora_out_B * self.scale
93
-
94
- return base_out + lora_update
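
A quick sanity check for the LoRA wrapper above, assuming the module is imported from the parent revision: replace_linear_with_lora swaps every nn.Linear in place while reusing the base weights, and because the update is just scale * lora_B(lora_A(x)), setting the scale to zero recovers the original outputs exactly.

import torch
from torch import nn
from flux.modules.lora import LinearLora, replace_linear_with_lora  # parent revision of this repo

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(32, 64), nn.GELU(), nn.Linear(64, 32))
x = torch.randn(4, 32)
y_base = mlp(x)

replace_linear_with_lora(mlp, max_rank=8, scale=1.0)   # both Linear layers become LinearLora with rank 8

for m in mlp.modules():
    if isinstance(m, LinearLora):
        m.set_scale(scale=0.0)                         # zero out the adapter contribution

assert torch.allclose(mlp(x), y_base)                  # base behaviour is recovered

FluxLoraWrapper in flux/model.py applies exactly this transformation to the whole transformer and exposes set_lora_scale for the same purpose.
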
flux/sampling.py DELETED
@@ -1,282 +0,0 @@
1
- import math
2
- from typing import Callable
3
-
4
- import numpy as np
5
- import torch
6
- from einops import rearrange, repeat
7
- from PIL import Image
8
- from torch import Tensor
9
-
10
- from .model import Flux
11
- from .modules.autoencoder import AutoEncoder
12
- from .modules.conditioner import HFEmbedder
13
- from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder
14
-
15
-
16
- def get_noise(
17
- num_samples: int,
18
- height: int,
19
- width: int,
20
- device: torch.device,
21
- dtype: torch.dtype,
22
- seed: int,
23
- ):
24
- return torch.randn(
25
- num_samples,
26
- 16,
27
- # allow for packing
28
- 2 * math.ceil(height / 16),
29
- 2 * math.ceil(width / 16),
30
- device=device,
31
- dtype=dtype,
32
- generator=torch.Generator(device=device).manual_seed(seed),
33
- )
34
-
35
-
36
- def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]:
37
- bs, c, h, w = img.shape
38
- if bs == 1 and not isinstance(prompt, str):
39
- bs = len(prompt)
40
-
41
- img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
42
- if img.shape[0] == 1 and bs > 1:
43
- img = repeat(img, "1 ... -> bs ...", bs=bs)
44
-
45
- img_ids = torch.zeros(h // 2, w // 2, 3)
46
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
47
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
48
- img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
49
-
50
- if isinstance(prompt, str):
51
- prompt = [prompt]
52
- txt = t5(prompt)
53
- if txt.shape[0] == 1 and bs > 1:
54
- txt = repeat(txt, "1 ... -> bs ...", bs=bs)
55
- txt_ids = torch.zeros(bs, txt.shape[1], 3)
56
-
57
- vec = clip(prompt)
58
- if vec.shape[0] == 1 and bs > 1:
59
- vec = repeat(vec, "1 ... -> bs ...", bs=bs)
60
-
61
- return {
62
- "img": img,
63
- "img_ids": img_ids.to(img.device),
64
- "txt": txt.to(img.device),
65
- "txt_ids": txt_ids.to(img.device),
66
- "vec": vec.to(img.device),
67
- }
68
-
69
-
70
- def prepare_control(
71
- t5: HFEmbedder,
72
- clip: HFEmbedder,
73
- img: Tensor,
74
- prompt: str | list[str],
75
- ae: AutoEncoder,
76
- encoder: DepthImageEncoder | CannyImageEncoder,
77
- img_cond_path: str,
78
- ) -> dict[str, Tensor]:
79
- # load and encode the conditioning image
80
- bs, _, h, w = img.shape
81
- if bs == 1 and not isinstance(prompt, str):
82
- bs = len(prompt)
83
-
84
- img_cond = Image.open(img_cond_path).convert("RGB")
85
-
86
- width = w * 8
87
- height = h * 8
88
- img_cond = img_cond.resize((width, height), Image.LANCZOS)
89
- img_cond = np.array(img_cond)
90
- img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
91
- img_cond = rearrange(img_cond, "h w c -> 1 c h w")
92
-
93
- with torch.no_grad():
94
- img_cond = encoder(img_cond)
95
- img_cond = ae.encode(img_cond)
96
-
97
- img_cond = img_cond.to(torch.bfloat16)
98
- img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
99
- if img_cond.shape[0] == 1 and bs > 1:
100
- img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
101
-
102
- return_dict = prepare(t5, clip, img, prompt)
103
- return_dict["img_cond"] = img_cond
104
- return return_dict
105
-
106
-
107
- def prepare_fill(
108
- t5: HFEmbedder,
109
- clip: HFEmbedder,
110
- img: Tensor,
111
- prompt: str | list[str],
112
- ae: AutoEncoder,
113
- img_cond_path: str,
114
- mask_path: str,
115
- ) -> dict[str, Tensor]:
116
- # load and encode the conditioning image and the mask
117
- bs, _, _, _ = img.shape
118
- if bs == 1 and not isinstance(prompt, str):
119
- bs = len(prompt)
120
-
121
- img_cond = Image.open(img_cond_path).convert("RGB")
122
- img_cond = np.array(img_cond)
123
- img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0
124
- img_cond = rearrange(img_cond, "h w c -> 1 c h w")
125
-
126
- mask = Image.open(mask_path).convert("L")
127
- mask = np.array(mask)
128
- mask = torch.from_numpy(mask).float() / 255.0
129
- mask = rearrange(mask, "h w -> 1 1 h w")
130
-
131
- with torch.no_grad():
132
- img_cond = img_cond.to(img.device)
133
- mask = mask.to(img.device)
134
- img_cond = img_cond * (1 - mask)
135
- img_cond = ae.encode(img_cond)
136
- mask = mask[:, 0, :, :]
137
- mask = mask.to(torch.bfloat16)
138
- mask = rearrange(
139
- mask,
140
- "b (h ph) (w pw) -> b (ph pw) h w",
141
- ph=8,
142
- pw=8,
143
- )
144
- mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
145
- if mask.shape[0] == 1 and bs > 1:
146
- mask = repeat(mask, "1 ... -> bs ...", bs=bs)
147
-
148
- img_cond = img_cond.to(torch.bfloat16)
149
- img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
150
- if img_cond.shape[0] == 1 and bs > 1:
151
- img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
152
-
153
- img_cond = torch.cat((img_cond, mask), dim=-1)
154
-
155
- return_dict = prepare(t5, clip, img, prompt)
156
- return_dict["img_cond"] = img_cond.to(img.device)
157
- return return_dict
158
-
159
-
160
- def prepare_redux(
161
- t5: HFEmbedder,
162
- clip: HFEmbedder,
163
- img: Tensor,
164
- prompt: str | list[str],
165
- encoder: ReduxImageEncoder,
166
- img_cond_path: str,
167
- ) -> dict[str, Tensor]:
168
- bs, _, h, w = img.shape
169
- if bs == 1 and not isinstance(prompt, str):
170
- bs = len(prompt)
171
-
172
- img_cond = Image.open(img_cond_path).convert("RGB")
173
- with torch.no_grad():
174
- img_cond = encoder(img_cond)
175
-
176
- img_cond = img_cond.to(torch.bfloat16)
177
- if img_cond.shape[0] == 1 and bs > 1:
178
- img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs)
179
-
180
- img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
181
- if img.shape[0] == 1 and bs > 1:
182
- img = repeat(img, "1 ... -> bs ...", bs=bs)
183
-
184
- img_ids = torch.zeros(h // 2, w // 2, 3)
185
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
186
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
187
- img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
188
-
189
- if isinstance(prompt, str):
190
- prompt = [prompt]
191
- txt = t5(prompt)
192
- txt = torch.cat((txt, img_cond.to(txt)), dim=-2)
193
- if txt.shape[0] == 1 and bs > 1:
194
- txt = repeat(txt, "1 ... -> bs ...", bs=bs)
195
- txt_ids = torch.zeros(bs, txt.shape[1], 3)
196
-
197
- vec = clip(prompt)
198
- if vec.shape[0] == 1 and bs > 1:
199
- vec = repeat(vec, "1 ... -> bs ...", bs=bs)
200
-
201
- return {
202
- "img": img,
203
- "img_ids": img_ids.to(img.device),
204
- "txt": txt.to(img.device),
205
- "txt_ids": txt_ids.to(img.device),
206
- "vec": vec.to(img.device),
207
- }
208
-
209
-
210
- def time_shift(mu: float, sigma: float, t: Tensor):
211
- return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
212
-
213
-
214
- def get_lin_function(
215
- x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
216
- ) -> Callable[[float], float]:
217
- m = (y2 - y1) / (x2 - x1)
218
- b = y1 - m * x1
219
- return lambda x: m * x + b
220
-
221
-
222
- def get_schedule(
223
- num_steps: int,
224
- image_seq_len: int,
225
- base_shift: float = 0.5,
226
- max_shift: float = 1.15,
227
- shift: bool = True,
228
- ) -> list[float]:
229
- # extra step for zero
230
- timesteps = torch.linspace(1, 0, num_steps + 1)
231
-
232
- # shifting the schedule to favor high timesteps for higher signal images
233
- if shift:
234
- # estimate mu based on linear estimation between two points
235
- mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
236
- timesteps = time_shift(mu, 1.0, timesteps)
237
-
238
- return timesteps.tolist()
239
-
240
-
241
- def denoise(
242
- model: Flux,
243
- # model input
244
- img: Tensor,
245
- img_ids: Tensor,
246
- txt: Tensor,
247
- txt_ids: Tensor,
248
- vec: Tensor,
249
- # sampling parameters
250
- timesteps: list[float],
251
- guidance: float = 4.0,
252
- # extra img tokens
253
- img_cond: Tensor | None = None,
254
- ):
255
- # this is ignored for schnell
256
- guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
257
- for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
258
- t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
259
- pred = model(
260
- img=torch.cat((img, img_cond), dim=-1) if img_cond is not None else img,
261
- img_ids=img_ids,
262
- txt=txt,
263
- txt_ids=txt_ids,
264
- y=vec,
265
- timesteps=t_vec,
266
- guidance=guidance_vec,
267
- )
268
-
269
- img = img + (t_prev - t_curr) * pred
270
-
271
- return img
272
-
273
-
274
- def unpack(x: Tensor, height: int, width: int) -> Tensor:
275
- return rearrange(
276
- x,
277
- "b (h w) (c ph pw) -> b c (h ph) (w pw)",
278
- h=math.ceil(height / 16),
279
- w=math.ceil(width / 16),
280
- ph=2,
281
- pw=2,
282
- )
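
The schedule helpers at the end of sampling.py deserve a worked example: for long token sequences the timesteps are shifted toward 1 so more of the step budget is spent at high noise levels. The pure-Python sketch below reproduces what get_schedule computes for an assumed 1024x1024 flux-dev run (128x128 latent, packed 2x2, hence 4096 image tokens).

import math

image_seq_len = 4096          # 64 * 64 packed latent tokens for a 1024px image
num_steps = 50

# get_lin_function(y1=0.5, y2=1.15) evaluated at the sequence length
mu = 0.5 + (1.15 - 0.5) / (4096 - 256) * (image_seq_len - 256)       # = 1.15 here

timesteps = [i / num_steps for i in range(num_steps, -1, -1)]        # linspace(1, 0, num_steps + 1)
shifted = [math.exp(mu) / (math.exp(mu) + (1 / t - 1)) if t > 0 else 0.0
           for t in timesteps]                                       # time_shift(mu, 1.0, t)

print(shifted[0], round(shifted[num_steps // 2], 2), shifted[-1])    # 1.0 0.76 0.0 (the unshifted midpoint would be 0.5)

denoise then walks consecutive pairs of these timesteps and takes one Euler step, img = img + (t_prev - t_curr) * pred, per pair.
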
flux/util.py DELETED
@@ -1,447 +0,0 @@
1
- import os
2
- from dataclasses import dataclass
3
-
4
- import torch
5
- from einops import rearrange
6
- from huggingface_hub import hf_hub_download
7
- from imwatermark import WatermarkEncoder
8
- from PIL import ExifTags, Image
9
- from safetensors.torch import load_file as load_sft
10
-
11
- from flux.model import Flux, FluxLoraWrapper, FluxParams
12
- from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
13
- from flux.modules.conditioner import HFEmbedder
14
-
15
-
16
- def save_image(
17
- nsfw_classifier,
18
- name: str,
19
- output_name: str,
20
- idx: int,
21
- x: torch.Tensor,
22
- add_sampling_metadata: bool,
23
- prompt: str,
24
- nsfw_threshold: float = 0.85,
25
- ) -> int:
26
- fn = output_name.format(idx=idx)
27
- print(f"Saving {fn}")
28
- # bring into PIL format and save
29
- x = x.clamp(-1, 1)
30
- x = embed_watermark(x.float())
31
- x = rearrange(x[0], "c h w -> h w c")
32
-
33
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
34
- nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
35
-
36
- if nsfw_score < nsfw_threshold:
37
- exif_data = Image.Exif()
38
- exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
39
- exif_data[ExifTags.Base.Make] = "Black Forest Labs"
40
- exif_data[ExifTags.Base.Model] = name
41
- if add_sampling_metadata:
42
- exif_data[ExifTags.Base.ImageDescription] = prompt
43
- img.save(fn, exif=exif_data, quality=95, subsampling=0)
44
- idx += 1
45
- else:
46
- print("Your generated image may contain NSFW content.")
47
-
48
- return idx
49
-
50
-
51
- @dataclass
52
- class ModelSpec:
53
- params: FluxParams
54
- ae_params: AutoEncoderParams
55
- ckpt_path: str | None
56
- lora_path: str | None
57
- ae_path: str | None
58
- repo_id: str | None
59
- repo_flow: str | None
60
- repo_ae: str | None
61
-
62
-
63
- configs = {
64
- "flux-dev": ModelSpec(
65
- repo_id="black-forest-labs/FLUX.1-dev",
66
- repo_flow="flux1-dev.safetensors",
67
- repo_ae="ae.safetensors",
68
- ckpt_path=os.getenv("FLUX_DEV"),
69
- lora_path=None,
70
- params=FluxParams(
71
- in_channels=64,
72
- out_channels=64,
73
- vec_in_dim=768,
74
- context_in_dim=4096,
75
- hidden_size=3072,
76
- mlp_ratio=4.0,
77
- num_heads=24,
78
- depth=19,
79
- depth_single_blocks=38,
80
- axes_dim=[16, 56, 56],
81
- theta=10_000,
82
- qkv_bias=True,
83
- guidance_embed=True,
84
- ),
85
- ae_path=os.getenv("AE"),
86
- ae_params=AutoEncoderParams(
87
- resolution=256,
88
- in_channels=3,
89
- ch=128,
90
- out_ch=3,
91
- ch_mult=[1, 2, 4, 4],
92
- num_res_blocks=2,
93
- z_channels=16,
94
- scale_factor=0.3611,
95
- shift_factor=0.1159,
96
- ),
97
- ),
98
- "flux-schnell": ModelSpec(
99
- repo_id="black-forest-labs/FLUX.1-schnell",
100
- repo_flow="flux1-schnell.safetensors",
101
- repo_ae="ae.safetensors",
102
- ckpt_path=os.getenv("FLUX_SCHNELL"),
103
- lora_path=None,
104
- params=FluxParams(
105
- in_channels=64,
106
- out_channels=64,
107
- vec_in_dim=768,
108
- context_in_dim=4096,
109
- hidden_size=3072,
110
- mlp_ratio=4.0,
111
- num_heads=24,
112
- depth=19,
113
- depth_single_blocks=38,
114
- axes_dim=[16, 56, 56],
115
- theta=10_000,
116
- qkv_bias=True,
117
- guidance_embed=False,
118
- ),
119
- ae_path=os.getenv("AE"),
120
- ae_params=AutoEncoderParams(
121
- resolution=256,
122
- in_channels=3,
123
- ch=128,
124
- out_ch=3,
125
- ch_mult=[1, 2, 4, 4],
126
- num_res_blocks=2,
127
- z_channels=16,
128
- scale_factor=0.3611,
129
- shift_factor=0.1159,
130
- ),
131
- ),
132
- "flux-dev-canny": ModelSpec(
133
- repo_id="black-forest-labs/FLUX.1-Canny-dev",
134
- repo_flow="flux1-canny-dev.safetensors",
135
- repo_ae="ae.safetensors",
136
- ckpt_path=os.getenv("FLUX_DEV_CANNY"),
137
- lora_path=None,
138
- params=FluxParams(
139
- in_channels=128,
140
- out_channels=64,
141
- vec_in_dim=768,
142
- context_in_dim=4096,
143
- hidden_size=3072,
144
- mlp_ratio=4.0,
145
- num_heads=24,
146
- depth=19,
147
- depth_single_blocks=38,
148
- axes_dim=[16, 56, 56],
149
- theta=10_000,
150
- qkv_bias=True,
151
- guidance_embed=True,
152
- ),
153
- ae_path=os.getenv("AE"),
154
- ae_params=AutoEncoderParams(
155
- resolution=256,
156
- in_channels=3,
157
- ch=128,
158
- out_ch=3,
159
- ch_mult=[1, 2, 4, 4],
160
- num_res_blocks=2,
161
- z_channels=16,
162
- scale_factor=0.3611,
163
- shift_factor=0.1159,
164
- ),
165
- ),
166
- "flux-dev-canny-lora": ModelSpec(
167
- repo_id="black-forest-labs/FLUX.1-dev",
168
- repo_flow="flux1-dev.safetensors",
169
- repo_ae="ae.safetensors",
170
- ckpt_path=os.getenv("FLUX_DEV"),
171
- lora_path=os.getenv("FLUX_DEV_CANNY_LORA"),
172
- params=FluxParams(
173
- in_channels=128,
174
- out_channels=64,
175
- vec_in_dim=768,
176
- context_in_dim=4096,
177
- hidden_size=3072,
178
- mlp_ratio=4.0,
179
- num_heads=24,
180
- depth=19,
181
- depth_single_blocks=38,
182
- axes_dim=[16, 56, 56],
183
- theta=10_000,
184
- qkv_bias=True,
185
- guidance_embed=True,
186
- ),
187
- ae_path=os.getenv("AE"),
188
- ae_params=AutoEncoderParams(
189
- resolution=256,
190
- in_channels=3,
191
- ch=128,
192
- out_ch=3,
193
- ch_mult=[1, 2, 4, 4],
194
- num_res_blocks=2,
195
- z_channels=16,
196
- scale_factor=0.3611,
197
- shift_factor=0.1159,
198
- ),
199
- ),
200
- "flux-dev-depth": ModelSpec(
201
- repo_id="black-forest-labs/FLUX.1-Depth-dev",
202
- repo_flow="flux1-depth-dev.safetensors",
203
- repo_ae="ae.safetensors",
204
- ckpt_path=os.getenv("FLUX_DEV_DEPTH"),
205
- lora_path=None,
206
- params=FluxParams(
207
- in_channels=128,
208
- out_channels=64,
209
- vec_in_dim=768,
210
- context_in_dim=4096,
211
- hidden_size=3072,
212
- mlp_ratio=4.0,
213
- num_heads=24,
214
- depth=19,
215
- depth_single_blocks=38,
216
- axes_dim=[16, 56, 56],
217
- theta=10_000,
218
- qkv_bias=True,
219
- guidance_embed=True,
220
- ),
221
- ae_path=os.getenv("AE"),
222
- ae_params=AutoEncoderParams(
223
- resolution=256,
224
- in_channels=3,
225
- ch=128,
226
- out_ch=3,
227
- ch_mult=[1, 2, 4, 4],
228
- num_res_blocks=2,
229
- z_channels=16,
230
- scale_factor=0.3611,
231
- shift_factor=0.1159,
232
- ),
233
- ),
234
- "flux-dev-depth-lora": ModelSpec(
235
- repo_id="black-forest-labs/FLUX.1-dev",
236
- repo_flow="flux1-dev.safetensors",
237
- repo_ae="ae.safetensors",
238
- ckpt_path=os.getenv("FLUX_DEV"),
239
- lora_path=os.getenv("FLUX_DEV_DEPTH_LORA"),
240
- params=FluxParams(
241
- in_channels=128,
242
- out_channels=64,
243
- vec_in_dim=768,
244
- context_in_dim=4096,
245
- hidden_size=3072,
246
- mlp_ratio=4.0,
247
- num_heads=24,
248
- depth=19,
249
- depth_single_blocks=38,
250
- axes_dim=[16, 56, 56],
251
- theta=10_000,
252
- qkv_bias=True,
253
- guidance_embed=True,
254
- ),
255
- ae_path=os.getenv("AE"),
256
- ae_params=AutoEncoderParams(
257
- resolution=256,
258
- in_channels=3,
259
- ch=128,
260
- out_ch=3,
261
- ch_mult=[1, 2, 4, 4],
262
- num_res_blocks=2,
263
- z_channels=16,
264
- scale_factor=0.3611,
265
- shift_factor=0.1159,
266
- ),
267
- ),
268
- "flux-dev-fill": ModelSpec(
269
- repo_id="black-forest-labs/FLUX.1-Fill-dev",
270
- repo_flow="flux1-fill-dev.safetensors",
271
- repo_ae="ae.safetensors",
272
- ckpt_path=os.getenv("FLUX_DEV_FILL"),
273
- lora_path=None,
274
- params=FluxParams(
275
- in_channels=384,
276
- out_channels=64,
277
- vec_in_dim=768,
278
- context_in_dim=4096,
279
- hidden_size=3072,
280
- mlp_ratio=4.0,
281
- num_heads=24,
282
- depth=19,
283
- depth_single_blocks=38,
284
- axes_dim=[16, 56, 56],
285
- theta=10_000,
286
- qkv_bias=True,
287
- guidance_embed=True,
288
- ),
289
- ae_path=os.getenv("AE"),
290
- ae_params=AutoEncoderParams(
291
- resolution=256,
292
- in_channels=3,
293
- ch=128,
294
- out_ch=3,
295
- ch_mult=[1, 2, 4, 4],
296
- num_res_blocks=2,
297
- z_channels=16,
298
- scale_factor=0.3611,
299
- shift_factor=0.1159,
300
- ),
301
- ),
302
- }
303
-
304
-
305
- def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
306
- if len(missing) > 0 and len(unexpected) > 0:
307
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
308
- print("\n" + "-" * 79 + "\n")
309
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
310
- elif len(missing) > 0:
311
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
312
- elif len(unexpected) > 0:
313
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
314
-
315
-
316
- def load_flow_model(
317
- name: str, device: str | torch.device = "cuda", hf_download: bool = True, verbose: bool = False
318
- ) -> Flux:
319
- # Loading Flux
320
- print("Init model")
321
- ckpt_path = configs[name].ckpt_path
322
- lora_path = configs[name].lora_path
323
- if (
324
- ckpt_path is None
325
- and configs[name].repo_id is not None
326
- and configs[name].repo_flow is not None
327
- and hf_download
328
- ):
329
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
330
-
331
- with torch.device("meta" if ckpt_path is not None else device):
332
- if lora_path is not None:
333
- model = FluxLoraWrapper(params=configs[name].params).to(torch.bfloat16)
334
- else:
335
- model = Flux(configs[name].params).to(torch.bfloat16)
336
-
337
- if ckpt_path is not None:
338
- print("Loading checkpoint")
339
- # load_sft doesn't support torch.device
340
- sd = load_sft(ckpt_path, device=str(device))
341
- sd = optionally_expand_state_dict(model, sd)
342
- missing, unexpected = model.load_state_dict(sd, strict=False, assign=True)
343
- if verbose:
344
- print_load_warning(missing, unexpected)
345
-
346
- if configs[name].lora_path is not None:
347
- print("Loading LoRA")
348
- lora_sd = load_sft(configs[name].lora_path, device=str(device))
349
- # loading the lora params + overwriting scale values in the norms
350
- missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True)
351
- if verbose:
352
- print_load_warning(missing, unexpected)
353
- return model
354
-
355
-
356
- def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
357
- # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
358
- return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
359
-
360
-
361
- def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
362
- return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
363
-
364
-
365
- def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
366
- ckpt_path = configs[name].ae_path
367
- if (
368
- ckpt_path is None
369
- and configs[name].repo_id is not None
370
- and configs[name].repo_ae is not None
371
- and hf_download
372
- ):
373
- ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
374
-
375
- # Loading the autoencoder
376
- print("Init AE")
377
- with torch.device("meta" if ckpt_path is not None else device):
378
- ae = AutoEncoder(configs[name].ae_params)
379
-
380
- if ckpt_path is not None:
381
- sd = load_sft(ckpt_path, device=str(device))
382
- missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
383
- print_load_warning(missing, unexpected)
384
- return ae
385
-
386
-
387
- def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict:
388
- """
389
- Optionally expand the state dict to match the model's parameter shapes.
390
- """
391
- for name, param in model.named_parameters():
392
- if name in state_dict:
393
- if state_dict[name].shape != param.shape:
394
- print(
395
- f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}."
396
- )
397
- # expand with zeros:
398
- expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device)
399
- slices = tuple(slice(0, dim) for dim in state_dict[name].shape)
400
- expanded_state_dict_weight[slices] = state_dict[name]
401
- state_dict[name] = expanded_state_dict_weight
402
-
403
- return state_dict
404
-
405
-
406
- class WatermarkEmbedder:
407
- def __init__(self, watermark):
408
- self.watermark = watermark
409
- self.num_bits = len(WATERMARK_BITS)
410
- self.encoder = WatermarkEncoder()
411
- self.encoder.set_watermark("bits", self.watermark)
412
-
413
- def __call__(self, image: torch.Tensor) -> torch.Tensor:
414
- """
415
- Adds a predefined watermark to the input image
416
-
417
- Args:
418
- image: ([N,] B, RGB, H, W) in range [-1, 1]
419
-
420
- Returns:
421
- same as input but watermarked
422
- """
423
- image = 0.5 * image + 0.5
424
- squeeze = len(image.shape) == 4
425
- if squeeze:
426
- image = image[None, ...]
427
- n = image.shape[0]
428
- image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1]
429
- # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
430
- # watermarking library expects input as cv2 BGR format
431
- for k in range(image_np.shape[0]):
432
- image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
433
- image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to(
434
- image.device
435
- )
436
- image = torch.clamp(image / 255, min=0.0, max=1.0)
437
- if squeeze:
438
- image = image[0]
439
- image = 2 * image - 1
440
- return image
441
-
442
-
443
- # A fixed 48-bit message that was chosen at random
444
- WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
445
- # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
446
- WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
447
- embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
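
Stitched together, the helpers in this deleted package form a complete text-to-image loop. The sketch below is a hypothetical minimal driver built only from functions visible in this diff; it assumes a large CUDA GPU, access to the gated Hugging Face repos (or FLUX_SCHNELL / AE pointing at local checkpoints), and an arbitrary example prompt.

import torch
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack   # parent revision
from flux.util import load_ae, load_clip, load_flow_model, load_t5

device = torch.device("cuda")
name, width, height, num_steps, seed = "flux-schnell", 1024, 1024, 4, 0

t5 = load_t5(device, max_length=256)        # the comment in load_t5 above notes 64 to 512 all work
clip = load_clip(device)
model = load_flow_model(name, device=device)
ae = load_ae(name, device=device)

with torch.no_grad():
    x = get_noise(1, height, width, device=device, dtype=torch.bfloat16, seed=seed)
    inp = prepare(t5, clip, x, prompt="a tiny robot watering a bonsai tree")
    timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
    x = denoise(model, **inp, timesteps=timesteps, guidance=3.5)   # guidance is ignored for schnell
    x = unpack(x.float(), height, width)
    img = ae.decode(x)                      # (1, 3, 1024, 1024), roughly in [-1, 1]

save_image above would then watermark the result, run the NSFW check and write the file with the EXIF metadata shown.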