Spaces:
Runtime error
Runtime error
Zhouyan248
commited on
Commit
•
f5c9414
1
Parent(s):
13a8aa5
Update base/models/unet.py
Browse files- base/models/unet.py +15 -15
base/models/unet.py
CHANGED
@@ -569,21 +569,21 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|
569 |
|
570 |
model = cls.from_config(config)
|
571 |
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
572 |
-
if not os.path.isfile(model_file):
|
573 |
-
|
574 |
-
state_dict = torch.load(model_file, map_location="cpu")
|
575 |
-
for k, v in model.state_dict().items():
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
model.load_state_dict(state_dict)
|
587 |
|
588 |
return model
|
589 |
|
|
|
569 |
|
570 |
model = cls.from_config(config)
|
571 |
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
572 |
+
# if not os.path.isfile(model_file):
|
573 |
+
# raise RuntimeError(f"{model_file} does not exist")
|
574 |
+
# state_dict = torch.load(model_file, map_location="cpu")
|
575 |
+
# for k, v in model.state_dict().items():
|
576 |
+
# # print(k)
|
577 |
+
# if '_temp' in k:
|
578 |
+
# state_dict.update({k: v})
|
579 |
+
# if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross
|
580 |
+
# k = k.replace('attn_fcross', 'attn1')
|
581 |
+
# state_dict.update({k: state_dict[k]})
|
582 |
+
# if 'norm_fcross' in k:
|
583 |
+
# k = k.replace('norm_fcross', 'norm1')
|
584 |
+
# state_dict.update({k: state_dict[k]})
|
585 |
+
|
586 |
+
# model.load_state_dict(state_dict)
|
587 |
|
588 |
return model
|
589 |
|