Zhouyan248 commited on
Commit
f5c9414
1 Parent(s): 13a8aa5

Update base/models/unet.py

Browse files
Files changed (1) hide show
  1. base/models/unet.py +15 -15
base/models/unet.py CHANGED
@@ -569,21 +569,21 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
569
 
570
  model = cls.from_config(config)
571
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
572
- if not os.path.isfile(model_file):
573
- raise RuntimeError(f"{model_file} does not exist")
574
- state_dict = torch.load(model_file, map_location="cpu")
575
- for k, v in model.state_dict().items():
576
- # print(k)
577
- if '_temp' in k:
578
- state_dict.update({k: v})
579
- if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
580
- k = k.replace('attn_fcross', 'attn1')
581
- state_dict.update({k: state_dict[k]})
582
- if 'norm_fcross' in k:
583
- k = k.replace('norm_fcross', 'norm1')
584
- state_dict.update({k: state_dict[k]})
585
-
586
- model.load_state_dict(state_dict)
587
 
588
  return model
589
 
 
569
 
570
  model = cls.from_config(config)
571
  model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
572
+ # if not os.path.isfile(model_file):
573
+ # raise RuntimeError(f"{model_file} does not exist")
574
+ # state_dict = torch.load(model_file, map_location="cpu")
575
+ # for k, v in model.state_dict().items():
576
+ # # print(k)
577
+ # if '_temp' in k:
578
+ # state_dict.update({k: v})
579
+ # if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
580
+ # k = k.replace('attn_fcross', 'attn1')
581
+ # state_dict.update({k: state_dict[k]})
582
+ # if 'norm_fcross' in k:
583
+ # k = k.replace('norm_fcross', 'norm1')
584
+ # state_dict.update({k: state_dict[k]})
585
+
586
+ # model.load_state_dict(state_dict)
587
 
588
  return model
589