Update train.py (#3667)
train.py:

```diff
@@ -22,7 +22,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
-import test  # import test.py to get mAP after each epoch
+import test  # for end-of-epoch mAP
 from models.experimental import attempt_load
 from models.yolo import Model
 from utils.autoanchor import check_anchors
@@ -39,7 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
 logger = logging.getLogger(__name__)
 
 
-def train(hyp, opt, device, tb_writer=None):
+def train(hyp,
+          opt,
+          device,
+          tb_writer=None
+          ):
     logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
     save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
@@ -341,7 +345,7 @@ def train(hyp, opt, device, tb_writer=None):
                                           save_dir.glob('train*.jpg') if x.exists()]})
 
             # end batch ------------------------------------------------------------------------------------------------
-
+
             # Scheduler
             lr = [x['lr'] for x in optimizer.param_groups]  # for tensorboard
             scheduler.step()
@@ -404,12 +408,11 @@ def train(hyp, opt, device, tb_writer=None):
                     torch.save(ckpt, best)
                 if wandb_logger.wandb:
                     if ((epoch + 1) % opt.save_period == 0 and not final_epoch) and opt.save_period != -1:
-                        wandb_logger.log_model(
-                            last.parent, opt, epoch, fi, best_model=best_fitness == fi)
+                        wandb_logger.log_model(last.parent, opt, epoch, fi, best_model=best_fitness == fi)
                 del ckpt
 
         # end epoch ----------------------------------------------------------------------------------------------------
-    # end training
+    # end training -----------------------------------------------------------------------------------------------------
     if rank in [-1, 0]:
         logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
     if plots:
```
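The W&B checkpoint-upload condition in the last hunk is easy to misread. Below is a minimal standalone sketch of the same gating logic; the `should_log_model` helper and the example values are hypothetical and not part of train.py:

```python
# Sketch of the save_period gate used before wandb_logger.log_model(...).
# Assumes 0-indexed epochs and the convention save_period == -1 meaning "disabled".

def should_log_model(epoch: int, final_epoch: bool, save_period: int) -> bool:
    # Fire every `save_period` epochs, but never on the final epoch
    # (the final checkpoint is handled separately) and never when periodic
    # logging is disabled. The explicit `!= -1` check matters because in
    # Python (epoch + 1) % -1 == 0 for every epoch.
    return ((epoch + 1) % save_period == 0 and not final_epoch) and save_period != -1

if __name__ == '__main__':
    for epoch in range(10):  # e.g. a 10-epoch run with save_period=3
        if should_log_model(epoch, final_epoch=epoch == 9, save_period=3):
            print(f'epoch {epoch}: would upload checkpoint')  # fires at epochs 2, 5, 8
```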
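The Scheduler hunk only touches whitespace, but the ordering it shows is worth noting: the per-group learning rates are read out before `scheduler.step()`, so the value logged to TensorBoard is the one the epoch actually trained with. A small self-contained illustration, using a stand-in model and schedule rather than anything from train.py:

```python
import torch

# Stand-in model/optimizer/schedule purely for illustration.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda e: 0.9 ** e)

for epoch in range(3):
    # ... batch loop would run here ...
    lr = [x['lr'] for x in optimizer.param_groups]  # lr used during this epoch
    scheduler.step()  # advance to the next epoch's lr only after recording
    print(f'epoch {epoch}: trained at lr={lr}')
```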
|