Keltezaa committed on
Commit
de6b25f
·
verified ·
1 Parent(s): f7099c9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -46
app.py CHANGED
@@ -105,18 +105,6 @@ def download_file(url, directory=None):
105
 
106
  return filepath
107
 
108
- def get_lora_weights(lora_repo, weight_name=None):
109
- try:
110
- # Download the weights from Hugging Face Hub
111
- file_path = hf_hub_download(
112
- repo_id=lora_repo,
113
- filename=weight_name if weight_name else "pytorch_model.bin"
114
- )
115
- return file_path
116
- except Exception as e:
117
- print(f"Failed to fetch weights for {lora_repo}: {e}")
118
- raise
119
-
120
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
121
  selected_index = evt.index
122
  selected_indices = selected_indices or []
@@ -243,7 +231,7 @@ def remove_lora_2(selected_indices, loras_state):
243
 
244
  def remove_lora_3(selected_indices, loras_state):
245
  if len(selected_indices) >= 3:
246
- selected_indices.pop(2)
247
  selected_info_1 = "Select a Celebrity as LoRA 1"
248
  selected_info_2 = "Select a LoRA 2"
249
  selected_info_3 = "Select a LoRA 3"
@@ -279,7 +267,7 @@ def remove_lora_3(selected_indices, loras_state):
279
 
280
  def remove_lora_4(selected_indices, loras_state):
281
  if len(selected_indices) >= 4:
282
- selected_indices.pop(3)
283
  selected_info_1 = "Select a Celebrity as LoRA 1"
284
  selected_info_2 = "Select a LoRA 2"
285
  selected_info_3 = "Select a LoRA 3"
@@ -463,29 +451,12 @@ def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
463
  yield img, seed, f"Generated image {img} with seed {seed}"
464
  return img
465
 
466
- @spaces.GPU(duration=85)
467
  def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
468
  if not selected_indices:
469
  raise gr.Error("You must select at least one LoRA before proceeding.")
470
 
471
  selected_loras = [loras_state[idx] for idx in selected_indices]
472
-
473
- # Debugging snippet: Inspect LoRAs before loading
474
- for idx, lora in enumerate(selected_loras):
475
- print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
476
- try:
477
- lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
478
- print(f"LoRA weights fetched from: {lora_weights_path}")
479
- lora_weights = torch.load(lora_weights_path, weights_only=True) #lora_weights = torch.load(lora_weights_path)
480
- print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
481
- except Exception as e:
482
- print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
483
- raise gr.Error(f"Failed to load LoRA weights for {lora['title']}.")
484
-
485
- # Print the selected LoRAs
486
- print("Running with the following LoRAs:")
487
- for lora in selected_loras:
488
- print(f"- {lora['title']} from {lora['repo']} with scale {lora_scale_1 if selected_loras.index(lora) == 0 else lora_scale_2}")
489
 
490
  # Build the prompt with trigger words
491
  prepends = []
@@ -499,7 +470,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
499
  appends.append(trigger_word)
500
  prompt_mash = " ".join(prepends + [prompt] + appends)
501
  print("Prompt Mash: ", prompt_mash)
502
- print(":--Seed--:", seed)
503
 
504
  # Unload previous LoRA weights
505
  with calculateDuration("Unloading LoRA"):
@@ -513,22 +484,20 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
513
  with calculateDuration("Loading LoRA weights"):
514
  for idx, lora in enumerate(selected_loras):
515
  lora_name = f"lora_{idx}"
516
- print(f"Loading LoRA: {lora['title']} with adapter name: {lora_name}")
517
- lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
 
518
  pipe.load_lora_weights(
519
- lora_weights_path,
 
520
  low_cpu_mem_usage=True,
521
  adapter_name=lora_name,
522
- merge_and_unload=True,
523
  )
524
-
 
525
  print("Adapter weights:", lora_weights)
526
- try:
527
  pipe.set_adapters(lora_names, adapter_weights=lora_weights)
528
- except Exception as e:
529
- print(f"Error while setting adapters: {e}")
530
- raise
531
- #pipe.set_adapters(lora_names, adapter_weights=lora_weights)
532
 
533
  # Set random seed if required
534
  if randomize_seed:
@@ -543,7 +512,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
543
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
544
  yield image, seed, gr.update(value=progress_bar, visible=True)
545
 
546
- run_lora.zerogpu = False
547
 
548
  def get_huggingface_safetensors(link):
549
  split_link = link.split("/")
@@ -734,8 +703,7 @@ with gr.Blocks(theme=gr.themes.Soft(font=font), css=css, delete_cache=(128, 256)
734
  gallery.select(
735
  update_selection,
736
  inputs=[selected_indices, loras_state, width, height],
737
- outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4]
738
- )
739
  remove_button_1.click(
740
  remove_lora_1,
741
  inputs=[selected_indices, loras_state],
 
105
 
106
  return filepath
107
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
109
  selected_index = evt.index
110
  selected_indices = selected_indices or []
 
231
 
232
  def remove_lora_3(selected_indices, loras_state):
233
  if len(selected_indices) >= 3:
234
+ selected_indices.pop(1)
235
  selected_info_1 = "Select a Celebrity as LoRA 1"
236
  selected_info_2 = "Select a LoRA 2"
237
  selected_info_3 = "Select a LoRA 3"
 
267
 
268
  def remove_lora_4(selected_indices, loras_state):
269
  if len(selected_indices) >= 4:
270
+ selected_indices.pop(1)
271
  selected_info_1 = "Select a Celebrity as LoRA 1"
272
  selected_info_2 = "Select a LoRA 2"
273
  selected_info_3 = "Select a LoRA 3"
 
451
  yield img, seed, f"Generated image {img} with seed {seed}"
452
  return img
453
 
454
+ @spaces.GPU(duration=75)
455
  def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, progress=gr.Progress(track_tqdm=True)):
456
  if not selected_indices:
457
  raise gr.Error("You must select at least one LoRA before proceeding.")
458
 
459
  selected_loras = [loras_state[idx] for idx in selected_indices]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
  # Build the prompt with trigger words
462
  prepends = []
 
470
  appends.append(trigger_word)
471
  prompt_mash = " ".join(prepends + [prompt] + appends)
472
  print("Prompt Mash: ", prompt_mash)
473
+ print("--Seed--:", seed)
474
 
475
  # Unload previous LoRA weights
476
  with calculateDuration("Unloading LoRA"):
 
484
  with calculateDuration("Loading LoRA weights"):
485
  for idx, lora in enumerate(selected_loras):
486
  lora_name = f"lora_{idx}"
487
+ lora_names.append(lora_name)
488
+ print(f"Lora Name: {lora_name}")
489
+ lora_weights.append(lora_scale_1 if idx == 0 else lora_scale_2)
490
  pipe.load_lora_weights(
491
+ lora['repo'],
492
+ weight_name=lora.get("weights"),
493
  low_cpu_mem_usage=True,
494
  adapter_name=lora_name,
 
495
  )
496
+ print("Base Model:", base_model)
497
+ print("Loaded LoRAs:", selected_indices)
498
  print("Adapter weights:", lora_weights)
499
+
500
  pipe.set_adapters(lora_names, adapter_weights=lora_weights)
 
 
 
 
501
 
502
  # Set random seed if required
503
  if randomize_seed:
 
512
  progress_bar = f'<div class="progress-container"><div class="progress-bar" style="--current: {step_counter}; --total: {steps};"></div></div>'
513
  yield image, seed, gr.update(value=progress_bar, visible=True)
514
 
515
+ run_lora.zerogpu = True
516
 
517
  def get_huggingface_safetensors(link):
518
  split_link = link.split("/")
 
703
  gallery.select(
704
  update_selection,
705
  inputs=[selected_indices, loras_state, width, height],
706
+ outputs=[prompt, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4])
 
707
  remove_button_1.click(
708
  remove_lora_1,
709
  inputs=[selected_indices, loras_state],