Update app.py
Browse files
app.py
CHANGED
@@ -105,6 +105,18 @@ def download_file(url, directory=None):
|
|
105 |
|
106 |
return filepath
|
107 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
109 |
selected_index = evt.index
|
110 |
selected_indices = selected_indices or []
|
@@ -462,7 +474,9 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
462 |
for idx, lora in enumerate(selected_loras):
|
463 |
print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
|
464 |
try:
|
465 |
-
|
|
|
|
|
466 |
print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
|
467 |
except Exception as e:
|
468 |
print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
|
@@ -499,16 +513,14 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
499 |
with calculateDuration("Loading LoRA weights"):
|
500 |
for idx, lora in enumerate(selected_loras):
|
501 |
lora_name = f"lora_{idx}"
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
pipe.load_lora_weights(
|
506 |
-
|
507 |
-
weight_name=lora_weights_path,
|
508 |
low_cpu_mem_usage=True,
|
509 |
adapter_name=lora_name,
|
510 |
-
merge_and_unload=True,
|
511 |
-
|
512 |
|
513 |
print("Adapter weights:", lora_weights)
|
514 |
try:
|
|
|
105 |
|
106 |
return filepath
|
107 |
|
108 |
+
def get_lora_weights(lora_repo, weight_name=None):
    """Fetch a LoRA weight file from the Hugging Face Hub.

    Parameters
    ----------
    lora_repo : str
        Hub repository id to download from.
    weight_name : str, optional
        Exact filename of the weight file inside the repo. When omitted,
        falls back to the conventional "pytorch_model.bin".

    Returns
    -------
    str
        Local filesystem path of the cached/downloaded file.

    Raises
    ------
    Exception
        Any error raised by ``hf_hub_download`` is logged and re-raised
        so the caller can decide how to recover.
    """
    # Resolve the target filename up front; the conditional itself cannot fail.
    target_file = weight_name if weight_name else "pytorch_model.bin"
    try:
        return hf_hub_download(repo_id=lora_repo, filename=target_file)
    except Exception as e:
        print(f"Failed to fetch weights for {lora_repo}: {e}")
        raise
|
119 |
+
|
120 |
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
121 |
selected_index = evt.index
|
122 |
selected_indices = selected_indices or []
|
|
|
474 |
for idx, lora in enumerate(selected_loras):
|
475 |
print(f"Inspecting LoRA {idx + 1}: {lora['title']}")
|
476 |
try:
|
477 |
+
lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
|
478 |
+
print(f"LoRA weights fetched from: {lora_weights_path}")
|
479 |
+
lora_weights = torch.load(lora_weights_path, weights_only=True) #lora_weights = torch.load(lora_weights_path)
|
480 |
print(f"LoRA Parameter Keys for {lora['title']}: {list(lora_weights.keys())}")
|
481 |
except Exception as e:
|
482 |
print(f"Error loading LoRA weights for {lora['title']} from {lora['repo']}: {e}")
|
|
|
513 |
with calculateDuration("Loading LoRA weights"):
|
514 |
for idx, lora in enumerate(selected_loras):
|
515 |
lora_name = f"lora_{idx}"
|
516 |
+
print(f"Loading LoRA: {lora['title']} with adapter name: {lora_name}")
|
517 |
+
lora_weights_path = get_lora_weights(lora['repo'], lora.get("weights"))
|
|
|
518 |
pipe.load_lora_weights(
|
519 |
+
lora_weights_path,
|
|
|
520 |
low_cpu_mem_usage=True,
|
521 |
adapter_name=lora_name,
|
522 |
+
merge_and_unload=True,
|
523 |
+
)
|
524 |
|
525 |
print("Adapter weights:", lora_weights)
|
526 |
try:
|