Spaces:
Sleeping
Sleeping
Update networks.py
Browse files- networks.py +27 -4
networks.py
CHANGED
@@ -12,7 +12,7 @@ class Options:
|
|
12 |
# Default values
|
13 |
self.fine_height = 256
|
14 |
self.fine_width = 192
|
15 |
-
self.grid_size =
|
16 |
self.use_dropout = False
|
17 |
|
18 |
def weights_init_normal(m):
|
@@ -499,7 +499,30 @@ def save_checkpoint(model, save_path):
|
|
499 |
os.makedirs(os.path.dirname(save_path))
|
500 |
torch.save(model.state_dict(), save_path)
|
501 |
|
502 |
-
def load_checkpoint(model, checkpoint_path):
|
503 |
if not os.path.exists(checkpoint_path):
|
504 |
-
|
505 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
# Default values
|
13 |
self.fine_height = 256
|
14 |
self.fine_width = 192
|
15 |
+
self.grid_size = 5
|
16 |
self.use_dropout = False
|
17 |
|
18 |
def weights_init_normal(m):
|
|
|
499 |
os.makedirs(os.path.dirname(save_path))
|
500 |
torch.save(model.state_dict(), save_path)
|
501 |
|
502 |
+
def load_checkpoint(model, checkpoint_path, strict=True):
|
503 |
if not os.path.exists(checkpoint_path):
|
504 |
+
raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
|
505 |
+
|
506 |
+
# Load checkpoint with strict=False to ignore size mismatches
|
507 |
+
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
508 |
+
|
509 |
+
# Filter out size-mismatched keys
|
510 |
+
model_state_dict = model.state_dict()
|
511 |
+
filtered_state_dict = {k: v for k, v in state_dict.items()
|
512 |
+
if k in model_state_dict and v.size() == model_state_dict[k].size()}
|
513 |
+
|
514 |
+
# Load the filtered state dict
|
515 |
+
model.load_state_dict(filtered_state_dict, strict=strict)
|
516 |
+
|
517 |
+
# Print warnings for mismatched keys
|
518 |
+
missing_keys = [k for k in model_state_dict.keys() if k not in state_dict]
|
519 |
+
unexpected_keys = [k for k in state_dict.keys() if k not in model_state_dict]
|
520 |
+
size_mismatch_keys = [k for k in state_dict.keys()
|
521 |
+
if k in model_state_dict and state_dict[k].size() != model_state_dict[k].size()]
|
522 |
+
|
523 |
+
if missing_keys:
|
524 |
+
print(f"Missing keys in checkpoint: {missing_keys}")
|
525 |
+
if unexpected_keys:
|
526 |
+
print(f"Unexpected keys in checkpoint: {unexpected_keys}")
|
527 |
+
if size_mismatch_keys:
|
528 |
+
print(f"Size mismatch for keys: {size_mismatch_keys}")
|