glenn-jocher commited on
Commit
6d6e2ca
·
unverified ·
1 Parent(s): ac34834

Update train.py (#3667)

Browse files
Files changed (1) hide show
  1. train.py +9 -6
train.py CHANGED
@@ -22,7 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
22
  from torch.utils.tensorboard import SummaryWriter
23
  from tqdm import tqdm
24
 
25
- import test # import test.py to get mAP after each epoch
26
  from models.experimental import attempt_load
27
  from models.yolo import Model
28
  from utils.autoanchor import check_anchors
@@ -39,7 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
- def train(hyp, opt, device, tb_writer=None):
 
 
 
 
43
  logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
44
  save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
45
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -341,7 +345,7 @@ def train(hyp, opt, device, tb_writer=None):
341
  save_dir.glob('train*.jpg') if x.exists()]})
342
 
343
  # end batch ------------------------------------------------------------------------------------------------
344
-
345
  # Scheduler
346
  lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
347
  scheduler.step()
@@ -404,12 +408,11 @@ def train(hyp, opt, device, tb_writer=None):
404
  torch.save(ckpt, best)
405
  if wandb_logger.wandb:
406
  if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
407
- wandb_logger.log_model(
408
- last.parent, opt, epoch, fi, best_model=best_fitness == fi)
409
  del ckpt
410
 
411
  # end epoch ----------------------------------------------------------------------------------------------------
412
- # end training
413
  if rank in [-1, 0]:
414
  logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
415
  if plots:
 
22
  from torch.utils.tensorboard import SummaryWriter
23
  from tqdm import tqdm
24
 
25
+ import test # for end-of-epoch mAP
26
  from models.experimental import attempt_load
27
  from models.yolo import Model
28
  from utils.autoanchor import check_anchors
 
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
+ def train(hyp,
43
+ opt,
44
+ device,
45
+ tb_writer=None
46
+ ):
47
  logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
48
  save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
49
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
 
345
  save_dir.glob('train*.jpg') if x.exists()]})
346
 
347
  # end batch ------------------------------------------------------------------------------------------------
348
+
349
  # Scheduler
350
  lr = [x['lr'] for x in optimizer.param_groups] # for tensorboard
351
  scheduler.step()
 
408
  torch.save(ckpt, best)
409
  if wandb_logger.wandb:
410
  if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
411
+ wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
 
412
  del ckpt
413
 
414
  # end epoch ----------------------------------------------------------------------------------------------------
415
+ # end training -----------------------------------------------------------------------------------------------------
416
  if rank in [-1, 0]:
417
  logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
418
  if plots: