Keltezaa commited on
Commit
f7099c9
·
verified ·
1 Parent(s): e91d986

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -105,6 +105,18 @@ def download_file(url, directory=None):
105
 
106
  return filepath
107
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
109
  selected_index = evt.index
110
  selected_indices = selected_indices or []
@@ -462,7 +474,9 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
462
  for idx, lora in enumerate(selected_loras):
463
  print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
464
  try:
465
- lora_weights = torch.load(lora['repo']) # Load the LoRA weights
 
 
466
  print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
467
  except Exception as e:
468
  print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
@@ -499,16 +513,14 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
499
  with calculateDuration("Loading LoRA weights"):
500
  for idx, lora in enumerate(selected_loras):
501
  lora_name = f"lora_{idx}"
502
- lora_weights_path = lora.get("weights")
503
- print(f"Loading LoRA: {lora['title']} from {lora['repo']} with adapter name: {lora_name}")
504
-
505
  pipe.load_lora_weights(
506
- lora['repo'],
507
- weight_name=lora_weights_path,
508
  low_cpu_mem_usage=True,
509
  adapter_name=lora_name,
510
- merge_and_unload=True, # Explicitly merge weights to avoid runtime conflicts
511
- )
512
 
513
  print("Adapter weights:", lora_weights)
514
  try:
 
105
 
106
  return filepath
107
 
108
+ def get_lora_weights(lora_repo, weight_name=None):
109
+ try:
110
+ # Download the weights from Hugging Face Hub
111
+ file_path = hf_hub_download(
112
+ repo_id=lora_repo,
113
+ filename=weight_name if weight_name else "pytorch_model.bin"
114
+ )
115
+ return file_path
116
+ except Exception as e:
117
+ print(f"Failed to fetch weights for {lora_repo}: {e}")
118
+ raise
119
+
120
  def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
121
  selected_index = evt.index
122
  selected_indices = selected_indices or []
 
474
  for idx, lora in enumerate(selected_loras):
475
  print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
476
  try:
477
+ lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
478
+ print(f"LoRA weights fetched from: {lora_weights_path}")
479
+ lora_weights = torch.load(lora_weights_path, weights_only=True) #lora_weights = torch.load(lora_weights_path)
480
  print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
481
  except Exception as e:
482
  print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
 
513
  with calculateDuration("Loading LoRA weights"):
514
  for idx, lora in enumerate(selected_loras):
515
  lora_name = f"lora_{idx}"
516
+ print(f"Loading LoRA: {lora['title']} with adapter name: {lora_name}")
517
+ lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
 
518
  pipe.load_lora_weights(
519
+ lora_weights_path,
 
520
  low_cpu_mem_usage=True,
521
  adapter_name=lora_name,
522
+ merge_and_unload=True,
523
+ )
524
 
525
  print("Adapter weights:", lora_weights)
526
  try: