Haoxin Chen committed
Commit 62b6d65 · 1 Parent(s): 3ab49e3

update change_lora func

Files changed (3):
  1. app.py +1 -1
  2. lvdm/models/modules/lora.py +77 -0
  3. videocrafter_test.py +5 -4
app.py CHANGED
@@ -12,7 +12,7 @@ def videocrafter_demo(result_dir='./tmp/'):
         with gr.Tab(label="VideoCrafter"):
             input_text = gr.Text()
             model_choices=['origin','vangogh','frozen','yourname', 'coco']
-            trigger_word_list=['','Loving Vincent style', 'frozenmovie style', 'MakotoShinkaiYourName style', 'coco style']
+            trigger_word_list=[' ','Loving Vincent style', 'frozenmovie style', 'MakotoShinkaiYourName style', 'coco style']
 
             with gr.Row():
                 model_index = gr.Dropdown(label='Models', elem_id=f"model", choices=model_choices, value=model_choices[0], type="index",interactive=True)
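The blank entry becomes a single space so it matches the `trigger_word != ' '` check introduced in videocrafter_test.py below. As a hedged sketch only (not part of this commit), one way the updated trigger_word_list could back a companion dropdown next to the model selector; the `trigger_word` dropdown and its wiring are assumptions, since they are not shown in this hunk:

```python
# Illustrative only: pairing the trigger-word list with the model dropdown,
# mirroring the gr.Dropdown call shown in the hunk above.
import gradio as gr

model_choices = ['origin', 'vangogh', 'frozen', 'yourname', 'coco']
trigger_word_list = [' ', 'Loving Vincent style', 'frozenmovie style',
                     'MakotoShinkaiYourName style', 'coco style']

with gr.Blocks() as demo:
    with gr.Row():
        model_index = gr.Dropdown(label='Models', choices=model_choices,
                                  value=model_choices[0], type="index", interactive=True)
        trigger_word = gr.Dropdown(label='Trigger word', choices=trigger_word_list,
                                   value=trigger_word_list[0], interactive=True)
```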
lvdm/models/modules/lora.py CHANGED
@@ -680,6 +680,83 @@ def change_lora(model, inject_lora=False, lora_scale=1.0, lora_path='', last_tim
         net_load_lora(model, lora_path, alpha=lora_scale)
 
 
+def net_load_lora_v2(net, checkpoint_path, alpha=1.0, remove=False, origin_weight=None):
+    visited=[]
+    state_dict = torch.load(checkpoint_path)
+    for k, v in state_dict.items():
+        state_dict[k] = v.to(net.device)
+    # import pdb;pdb.set_trace()
+    for key in state_dict:
+        if ".alpha" in key or key in visited:
+            continue
+        layer_infos = key.split(".")[:-2]  # remove lora_up and down weight
+        curr_layer = net
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            curr_layer = curr_layer.__getattr__(temp_name)
+            if len(layer_infos) > 0:
+                temp_name = layer_infos.pop(0)
+            elif len(layer_infos) == 0:
+                break
+        if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
+            print('missing param at:', key)
+            continue
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # storage weight
+        if origin_weight is None:
+            origin_weight = dict()
+            storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
+            origin_weight[storage_key] = curr_layer.weight.data.clone()
+        else:
+            storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
+            if storage_key not in origin_weight.keys():
+                origin_weight[storage_key] = curr_layer.weight.data.clone()
+
+        # update
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            # for conv
+            if remove:
+                curr_layer.weight.data = origin_weight[storage_key].clone()
+            else:
+                weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        else:
+            # for linear
+            if remove:
+                curr_layer.weight.data = origin_weight[storage_key].clone()
+            else:
+                weight_up = state_dict[pair_keys[0]].to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].to(torch.float32)
+                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
+
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+    print('load_weight_num:',len(visited))
+    return origin_weight
+
+
+def change_lora_v2(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0, origin_weight=None):
+    # remove lora
+    if last_time_lora != '':
+        origin_weight = net_load_lora_v2(model, last_time_lora, alpha=last_time_lora_scale, remove=True, origin_weight=origin_weight)
+    # add new lora
+    if inject_lora:
+        origin_weight = net_load_lora_v2(model, lora_path, alpha=lora_scale, origin_weight=origin_weight)
+    return origin_weight
+
+
 
 def load_safeloras(path, device="cpu"):
     safeloras = safe_open(path, framework="pt", device=device)
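The key difference from the existing change_lora path is that net_load_lora_v2 caches each affected layer's original weight (keyed by the LoRA parameter name) the first time a LoRA touches it, and removal restores that cached tensor instead of subtracting a delta, which avoids accumulated floating-point drift when switching LoRAs repeatedly. A minimal, self-contained sketch of the idea on a single linear layer (the key name "layer.lora" and the toy shapes are illustrative only, not taken from the commit):

```python
# Sketch of the merge/restore idea behind net_load_lora_v2:
# inject as W += alpha * (up @ down), remove by restoring the cached original W.
import torch
import torch.nn as nn

layer = nn.Linear(8, 8, bias=False)
rank = 2
lora_up = torch.randn(8, rank)    # plays the role of a "lora_up" tensor
lora_down = torch.randn(rank, 8)  # plays the role of a "lora_down" tensor
alpha = 1.0

# cache the original weight before merging, keyed like the storage_key above
origin_weight = {"layer.lora": layer.weight.data.clone()}

# inject: same update as the linear branch of net_load_lora_v2
layer.weight.data += alpha * torch.mm(lora_up, lora_down)

# remove: restore the cached weight exactly rather than subtracting the delta
layer.weight.data = origin_weight["layer.lora"].clone()
assert torch.equal(layer.weight.data, origin_weight["layer.lora"])
```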
videocrafter_test.py CHANGED
@@ -13,7 +13,7 @@ from lvdm.utils.common_utils import str2bool
 from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
 from scripts.sample_text2video import sample_text2video
 from scripts.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np
-from lvdm.models.modules.lora import change_lora
+from lvdm.models.modules.lora import change_lora, change_lora_v2
 
 from huggingface_hub import hf_hub_download
 
@@ -50,13 +50,14 @@ class Text2Video():
         self.result_dir = result_dir
         self.save_fps = 8
         self.ddim_sampler = DDIMSampler(model)
+        self.origin_weight = None
 
     def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0, trigger_word=''):
-        if trigger_word !='':
+        if trigger_word !=' ':
             input_text = input_text + ', ' + trigger_word
         inject_lora = model_index > 0
-        change_lora(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=self.lora_path_list[model_index],
-                    last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale)
+        self.origin_weight = change_lora_v2(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=self.lora_path_list[model_index],
+                                            last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale, origin_weight=self.origin_weight)
 
         all_videos = sample_text2video(self.model, input_text, n_samples=1, batch_size=1,
                                        sample_type='ddim', sampler=self.ddim_sampler,
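The wrapper now threads self.origin_weight through every call, so each request first restores the weights touched by the previously applied LoRA (self.last_time_lora) and then merges the newly selected one. For context, a hedged sketch of the resulting call pattern across two successive requests; the checkpoint paths are illustrative, and updating last_time_lora after each call is assumed to happen elsewhere in the wrapper, as in the existing code:

```python
# Illustrative call sequence, assuming `model` is the loaded VideoCrafter model
# and the .ckpt paths exist; not part of the commit itself.
from lvdm.models.modules.lora import change_lora_v2

origin_weight = None
last_time_lora, last_time_lora_scale = '', 1.0

# request 1: inject the "vangogh" LoRA
origin_weight = change_lora_v2(model, inject_lora=True, lora_scale=1.0,
                               lora_path='models/lora/vangogh.ckpt',
                               last_time_lora=last_time_lora,
                               last_time_lora_scale=last_time_lora_scale,
                               origin_weight=origin_weight)
last_time_lora, last_time_lora_scale = 'models/lora/vangogh.ckpt', 1.0

# request 2: switch to "frozen" -- the vangogh-modified weights are restored
# from the cache first, then the new LoRA is merged in
origin_weight = change_lora_v2(model, inject_lora=True, lora_scale=1.0,
                               lora_path='models/lora/frozen.ckpt',
                               last_time_lora=last_time_lora,
                               last_time_lora_scale=last_time_lora_scale,
                               origin_weight=origin_weight)
```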