Commit
·
a21bd06
1
Parent(s):
455f7b8
Update train.py forward simplification
Browse files
train.py
CHANGED
@@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None):
|
|
265 |
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
|
266 |
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
|
267 |
|
268 |
-
#
|
269 |
with amp.autocast(enabled=cuda):
|
270 |
-
#
|
271 |
-
|
272 |
-
|
273 |
-
# Loss
|
274 |
-
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
|
275 |
if rank != -1:
|
276 |
loss *= opt.world_size # gradient averaged between devices in DDP mode
|
277 |
-
# if not torch.isfinite(loss):
|
278 |
-
# logger.info('WARNING: non-finite loss, ending training ', loss_items)
|
279 |
-
# return results
|
280 |
|
281 |
# Backward
|
282 |
scaler.scale(loss).backward()
|
|
|
265 |
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
|
266 |
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
|
267 |
|
268 |
+
# Forward
|
269 |
with amp.autocast(enabled=cuda):
|
270 |
+
pred = model(imgs) # forward
|
271 |
+
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
|
|
|
|
|
|
|
272 |
if rank != -1:
|
273 |
loss *= opt.world_size # gradient averaged between devices in DDP mode
|
|
|
|
|
|
|
274 |
|
275 |
# Backward
|
276 |
scaler.scale(loss).backward()
|