Commit
·
65857ad
1
Parent(s):
7b0c9be
update train.py ckpt loading
Browse files
train.py
CHANGED
@@ -119,8 +119,7 @@ def train(hyp):
|
|
119 |
|
120 |
# load model
|
121 |
try:
|
122 |
-
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
|
123 |
-
if model.state_dict()[k].shape == v.shape} # to FP32, filter
|
124 |
model.load_state_dict(ckpt['model'], strict=False)
|
125 |
except KeyError as e:
|
126 |
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
|
|
119 |
|
120 |
# load model
|
121 |
try:
|
122 |
+
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items() if k in model.state_dict()}
|
|
|
123 |
model.load_state_dict(ckpt['model'], strict=False)
|
124 |
except KeyError as e:
|
125 |
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|