gaur3009 commited on
Commit
960b661
·
verified ·
1 Parent(s): 7335794

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +27 -4
networks.py CHANGED
@@ -12,7 +12,7 @@ class Options:
12
  # Default values
13
  self.fine_height = 256
14
  self.fine_width = 192
15
- self.grid_size = 3
16
  self.use_dropout = False
17
 
18
  def weights_init_normal(m):
@@ -499,7 +499,30 @@ def save_checkpoint(model, save_path):
499
  os.makedirs(os.path.dirname(save_path))
500
  torch.save(model.state_dict(), save_path)
501
 
502
- def load_checkpoint(model, checkpoint_path):
503
  if not os.path.exists(checkpoint_path):
504
- return
505
- model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  # Default values
13
  self.fine_height = 256
14
  self.fine_width = 192
15
+ self.grid_size = 5
16
  self.use_dropout = False
17
 
18
  def weights_init_normal(m):
 
499
  os.makedirs(os.path.dirname(save_path))
500
  torch.save(model.state_dict(), save_path)
501
 
502
+ def load_checkpoint(model, checkpoint_path, strict=True):
503
  if not os.path.exists(checkpoint_path):
504
+ raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
505
+
506
+ # Load checkpoint with strict=False to ignore size mismatches
507
+ state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
508
+
509
+ # Filter out size-mismatched keys
510
+ model_state_dict = model.state_dict()
511
+ filtered_state_dict = {k: v for k, v in state_dict.items()
512
+ if k in model_state_dict and v.size() == model_state_dict[k].size()}
513
+
514
+ # Load the filtered state dict
515
+ model.load_state_dict(filtered_state_dict, strict=strict)
516
+
517
+ # Print warnings for mismatched keys
518
+ missing_keys = [k for k in model_state_dict.keys() if k not in state_dict]
519
+ unexpected_keys = [k for k in state_dict.keys() if k not in model_state_dict]
520
+ size_mismatch_keys = [k for k in state_dict.keys()
521
+ if k in model_state_dict and state_dict[k].size() != model_state_dict[k].size()]
522
+
523
+ if missing_keys:
524
+ print(f"Missing keys in checkpoint: {missing_keys}")
525
+ if unexpected_keys:
526
+ print(f"Unexpected keys in checkpoint: {unexpected_keys}")
527
+ if size_mismatch_keys:
528
+ print(f"Size mismatch for keys: {size_mismatch_keys}")