Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,10 @@
|
|
1 |
|
|
|
2 |
import os
|
3 |
import json
|
4 |
import torch
|
5 |
import random
|
|
|
6 |
|
7 |
import gradio as gr
|
8 |
from glob import glob
|
@@ -147,8 +149,8 @@ class AnimateController:
|
|
147 |
raise gr.Error(f"Please select a pretrained model path.")
|
148 |
if motion_module_dropdown == "":
|
149 |
raise gr.Error(f"Please select a motion module.")
|
150 |
-
if base_model_dropdown == "":
|
151 |
-
|
152 |
|
153 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
154 |
|
@@ -158,11 +160,13 @@ class AnimateController:
|
|
158 |
).to("cuda")
|
159 |
|
160 |
if self.lora_model_state_dict != {}:
|
161 |
-
|
|
|
162 |
|
163 |
pipeline.to("cuda")
|
164 |
|
165 |
-
|
|
|
166 |
else: torch.seed()
|
167 |
seed = torch.initial_seed()
|
168 |
|
@@ -259,7 +263,7 @@ def ui():
|
|
259 |
)
|
260 |
lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
|
261 |
|
262 |
-
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.
|
263 |
|
264 |
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
265 |
def update_personalized_model():
|
|
|
1 |
|
2 |
+
|
3 |
import os
|
4 |
import json
|
5 |
import torch
|
6 |
import random
|
7 |
+
import copy
|
8 |
|
9 |
import gradio as gr
|
10 |
from glob import glob
|
|
|
149 |
raise gr.Error(f"Please select a pretrained model path.")
|
150 |
if motion_module_dropdown == "":
|
151 |
raise gr.Error(f"Please select a motion module.")
|
152 |
+
# if base_model_dropdown == "":
|
153 |
+
# raise gr.Error(f"Please select a base DreamBooth model.")
|
154 |
|
155 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
156 |
|
|
|
160 |
).to("cuda")
|
161 |
|
162 |
if self.lora_model_state_dict != {}:
|
163 |
+
print(f"Lora alpha: {lora_alpha_slider}")
|
164 |
+
pipeline = convert_lora(copy.deepcopy(pipeline), self.lora_model_state_dict, alpha=lora_alpha_slider)
|
165 |
|
166 |
pipeline.to("cuda")
|
167 |
|
168 |
+
seed_textbox = int(seed_textbox)
|
169 |
+
if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(seed_textbox)
|
170 |
else: torch.seed()
|
171 |
seed = torch.initial_seed()
|
172 |
|
|
|
263 |
)
|
264 |
lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
|
265 |
|
266 |
+
lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.7, minimum=0, maximum=2, interactive=True)
|
267 |
|
268 |
personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
|
269 |
def update_personalized_model():
|