Update app.py
app.py (CHANGED)
@@ -231,7 +231,7 @@ def remove_lora_2(selected_indices, loras_state):
 
 def remove_lora_3(selected_indices, loras_state):
     if len(selected_indices) >= 3:
-        selected_indices.pop(
+        selected_indices.pop(2)
         selected_info_1 = "Select a Celebrity as LoRA 1"
         selected_info_2 = "Select a LoRA 2"
         selected_info_3 = "Select a LoRA 3"
@@ -267,7 +267,7 @@ def remove_lora_3(selected_indices, loras_state):
 
 def remove_lora_4(selected_indices, loras_state):
     if len(selected_indices) >= 4:
-        selected_indices.pop(
+        selected_indices.pop(3)
         selected_info_1 = "Select a Celebrity as LoRA 1"
         selected_info_2 = "Select a LoRA 2"
         selected_info_3 = "Select a LoRA 3"
@@ -451,12 +451,27 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
         yield img, seed, f"Generated image {img} with seed {seed}"
     return img
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=85)
 def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
     if not selected_indices:
         raise gr.Error("You must select at least one LoRA before proceeding.")
 
     selected_loras = [loras_state[idx] for idx in selected_indices]
+
+    # Debugging snippet: Inspect LoRAs before loading
+    for idx, lora in enumerate(selected_loras):
+        print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
+        try:
+            lora_weights = torch.load(lora['repo'])  # Load the LoRA weights
+            print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
+        except Exception as e:
+            print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
+            raise gr.Error(f"Failed to load LoRA weights for {lora['title']}.")
+
+    # Print the selected LoRAs
+    print("Running with the following LoRAs:")
+    for lora in selected_loras:
+        print(f"- {lora['title']} from {lora['repo']} with scale {lora_scale_1 if selected_loras.index(lora) == 0 else lora_scale_2}")
 
     # Build the prompt with trigger words
     prepends = []
@@ -470,7 +485,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
             appends.append(trigger_word)
     prompt_mash = " ".join(prepends + [prompt] + appends)
     print("Prompt Mash: ", prompt_mash)
-    print("
+    print(":--Seed--:", seed)
 
     # Unload previous LoRA weights
     with calculateDuration("Unloading LoRA"):
@@ -484,20 +499,24 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
     with calculateDuration("Loading LoRA weights"):
         for idx, lora in enumerate(selected_loras):
             lora_name = f"lora_{idx}"
-
-            print(f"
-
+            lora_weights_path = lora.get("weights")
+            print(f"Loading LoRA: {lora['title']} from {lora['repo']} with adapter name: {lora_name}")
+
             pipe.load_lora_weights(
                 lora['repo'],
-                weight_name=
+                weight_name=lora_weights_path,
                 low_cpu_mem_usage=True,
                 adapter_name=lora_name,
-
-
-        print("Loaded LoRAs:", selected_indices)
-        print("Adapter weights:", lora_weights)
+                merge_and_unload=True, # Explicitly merge weights to avoid runtime conflicts
+            )
 
+    print("Adapter weights:", lora_weights)
+    try:
         pipe.set_adapters(lora_names, adapter_weights=lora_weights)
+    except Exception as e:
+        print(f"Error while setting adapters: {e}")
+        raise
+    #pipe.set_adapters(lora_names, adapter_weights=lora_weights)
 
     # Set random seed if required
     if randomize_seed:
@@ -512,7 +531,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
         progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
         yield image, seed, gr.update(value=progress_bar, visible=True)
 
-run_lora.zerogpu =
+run_lora.zerogpu = False
 
 def get_huggingface_safetensors(link):
     split_link = link.split("/")
@@ -703,7 +722,8 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(128, 256)
     gallery.select(
         update_selection,
        inputs=[selected_indices, loras_state, width, height],
-        outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4]
+        outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4]
+    )
     remove_button_1.click(
         remove_lora_1,
         inputs=[selected_indices, loras_state],