Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Megvii, Inc. and its affiliates. | |
import os | |
import sys | |
import random | |
import time | |
import warnings | |
from loguru import logger | |
import torch | |
import torch.backends.cudnn as cudnn | |
from yolox.exp import Exp, get_exp | |
from yolox.core import Trainer | |
from yolox.utils import configure_module, configure_omp | |
from yolox.tools.train import make_parser | |
class AssignVisualizer(Trainer): | |
def __init__(self, exp: Exp, args): | |
super().__init__(exp, args) | |
self.batch_cnt = 0 | |
self.vis_dir = os.path.join(self.file_name, "vis") | |
os.makedirs(self.vis_dir, exist_ok=True) | |
def train_one_iter(self): | |
iter_start_time = time.time() | |
inps, targets = self.prefetcher.next() | |
inps = inps.to(self.data_type) | |
targets = targets.to(self.data_type) | |
targets.requires_grad = False | |
inps, targets = self.exp.preprocess(inps, targets, self.input_size) | |
data_end_time = time.time() | |
with torch.cuda.amp.autocast(enabled=self.amp_training): | |
path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_") | |
self.model.visualize(inps, targets, path_prefix) | |
if self.use_model_ema: | |
self.ema_model.update(self.model) | |
iter_end_time = time.time() | |
self.meter.update( | |
iter_time=iter_end_time - iter_start_time, | |
data_time=data_end_time - iter_start_time, | |
) | |
self.batch_cnt += 1 | |
if self.batch_cnt >= self.args.max_batch: | |
sys.exit(0) | |
def after_train(self): | |
logger.info("Finish visualize assignment, exit...") | |
def assign_vis_parser(): | |
parser = make_parser() | |
parser.add_argument("--max-batch", type=int, default=1, help="max batch of images to visualize") | |
return parser | |
def main(exp: Exp, args): | |
if exp.seed is not None: | |
random.seed(exp.seed) | |
torch.manual_seed(exp.seed) | |
cudnn.deterministic = True | |
warnings.warn( | |
"You have chosen to seed training. This will turn on the CUDNN deterministic setting, " | |
"which can slow down your training considerably! You may see unexpected behavior " | |
"when restarting from checkpoints." | |
) | |
# set environment variables for distributed training | |
configure_omp() | |
cudnn.benchmark = True | |
visualizer = AssignVisualizer(exp, args) | |
visualizer.train() | |
if __name__ == "__main__": | |
configure_module() | |
args = assign_vis_parser().parse_args() | |
exp = get_exp(args.exp_file, args.name) | |
exp.merge(args.opts) | |
if not args.experiment_name: | |
args.experiment_name = exp.exp_name | |
main(exp, args) | |