Update train.py (#462)
Browse files
train.py
CHANGED
@@ -123,9 +123,12 @@ def train(hyp, tb_writer, opt, device):
|
|
123 |
|
124 |
# load model
|
125 |
try:
|
|
|
126 |
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
|
127 |
-
if k in model.state_dict() and
|
|
|
128 |
model.load_state_dict(ckpt['model'], strict=False)
|
|
|
129 |
except KeyError as e:
|
130 |
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
131 |
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
|
|
|
123 |
|
124 |
# load model
|
125 |
try:
|
126 |
+
exclude = ['anchor'] # exclude keys
|
127 |
ckpt['model'] = {k: v for k, v in ckpt['model'].float().state_dict().items()
|
128 |
+
if k in model.state_dict() and not any(x in k for x in exclude)
|
129 |
+
and model.state_dict()[k].shape == v.shape}
|
130 |
model.load_state_dict(ckpt['model'], strict=False)
|
131 |
+
print('Transferred %g/%g items from %s' % (len(ckpt['model']), len(model.state_dict()), weights))
|
132 |
except KeyError as e:
|
133 |
s = "%s is not compatible with %s. This may be due to model differences or %s may be out of date. " \
|
134 |
"Please delete or update %s and try again, or use --weights '' to train from scratch." \
|