mpatel57 committed (verified)
Commit a859ff0 · 1 Parent(s): 42b3cea
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +15 -0
  2. app.py +427 -0
  3. assets/dog.webp +0 -0
  4. assets/vulcano.jpg +0 -0
  5. assets/vulcano_mask.webp +0 -0
  6. fluxcombined.py +1607 -0
  7. requirements.txt +9 -0
  8. saved_results/20241126_053639/input.png +0 -0
  9. saved_results/20241126_053639/mask.png +0 -0
  10. saved_results/20241126_053639/output.png +3 -0
  11. saved_results/20241126_053639/parameters.json +13 -0
  12. saved_results/20241126_055109/input.png +0 -0
  13. saved_results/20241126_055109/mask.png +0 -0
  14. saved_results/20241126_055109/output.png +3 -0
  15. saved_results/20241126_055109/parameters.json +13 -0
  16. saved_results/20241126_173140/input.png +0 -0
  17. saved_results/20241126_173140/mask.png +0 -0
  18. saved_results/20241126_173140/output.png +3 -0
  19. saved_results/20241126_173140/parameters.json +13 -0
  20. saved_results/20241126_181436/input.png +3 -0
  21. saved_results/20241126_181436/mask.png +0 -0
  22. saved_results/20241126_181436/output.png +0 -0
  23. saved_results/20241126_181436/parameters.json +13 -0
  24. saved_results/20241126_181633/input.png +3 -0
  25. saved_results/20241126_181633/mask.png +0 -0
  26. saved_results/20241126_181633/output.png +0 -0
  27. saved_results/20241126_181633/parameters.json +13 -0
  28. saved_results/20241126_214810/input.png +0 -0
  29. saved_results/20241126_214810/mask.png +0 -0
  30. saved_results/20241126_214810/output.png +3 -0
  31. saved_results/20241126_214810/parameters.json +13 -0
  32. saved_results/20241126_214908/input.png +0 -0
  33. saved_results/20241126_214908/mask.png +0 -0
  34. saved_results/20241126_214908/output.png +3 -0
  35. saved_results/20241126_214908/parameters.json +13 -0
  36. saved_results/20241126_215043/input.png +0 -0
  37. saved_results/20241126_215043/mask.png +0 -0
  38. saved_results/20241126_215043/output.png +3 -0
  39. saved_results/20241126_215043/parameters.json +13 -0
  40. saved_results/20241126_221300/input.png +0 -0
  41. saved_results/20241126_221300/mask.png +0 -0
  42. saved_results/20241126_221300/output.png +3 -0
  43. saved_results/20241126_221300/parameters.json +13 -0
  44. saved_results/20241126_222257/input.png +0 -0
  45. saved_results/20241126_222257/mask.png +0 -0
  46. saved_results/20241126_222257/output.png +3 -0
  47. saved_results/20241126_222257/parameters.json +13 -0
  48. saved_results/20241126_222442/input.png +0 -0
  49. saved_results/20241126_222442/mask.png +0 -0
  50. saved_results/20241126_222442/output.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ saved_results/20241126_053639/output.png filter=lfs diff=lfs merge=lfs -text
37
+ saved_results/20241126_055109/output.png filter=lfs diff=lfs merge=lfs -text
38
+ saved_results/20241126_173140/output.png filter=lfs diff=lfs merge=lfs -text
39
+ saved_results/20241126_181436/input.png filter=lfs diff=lfs merge=lfs -text
40
+ saved_results/20241126_181633/input.png filter=lfs diff=lfs merge=lfs -text
41
+ saved_results/20241126_214810/output.png filter=lfs diff=lfs merge=lfs -text
42
+ saved_results/20241126_214908/output.png filter=lfs diff=lfs merge=lfs -text
43
+ saved_results/20241126_215043/output.png filter=lfs diff=lfs merge=lfs -text
44
+ saved_results/20241126_221300/output.png filter=lfs diff=lfs merge=lfs -text
45
+ saved_results/20241126_222257/output.png filter=lfs diff=lfs merge=lfs -text
46
+ saved_results/20241126_222442/output.png filter=lfs diff=lfs merge=lfs -text
47
+ saved_results/20241126_222522/output.png filter=lfs diff=lfs merge=lfs -text
48
+ saved_results/20241126_223634/output.png filter=lfs diff=lfs merge=lfs -text
49
+ saved_results/20241126_223719/output.png filter=lfs diff=lfs merge=lfs -text
50
+ saved_results/20241127_025429/output.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,427 @@
1
+ import spaces
2
+ import gradio as gr
3
+ import torch
4
+ from diffusers import StableDiffusionInpaintPipeline, StableDiffusionImg2ImgPipeline
5
+ from PIL import Image
6
+ import random
7
+ import numpy as np
8
+ import torch
9
+ import os
10
+ import json
11
+ from datetime import datetime
12
+
13
+ from fluxcombined import FluxPipeline
14
+ from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
15
+
16
+ # Load the Stable Diffusion Inpainting model
17
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="scheduler")
18
+ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16, scheduler=scheduler)
19
+ pipe.to("cuda") # Comment this line if GPU is not available
20
+
21
+ # Function to process the image
22
+ @spaces.GPU(duration=120)
23
+ def process_image(
24
+ mode, image_layers, prompt, edit_prompt, seed, randomize_seed, num_inference_steps,
25
+ max_steps, learning_rate, max_source_steps, optimization_steps, true_cfg, mask_input
26
+ ):
27
+ image_with_mask = {
28
+ "image": image_layers["background"],
29
+ "mask": image_layers["layers"][0] if mask_input is None else mask_input
30
+ }
31
+
32
+ # Set seed
33
+ if randomize_seed or seed is None:
34
+ seed = random.randint(0, 2**32 - 1)
35
+ generator = torch.Generator("cuda").manual_seed(int(seed))
36
+
37
+ # Unpack image and mask
38
+ if image_with_mask is None:
39
+ return None, f"❌ Please upload an image and create a mask."
40
+ image = image_with_mask["image"]
41
+ mask = image_with_mask["mask"]
42
+
43
+ if image is None or mask is None:
44
+ return None, f"❌ Please ensure both image and mask are provided."
45
+
46
+ # Convert images to RGB
47
+ image = image.convert("RGB")
48
+ mask = mask.split()[-1] # Convert mask to grayscale
49
+
50
+ if mode == "Inpainting":
51
+ if not prompt:
52
+ return None, f"❌ Please provide a prompt for inpainting."
53
+ with torch.autocast("cuda"):
54
+ # Placeholder for using advanced parameters in the future
55
+ # Adjust parameters according to advanced settings if applicable
56
+ result = pipe.inpaint(
57
+ prompt=prompt,
58
+ input_image=image.resize((1024, 1024)),
59
+ mask_image=mask.resize((1024, 1024)),
60
+ num_inference_steps=num_inference_steps,
61
+ guidance_scale=0.5,
62
+ generator=generator,
63
+ save_masked_image=False,
64
+ output_path="",
65
+ learning_rate=learning_rate,
66
+ max_steps=max_steps
67
+ ).images[0]
68
+ pipe.vae = pipe.vae.to(torch.float16)
69
+ return result, f"✅ Inpainting completed with seed {seed}."
70
+ elif mode == "Editing":
71
+ if not edit_prompt:
72
+ return None, f"❌ Please provide a prompt for editing."
73
+ if not prompt:
74
+ prompt = ""
75
+ # Resize the mask to match the image
76
+ # mask = mask.resize(image.size)
77
+ with torch.autocast("cuda"):
78
+ # Placeholder for using advanced parameters in the future
79
+ # Adjust parameters according to advanced settings if applicable
80
+ result = pipe.edit2(
81
+ prompt=edit_prompt,
82
+ input_image=image.resize((1024, 1024)),
83
+ mask_image=mask.resize((1024, 1024)),
84
+ num_inference_steps=num_inference_steps,
85
+ guidance_scale=0.0,
86
+ generator=generator,
87
+ save_masked_image=False,
88
+ output_path="",
89
+ learning_rate=learning_rate,
90
+ max_steps=max_steps,
91
+ optimization_steps=optimization_steps,
92
+ true_cfg=true_cfg,
93
+ negative_prompt=prompt,
94
+ source_steps=max_source_steps,
95
+ ).images[0]
96
+ return result, f"✅ Editing completed with seed {seed}."
97
+ else:
98
+ return None, f"❌ Invalid mode selected."
99
+
100
+ # Design the Gradio interface
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown(
103
+ """
104
+ <style>
105
+ body {background-color: #f5f5f5; color: #333333;}
106
+ h1 {text-align: center; font-family: 'Helvetica', sans-serif; margin-bottom: 10px;}
107
+ h2 {text-align: center; color: #666666; font-weight: normal; margin-bottom: 30px;}
108
+ .gradio-container {max-width: 800px; margin: auto;}
109
+ .footer {text-align: center; margin-top: 20px; color: #999999; font-size: 12px;}
110
+ </style>
111
+ """
112
+ )
113
+ gr.Markdown("<h1>🍲 FlowChef 🍲</h1>")
114
+ gr.Markdown("<h2>Inversion/Gradient/Training-free Steering of Flux.1[Dev]</h2>")
115
+ gr.Markdown("<h2><p><a href='https://flowchef.github.io/'>Project Page</a> | <a href='#'>Paper</a></p> (Steering Rectified Flow Models in the Vector Field for Controlled Image Generation)</h2>")
116
+ gr.Markdown("<h3>💡 We recommend going through our <a href='#'>tutorial introduction</a> before getting started!</h3>")
117
+
118
+ # Store current state
119
+ current_input_image = None
120
+ current_mask = None
121
+ current_output_image = None
122
+ current_params = {}
123
+
124
+ # Images at the top
125
+ with gr.Row():
126
+ with gr.Column():
127
+ image_input = gr.ImageMask(
128
+ # source="upload",
129
+ # tool="sketch",
130
+ type="pil",
131
+ label="Input Image and Mask",
132
+ image_mode="RGBA",
133
+ height=512,
134
+ width=512,
135
+ )
136
+ with gr.Column():
137
+ output_image = gr.Image(label="Output Image")
138
+
139
+ # All options below
140
+ with gr.Column():
141
+ mode = gr.Radio(
142
+ choices=["Inpainting", "Editing"], label="Select Mode", value="Inpainting"
143
+ )
144
+ prompt = gr.Textbox(
145
+ label="Prompt",
146
+ placeholder="Describe what should appear in the masked area..."
147
+ )
148
+ edit_prompt = gr.Textbox(
149
+ label="Editing Prompt",
150
+ placeholder="Describe how you want to edit the image..."
151
+ )
152
+ with gr.Row():
153
+ seed = gr.Number(label="Seed (Optional)", value=None)
154
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
155
+ num_inference_steps = gr.Slider(
156
+ label="Inference Steps", minimum=1, maximum=50, value=30
157
+ )
158
+ # Advanced settings in an accordion
159
+ with gr.Accordion("Advanced Settings", open=False):
160
+ max_steps = gr.Slider(label="Max Steps", minimum=1, maximum=30, value=30)
161
+ learning_rate = gr.Slider(label="Learning Rate", minimum=0.1, maximum=1.0, value=0.5)
162
+ true_cfg = gr.Slider(label="Guidance Scale (only for editing)", minimum=1, maximum=20, value=2)
163
+ max_source_steps = gr.Slider(label="Max Source Steps (only for editing)", minimum=1, maximum=30, value=20)
164
+ optimization_steps = gr.Slider(label="Optimization Steps", minimum=1, maximum=10, value=1)
165
+ mask_input = gr.Image(
166
+ type="pil",
167
+ label="Optional Mask",
168
+ image_mode="RGBA",
169
+ )
170
+ with gr.Row():
171
+ run_button = gr.Button("Run", variant="primary")
172
+ save_button = gr.Button("Save Data", variant="secondary")
173
+
174
+ def update_visibility(selected_mode):
175
+ if selected_mode == "Inpainting":
176
+ return gr.update(visible=True), gr.update(visible=False)
177
+ else:
178
+ return gr.update(visible=True), gr.update(visible=True)
179
+
180
+ mode.change(
181
+ update_visibility,
182
+ inputs=mode,
183
+ outputs=[prompt, edit_prompt],
184
+ )
185
+
186
+ def run_and_update_status(
187
+ mode, image_with_mask, prompt, edit_prompt, seed, randomize_seed, num_inference_steps,
188
+ max_steps, learning_rate, max_source_steps, optimization_steps, true_cfg, mask_input
189
+ ):
190
+ result_image, result_status = process_image(
191
+ mode, image_with_mask, prompt, edit_prompt, seed, randomize_seed, num_inference_steps,
192
+ max_steps, learning_rate, max_source_steps, optimization_steps, true_cfg, mask_input
193
+ )
194
+
195
+ # Store current state
196
+ global current_input_image, current_mask, current_output_image, current_params
197
+
198
+ current_input_image = image_with_mask["background"] if image_with_mask else None
199
+ current_mask = mask_input if mask_input is not None else (image_with_mask["layers"][0] if image_with_mask else None)
200
+ current_output_image = result_image
201
+ current_params = {
202
+ "mode": mode,
203
+ "prompt": prompt,
204
+ "edit_prompt": edit_prompt,
205
+ "seed": seed,
206
+ "randomize_seed": randomize_seed,
207
+ "num_inference_steps": num_inference_steps,
208
+ "max_steps": max_steps,
209
+ "learning_rate": learning_rate,
210
+ "max_source_steps": max_source_steps,
211
+ "optimization_steps": optimization_steps,
212
+ "true_cfg": true_cfg
213
+ }
214
+
215
+ return result_image
216
+
217
+ def save_data():
218
+ if not os.path.exists("saved_results"):
219
+ os.makedirs("saved_results")
220
+
221
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
222
+ save_dir = os.path.join("saved_results", timestamp)
223
+ os.makedirs(save_dir)
224
+
225
+ # Save images
226
+ if current_input_image:
227
+ current_input_image.save(os.path.join(save_dir, "input.png"))
228
+ if current_mask:
229
+ current_mask.save(os.path.join(save_dir, "mask.png"))
230
+ if current_output_image:
231
+ current_output_image.save(os.path.join(save_dir, "output.png"))
232
+
233
+ # Save parameters
234
+ with open(os.path.join(save_dir, "parameters.json"), "w") as f:
235
+ json.dump(current_params, f, indent=4)
236
+
237
+ return f"✅ Data saved in {save_dir}"
238
+
239
+ run_button.click(
240
+ fn=run_and_update_status,
241
+ inputs=[
242
+ mode,
243
+ image_input,
244
+ prompt,
245
+ edit_prompt,
246
+ seed,
247
+ randomize_seed,
248
+ num_inference_steps,
249
+ max_steps,
250
+ learning_rate,
251
+ max_source_steps,
252
+ optimization_steps,
253
+ true_cfg,
254
+ mask_input
255
+ ],
256
+ outputs=output_image,
257
+ )
258
+
259
+ save_button.click(fn=save_data)
260
+
261
+ gr.Markdown(
262
+ "<div class='footer'>Developed with ❤️ using Flux and Gradio by <a href='https://maitreyapatel.com'>Maitreya Patel</a></div>"
263
+ )
264
+
265
+ def load_example_image_with_mask(image_path):
266
+ # Load the image
267
+ image = Image.open(image_path)
268
+ # Create an empty mask of the same size
269
+ mask = Image.new('L', image.size, 0)
270
+ return {"background": image, "layers": [mask], "composite": image}
271
+
272
+ examples_dir = "assets"
273
+ volcano_dict = load_example_image_with_mask(os.path.join(examples_dir, "vulcano.jpg"))
274
+ dog_dict = load_example_image_with_mask(os.path.join(examples_dir, "dog.webp"))
275
+
276
+ gr.Examples(
277
+ examples=[
278
+ [
279
+ "Inpainting", # mode
280
+ "./saved_results/20241126_053639/input.png", # image with mask
281
+ "./saved_results/20241126_053639/mask.png",
282
+ "./saved_results/20241126_053639/output.png",
283
+ "a dog", # prompt
284
+ " ", # edit_prompt
285
+ 0, # seed
286
+ True, # randomize_seed
287
+ 30, # num_inference_steps
288
+ 30, # max_steps
289
+ 1.0, # learning_rate
290
+ 20, # max_source_steps
291
+ 10, # optimization_steps
292
+ 2, # true_cfg
293
+ ],
294
+ [
295
+ "Inpainting", # mode
296
+ "./saved_results/20241126_173140/input.png", # image with mask
297
+ "./saved_results/20241126_173140/mask.png",
298
+ "./saved_results/20241126_173140/output.png",
299
+ "a cat with blue eyes", # prompt
300
+ " ", # edit_prompt
301
+ 0, # seed
302
+ True, # randomize_seed
303
+ 30, # num_inference_steps
304
+ 20, # max_steps
305
+ 1.0, # learning_rate
306
+ 20, # max_source_steps
307
+ 10, # optimization_steps
308
+ 2, # true_cfg
309
+ ],
310
+ [
311
+ "Editing", # mode
312
+ "./saved_results/20241126_181633/input.png", # image with mask
313
+ "./saved_results/20241126_181633/mask.png",
314
+ "./saved_results/20241126_181633/output.png",
315
+ " ", # prompt
316
+ "volcano eruption", # edit_prompt
317
+ 0, # seed
318
+ True, # randomize_seed
319
+ 30, # num_inference_steps
320
+ 20, # max_steps
321
+ 0.5, # learning_rate
322
+ 2, # max_source_steps
323
+ 3, # optimization_steps
324
+ 4.5, # true_cfg
325
+ ],
326
+ [
327
+ "Editing", # mode
328
+ "./saved_results/20241126_214810/input.png", # image with mask
329
+ "./saved_results/20241126_214810/mask.png",
330
+ "./saved_results/20241126_214810/output.png",
331
+ " ", # prompt
332
+ "a dog with flowers in the mouth", # edit_prompt
333
+ 0, # seed
334
+ True, # randomize_seed
335
+ 30, # num_inference_steps
336
+ 30, # max_steps
337
+ 1, # learning_rate
338
+ 5, # max_source_steps
339
+ 3, # optimization_steps
340
+ 4.5, # true_cfg
341
+ ],
342
+ [
343
+ "Inpainting", # mode
344
+ "./saved_results/20241127_025429/input.png", # image with mask
345
+ "./saved_results/20241127_025429/mask.png",
346
+ "./saved_results/20241127_025429/output.png",
347
+ "A building with \"ASU\" written on it.", # prompt
348
+ "", # edit_prompt
349
+ 52, # seed
350
+ False, # randomize_seed
351
+ 30, # num_inference_steps
352
+ 30, # max_steps
353
+ 1, # learning_rate
354
+ 20, # max_source_steps
355
+ 10, # optimization_steps
356
+ 2, # true_cfg
357
+ ],
358
+ [
359
+ "Inpainting", # mode
360
+ "./saved_results/20241126_222257/input.png", # image with mask
361
+ "./saved_results/20241126_222257/mask.png",
362
+ "./saved_results/20241126_222257/output.png",
363
+ "A cute pig with big eyes", # prompt
364
+ "", # edit_prompt
365
+ 0, # seed
366
+ True, # randomize_seed
367
+ 30, # num_inference_steps
368
+ 20, # max_steps
369
+ 1, # learning_rate
370
+ 20, # max_source_steps
371
+ 5, # optimization_steps
372
+ 2, # true_cfg
373
+ ],
374
+ [
375
+ "Editing", # mode
376
+ "./saved_results/20241126_222522/input.png", # image with mask
377
+ "./saved_results/20241126_222522/mask.png",
378
+ "./saved_results/20241126_222522/output.png",
379
+ "A cute rabbit with big eyes", # prompt
380
+ "A cute pig with big eyes", # edit_prompt
381
+ 0, # seed
382
+ True, # randomize_seed
383
+ 30, # num_inference_steps
384
+ 20, # max_steps
385
+ 0.4, # learning_rate
386
+ 5, # max_source_steps
387
+ 5, # optimization_steps
388
+ 4.5, # true_cfg
389
+ ],
390
+ [
391
+ "Editing", # mode
392
+ "./saved_results/20241126_223719/input.png", # image with mask
393
+ "./saved_results/20241126_223719/mask.png",
394
+ "./saved_results/20241126_223719/output.png",
395
+ "a cat", # prompt
396
+ "a tiger", # edit_prompt
397
+ 0, # seed
398
+ True, # randomize_seed
399
+ 30, # num_inference_steps
400
+ 30, # max_steps
401
+ 0.6, # learning_rate
402
+ 10, # max_source_steps
403
+ 5, # optimization_steps
404
+ 4.5, # true_cfg
405
+ ],
406
+ ],
407
+ inputs=[
408
+ mode,
409
+ image_input,
410
+ mask_input,
411
+ output_image,
412
+ prompt,
413
+ edit_prompt,
414
+ seed,
415
+ randomize_seed,
416
+ num_inference_steps,
417
+ max_steps,
418
+ learning_rate,
419
+ max_source_steps,
420
+ optimization_steps,
421
+ true_cfg,
422
+ ],
423
+ # outputs=[output_image],
424
+ # fn=run_and_update_status,
425
+ # cache_examples=True,
426
+ )
427
+ demo.launch()
assets/dog.webp ADDED
assets/vulcano.jpg ADDED
assets/vulcano_mask.webp ADDED
fluxcombined.py ADDED
@@ -0,0 +1,1607 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
21
+
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
24
+ from diffusers.models.autoencoders import AutoencoderKL
25
+ from diffusers.models.transformers import FluxTransformer2DModel
26
+ from scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
27
+ from diffusers.utils import (
28
+ USE_PEFT_BACKEND,
29
+ is_torch_xla_available,
30
+ logging,
31
+ replace_example_docstring,
32
+ scale_lora_layers,
33
+ unscale_lora_layers,
34
+ )
35
+ from diffusers.utils.torch_utils import randn_tensor
36
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
37
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
38
+
39
+ import os
40
+ import torch
41
+ import torch.nn as nn
42
+ from os.path import expanduser # pylint: disable=import-outside-toplevel
43
+ from urllib.request import urlretrieve # pylint: disable=import-outside-toplevel
44
+ from torchvision import transforms as TF
45
+
46
+ if is_torch_xla_available():
47
+ import torch_xla.core.xla_model as xm
48
+
49
+ XLA_AVAILABLE = True
50
+ else:
51
+ XLA_AVAILABLE = False
52
+
53
+
54
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
+
56
+ EXAMPLE_DOC_STRING = """
57
+ Examples:
58
+ ```py
59
+ >>> import torch
60
+ >>> from diffusers import FluxPipeline
61
+
62
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
63
+ >>> pipe.to("cuda")
64
+ >>> prompt = "A cat holding a sign that says hello world"
65
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
66
+ >>> # Refer to the pipeline documentation for more details.
67
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
68
+ >>> image.save("flux.png")
69
+ ```
70
+ """
71
+
72
+ import sys
73
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
74
+
75
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
76
+
77
+ def retrieve_latents(
78
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
79
+ ):
80
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
81
+ return encoder_output.latent_dist.sample(generator)
82
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
83
+ return encoder_output.latent_dist.mode()
84
+ elif hasattr(encoder_output, "latents"):
85
+ return encoder_output.latents
86
+ else:
87
+ raise AttributeError("Could not access latents of provided encoder_output")
88
+
89
+
90
+ def calculate_shift(
91
+ image_seq_len,
92
+ base_seq_len: int = 256,
93
+ max_seq_len: int = 4096,
94
+ base_shift: float = 0.5,
95
+ max_shift: float = 1.16,
96
+ ):
97
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
98
+ b = base_shift - m * base_seq_len
99
+ mu = image_seq_len * m + b
100
+ return mu
101
+
102
+
103
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104
+ def retrieve_timesteps(
105
+ scheduler,
106
+ num_inference_steps: Optional[int] = None,
107
+ device: Optional[Union[str, torch.device]] = None,
108
+ timesteps: Optional[List[int]] = None,
109
+ sigmas: Optional[List[float]] = None,
110
+ **kwargs,
111
+ ):
112
+ """
113
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
114
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
115
+
116
+ Args:
117
+ scheduler (`SchedulerMixin`):
118
+ The scheduler to get timesteps from.
119
+ num_inference_steps (`int`):
120
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
121
+ must be `None`.
122
+ device (`str` or `torch.device`, *optional*):
123
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
124
+ timesteps (`List[int]`, *optional*):
125
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
126
+ `num_inference_steps` and `sigmas` must be `None`.
127
+ sigmas (`List[float]`, *optional*):
128
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
129
+ `num_inference_steps` and `timesteps` must be `None`.
130
+
131
+ Returns:
132
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
133
+ second element is the number of inference steps.
134
+ """
135
+ if timesteps is not None and sigmas is not None:
136
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
137
+ if timesteps is not None:
138
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
139
+ if not accepts_timesteps:
140
+ raise ValueError(
141
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
142
+ f" timestep schedules. Please check whether you are using the correct scheduler."
143
+ )
144
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
145
+ timesteps = scheduler.timesteps
146
+ num_inference_steps = len(timesteps)
147
+ elif sigmas is not None:
148
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
149
+ if not accept_sigmas:
150
+ raise ValueError(
151
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
152
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
153
+ )
154
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ num_inference_steps = len(timesteps)
157
+ else:
158
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
159
+ timesteps = scheduler.timesteps
160
+ return timesteps, num_inference_steps
161
+
162
+
163
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
164
+ r"""
165
+ The Flux pipeline for text-to-image generation.
166
+
167
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
168
+
169
+ Args:
170
+ transformer ([`FluxTransformer2DModel`]):
171
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
172
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
173
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
174
+ vae ([`AutoencoderKL`]):
175
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
176
+ text_encoder ([`CLIPTextModel`]):
177
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
178
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
179
+ text_encoder_2 ([`T5EncoderModel`]):
180
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
181
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
182
+ tokenizer (`CLIPTokenizer`):
183
+ Tokenizer of class
184
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
185
+ tokenizer_2 (`T5TokenizerFast`):
186
+ Second Tokenizer of class
187
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
188
+ """
189
+
190
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
191
+ _optional_components = []
192
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
193
+
194
+ def __init__(
195
+ self,
196
+ scheduler: FlowMatchEulerDiscreteScheduler,
197
+ vae: AutoencoderKL,
198
+ text_encoder: CLIPTextModel,
199
+ tokenizer: CLIPTokenizer,
200
+ text_encoder_2: T5EncoderModel,
201
+ tokenizer_2: T5TokenizerFast,
202
+ transformer: FluxTransformer2DModel,
203
+ ):
204
+ super().__init__()
205
+
206
+ self.register_modules(
207
+ vae=vae,
208
+ text_encoder=text_encoder,
209
+ text_encoder_2=text_encoder_2,
210
+ tokenizer=tokenizer,
211
+ tokenizer_2=tokenizer_2,
212
+ transformer=transformer,
213
+ scheduler=scheduler,
214
+ )
215
+ self.vae_scale_factor = (
216
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
217
+ )
218
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
219
+ self.tokenizer_max_length = (
220
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
221
+ )
222
+ self.default_sample_size = 64
223
+
224
+ def _get_t5_prompt_embeds(
225
+ self,
226
+ prompt: Union[str, List[str]] = None,
227
+ num_images_per_prompt: int = 1,
228
+ max_sequence_length: int = 512,
229
+ device: Optional[torch.device] = None,
230
+ dtype: Optional[torch.dtype] = None,
231
+ ):
232
+ device = device or self._execution_device
233
+ dtype = dtype or self.text_encoder.dtype
234
+
235
+ prompt = [prompt] if isinstance(prompt, str) else prompt
236
+ batch_size = len(prompt)
237
+
238
+ text_inputs = self.tokenizer_2(
239
+ prompt,
240
+ padding="max_length",
241
+ max_length=max_sequence_length,
242
+ truncation=True,
243
+ return_length=False,
244
+ return_overflowing_tokens=False,
245
+ return_tensors="pt",
246
+ )
247
+ text_input_ids = text_inputs.input_ids
248
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
249
+
250
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
251
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
252
+ logger.warning(
253
+ "The following part of your input was truncated because `max_sequence_length` is set to "
254
+ f" {max_sequence_length} tokens: {removed_text}"
255
+ )
256
+
257
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
258
+
259
+ dtype = self.text_encoder_2.dtype
260
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
261
+
262
+ _, seq_len, _ = prompt_embeds.shape
263
+
264
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
265
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
266
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
267
+
268
+ return prompt_embeds
269
+
270
+ def _get_clip_prompt_embeds(
271
+ self,
272
+ prompt: Union[str, List[str]],
273
+ num_images_per_prompt: int = 1,
274
+ device: Optional[torch.device] = None,
275
+ ):
276
+ device = device or self._execution_device
277
+
278
+ prompt = [prompt] if isinstance(prompt, str) else prompt
279
+ batch_size = len(prompt)
280
+
281
+ text_inputs = self.tokenizer(
282
+ prompt,
283
+ padding="max_length",
284
+ max_length=self.tokenizer_max_length,
285
+ truncation=True,
286
+ return_overflowing_tokens=False,
287
+ return_length=False,
288
+ return_tensors="pt",
289
+ )
290
+
291
+ text_input_ids = text_inputs.input_ids
292
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
293
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
294
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
295
+ logger.warning(
296
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
297
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
298
+ )
299
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
300
+
301
+ # Use pooled output of CLIPTextModel
302
+ prompt_embeds = prompt_embeds.pooler_output
303
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
304
+
305
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
306
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
307
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
308
+
309
+ return prompt_embeds
310
+
311
+ def encode_prompt(
312
+ self,
313
+ prompt: Union[str, List[str]],
314
+ prompt_2: Union[str, List[str]],
315
+ device: Optional[torch.device] = None,
316
+ num_images_per_prompt: int = 1,
317
+ prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
319
+ max_sequence_length: int = 512,
320
+ lora_scale: Optional[float] = None,
321
+ ):
322
+ r"""
323
+
324
+ Args:
325
+ prompt (`str` or `List[str]`, *optional*):
326
+ prompt to be encoded
327
+ prompt_2 (`str` or `List[str]`, *optional*):
328
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
329
+ used in all text-encoders
330
+ device: (`torch.device`):
331
+ torch device
332
+ num_images_per_prompt (`int`):
333
+ number of images that should be generated per prompt
334
+ prompt_embeds (`torch.FloatTensor`, *optional*):
335
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
336
+ provided, text embeddings will be generated from `prompt` input argument.
337
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
338
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
339
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
340
+ lora_scale (`float`, *optional*):
341
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
342
+ """
343
+ device = device or self._execution_device
344
+
345
+ # set lora scale so that monkey patched LoRA
346
+ # function of text encoder can correctly access it
347
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
348
+ self._lora_scale = lora_scale
349
+
350
+ # dynamically adjust the LoRA scale
351
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
352
+ scale_lora_layers(self.text_encoder, lora_scale)
353
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
354
+ scale_lora_layers(self.text_encoder_2, lora_scale)
355
+
356
+ prompt = [prompt] if isinstance(prompt, str) else prompt
357
+ if prompt is not None:
358
+ batch_size = len(prompt)
359
+ else:
360
+ batch_size = prompt_embeds.shape[0]
361
+
362
+ if prompt_embeds is None:
363
+ prompt_2 = prompt_2 or prompt
364
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
365
+
366
+ # We only use the pooled prompt output from the CLIPTextModel
367
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
368
+ prompt=prompt,
369
+ device=device,
370
+ num_images_per_prompt=num_images_per_prompt,
371
+ )
372
+ prompt_embeds = self._get_t5_prompt_embeds(
373
+ prompt=prompt_2,
374
+ num_images_per_prompt=num_images_per_prompt,
375
+ max_sequence_length=max_sequence_length,
376
+ device=device,
377
+ )
378
+
379
+ if self.text_encoder is not None:
380
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
381
+ # Retrieve the original scale by scaling back the LoRA layers
382
+ unscale_lora_layers(self.text_encoder, lora_scale)
383
+
384
+ if self.text_encoder_2 is not None:
385
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
386
+ # Retrieve the original scale by scaling back the LoRA layers
387
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
388
+
389
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
390
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
391
+ # text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
392
+
393
+ return prompt_embeds, pooled_prompt_embeds, text_ids
394
+
395
+ def encode_prompt_edit(
396
+ self,
397
+ prompt: Union[str, List[str]],
398
+ prompt_2: Union[str, List[str]],
399
+ negative_prompt: Union[str, List[str]] = None,
400
+ negative_prompt_2: Union[str, List[str]] = None,
401
+ device: Optional[torch.device] = None,
402
+ num_images_per_prompt: int = 1,
403
+ prompt_embeds: Optional[torch.FloatTensor] = None,
404
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
405
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
406
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
407
+ max_sequence_length: int = 512,
408
+ lora_scale: Optional[float] = None,
409
+ do_true_cfg: bool = False,
410
+ ):
411
+ device = device or self._execution_device
412
+
413
+ # Set LoRA scale if applicable
414
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
415
+ self._lora_scale = lora_scale
416
+
417
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
418
+ scale_lora_layers(self.text_encoder, lora_scale)
419
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
420
+ scale_lora_layers(self.text_encoder_2, lora_scale)
421
+
422
+ prompt = [prompt] if isinstance(prompt, str) else prompt
423
+ batch_size = len(prompt)
424
+
425
+ if do_true_cfg and negative_prompt is not None:
426
+ negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
427
+ negative_batch_size = len(negative_prompt)
428
+
429
+ if negative_batch_size != batch_size:
430
+ raise ValueError(
431
+ f"Negative prompt batch size ({negative_batch_size}) does not match prompt batch size ({batch_size})"
432
+ )
433
+
434
+ # Concatenate prompts
435
+ prompts = prompt + negative_prompt
436
+ prompts_2 = (
437
+ prompt_2 + negative_prompt_2 if prompt_2 is not None and negative_prompt_2 is not None else None
438
+ )
439
+ else:
440
+ prompts = prompt
441
+ prompts_2 = prompt_2
442
+
443
+ if prompt_embeds is None:
444
+ if prompts_2 is None:
445
+ prompts_2 = prompts
446
+
447
+ # Get pooled prompt embeddings from CLIPTextModel
448
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
449
+ prompt=prompts,
450
+ device=device,
451
+ num_images_per_prompt=num_images_per_prompt,
452
+ )
453
+ prompt_embeds = self._get_t5_prompt_embeds(
454
+ prompt=prompts_2,
455
+ num_images_per_prompt=num_images_per_prompt,
456
+ max_sequence_length=max_sequence_length,
457
+ device=device,
458
+ )
459
+
460
+ if do_true_cfg and negative_prompt is not None:
461
+ # Split embeddings back into positive and negative parts
462
+ total_batch_size = batch_size * num_images_per_prompt
463
+ positive_indices = slice(0, total_batch_size)
464
+ negative_indices = slice(total_batch_size, 2 * total_batch_size)
465
+
466
+ positive_pooled_prompt_embeds = pooled_prompt_embeds[positive_indices]
467
+ negative_pooled_prompt_embeds = pooled_prompt_embeds[negative_indices]
468
+
469
+ positive_prompt_embeds = prompt_embeds[positive_indices]
470
+ negative_prompt_embeds = prompt_embeds[negative_indices]
471
+
472
+ pooled_prompt_embeds = positive_pooled_prompt_embeds
473
+ prompt_embeds = positive_prompt_embeds
474
+
475
+ # Unscale LoRA layers
476
+ if self.text_encoder is not None:
477
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
478
+ unscale_lora_layers(self.text_encoder, lora_scale)
479
+
480
+ if self.text_encoder_2 is not None:
481
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
482
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
483
+
484
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
485
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
486
+
487
+ if do_true_cfg and negative_prompt is not None:
488
+ return (
489
+ prompt_embeds,
490
+ pooled_prompt_embeds,
491
+ text_ids,
492
+ negative_prompt_embeds,
493
+ negative_pooled_prompt_embeds,
494
+ )
495
+ else:
496
+ return prompt_embeds, pooled_prompt_embeds, text_ids, None, None
497
+
498
+
499
+ def check_inputs(
500
+ self,
501
+ prompt,
502
+ prompt_2,
503
+ height,
504
+ width,
505
+ prompt_embeds=None,
506
+ pooled_prompt_embeds=None,
507
+ callback_on_step_end_tensor_inputs=None,
508
+ max_sequence_length=None,
509
+ ):
510
+ if height % 8 != 0 or width % 8 != 0:
511
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
512
+
513
+ if callback_on_step_end_tensor_inputs is not None and not all(
514
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
515
+ ):
516
+ raise ValueError(
517
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
518
+ )
519
+
520
+ if prompt is not None and prompt_embeds is not None:
521
+ raise ValueError(
522
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
523
+ " only forward one of the two."
524
+ )
525
+ elif prompt_2 is not None and prompt_embeds is not None:
526
+ raise ValueError(
527
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
528
+ " only forward one of the two."
529
+ )
530
+ elif prompt is None and prompt_embeds is None:
531
+ raise ValueError(
532
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
533
+ )
534
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
535
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
536
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
537
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
538
+
539
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
540
+ raise ValueError(
541
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
542
+ )
543
+
544
+ if max_sequence_length is not None and max_sequence_length > 512:
545
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
546
+
547
+ @staticmethod
548
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
549
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
550
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
551
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
552
+
553
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
554
+
555
+ # latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
556
+ latent_image_ids = latent_image_ids.reshape(
557
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
558
+ )
559
+
560
+ return latent_image_ids.to(device=device, dtype=dtype)
561
+
562
+ @staticmethod
563
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
564
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
565
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
566
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
567
+
568
+ return latents
569
+
570
+ @staticmethod
571
+ def _unpack_latents(latents, height, width, vae_scale_factor):
572
+ batch_size, num_patches, channels = latents.shape
573
+
574
+ height = height // vae_scale_factor
575
+ width = width // vae_scale_factor
576
+
577
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
578
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
579
+
580
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
581
+
582
+ return latents
583
+
584
+ def prepare_latents(
585
+ self,
586
+ batch_size,
587
+ num_channels_latents,
588
+ height,
589
+ width,
590
+ dtype,
591
+ device,
592
+ generator,
593
+ latents=None,
594
+ ):
595
+ height = 2 * (int(height) // self.vae_scale_factor)
596
+ width = 2 * (int(width) // self.vae_scale_factor)
597
+
598
+ shape = (batch_size, num_channels_latents, height, width)
599
+
600
+ if latents is not None:
601
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
602
+ return latents.to(device=device, dtype=dtype), latent_image_ids
603
+
604
+ if isinstance(generator, list) and len(generator) != batch_size:
605
+ raise ValueError(
606
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
607
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
608
+ )
609
+
610
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
611
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
612
+
613
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
614
+
615
+ return latents, latent_image_ids
616
+
617
+ @property
618
+ def guidance_scale(self):
619
+ return self._guidance_scale
620
+
621
+ @property
622
+ def joint_attention_kwargs(self):
623
+ return self._joint_attention_kwargs
624
+
625
+ @property
626
+ def num_timesteps(self):
627
+ return self._num_timesteps
628
+
629
+ @property
630
+ def interrupt(self):
631
+ return self._interrupt
632
+
633
+
634
+ def prepare_mask_latents(
635
+ self,
636
+ mask,
637
+ masked_image,
638
+ batch_size,
639
+ num_channels_latents,
640
+ num_images_per_prompt,
641
+ height,
642
+ width,
643
+ dtype,
644
+ device,
645
+ generator,
646
+ ):
647
+ height = 2 * (int(height) // self.vae_scale_factor)
648
+ width = 2 * (int(width) // self.vae_scale_factor)
649
+ # resize the mask to latents shape as we concatenate the mask to the latents
650
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
651
+ # and half precision
652
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
653
+ mask = mask.to(device=device, dtype=dtype)
654
+
655
+ batch_size = batch_size * num_images_per_prompt
656
+
657
+ masked_image = masked_image.to(device=device, dtype=dtype)
658
+
659
+ if masked_image.shape[1] == 16:
660
+ masked_image_latents = masked_image
661
+ else:
662
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
663
+
664
+ masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
665
+
666
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
667
+ if mask.shape[0] < batch_size:
668
+ if not batch_size % mask.shape[0] == 0:
669
+ raise ValueError(
670
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
671
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
672
+ " of masks that you pass is divisible by the total requested batch size."
673
+ )
674
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
675
+ if masked_image_latents.shape[0] < batch_size:
676
+ if not batch_size % masked_image_latents.shape[0] == 0:
677
+ raise ValueError(
678
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
679
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
680
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
681
+ )
682
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
683
+
684
+ # aligning device to prevent device errors when concating it with the latent model input
685
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
686
+
687
+ masked_image_latents = self._pack_latents(
688
+ masked_image_latents,
689
+ batch_size,
690
+ num_channels_latents,
691
+ height,
692
+ width,
693
+ )
694
+ mask = self._pack_latents(
695
+ mask.repeat(1, num_channels_latents, 1, 1),
696
+ batch_size,
697
+ num_channels_latents,
698
+ height,
699
+ width,
700
+ )
701
+
702
+ return mask, masked_image_latents
703
+
704
+ @torch.no_grad()
705
+ def inpaint(
706
+ self,
707
+ prompt: Union[str, List[str]] = None,
708
+ prompt_2: Optional[Union[str, List[str]]] = None,
709
+ height: Optional[int] = None,
710
+ width: Optional[int] = None,
711
+ num_inference_steps: int = 28,
712
+ timesteps: List[int] = None,
713
+ guidance_scale: float = 7.0,
714
+ num_images_per_prompt: Optional[int] = 1,
715
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
716
+ latents: Optional[torch.FloatTensor] = None,
717
+ prompt_embeds: Optional[torch.FloatTensor] = None,
718
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
719
+ output_type: Optional[str] = "pil",
720
+ return_dict: bool = True,
721
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
722
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
723
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
724
+ max_sequence_length: int = 512,
725
+ optimization_steps: int = 3,
726
+ learning_rate: float = 0.8,
727
+ max_steps: int = 5,
728
+ input_image = None,
729
+ save_masked_image = False,
730
+ output_path="",
731
+ mask_image = None,
732
+ ):
733
+
734
+ height = height or self.default_sample_size * self.vae_scale_factor
735
+ width = width or self.default_sample_size * self.vae_scale_factor
736
+
737
+ # 1. Check inputs. Raise error if not correct
738
+ self.check_inputs(
739
+ prompt,
740
+ prompt_2,
741
+ height,
742
+ width,
743
+ prompt_embeds=prompt_embeds,
744
+ pooled_prompt_embeds=pooled_prompt_embeds,
745
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
746
+ max_sequence_length=max_sequence_length,
747
+ )
748
+
749
+ self._guidance_scale = guidance_scale
750
+ self._joint_attention_kwargs = joint_attention_kwargs
751
+ self._interrupt = False
752
+
753
+ # 2. Define call parameters
754
+ if prompt is not None and isinstance(prompt, str):
755
+ batch_size = 1
756
+ elif prompt is not None and isinstance(prompt, list):
757
+ batch_size = len(prompt)
758
+ else:
759
+ batch_size = prompt_embeds.shape[0]
760
+
761
+ device = self._execution_device
762
+
763
+ lora_scale = (
764
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
765
+ )
766
+ (
767
+ prompt_embeds,
768
+ pooled_prompt_embeds,
769
+ text_ids,
770
+ ) = self.encode_prompt(
771
+ prompt=prompt,
772
+ prompt_2=prompt_2,
773
+ prompt_embeds=prompt_embeds,
774
+ pooled_prompt_embeds=pooled_prompt_embeds,
775
+ device=device,
776
+ num_images_per_prompt=num_images_per_prompt,
777
+ max_sequence_length=max_sequence_length,
778
+ lora_scale=lora_scale,
779
+ )
780
+
781
+ # 4. Prepare latent variables
782
+ num_channels_latents = self.transformer.config.in_channels // 4
783
+ random_latents, latent_image_ids = self.prepare_latents(
784
+ batch_size * num_images_per_prompt,
785
+ num_channels_latents,
786
+ height,
787
+ width,
788
+ prompt_embeds.dtype,
789
+ device,
790
+ generator,
791
+ latents,
792
+ )
793
+
794
+ # 5. Prepare timesteps
795
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
796
+ image_seq_len = random_latents.shape[1]
797
+ mu = calculate_shift(
798
+ image_seq_len,
799
+ self.scheduler.config.base_image_seq_len,
800
+ self.scheduler.config.max_image_seq_len,
801
+ self.scheduler.config.base_shift,
802
+ self.scheduler.config.max_shift,
803
+ )
804
+ timesteps, num_inference_steps = retrieve_timesteps(
805
+ self.scheduler,
806
+ num_inference_steps,
807
+ device,
808
+ timesteps,
809
+ sigmas,
810
+ mu=mu,
811
+ )
812
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
813
+ self._num_timesteps = len(timesteps)
814
+
815
+ # 4. Preprocess image
816
+ # Preprocess mask image
817
+ mask_image = mask_image.convert("L")
818
+ mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
819
+ mask = TF.Resize(input_image.size, interpolation=TF.InterpolationMode.NEAREST)(mask)
820
+ mask = (mask > 0.5)
821
+ mask = ~mask
822
+
823
+ # # Convert input image to tensor and apply mask
824
+ # input_image = TF.ToTensor()(input_image).to(device=device, dtype=self.transformer.dtype)
825
+ # input_image = input_image * mask.float().expand_as(input_image)
826
+ # input_image = TF.ToPILImage()(input_image.cpu())
827
+
828
+ image = self.image_processor.preprocess(input_image)
829
+ image = image.to(device=device, dtype=self.transformer.dtype)
830
+ latents = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
831
+
832
+
833
+ h, w = latents.shape[2], latents.shape[3]
834
+ mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
835
+ mask = TF.Resize((h, w), interpolation=TF.InterpolationMode.NEAREST)(mask)
836
+
837
+ # Optionally dilate the mask to increase coverage (with a 1x1 kernel this is a no-op)
838
+ kernel_size = 1  # 1x1 kernel: no dilation; raise to e.g. 3 to actually dilate
839
+ kernel = torch.ones((1, 1, kernel_size, kernel_size), device=device)
840
+ mask = torch.nn.functional.conv2d(
841
+ mask.unsqueeze(0),
842
+ kernel,
843
+ padding=0
844
+ ).squeeze(0)
845
+ mask = torch.clamp(mask, 0, 1)
846
+
847
+ mask = (mask > 0.1).float()
848
+
849
+ # Remove extra channel dimension if present
850
+ if len(mask.shape) == 3 and mask.shape[0] == 1:
851
+ mask = mask.squeeze(0)
852
+
853
+ bool_mask = mask.bool().unsqueeze(0).unsqueeze(0).expand_as(latents)
854
+ mask = ~bool_mask
855
+
856
+ print(mask.shape, latents.shape)
857
+
858
+ masked_latents = (latents * mask).clone().detach() # apply the mask and get gt_latents
859
+ masked_latents = self._pack_latents(masked_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
860
+
861
+ mask = self._pack_latents(mask, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
862
+
863
+ # Decode and save the masked image
864
+ if save_masked_image:
865
+ with torch.no_grad():
866
+ save_masked_latents = self._unpack_latents(masked_latents, height, width, self.vae_scale_factor)  # use the requested size instead of a hardcoded 1024x1024
867
+ save_masked_latents = (save_masked_latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
868
+ mask_image = self.vae.decode(save_masked_latents, return_dict=False)[0]
869
+ mask_image = self.image_processor.postprocess(mask_image, output_type="pil")
870
+ mask_image_path = output_path.replace(".png", "_masked.png")
871
+ mask_image[0].save(mask_image_path)
872
+
873
+
874
+ # initialize the random noise for denoising
875
+ latents = random_latents.clone().detach()
876
+
877
+ self.vae = self.vae.to(torch.float32)
878
+
879
+ # Put the transformer and VAE in eval mode
880
+ self.transformer.eval()
881
+ self.vae.eval()
882
+
883
+ # 6. Denoising loop
884
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
885
+ for i, t in enumerate(timesteps):
886
+ if self.interrupt:
887
+ continue
888
+
889
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
890
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
891
+
892
+ # handle guidance
893
+ if self.transformer.config.guidance_embeds:
894
+ guidance = torch.tensor([guidance_scale], device=device)
895
+ guidance = guidance.expand(latents.shape[0])
896
+ else:
897
+ guidance = None
898
+
899
+ noise_pred = self.transformer(
900
+ hidden_states=latents,
901
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it, but I want to keep the inputs the same for the model for testing)
902
+ timestep=timestep / 1000,
903
+ guidance=guidance,
904
+ pooled_projections=pooled_prompt_embeds,
905
+ encoder_hidden_states=prompt_embeds,
906
+ txt_ids=text_ids,
907
+ img_ids=latent_image_ids,
908
+ joint_attention_kwargs=self.joint_attention_kwargs,
909
+ return_dict=False,
910
+ )[0]
911
+
912
+ # compute the previous noisy sample x_t -> x_t-1
913
+ latents_dtype = latents.dtype
914
+
915
+ # perform CG
916
+ if i < max_steps:
917
+ opt_latents = latents.detach().clone()
918
+ with torch.enable_grad():
919
+ opt_latents = opt_latents.detach().requires_grad_()
920
+ opt_latents = torch.autograd.Variable(opt_latents, requires_grad=True)  # note: torch.autograd.Variable is deprecated; the requires_grad_() call above already suffices
921
+ # optimizer = torch.optim.Adam([opt_latents], lr=learning_rate)
922
+
923
+ for _ in range(optimization_steps):
924
+ latents_p = self.scheduler.step_final(noise_pred, t, opt_latents, return_dict=False)[0]
925
+ loss = (1000*torch.nn.functional.mse_loss(latents_p, masked_latents, reduction='none')*mask).mean()
926
+
927
+ grad = torch.autograd.grad(loss, opt_latents)[0]
928
+ # grad = torch.clamp(grad, -0.5, 0.5)
929
+ opt_latents = opt_latents - learning_rate * grad
930
+
931
+ latents = opt_latents.detach().clone()
932
+
933
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
934
+
935
+ if latents.dtype != latents_dtype:
936
+ if torch.backends.mps.is_available():
937
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
938
+ latents = latents.to(latents_dtype)
939
+
940
+ if callback_on_step_end is not None:
941
+ callback_kwargs = {}
942
+ for k in callback_on_step_end_tensor_inputs:
943
+ callback_kwargs[k] = locals()[k]
944
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
945
+
946
+ latents = callback_outputs.pop("latents", latents)
947
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
948
+
949
+ # call the callback, if provided
950
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
951
+ progress_bar.update()
952
+
953
+ if XLA_AVAILABLE:
954
+ xm.mark_step()
955
+
956
+ if output_type == "latent":
957
+ image = latents
958
+
959
+ else:
960
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
961
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
962
+ image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
963
+ image = self.image_processor.postprocess(image, output_type=output_type)
964
+
965
+ # Offload all models
966
+ self.maybe_free_model_hooks()
967
+
968
+ if not return_dict:
969
+ return (image,)
970
+
971
+ return FluxPipelineOutput(images=image)
972
+
973
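
The loop in the method above repeatedly nudges the noisy latents so that the scheduler's one-step clean prediction agrees with the known (masked) source latents. A minimal sketch of that correction step is shown below, under assumptions: `predict_clean`, `target`, and `keep_mask` are hypothetical stand-ins for `self.scheduler.step_final`, the packed `masked_latents`, and the packed `mask`.

import torch

def masked_latent_correction(latents, noise_pred, predict_clean, target, keep_mask, steps=3, lr=0.8):
    # Gradient steps that pull the one-step clean prediction toward `target`
    # inside `keep_mask`, mirroring the "perform CG" block in the loop above.
    opt = latents.detach().clone().requires_grad_(True)
    for _ in range(steps):
        with torch.enable_grad():
            pred = predict_clean(noise_pred, opt)               # one-step x0 estimate
            loss = ((pred - target) ** 2 * keep_mask).mean()    # masked MSE
            grad, = torch.autograd.grad(loss, opt)
        opt = (opt - lr * grad).detach().requires_grad_(True)   # plain gradient descent step
    return opt.detach()
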
+ def get_diff_image(self, latents):
974
+ latents = self._unpack_latents(latents, 1024, 1024, self.vae_scale_factor)
975
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
976
+ image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
977
+ image = self.image_processor.postprocess(image, output_type="pt")
978
+ return image
979
+
980
+ def load_and_preprocess_image(self, image_path, custom_image_processor, device):
981
+ from diffusers.utils import load_image
982
+ img = load_image(image_path)
983
+ img = img.resize((512, 512))
984
+ return custom_image_processor(img).unsqueeze(0).to(device)
985
+
986
+
987
+ @torch.no_grad()
988
+ def edit(
989
+ self,
990
+ prompt: Union[str, List[str]] = None,
991
+ prompt_2: Optional[Union[str, List[str]]] = None,
992
+ negative_prompt: Union[str, List[str]] = None, #
993
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
994
+ true_cfg: float = 1.0, #
995
+ height: Optional[int] = None,
996
+ width: Optional[int] = None,
997
+ num_inference_steps: int = 28,
998
+ timesteps: List[int] = None,
999
+ guidance_scale: float = 3.5,
1000
+ num_images_per_prompt: Optional[int] = 1,
1001
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1002
+ latents: Optional[torch.FloatTensor] = None,
1003
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1004
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1005
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1006
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1007
+ output_type: Optional[str] = "pil",
1008
+ return_dict: bool = True,
1009
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1010
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1011
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1012
+ max_sequence_length: int = 512,
1013
+ optimization_steps: int = 3,
1014
+ learning_rate: float = 0.8,
1015
+ max_steps: int = 5,
1016
+ input_image = None,
1017
+ save_masked_image = False,
1018
+ output_path="",
1019
+ mask_image=None,
1020
+ source_steps=1,
1021
+ ):
1022
+
1023
+ height = height or self.default_sample_size * self.vae_scale_factor
1024
+ width = width or self.default_sample_size * self.vae_scale_factor
1025
+
1026
+ # 1. Check inputs. Raise error if not correct
1027
+ self.check_inputs(
1028
+ prompt,
1029
+ prompt_2,
1030
+ height,
1031
+ width,
1032
+ # negative_prompt=negative_prompt,
1033
+ # negative_prompt_2=negative_prompt_2,
1034
+ prompt_embeds=prompt_embeds,
1035
+ # negative_prompt_embeds=negative_prompt_embeds,
1036
+ pooled_prompt_embeds=pooled_prompt_embeds,
1037
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1038
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1039
+ max_sequence_length=max_sequence_length,
1040
+ )
1041
+
1042
+ self._guidance_scale = guidance_scale
1043
+ self._joint_attention_kwargs = joint_attention_kwargs
1044
+ self._interrupt = False
1045
+
1046
+ # 2. Define call parameters
1047
+ if prompt is not None and isinstance(prompt, str):
1048
+ batch_size = 1
1049
+ elif prompt is not None and isinstance(prompt, list):
1050
+ batch_size = len(prompt)
1051
+ else:
1052
+ batch_size = prompt_embeds.shape[0]
1053
+
1054
+ device = self._execution_device
1055
+
1056
+ lora_scale = (
1057
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1058
+ )
1059
+ do_true_cfg = true_cfg > 1 and negative_prompt is not None
1060
+ (
1061
+ prompt_embeds,
1062
+ pooled_prompt_embeds,
1063
+ text_ids,
1064
+ negative_prompt_embeds,
1065
+ negative_pooled_prompt_embeds,
1066
+ ) = self.encode_prompt_edit(
1067
+ prompt=prompt,
1068
+ prompt_2=prompt_2,
1069
+ negative_prompt=negative_prompt,
1070
+ negative_prompt_2=negative_prompt_2,
1071
+ prompt_embeds=prompt_embeds,
1072
+ pooled_prompt_embeds=pooled_prompt_embeds,
1073
+ negative_prompt_embeds=negative_prompt_embeds,
1074
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1075
+ device=device,
1076
+ num_images_per_prompt=num_images_per_prompt,
1077
+ max_sequence_length=max_sequence_length,
1078
+ lora_scale=lora_scale,
1079
+ do_true_cfg=do_true_cfg,
1080
+ )
1081
+ # text_ids = text_ids.repeat(batch_size, 1, 1)
1082
+
1083
+ if do_true_cfg:
1084
+ # Concatenate embeddings
1085
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1086
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1087
+
1088
+ # 4. Prepare latent variables
1089
+ num_channels_latents = self.transformer.config.in_channels // 4
1090
+ random_latents, latent_image_ids = self.prepare_latents(
1091
+ batch_size * num_images_per_prompt,
1092
+ num_channels_latents,
1093
+ height,
1094
+ width,
1095
+ prompt_embeds.dtype,
1096
+ device,
1097
+ generator,
1098
+ latents,
1099
+ )
1100
+ # latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1)
1101
+
1102
+ # 5. Prepare timesteps
1103
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1104
+ image_seq_len = random_latents.shape[1]
1105
+ mu = calculate_shift(
1106
+ image_seq_len,
1107
+ self.scheduler.config.base_image_seq_len,
1108
+ self.scheduler.config.max_image_seq_len,
1109
+ self.scheduler.config.base_shift,
1110
+ self.scheduler.config.max_shift,
1111
+ )
1112
+ timesteps, num_inference_steps = retrieve_timesteps(
1113
+ self.scheduler,
1114
+ num_inference_steps,
1115
+ device,
1116
+ timesteps,
1117
+ sigmas,
1118
+ mu=mu,
1119
+ )
1120
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1121
+ self._num_timesteps = len(timesteps)
1122
+
1123
+ # 4. Preprocess image
1124
+ image = self.image_processor.preprocess(input_image)
1125
+ image = image.to(device=device, dtype=self.transformer.dtype)
1126
+ latents = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
1127
+
1128
+
1129
+ # Convert PIL image to tensor
1130
+ if mask_image is not None:
1131
+ from torchvision import transforms as TF
1132
+
1133
+ h, w = latents.shape[2], latents.shape[3]
1134
+ mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
1135
+ mask = TF.Resize((h, w), interpolation=TF.InterpolationMode.NEAREST)(mask)
1136
+ mask = (mask > 0.5).float()
1137
+ mask = mask.squeeze(0)  # drop the leading channel dimension (assumes a single-channel mask)
1138
+ else:
1139
+ mask = torch.ones_like(latents).to(device=device)
1140
+
1141
+ print(mask.shape, latents.shape)
1142
+
1143
+ bool_mask = mask.unsqueeze(0).unsqueeze(0).expand_as(latents)
1144
+ mask = (1 - bool_mask * 1.0).to(latents.dtype)
1145
+
1146
+ masked_latents = (latents * mask).clone().detach() # apply the mask and get gt_latents
1147
+ masked_latents = self._pack_latents(masked_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
1148
+
1149
+ source_latents = latents.clone().detach()  # unmasked source latents, used as the target for the first source_steps
1150
+ source_latents = self._pack_latents(source_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
1151
+
1152
+ mask = self._pack_latents(mask, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
1153
+
1154
+ # initialize the random noise for denoising
1155
+ latents = random_latents.clone().detach()
1156
+
1157
+ self.vae = self.vae.to(torch.float32)
1158
+
1159
+ # Put the transformer and VAE in eval mode
1160
+ self.transformer.eval()
1161
+ self.vae.eval()
1162
+
1163
+ # 6. Denoising loop
1164
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1165
+ for i, t in enumerate(timesteps):
1166
+ if self.interrupt:
1167
+ continue
1168
+
1169
+ latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
1170
+
1171
+ # handle guidance
1172
+ if self.transformer.config.guidance_embeds:
1173
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1174
+ guidance = guidance.expand(latent_model_input.shape[0])
1175
+ else:
1176
+ guidance = None
1177
+
1178
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1179
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
1180
+
1181
+ noise_pred = self.transformer(
1182
+ hidden_states=latent_model_input,
1183
+ timestep=timestep / 1000,
1184
+ guidance=guidance,
1185
+ pooled_projections=pooled_prompt_embeds,
1186
+ encoder_hidden_states=prompt_embeds,
1187
+ txt_ids=text_ids,
1188
+ img_ids=latent_image_ids,
1189
+ joint_attention_kwargs=self.joint_attention_kwargs,
1190
+ return_dict=False,
1191
+ )[0]
1192
+
1193
+ if do_true_cfg:
1194
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
1195
+ # noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
1196
+ noise_pred = noise_pred + (1 - mask) * (noise_pred - neg_noise_pred) * true_cfg
1197
+ # else:
1198
+ # neg_noise_pred, noise_pred = noise_pred.chunk(2)
1199
+
1200
+ # perform CG
1201
+ if i < max_steps:
1202
+ opt_latents = latents.detach().clone()
1203
+ with torch.enable_grad():
1204
+ opt_latents = opt_latents.detach().requires_grad_()
1205
+ opt_latents = torch.autograd.Variable(opt_latents, requires_grad=True)  # note: torch.autograd.Variable is deprecated; the requires_grad_() call above already suffices
1206
+ # optimizer = torch.optim.Adam([opt_latents], lr=learning_rate)
1207
+
1208
+ for _ in range(optimization_steps):
1209
+ latents_p = self.scheduler.step_final(noise_pred, t, opt_latents, return_dict=False)[0]
1210
+ if i < source_steps:
1211
+ loss = (1000*torch.nn.functional.mse_loss(latents_p, source_latents, reduction='none')).mean()
1212
+ else:
1213
+ loss = (1000*torch.nn.functional.mse_loss(latents_p, masked_latents, reduction='none')*mask).mean()
1214
+
1215
+ grad = torch.autograd.grad(loss, opt_latents)[0]
1216
+ # grad = torch.clamp(grad, -0.5, 0.5)
1217
+ opt_latents = opt_latents - learning_rate * grad
1218
+
1219
+ latents = opt_latents.detach().clone()
1220
+
1221
+
1222
+ # compute the previous noisy sample x_t -> x_t-1
1223
+ latents_dtype = latents.dtype
1224
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1225
+
1226
+ if latents.dtype != latents_dtype:
1227
+ if torch.backends.mps.is_available():
1228
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1229
+ latents = latents.to(latents_dtype)
1230
+
1231
+ if callback_on_step_end is not None:
1232
+ callback_kwargs = {}
1233
+ for k in callback_on_step_end_tensor_inputs:
1234
+ callback_kwargs[k] = locals()[k]
1235
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1236
+
1237
+ latents = callback_outputs.pop("latents", latents)
1238
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1239
+
1240
+ # call the callback, if provided
1241
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1242
+ progress_bar.update()
1243
+
1244
+ if XLA_AVAILABLE:
1245
+ xm.mark_step()
1246
+
1247
+ if output_type == "latent":
1248
+ image = latents
1249
+
1250
+ else:
1251
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1252
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1253
+ image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
1254
+ image = self.image_processor.postprocess(image, output_type=output_type)
1255
+
1256
+ # Offload all models
1257
+ self.maybe_free_model_hooks()
1258
+
1259
+ if not return_dict:
1260
+ return (image,)
1261
+
1262
+ return FluxPipelineOutput(images=image)
1263
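
A possible call to the `edit` method above is sketched here. The pipeline object `pipe`, the mapping of the app's source/edit prompts onto `prompt`/`negative_prompt`, and the image sizes are assumptions; the parameter values echo the saved_results runs committed alongside this file.

from PIL import Image

# pipe = <a loaded Flux pipeline exposing edit(); how it is constructed is handled elsewhere and not shown in this diff>
input_image = Image.open("assets/dog.webp").convert("RGB").resize((1024, 1024))
mask_image = Image.open("saved_results/20241126_214810/mask.png").convert("L").resize((1024, 1024))

result = pipe.edit(
    prompt="a dog with flowers in the mouth",   # assumed: the edit prompt
    negative_prompt=" ",                        # assumed: the source/neutral prompt
    true_cfg=4.5,
    num_inference_steps=30,
    max_steps=30,
    learning_rate=1.0,
    optimization_steps=3,
    source_steps=5,
    height=1024,
    width=1024,
    input_image=input_image,
    mask_image=mask_image,
)
result.images[0].save("edited.png")
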
+
1264
+ @torch.no_grad()
1265
+ def edit2(
1266
+ self,
1267
+ prompt: Union[str, List[str]] = None,
1268
+ prompt_2: Optional[Union[str, List[str]]] = None,
1269
+ negative_prompt: Union[str, List[str]] = None, #
1270
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1271
+ true_cfg: float = 1.0, #
1272
+ height: Optional[int] = None,
1273
+ width: Optional[int] = None,
1274
+ num_inference_steps: int = 28,
1275
+ timesteps: List[int] = None,
1276
+ guidance_scale: float = 3.5,
1277
+ num_images_per_prompt: Optional[int] = 1,
1278
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1279
+ latents: Optional[torch.FloatTensor] = None,
1280
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1281
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1282
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1283
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1284
+ output_type: Optional[str] = "pil",
1285
+ return_dict: bool = True,
1286
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
1287
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1288
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1289
+ max_sequence_length: int = 512,
1290
+ optimization_steps: int = 3,
1291
+ learning_rate: float = 0.8,
1292
+ max_steps: int = 5,
1293
+ input_image = None,
1294
+ save_masked_image = False,
1295
+ output_path="",
1296
+ mask_image=None,
1297
+ source_steps=1,
1298
+ ):
1299
+ r"""
1300
+ Function invoked when calling the pipeline for generation.
1301
+
1302
+ Args:
1303
+ prompt (`str` or `List[str]`, *optional*):
1304
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1305
+ instead.
1306
+ prompt_2 (`str` or `List[str]`, *optional*):
1307
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
1308
+ will be used instead.
1309
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
1310
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1311
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
1312
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1313
+ num_inference_steps (`int`, *optional*, defaults to 28):
1314
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1315
+ expense of slower inference.
1316
+ timesteps (`List[int]`, *optional*):
1317
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1318
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1319
+ passed will be used. Must be in descending order.
1320
+ guidance_scale (`float`, *optional*, defaults to 3.5):
1321
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1322
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1323
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1324
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1325
+ usually at the expense of lower image quality.
1326
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1327
+ The number of images to generate per prompt.
1328
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1329
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1330
+ to make generation deterministic.
1331
+ latents (`torch.FloatTensor`, *optional*):
1332
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1333
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1334
+ tensor will be generated by sampling using the supplied random `generator`.
1335
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1336
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1337
+ provided, text embeddings will be generated from `prompt` input argument.
1338
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1339
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1340
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1341
+ output_type (`str`, *optional*, defaults to `"pil"`):
1342
+ The output format of the generated image. Choose between
1343
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1344
+ return_dict (`bool`, *optional*, defaults to `True`):
1345
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
1346
+ joint_attention_kwargs (`dict`, *optional*):
1347
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1348
+ `self.processor` in
1349
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1350
+ callback_on_step_end (`Callable`, *optional*):
1351
+ A function that is called at the end of each denoising step during inference. The function is called
1352
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1353
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1354
+ `callback_on_step_end_tensor_inputs`.
1355
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1356
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1357
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1358
+ `._callback_tensor_inputs` attribute of your pipeline class.
1359
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
1360
+
1361
+ Examples:
1362
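Illustrative sketch only; `pipe` is assumed to be a loaded Flux pipeline exposing this method, and the prompt/negative_prompt mapping is an assumption:

    >>> result = pipe.edit2(
    ...     prompt="a dog with flowers in the mouth",
    ...     negative_prompt=" ",
    ...     input_image=image, mask_image=mask,
    ...     true_cfg=4.5, num_inference_steps=30, max_steps=20,
    ...     learning_rate=0.5, optimization_steps=3, source_steps=2,
    ... )
    >>> result.images[0].save("edited.png")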
+
1363
+ Returns:
1364
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
1365
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
1366
+ images.
1367
+ """
1368
+
1369
+ height = height or self.default_sample_size * self.vae_scale_factor
1370
+ width = width or self.default_sample_size * self.vae_scale_factor
1371
+
1372
+ # 1. Check inputs. Raise error if not correct
1373
+ self.check_inputs(
1374
+ prompt,
1375
+ prompt_2,
1376
+ height,
1377
+ width,
1378
+ # negative_prompt=negative_prompt,
1379
+ # negative_prompt_2=negative_prompt_2,
1380
+ prompt_embeds=prompt_embeds,
1381
+ # negative_prompt_embeds=negative_prompt_embeds,
1382
+ pooled_prompt_embeds=pooled_prompt_embeds,
1383
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1384
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1385
+ max_sequence_length=max_sequence_length,
1386
+ )
1387
+
1388
+ self._guidance_scale = guidance_scale
1389
+ self._joint_attention_kwargs = joint_attention_kwargs
1390
+ self._interrupt = False
1391
+
1392
+ # 2. Define call parameters
1393
+ if prompt is not None and isinstance(prompt, str):
1394
+ batch_size = 1
1395
+ elif prompt is not None and isinstance(prompt, list):
1396
+ batch_size = len(prompt)
1397
+ else:
1398
+ batch_size = prompt_embeds.shape[0]
1399
+
1400
+ device = self._execution_device
1401
+
1402
+ lora_scale = (
1403
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1404
+ )
1405
+ do_true_cfg = true_cfg > 1 and negative_prompt is not None
1406
+ (
1407
+ prompt_embeds,
1408
+ pooled_prompt_embeds,
1409
+ text_ids,
1410
+ negative_prompt_embeds,
1411
+ negative_pooled_prompt_embeds,
1412
+ ) = self.encode_prompt_edit(
1413
+ prompt=prompt,
1414
+ prompt_2=prompt_2,
1415
+ negative_prompt=negative_prompt,
1416
+ negative_prompt_2=negative_prompt_2,
1417
+ prompt_embeds=prompt_embeds,
1418
+ pooled_prompt_embeds=pooled_prompt_embeds,
1419
+ negative_prompt_embeds=negative_prompt_embeds,
1420
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1421
+ device=device,
1422
+ num_images_per_prompt=num_images_per_prompt,
1423
+ max_sequence_length=max_sequence_length,
1424
+ lora_scale=lora_scale,
1425
+ do_true_cfg=do_true_cfg,
1426
+ )
1427
+ # text_ids = text_ids.repeat(batch_size, 1, 1)
1428
+
1429
+ if do_true_cfg:
1430
+ # Concatenate embeddings
1431
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1432
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1433
+
1434
+ # 4. Prepare latent variables
1435
+ num_channels_latents = self.transformer.config.in_channels // 4
1436
+ random_latents, latent_image_ids = self.prepare_latents(
1437
+ batch_size * num_images_per_prompt,
1438
+ num_channels_latents,
1439
+ height,
1440
+ width,
1441
+ prompt_embeds.dtype,
1442
+ device,
1443
+ generator,
1444
+ latents,
1445
+ )
1446
+ # latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1)
1447
+
1448
+ # 5. Prepare timesteps
1449
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
1450
+ image_seq_len = random_latents.shape[1]
1451
+ mu = calculate_shift(
1452
+ image_seq_len,
1453
+ self.scheduler.config.base_image_seq_len,
1454
+ self.scheduler.config.max_image_seq_len,
1455
+ self.scheduler.config.base_shift,
1456
+ self.scheduler.config.max_shift,
1457
+ )
1458
+ timesteps, num_inference_steps = retrieve_timesteps(
1459
+ self.scheduler,
1460
+ num_inference_steps,
1461
+ device,
1462
+ timesteps,
1463
+ sigmas,
1464
+ mu=mu,
1465
+ )
1466
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1467
+ self._num_timesteps = len(timesteps)
1468
+
1469
+ # 4. Preprocess image
1470
+ image = self.image_processor.preprocess(input_image)
1471
+ image = image.to(device=device, dtype=self.transformer.dtype)
1472
+ latents = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
1473
+
1474
+
1475
+ # Convert PIL image to tensor
1476
+ if mask_image is not None:
1477
+ from torchvision import transforms as TF
1478
+
1479
+ h, w = latents.shape[2], latents.shape[3]
1480
+ mask = TF.ToTensor()(mask_image).to(device=device, dtype=self.transformer.dtype)
1481
+ mask = TF.Resize((h, w), interpolation=TF.InterpolationMode.NEAREST)(mask)
1482
+ mask = (mask > 0.1).float()
1483
+ mask = mask.squeeze(0)  # drop the leading channel dimension (assumes a single-channel mask)
1484
+ else:
1485
+ mask = torch.ones_like(latents).to(device=device)
1486
+
1487
+ bool_mask = mask.unsqueeze(0).unsqueeze(0).expand_as(latents)
1488
+ mask = (1 - bool_mask * 1.0).to(latents.dtype)
1489
+
1490
+ masked_latents = (latents * mask).clone().detach() # apply the mask and get gt_latents
1491
+ masked_latents = self._pack_latents(masked_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
1492
+
1493
+ source_latents = latents.clone().detach()  # unmasked source latents, used as the target for the first source_steps
1494
+ source_latents = self._pack_latents(source_latents, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
1495
+
1496
+ mask = self._pack_latents(mask, batch_size, num_channels_latents, 2 * (int(height) // self.vae_scale_factor), 2 * (int(width) // self.vae_scale_factor))
1497
+
1498
+ # initialize the random noise for denoising
1499
+ latents = random_latents.clone().detach()
1500
+
1501
+ self.vae = self.vae.to(torch.float32)
1502
+
1503
+ # Put the transformer and VAE in eval mode
1504
+ self.transformer.eval()
1505
+ self.vae.eval()
1506
+
1507
+ # 6. Denoising loop
1508
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1509
+ for i, t in enumerate(timesteps):
1510
+ if self.interrupt:
1511
+ continue
1512
+
1513
+ latent_model_input = torch.cat([latents] * 2) if do_true_cfg else latents
1514
+
1515
+ # handle guidance
1516
+ if self.transformer.config.guidance_embeds:
1517
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1518
+ guidance = guidance.expand(latent_model_input.shape[0])
1519
+ else:
1520
+ guidance = None
1521
+
1522
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1523
+ timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
1524
+
1525
+ noise_pred = self.transformer(
1526
+ hidden_states=latent_model_input,
1527
+ timestep=timestep / 1000,
1528
+ guidance=guidance,
1529
+ pooled_projections=pooled_prompt_embeds,
1530
+ encoder_hidden_states=prompt_embeds,
1531
+ txt_ids=text_ids,
1532
+ img_ids=latent_image_ids,
1533
+ joint_attention_kwargs=self.joint_attention_kwargs,
1534
+ return_dict=False,
1535
+ )[0]
1536
+
1537
+ if do_true_cfg and i < max_steps:
1538
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
1539
+ # noise_pred = neg_noise_pred + true_cfg * (noise_pred - neg_noise_pred)
1540
+ noise_pred = noise_pred + (1 - mask) * (noise_pred - neg_noise_pred) * true_cfg
1541
+ elif do_true_cfg:  # past max_steps: keep only the conditional branch (avoids chunking a batch that was never doubled)
1542
+ neg_noise_pred, noise_pred = noise_pred.chunk(2)
1543
+
1544
+ # perform CG
1545
+ if i < max_steps:
1546
+ opt_latents = latents.detach().clone()
1547
+ with torch.enable_grad():
1548
+ opt_latents = opt_latents.detach().requires_grad_()
1549
+ opt_latents = torch.autograd.Variable(opt_latents, requires_grad=True)  # note: torch.autograd.Variable is deprecated; the requires_grad_() call above already suffices
1550
+ # optimizer = torch.optim.Adam([opt_latents], lr=learning_rate)
1551
+
1552
+ for _ in range(optimization_steps):
1553
+ latents_p = self.scheduler.step_final(noise_pred, t, opt_latents, return_dict=False)[0]
1554
+ if i < source_steps:
1555
+ loss = (1000*torch.nn.functional.mse_loss(latents_p, source_latents, reduction='none')).mean()
1556
+ else:
1557
+ loss = (1000*torch.nn.functional.mse_loss(latents_p, masked_latents, reduction='none')*mask).mean()
1558
+
1559
+ grad = torch.autograd.grad(loss, opt_latents)[0]
1560
+ # grad = torch.clamp(grad, -0.5, 0.5)
1561
+ opt_latents = opt_latents - learning_rate * grad
1562
+
1563
+ latents = opt_latents.detach().clone()
1564
+
1565
+
1566
+ # compute the previous noisy sample x_t -> x_t-1
1567
+ latents_dtype = latents.dtype
1568
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1569
+
1570
+ if latents.dtype != latents_dtype:
1571
+ if torch.backends.mps.is_available():
1572
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1573
+ latents = latents.to(latents_dtype)
1574
+
1575
+ if callback_on_step_end is not None:
1576
+ callback_kwargs = {}
1577
+ for k in callback_on_step_end_tensor_inputs:
1578
+ callback_kwargs[k] = locals()[k]
1579
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1580
+
1581
+ latents = callback_outputs.pop("latents", latents)
1582
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1583
+
1584
+ # call the callback, if provided
1585
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1586
+ progress_bar.update()
1587
+
1588
+ if XLA_AVAILABLE:
1589
+ xm.mark_step()
1590
+
1591
+ if output_type == "latent":
1592
+ image = latents
1593
+
1594
+ else:
1595
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1596
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1597
+ image = self.vae.decode(latents.to(torch.float32), return_dict=False)[0]
1598
+ image = self.image_processor.postprocess(image, output_type=output_type)
1599
+
1600
+ # Offload all models
1601
+ self.maybe_free_model_hooks()
1602
+
1603
+ if not return_dict:
1604
+ return (image,)
1605
+
1606
+ return FluxPipelineOutput(images=image)
1607
+
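
The first, partially shown method in this file appears to drive the app's "Inpainting" mode. A hypothetical invocation is sketched below; the method name `inpaint` and the pipeline construction are assumptions rather than facts from this diff, and the image paths come from the saved_results folders added in this commit.

from PIL import Image

# pipe = <a loaded Flux pipeline from fluxcombined.py; construction is handled in app.py and not shown here>
image = Image.open("saved_results/20241126_053639/input.png").convert("RGB").resize((1024, 1024))
mask = Image.open("saved_results/20241126_053639/mask.png").convert("L").resize((1024, 1024))

out = pipe.inpaint(          # assumed method name
    prompt="a dog",
    input_image=image,
    mask_image=mask,
    num_inference_steps=30,
    max_steps=30,
    learning_rate=1.0,
    optimization_steps=10,
    height=1024,
    width=1024,
)
out.images[0].save("inpainted.png")
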
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ spaces
2
+ diffusers==0.31.0
3
+ gradio==5.6.0
4
+ numpy==2.1.3
5
+ Pillow==11.0.0
6
+ torch==2.1.2
7
+ torch_xla==2.5.1
8
+ torchvision==0.16.2
9
+ transformers==4.45.2
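
The pins above can be checked against the installed environment; a small sketch using only the standard library (package names and versions are taken verbatim from requirements.txt):

from importlib.metadata import PackageNotFoundError, version

pins = {
    "diffusers": "0.31.0",
    "gradio": "5.6.0",
    "numpy": "2.1.3",
    "Pillow": "11.0.0",
    "torch": "2.1.2",
    "torch_xla": "2.5.1",
    "torchvision": "0.16.2",
    "transformers": "4.45.2",
}

for name, wanted in pins.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "<not installed>"
    marker = "" if installed == wanted else "  <-- differs from requirements.txt"
    print(f"{name:12s} pinned {wanted:8s} installed {installed}{marker}")
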
saved_results/20241126_053639/input.png ADDED
saved_results/20241126_053639/mask.png ADDED
saved_results/20241126_053639/output.png ADDED

Git LFS Details

  • SHA256: 5f1fdeb3a98da1b0cc536e4f97f4d36cfc912b9645225607402b87b7221047ef
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
saved_results/20241126_053639/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Inpainting",
3
+ "prompt": "a dog",
4
+ "edit_prompt": "",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 30,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 20,
11
+ "optimization_steps": 10,
12
+ "true_cfg": 2
13
+ }
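
Each saved_results/<timestamp>/ folder pairs input.png, mask.png, and output.png with a parameters.json like the one above. A small sketch of reading one run back, assuming only the folder layout visible in this commit:

import json
from pathlib import Path
from PIL import Image

run = Path("saved_results/20241126_053639")
params = json.loads((run / "parameters.json").read_text())
images = {name: Image.open(run / f"{name}.png") for name in ("input", "mask", "output")}

print(params["mode"], "|", params["prompt"], "| steps:", params["num_inference_steps"])
for name, img in images.items():
    print(f"{name}: {img.size[0]}x{img.size[1]}")
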
saved_results/20241126_055109/input.png ADDED
saved_results/20241126_055109/mask.png ADDED
saved_results/20241126_055109/output.png ADDED

Git LFS Details

  • SHA256: b8111afa5b56049ddf13a1a35c0a8c3be4ea9ab777476c7669e1719b6a4343d1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
saved_results/20241126_055109/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Inpainting",
3
+ "prompt": "a dog",
4
+ "edit_prompt": "",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 30,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 20,
11
+ "optimization_steps": 10,
12
+ "true_cfg": 2
13
+ }
saved_results/20241126_173140/input.png ADDED
saved_results/20241126_173140/mask.png ADDED
saved_results/20241126_173140/output.png ADDED

Git LFS Details

  • SHA256: c484065f2bb53cb427858a9137c4f5e2b32426609a9a1bb401883368f2a6f6da
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
saved_results/20241126_173140/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Inpainting",
3
+ "prompt": "a cat with blue eyes",
4
+ "edit_prompt": "",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 20,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 20,
11
+ "optimization_steps": 10,
12
+ "true_cfg": 2
13
+ }
saved_results/20241126_181436/input.png ADDED

Git LFS Details

  • SHA256: 4e6a77c0cb907998fe78e28658c7f65c674f1be6bf8d6984279f2fdcbc30482e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.84 MB
saved_results/20241126_181436/mask.png ADDED
saved_results/20241126_181436/output.png ADDED
saved_results/20241126_181436/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Editing",
3
+ "prompt": " ",
4
+ "edit_prompt": "volcano eruption",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 20,
9
+ "learning_rate": 0.5,
10
+ "max_source_steps": 2,
11
+ "optimization_steps": 3,
12
+ "true_cfg": 4.5
13
+ }
saved_results/20241126_181633/input.png ADDED

Git LFS Details

  • SHA256: 4e6a77c0cb907998fe78e28658c7f65c674f1be6bf8d6984279f2fdcbc30482e
  • Pointer size: 132 Bytes
  • Size of remote file: 4.84 MB
saved_results/20241126_181633/mask.png ADDED
saved_results/20241126_181633/output.png ADDED
saved_results/20241126_181633/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Editing",
3
+ "prompt": " ",
4
+ "edit_prompt": "volcano eruption",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 20,
9
+ "learning_rate": 0.5,
10
+ "max_source_steps": 2,
11
+ "optimization_steps": 3,
12
+ "true_cfg": 4.5
13
+ }
saved_results/20241126_214810/input.png ADDED
saved_results/20241126_214810/mask.png ADDED
saved_results/20241126_214810/output.png ADDED

Git LFS Details

  • SHA256: c89f9f31166f6418be1533d462e74aefadb2b4521dd8015640f1d9403673a6f8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
saved_results/20241126_214810/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Editing",
3
+ "prompt": " ",
4
+ "edit_prompt": "a dog with flowers in the mouth",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 30,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 5,
11
+ "optimization_steps": 3,
12
+ "true_cfg": 4.5
13
+ }
saved_results/20241126_214908/input.png ADDED
saved_results/20241126_214908/mask.png ADDED
saved_results/20241126_214908/output.png ADDED

Git LFS Details

  • SHA256: 4e8b13890624af98097b57ca314f2af4ec98e73f04c26be10ec93466edf7cabf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
saved_results/20241126_214908/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Editing",
3
+ "prompt": " ",
4
+ "edit_prompt": "a dog with flowers in the mouth",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 20,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 5,
11
+ "optimization_steps": 3,
12
+ "true_cfg": 4.5
13
+ }
saved_results/20241126_215043/input.png ADDED
saved_results/20241126_215043/mask.png ADDED
saved_results/20241126_215043/output.png ADDED

Git LFS Details

  • SHA256: d29bd6e39a70b4a9273467af120c0c1a19ae4019ba975a6b88bdefc4e1620271
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
saved_results/20241126_215043/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Editing",
3
+ "prompt": " ",
4
+ "edit_prompt": "a dog with flowers in the mouth",
5
+ "seed": 52,
6
+ "randomize_seed": false,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 20,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 5,
11
+ "optimization_steps": 3,
12
+ "true_cfg": 4.5
13
+ }
saved_results/20241126_221300/input.png ADDED
saved_results/20241126_221300/mask.png ADDED
saved_results/20241126_221300/output.png ADDED

Git LFS Details

  • SHA256: 9b09790e41c4f7502684c6682c623029934c9a8774ed5b45ceba5c1be474638b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
saved_results/20241126_221300/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Inpainting",
3
+ "prompt": "A building with \"ASU\" written on it.",
4
+ "edit_prompt": "",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 30,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 20,
11
+ "optimization_steps": 5,
12
+ "true_cfg": 2
13
+ }
saved_results/20241126_222257/input.png ADDED
saved_results/20241126_222257/mask.png ADDED
saved_results/20241126_222257/output.png ADDED

Git LFS Details

  • SHA256: 825b90e103c5eb976f47c58cd31cd502f0334714edd26675d76b9e7e1fcad7b9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
saved_results/20241126_222257/parameters.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mode": "Inpainting",
3
+ "prompt": "A cute pig with big eyes",
4
+ "edit_prompt": "",
5
+ "seed": 0,
6
+ "randomize_seed": true,
7
+ "num_inference_steps": 30,
8
+ "max_steps": 19.8,
9
+ "learning_rate": 1,
10
+ "max_source_steps": 20,
11
+ "optimization_steps": 5,
12
+ "true_cfg": 2
13
+ }
saved_results/20241126_222442/input.png ADDED
saved_results/20241126_222442/mask.png ADDED
saved_results/20241126_222442/output.png ADDED

Git LFS Details

  • SHA256: f8850c5c54035ebb6d7ab27f91779c942936b6736292b2e488583b06610d6e4e
  • Pointer size: 132 Bytes
  • Size of remote file: 1 MB