Update app.py
Browse files
app.py
CHANGED
@@ -92,17 +92,56 @@ def download_file(url, directory=None):
|
|
92 |
file.write(response.content)
|
93 |
|
94 |
return filepath
|
95 |
-
|
96 |
-
def get_trigger_word(base_model, lora_models):
|
97 |
-
trigger_words = [] # Initialize an empty list to hold trigger words
|
98 |
-
for lora in lora_models:
|
99 |
-
trigger_word = lora.get('trigger_word', '')
|
100 |
-
if trigger_word:
|
101 |
-
trigger_words.append(trigger_word)
|
102 |
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
def randomize_loras(selected_indices, loras_state):
|
108 |
if len(loras_state) < 2:
|
@@ -127,42 +166,6 @@ def randomize_loras(selected_indices, loras_state):
|
|
127 |
|
128 |
return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
|
129 |
|
130 |
-
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
131 |
-
selected_index = evt.index
|
132 |
-
selected_indices = selected_indices or []
|
133 |
-
|
134 |
-
if selected_index in selected_indices:
|
135 |
-
selected_indices.remove(selected_index)
|
136 |
-
else:
|
137 |
-
if len(selected_indices) < 4:
|
138 |
-
selected_indices.append(selected_index)
|
139 |
-
else:
|
140 |
-
gr.Warning("You can select up to 4 LoRAs, remove one to select a new one.")
|
141 |
-
return gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), width, height, gr.update(), gr.update(), gr.update()
|
142 |
-
|
143 |
-
selected_info = ["Select a LoRA {}".format(i) for i in range(1, 5)]
|
144 |
-
lora_images = [None] * 4
|
145 |
-
lora_scales = [0.5] * 4
|
146 |
-
|
147 |
-
for i in range(len(selected_indices)):
|
148 |
-
lora = loras_state[selected_indices[i]]
|
149 |
-
trigger_word = lora.get('trigger_word', '')
|
150 |
-
selected_info[i] = f"### LoRA {i + 1} Selected: [{lora['title']}](https://huggingface.co/{lora['repo']}) ✨ {trigger_word}"
|
151 |
-
lora_images[i] = lora['image']
|
152 |
-
|
153 |
-
if selected_indices:
|
154 |
-
last_selected_lora = loras_state[selected_indices[-1]]
|
155 |
-
new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
|
156 |
-
else:
|
157 |
-
new_placeholder = "Type a prompt after selecting a LoRA"
|
158 |
-
|
159 |
-
return (gr.update(placeholder=new_placeholder),
|
160 |
-
*selected_info, selected_indices,
|
161 |
-
*lora_scales,
|
162 |
-
width, height,
|
163 |
-
*lora_images,
|
164 |
-
gr.update()
|
165 |
-
)
|
166 |
def remove_lora_1(selected_indices, loras_state):
|
167 |
if len(selected_indices) >= 1:
|
168 |
selected_indices.pop(0)
|
@@ -310,24 +313,7 @@ def remove_lora_4(selected_indices, loras_state):
|
|
310 |
selected_info_4 = f"### LoRA 4 Selected: [{lora4['title']}]({lora4['repo']}) ✨ {trigger_word}"
|
311 |
lora_image_4 = lora4['image']
|
312 |
return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, gr.update()
|
313 |
-
|
314 |
-
def remove_lora(selected_indices, loras_state):
|
315 |
-
# Remove the LoRA based on the index
|
316 |
-
if selected_indices:
|
317 |
-
selected_indices.pop() # Remove the last selected LoRA
|
318 |
-
|
319 |
-
selected_info = ["Select a LoRA"] * 4
|
320 |
-
lora_images = [None] * 4
|
321 |
-
lora_scales = [0.5] * 4
|
322 |
-
|
323 |
-
for i in range(min(len(selected_indices), 4)):
|
324 |
-
lora = loras_state[selected_indices[i]]
|
325 |
-
trigger_word = lora.get('trigger_word', '')
|
326 |
-
selected_info[i] = f"### LoRA {i + 1} Selected: [{lora['title']}]({lora['repo']}) ✨ {trigger_word}"
|
327 |
-
lora_images[i] = lora['image']
|
328 |
-
|
329 |
-
return selected_info, selected_indices, lora_scales, lora_images, gr.update()
|
330 |
-
|
331 |
def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
|
332 |
if custom_lora:
|
333 |
try:
|
@@ -395,9 +381,9 @@ def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
|
|
395 |
except Exception as e:
|
396 |
print(e)
|
397 |
gr.Warning(str(e))
|
398 |
-
return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
|
399 |
else:
|
400 |
-
return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update()
|
401 |
|
402 |
def remove_custom_lora(selected_indices, current_loras, gallery):
|
403 |
if current_loras:
|
@@ -465,13 +451,13 @@ def generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
|
|
465 |
output_type="pil",
|
466 |
good_vae=good_vae,
|
467 |
):
|
468 |
-
print("Image generated successfully.") # Debugging statement
|
469 |
yield img
|
470 |
final_image = img # Update final_image with the current image
|
471 |
return final_image
|
472 |
|
473 |
@spaces.GPU(duration=75)
|
474 |
-
def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, image_input=None, progress=gr.Progress(track_tqdm=True)):
|
475 |
print("run_lora function called.") # Debugging statement
|
476 |
print(f"Inputs received - Prompt: {prompt}, CFG Scale: {cfg_scale}, Steps: {steps}, Seed: {seed}, Width: {width}, Height: {height}") # Debugging statement
|
477 |
|
@@ -487,7 +473,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
487 |
for lora in selected_loras:
|
488 |
trigger_word = lora.get('trigger_word', '')
|
489 |
if trigger_word:
|
490 |
-
if lora.get("trigger_position") == "
|
491 |
prepends.append(trigger_word)
|
492 |
else:
|
493 |
appends.append(trigger_word)
|
@@ -540,7 +526,7 @@ def run_lora(prompt, cfg_scale, steps, selected_indices, lora_scale_1, lora_scal
|
|
540 |
try:
|
541 |
if image_input is not None:
|
542 |
final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
|
543 |
-
yield final_image, seed, gr.update(visible=
|
544 |
else:
|
545 |
image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
|
546 |
# Consume the generator to get the final image
|
@@ -563,7 +549,7 @@ run_lora.zerogpu = True
|
|
563 |
|
564 |
def get_huggingface_safetensors(link):
|
565 |
split_link = link.split("/")
|
566 |
-
if len(split_link) ==
|
567 |
model_card = ModelCard.load(link)
|
568 |
base_model = model_card.data.get("base_model")
|
569 |
print(f"Base model: {base_model}")
|
|
|
92 |
file.write(response.content)
|
93 |
|
94 |
return filepath
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
+
def update_selection(evt: gr.SelectData, selected_indices, loras_state, width, height):
|
97 |
+
selected_index = evt.index
|
98 |
+
selected_indices = selected_indices or []
|
99 |
+
if selected_index in selected_indices:
|
100 |
+
selected_indices.remove(selected_index)
|
101 |
+
else:
|
102 |
+
if len(selected_indices) < 4:
|
103 |
+
selected_indices.append(selected_index)
|
104 |
+
else:
|
105 |
+
gr.Warning("You can select up to 4 LoRAs, remove one to select a new one.")
|
106 |
+
return gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), width, height, gr.update(), gr.update(), gr.update(), gr.update()
|
107 |
+
|
108 |
+
selected_info_1 = "Select a LoRA 1"
|
109 |
+
selected_info_2 = "Select a LoRA 2"
|
110 |
+
selected_info_3 = "Select a LoRA 3"
|
111 |
+
selected_info_4 = "Select a LoRA 4"
|
112 |
+
lora_scale_1 = 0.5
|
113 |
+
lora_scale_2 = 0.5
|
114 |
+
lora_scale_3 = 0.5
|
115 |
+
lora_scale_4 = 0.5
|
116 |
+
lora_image_1 = None
|
117 |
+
lora_image_2 = None
|
118 |
+
lora_image_3 = None
|
119 |
+
lora_image_4 = None
|
120 |
+
if len(selected_indices) >= 1:
|
121 |
+
lora1 = loras_state[selected_indices[0]]
|
122 |
+
selected_info_1 = f"### LoRA 1 Selected: [{lora1['title']}](https://huggingface.co/{lora1['repo']}) ✨"
|
123 |
+
lora_image_1 = lora1['image']
|
124 |
+
if len(selected_indices) >= 2:
|
125 |
+
lora2 = loras_state[selected_indices[1]]
|
126 |
+
selected_info_2 = f"### LoRA 2 Selected: [{lora2['title']}](https://huggingface.co/{lora2['repo']}) ✨"
|
127 |
+
lora_image_2 = lora2['image']
|
128 |
+
if len(selected_indices) >= 3:
|
129 |
+
lora3 = loras_state[selected_indices[2]]
|
130 |
+
selected_info_3 = f"### LoRA 3 Selected: [{lora2['title']}](https://huggingface.co/{lora3['repo']}) ✨"
|
131 |
+
lora_image_3 = lora3['image']
|
132 |
+
if len(selected_indices) >= 4:
|
133 |
+
lora4 = loras_state[selected_indices[3]]
|
134 |
+
selected_info_4 = f"### LoRA 4 Selected: [{lora4['title']}](https://huggingface.co/{lora4['repo']}) ✨"
|
135 |
+
lora_image_4 = lora4['image']
|
136 |
+
|
137 |
+
if selected_indices:
|
138 |
+
last_selected_lora = loras_state[selected_indices[-1]]
|
139 |
+
new_placeholder = f"Type a prompt for {last_selected_lora['title']}"
|
140 |
+
else:
|
141 |
+
new_placeholder = "Type a prompt after selecting a LoRA"
|
142 |
+
|
143 |
+
return gr.update(placeholder=new_placeholder), selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, width, height, lora_image_1, lora_image_2, lora_image_3, lora_image_4
|
144 |
+
|
145 |
|
146 |
def randomize_loras(selected_indices, loras_state):
|
147 |
if len(loras_state) < 2:
|
|
|
166 |
|
167 |
return selected_info_1, selected_info_2, selected_indices, lora_scale_1, lora_scale_2, lora_image_1, lora_image_2, random_prompt
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
def remove_lora_1(selected_indices, loras_state):
|
170 |
if len(selected_indices) >= 1:
|
171 |
selected_indices.pop(0)
|
|
|
313 |
selected_info_4 = f"### LoRA 4 Selected: [{lora4['title']}]({lora4['repo']}) ✨ {trigger_word}"
|
314 |
lora_image_4 = lora4['image']
|
315 |
return selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, lora_image_1, lora_image_2, lora_image_3, lora_image_4, gr.update()
|
316 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
def add_custom_lora(custom_lora, selected_indices, current_loras, gallery):
|
318 |
if custom_lora:
|
319 |
try:
|
|
|
381 |
except Exception as e:
|
382 |
print(e)
|
383 |
gr.Warning(str(e))
|
384 |
+
return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
385 |
else:
|
386 |
+
return current_loras, gr.update(), gr.update(), gr.update(), selected_indices, gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
|
387 |
|
388 |
def remove_custom_lora(selected_indices, current_loras, gallery):
|
389 |
if current_loras:
|
|
|
451 |
output_type="pil",
|
452 |
good_vae=good_vae,
|
453 |
):
|
454 |
+
#print("Image generated successfully.") # Debugging statement
|
455 |
yield img
|
456 |
final_image = img # Update final_image with the current image
|
457 |
return final_image
|
458 |
|
459 |
@spaces.GPU(duration=75)
|
460 |
+
def run_lora(prompt, cfg_scale, steps, selected_info_1, selected_info_2, selected_info_3, selected_info_4, selected_indices, lora_scale_1, lora_scale_2, lora_scale_3, lora_scale_4, randomize_seed, seed, width, height, loras_state, image_input=None, progress=gr.Progress(track_tqdm=True)):
|
461 |
print("run_lora function called.") # Debugging statement
|
462 |
print(f"Inputs received - Prompt: {prompt}, CFG Scale: {cfg_scale}, Steps: {steps}, Seed: {seed}, Width: {width}, Height: {height}") # Debugging statement
|
463 |
|
|
|
473 |
for lora in selected_loras:
|
474 |
trigger_word = lora.get('trigger_word', '')
|
475 |
if trigger_word:
|
476 |
+
if lora.get("trigger_position") == "append":
|
477 |
prepends.append(trigger_word)
|
478 |
else:
|
479 |
appends.append(trigger_word)
|
|
|
526 |
try:
|
527 |
if image_input is not None:
|
528 |
final_image = generate_image_to_image(prompt_mash, image_input, image_strength, steps, cfg_scale, width, height, seed)
|
529 |
+
yield final_image, seed, gr.update(visible=False)
|
530 |
else:
|
531 |
image_generator = generate_image(prompt_mash, steps, seed, cfg_scale, width, height, progress)
|
532 |
# Consume the generator to get the final image
|
|
|
549 |
|
550 |
def get_huggingface_safetensors(link):
|
551 |
split_link = link.split("/")
|
552 |
+
if len(split_link) == 4:
|
553 |
model_card = ModelCard.load(link)
|
554 |
base_model = model_card.data.get("base_model")
|
555 |
print(f"Base model: {base_model}")
|