Haoxin Chen committed · Commit 62b6d65 · 1 parent: 3ab49e3

update change_lora func

Files changed:
- app.py (+1, -1)
- lvdm/models/modules/lora.py (+77, -0)
- videocrafter_test.py (+5, -4)
app.py CHANGED

@@ -12,7 +12,7 @@ def videocrafter_demo(result_dir='./tmp/'):
     with gr.Tab(label="VideoCrafter"):
         input_text = gr.Text()
         model_choices=['origin','vangogh','frozen','yourname', 'coco']
-        trigger_word_list=['','Loving Vincent style', 'frozenmovie style', 'MakotoShinkaiYourName style', 'coco style']
+        trigger_word_list=[' ','Loving Vincent style', 'frozenmovie style', 'MakotoShinkaiYourName style', 'coco style']

         with gr.Row():
             model_index = gr.Dropdown(label='Models', elem_id=f"model", choices=model_choices, value=model_choices[0], type="index",interactive=True)
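Note how this one-character change pairs with the updated guard in videocrafter_test.py further down in this commit. A minimal sketch (the prompt text is made up) of how the single-space sentinel keeps the 'origin' model from getting a trigger word appended:

    trigger_word_list = [' ', 'Loving Vincent style', 'frozenmovie style',
                         'MakotoShinkaiYourName style', 'coco style']

    input_text = 'a boat sailing on a lake'   # hypothetical prompt
    trigger_word = trigger_word_list[0]       # the 'origin' model selects the sentinel
    if trigger_word != ' ':                   # same comparison as the new get_prompt guard
        input_text = input_text + ', ' + trigger_word
    # input_text stays unchanged for 'origin'; any real style appends its trigger word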
lvdm/models/modules/lora.py CHANGED

@@ -680,6 +680,83 @@ def change_lora(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0):
         net_load_lora(model, lora_path, alpha=lora_scale)
 
 
+def net_load_lora_v2(net, checkpoint_path, alpha=1.0, remove=False, origin_weight=None):
+    visited = []
+    state_dict = torch.load(checkpoint_path)
+    for k, v in state_dict.items():
+        state_dict[k] = v.to(net.device)
+    for key in state_dict:
+        if ".alpha" in key or key in visited:
+            continue
+        layer_infos = key.split(".")[:-2]  # drop the trailing lora_up/lora_down weight parts
+        curr_layer = net
+        # walk the module path to find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            curr_layer = curr_layer.__getattr__(temp_name)
+            if len(layer_infos) > 0:
+                temp_name = layer_infos.pop(0)
+            elif len(layer_infos) == 0:
+                break
+        if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
+            print('missing param at:', key)
+            continue
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # cache the original weight so a later remove=True call can restore it
+        storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
+        if origin_weight is None:
+            origin_weight = dict()
+            origin_weight[storage_key] = curr_layer.weight.data.clone()
+        elif storage_key not in origin_weight.keys():
+            origin_weight[storage_key] = curr_layer.weight.data.clone()
+
+        # update the layer weight
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            # for conv
+            if remove:
+                curr_layer.weight.data = origin_weight[storage_key].clone()
+            else:
+                weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        else:
+            # for linear
+            if remove:
+                curr_layer.weight.data = origin_weight[storage_key].clone()
+            else:
+                weight_up = state_dict[pair_keys[0]].to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].to(torch.float32)
+                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
+
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+    print('load_weight_num:', len(visited))
+    return origin_weight
+
+
+def change_lora_v2(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0, origin_weight=None):
+    # remove the previously injected lora
+    if last_time_lora != '':
+        origin_weight = net_load_lora_v2(model, last_time_lora, alpha=last_time_lora_scale, remove=True, origin_weight=origin_weight)
+    # add the new lora
+    if inject_lora:
+        origin_weight = net_load_lora_v2(model, lora_path, alpha=lora_scale, origin_weight=origin_weight)
+    return origin_weight
+
+
 def load_safeloras(path, device="cpu"):
     safeloras = safe_open(path, framework="pt", device=device)
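For readers skimming the diff, here is a minimal standalone sketch of the invariant net_load_lora_v2 relies on: injecting a LoRA adds alpha * (up @ down) into a matched layer's weight, and the remove=True path restores the clone cached in origin_weight. The toy layer shape and rank below are made up for the demo.

    import torch
    import torch.nn as nn

    layer = nn.Linear(8, 8, bias=False)   # stands in for a matched nn.Linear inside the model
    weight_up = torch.randn(8, 2)         # hypothetical rank-2 lora_up weight
    weight_down = torch.randn(2, 8)       # hypothetical rank-2 lora_down weight
    alpha = 1.0

    origin = layer.weight.data.clone()    # what origin_weight[storage_key] caches
    layer.weight.data += alpha * torch.mm(weight_up, weight_down)  # inject (linear branch)
    layer.weight.data = origin.clone()    # remove=True branch: restore the cached weight
    assert torch.equal(layer.weight.data, origin)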
videocrafter_test.py CHANGED

@@ -13,7 +13,7 @@ from lvdm.utils.common_utils import str2bool
 from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
 from scripts.sample_text2video import sample_text2video
 from scripts.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np
-from lvdm.models.modules.lora import change_lora
+from lvdm.models.modules.lora import change_lora, change_lora_v2
 
 from huggingface_hub import hf_hub_download
 
@@ -50,13 +50,14 @@ class Text2Video():
         self.result_dir = result_dir
         self.save_fps = 8
         self.ddim_sampler = DDIMSampler(model)
+        self.origin_weight = None
 
     def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0, trigger_word=''):
-        if trigger_word !='':
+        if trigger_word !=' ':
             input_text = input_text + ', ' + trigger_word
         inject_lora = model_index > 0
-        change_lora(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=self.lora_path_list[model_index],
-                    last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale)
+        self.origin_weight = change_lora_v2(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=self.lora_path_list[model_index],
+                                            last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale, origin_weight=self.origin_weight)
 
         all_videos = sample_text2video(self.model, input_text, n_samples=1, batch_size=1,
                                        sample_type='ddim', sampler=self.ddim_sampler,