Commit f30e9ce · Parent: eb27870 · Updates

app.py CHANGED
@@ -27,7 +27,7 @@ base_model = "black-forest-labs/FLUX.1-dev"
 taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
-
+pipe_i2i = FluxImg2ImgPipeline.from_pretrained(
     base_model,
     vae=good_vae,
     transformer=pipe.transformer,
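The new pipe_i2i is built from the same base weights but reuses the modules already loaded into pipe, so the transformer and text encoders exist only once in GPU memory. The diff is truncated after transformer=...; as a minimal sketch of the component-sharing pattern (the text-encoder and tokenizer kwargs below are assumptions based on the FLUX pipeline layout, not lines from this commit):

import torch
from diffusers import AutoencoderKL, AutoencoderTiny, DiffusionPipeline, FluxImg2ImgPipeline

dtype = torch.bfloat16  # assumed dtype; defined earlier in the real app.py
device = "cuda"
base_model = "black-forest-labs/FLUX.1-dev"

# Tiny VAE for cheap intermediate previews, full KL VAE for final quality.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)

pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)

# Passing already-instantiated modules stops from_pretrained from loading a
# second copy; both pipelines then share one transformer and text encoders.
pipe_i2i = FluxImg2ImgPipeline.from_pretrained(
    base_model,
    vae=good_vae,
    transformer=pipe.transformer,
    text_encoder=pipe.text_encoder,      # assumption: shared CLIP encoder
    tokenizer=pipe.tokenizer,
    text_encoder_2=pipe.text_encoder_2,  # assumption: shared T5 encoder
    tokenizer_2=pipe.tokenizer_2,
    torch_dtype=dtype,
).to(device)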
@@ -121,7 +121,7 @@ def generate_image_to_image(prompt_mash, image_input_path, image_strength, steps
     ).images[0]
     return final_image
 
-def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
+def run_lora(prompt, cfg_scale, steps, selected_index, randomize_seed, seed, width, height, lora_scale, progress=gr.Progress(track_tqdm=True)):
     global selected_lora_index
     if selected_lora_index is None:
         raise gr.Error("You must select a LoRA before proceeding.")
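One detail in this hunk: raising gr.Error inside a Gradio event handler aborts the run and shows the message as an error popup in the UI rather than a server traceback. A minimal sketch of the guard, with the gallery wiring assumed (it is not part of this commit):

import gradio as gr

selected_lora_index = None  # assumed global, set when the user picks a LoRA

def on_gallery_select(evt: gr.SelectData):
    # Remember which gallery item was clicked.
    global selected_lora_index
    selected_lora_index = evt.index

def run_lora(prompt):
    if selected_lora_index is None:
        # Aborts the event; Gradio renders this as an error popup.
        raise gr.Error("You must select a LoRA before proceeding.")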
@@ -143,20 +143,13 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
 
     with calculateDuration("Unloading LoRA"):
         pipe.unload_lora_weights()
-        pipe_i2i.unload_lora_weights()
 
     # Load LoRA weights
     with calculateDuration(f"Loading LoRA weights for {selected_lora['title']}"):
-        if image_input is not None:
-            if "weights" in selected_lora:
-                pipe_i2i.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
-            else:
-                pipe_i2i.load_lora_weights(lora_path)
-        else:
-            if "weights" in selected_lora:
-                pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
-            else:
-                pipe.load_lora_weights(lora_path)
+        if "weights" in selected_lora:
+            pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
+        else:
+            pipe.load_lora_weights(lora_path)
 
     # Set random seed for reproducibility
     with calculateDuration("Randomizing seed"):
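The collapsed branch relies on a diffusers behavior worth spelling out: load_lora_weights accepts a Hub repo id (or local folder), and weight_name is only needed when the repo holds more than one weights file. A short sketch with hypothetical repo and file names:

# Hypothetical values; the real ones come from the Space's LoRA list.
lora_path = "some-user/some-flux-lora"
selected_lora = {"title": "Example LoRA", "weights": "lora.safetensors"}

# Detach whatever adapter the previous request loaded.
pipe.unload_lora_weights()

if "weights" in selected_lora:
    # The repo ships several .safetensors files: name the one to use.
    pipe.load_lora_weights(lora_path, weight_name=selected_lora["weights"])
else:
    # Single-file repos resolve the weights file automatically.
    pipe.load_lora_weights(lora_path)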
@@ -168,16 +161,16 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
         yield final_image, seed, gr.update(visible=False)
     else:
         image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, lora_scale, progress)
-
-
-
-
-
-
-
-
-
-
+        # Consume the generator to get the final image
+        final_image = None
+        step_counter = 0
+        for image in image_generator:
+            step_counter += 1
+            final_image = image
+            progress_bar = f'Generating image... Step {step_counter}/{steps}'
+            yield image, seed, gr.update(visible=True, value=progress_bar)
+
+        yield final_image, seed, gr.update(visible=False)
 
 # ...
 
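The added loop leans on the fact that a Gradio event handler written as a generator pushes a fresh render of its outputs to the browser on every yield. A minimal, self-contained sketch of that pattern (generate_image is replaced by a stand-in here):

import gradio as gr

def fake_image_generator(steps):
    # Stand-in for generate_image(...), which yields one preview per step.
    for i in range(steps):
        yield f"preview-at-step-{i + 1}"  # the real code yields PIL images

def run(steps=4):
    final_image = None
    for step_counter, image in enumerate(fake_image_generator(steps), start=1):
        final_image = image
        progress_bar = f"Generating image... Step {step_counter}/{steps}"
        # Each yield updates the image output and the progress label in the UI.
        yield image, gr.update(visible=True, value=progress_bar)
    # Final yield pins the last frame and hides the progress label.
    yield final_image, gr.update(visible=False)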