Resume with custom anchors fix (#2361)
* Resume with custom anchors fix
* Update train.py
train.py CHANGED
```diff
@@ -75,10 +75,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
         with torch_distributed_zero_first(rank):
             attempt_download(weights)  # download if not found locally
         ckpt = torch.load(weights, map_location=device)  # load checkpoint
-        if hyp.get('anchors'):
-            ckpt['model'].yaml['anchors'] = round(hyp['anchors'])  # force autoanchor
-        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc).to(device)  # create
-        exclude = ['anchor'] if opt.cfg or hyp.get('anchors') else []  # exclude keys
+        model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # create
+        exclude = ['anchor'] if (opt.cfg or hyp.get('anchors')) and not opt.resume else []  # exclude keys
         state_dict = ckpt['model'].float().state_dict()  # to FP32
         state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=exclude)  # intersect
         model.load_state_dict(state_dict, strict=False)  # load
```
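The `exclude` change is the heart of the fix: `intersect_dicts` drops any checkpoint key that matches an excluded substring, so before this PR a run launched with custom anchors (`hyp['anchors']` set) would also discard its own evolved anchors when restarted with `--resume`. The `and not opt.resume` guard keeps the checkpoint anchors on resume. A minimal sketch of that filtering behavior, using a simplified stand-in for YOLOv5's `intersect_dicts` helper; the dict keys and tensor shapes below are illustrative, not the real model's:

```python
import torch

def intersect_dicts(da, db, exclude=()):
    # Keep checkpoint entries that also exist in the new model with a matching
    # shape and whose key contains no excluded substring (simplified stand-in
    # for the YOLOv5 utils helper of the same name).
    return {k: v for k, v in da.items()
            if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}

ckpt_sd = {'model.24.anchors': torch.ones(3, 3, 2), 'model.0.conv.weight': torch.ones(8, 3, 3, 3)}
model_sd = {'model.24.anchors': torch.zeros(3, 3, 2), 'model.0.conv.weight': torch.zeros(8, 3, 3, 3)}

# Fresh training with custom anchors: drop checkpoint anchors, keep conv weights.
print(intersect_dicts(ckpt_sd, model_sd, exclude=['anchor']).keys())  # conv weight only
# Resuming (exclude=[]): the anchors already evolved by autoanchor are kept.
print(intersect_dicts(ckpt_sd, model_sd).keys())  # anchors + conv weight
```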
```diff
@@ -216,6 +214,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
             # Anchors
             if not opt.noautoanchor:
                 check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
+            model.half().float()  # pre-reduce anchor precision
 
     # Model parameters
     hyp['box'] *= 3. / nl  # scale to layers
```
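The new `model.half().float()` line rounds every parameter and buffer, anchors included, through FP16, matching the precision loss the checkpoint itself undergoes (YOLOv5 saves checkpoints in half precision and restores them via `ckpt['model'].float()`). Anchors written by `check_anchors` therefore reload identically on resume instead of drifting by a rounding error. A quick sketch of the round-trip effect, with illustrative values:

```python
import torch

# Anchor values as autoanchor might leave them (illustrative numbers).
anchors = torch.tensor([1.2345678, 3.1415927, 9.8765432])

saved = anchors.half().float()  # what survives an FP16 checkpoint round-trip

print(torch.equal(anchors, saved))                 # False: FP16 rounding drifted
print(torch.equal(anchors.half().float(), saved))  # True: pre-reduced values are stable
```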