Upload app.py
Browse files
app.py
CHANGED
@@ -105,18 +105,6 @@ def download_file(url, directory=None):
|
|
105 |
|
106 |
return filepath
|
107 |
|
108 |
-
def get_lora_weights(lora_repo, weight_name=None):
|
109 |
-
try:
|
110 |
-
# Download the weights from Hugging Face Hub
|
111 |
-
file_path = hf_hub_download(
|
112 |
-
repo_id=lora_repo,
|
113 |
-
filename=weight_name if weight_name else "pytorch_model.bin"
|
114 |
-
)
|
115 |
-
return file_path
|
116 |
-
except Exception as e:
|
117 |
-
print(f"Failed to fetch weights for {lora_repo}: {e}")
|
118 |
-
raise
|
119 |
-
|
120 |
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
121 |
selected_index = evt.index
|
122 |
selected_indices = selected_indices or []
|
@@ -243,7 +231,7 @@ def remove_lora_2(selected_indices, loras_state):
|
|
243 |
|
244 |
def remove_lora_3(selected_indices, loras_state):
|
245 |
if len(selected_indices) >= 3:
|
246 |
-
selected_indices.pop(
|
247 |
selected_info_1 = "Select a Celebrity as LoRA 1"
|
248 |
selected_info_2 = "Select a LoRA 2"
|
249 |
selected_info_3 = "Select a LoRA 3"
|
@@ -279,7 +267,7 @@ def remove_lora_3(selected_indices, loras_state):
|
|
279 |
|
280 |
def remove_lora_4(selected_indices, loras_state):
|
281 |
if len(selected_indices) >= 4:
|
282 |
-
selected_indices.pop(
|
283 |
selected_info_1 = "Select a Celebrity as LoRA 1"
|
284 |
selected_info_2 = "Select a LoRA 2"
|
285 |
selected_info_3 = "Select a LoRA 3"
|
@@ -463,29 +451,12 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
|
|
463 |
yield img, seed, f"Generated image {img} with seed {seed}"
|
464 |
return img
|
465 |
|
466 |
-
@spaces.GPU(duration=
|
467 |
def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
|
468 |
if not selected_indices:
|
469 |
raise gr.Error("You must select at least one LoRA before proceeding.")
|
470 |
|
471 |
selected_loras = [loras_state[idx] for idx in selected_indices]
|
472 |
-
|
473 |
-
# Debugging snippet: Inspect LoRAs before loading
|
474 |
-
for idx, lora in enumerate(selected_loras):
|
475 |
-
print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
|
476 |
-
try:
|
477 |
-
lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
|
478 |
-
print(f"LoRA weights fetched from: {lora_weights_path}")
|
479 |
-
lora_weights = torch.load(lora_weights_path, weights_only=True) #lora_weights = torch.load(lora_weights_path)
|
480 |
-
print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
|
481 |
-
except Exception as e:
|
482 |
-
print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
|
483 |
-
raise gr.Error(f"Failed to load LoRA weights for {lora['title']}.")
|
484 |
-
|
485 |
-
# Print the selected LoRAs
|
486 |
-
print("Running with the following LoRAs:")
|
487 |
-
for lora in selected_loras:
|
488 |
-
print(f"- {lora['title']} from {lora['repo']} with scale {lora_scale_1 if selected_loras.index(lora) == 0 else lora_scale_2}")
|
489 |
|
490 |
# Build the prompt with trigger words
|
491 |
prepends = []
|
@@ -499,7 +470,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
499 |
appends.append(trigger_word)
|
500 |
prompt_mash = " ".join(prepends + [prompt] + appends)
|
501 |
print("Prompt Mash: ", prompt_mash)
|
502 |
-
print("
|
503 |
|
504 |
# Unload previous LoRA weights
|
505 |
with calculateDuration("Unloading LoRA"):
|
@@ -513,22 +484,20 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
513 |
with calculateDuration("Loading LoRA weights"):
|
514 |
for idx, lora in enumerate(selected_loras):
|
515 |
lora_name = f"lora_{idx}"
|
516 |
-
|
517 |
-
|
|
|
518 |
pipe.load_lora_weights(
|
519 |
-
|
|
|
520 |
low_cpu_mem_usage=True,
|
521 |
adapter_name=lora_name,
|
522 |
-
merge_and_unload=True,
|
523 |
)
|
524 |
-
|
|
|
525 |
print("Adapter weights:", lora_weights)
|
526 |
-
|
527 |
pipe.set_adapters(lora_names, adapter_weights=lora_weights)
|
528 |
-
except Exception as e:
|
529 |
-
print(f"Error while setting adapters: {e}")
|
530 |
-
raise
|
531 |
-
#pipe.set_adapters(lora_names, adapter_weights=lora_weights)
|
532 |
|
533 |
# Set random seed if required
|
534 |
if randomize_seed:
|
@@ -543,7 +512,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
543 |
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
|
544 |
yield image, seed, gr.update(value=progress_bar, visible=True)
|
545 |
|
546 |
-
run_lora.zerogpu =
|
547 |
|
548 |
def get_huggingface_safetensors(link):
|
549 |
split_link = link.split("/")
|
@@ -734,8 +703,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(128, 256)
|
|
734 |
gallery.select(
|
735 |
update_selection,
|
736 |
inputs=[selected_indices, loras_state, width, height],
|
737 |
-
outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4]
|
738 |
-
)
|
739 |
remove_button_1.click(
|
740 |
remove_lora_1,
|
741 |
inputs=[selected_indices, loras_state],
|
|
|
105 |
|
106 |
return filepath
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
109 |
selected_index = evt.index
|
110 |
selected_indices = selected_indices or []
|
|
|
231 |
|
232 |
def remove_lora_3(selected_indices, loras_state):
|
233 |
if len(selected_indices) >= 3:
|
234 |
+
selected_indices.pop(1)
|
235 |
selected_info_1 = "Select a Celebrity as LoRA 1"
|
236 |
selected_info_2 = "Select a LoRA 2"
|
237 |
selected_info_3 = "Select a LoRA 3"
|
|
|
267 |
|
268 |
def remove_lora_4(selected_indices, loras_state):
|
269 |
if len(selected_indices) >= 4:
|
270 |
+
selected_indices.pop(1)
|
271 |
selected_info_1 = "Select a Celebrity as LoRA 1"
|
272 |
selected_info_2 = "Select a LoRA 2"
|
273 |
selected_info_3 = "Select a LoRA 3"
|
|
|
451 |
yield img, seed, f"Generated image {img} with seed {seed}"
|
452 |
return img
|
453 |
|
454 |
+
@spaces.GPU(duration=75)
|
455 |
def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
|
456 |
if not selected_indices:
|
457 |
raise gr.Error("You must select at least one LoRA before proceeding.")
|
458 |
|
459 |
selected_loras = [loras_state[idx] for idx in selected_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
|
461 |
# Build the prompt with trigger words
|
462 |
prepends = []
|
|
|
470 |
appends.append(trigger_word)
|
471 |
prompt_mash = " ".join(prepends + [prompt] + appends)
|
472 |
print("Prompt Mash: ", prompt_mash)
|
473 |
+
print("--Seed--:", seed)
|
474 |
|
475 |
# Unload previous LoRA weights
|
476 |
with calculateDuration("Unloading LoRA"):
|
|
|
484 |
with calculateDuration("Loading LoRA weights"):
|
485 |
for idx, lora in enumerate(selected_loras):
|
486 |
lora_name = f"lora_{idx}"
|
487 |
+
lora_names.append(lora_name)
|
488 |
+
print(f"Lora Name: {lora_name}")
|
489 |
+
lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
|
490 |
pipe.load_lora_weights(
|
491 |
+
lora['repo'],
|
492 |
+
weight_name=lora.get("weights"),
|
493 |
low_cpu_mem_usage=True,
|
494 |
adapter_name=lora_name,
|
|
|
495 |
)
|
496 |
+
print("Base Model:", base_model)
|
497 |
+
print("Loaded LoRAs:", selected_indices)
|
498 |
print("Adapter weights:", lora_weights)
|
499 |
+
|
500 |
pipe.set_adapters(lora_names, adapter_weights=lora_weights)
|
|
|
|
|
|
|
|
|
501 |
|
502 |
# Set random seed if required
|
503 |
if randomize_seed:
|
|
|
512 |
progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
|
513 |
yield image, seed, gr.update(value=progress_bar, visible=True)
|
514 |
|
515 |
+
run_lora.zerogpu = True
|
516 |
|
517 |
def get_huggingface_safetensors(link):
|
518 |
split_link = link.split("/")
|
|
|
703 |
gallery.select(
|
704 |
update_selection,
|
705 |
inputs=[selected_indices, loras_state, width, height],
|
706 |
+
outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4])
|
|
|
707 |
remove_button_1.click(
|
708 |
remove_lora_1,
|
709 |
inputs=[selected_indices, loras_state],
|