glenn-jocher commited on
Commit
65857ad
·
1 Parent(s): 7b0c9be

update train.py ckpt loading

Browse files
Files changed (1) hide show
  1. train.py +1 -2
train.py CHANGED
@@ -119,8 +119,7 @@ def train(hyp):
119
 
120
  # load model
121
  try:
122
- ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
123
- if model.state_dict()[k].shape == v.shape} # to FP32, filter
124
  model.load_state_dict(ckpt['model'], strict=False)
125
  except KeyError as e:
126
  s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
 
119
 
120
  # load model
121
  try:
122
+ ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() if k in model.state_dict()}
 
123
  model.load_state_dict(ckpt['model'], strict=False)
124
  except KeyError as e:
125
  s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \