Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
import argparse | |
import os | |
import os.path as osp | |
import torch | |
import torch.distributed as dist | |
import sys | |
# Put the current working directory on sys.path (once) so the `yolov6`
# package can be imported when the script is launched from that directory
# (presumably the repository root).
ROOT = os.getcwd()
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))
from yolov6.core.engine import Trainer | |
from yolov6.utils.config import Config | |
from yolov6.utils.events import LOGGER, save_yaml | |
from yolov6.utils.envs import get_envs, select_device, set_random_seed | |
def get_args_parser(add_help=True):
    """Build the command-line parser for YOLOv6 training.

    Args:
        add_help: Forwarded to ``argparse.ArgumentParser``; when False the
            automatic ``-h/--help`` option is not registered.

    Returns:
        argparse.ArgumentParser: Parser with every training option registered.
    """
    # One (flag, add_argument options) entry per CLI option, in the order
    # they are registered on the parser.
    option_table = (
        ('--data-path', dict(default='./data/coco.yaml', type=str, help='dataset path')),
        ('--conf-file', dict(default='./configs/yolov6s.py', type=str, help='experiment description file')),
        ('--img-size', dict(type=int, default=640, help='train, val image size (pixels)')),
        ('--batch-size', dict(default=32, type=int, help='total batch size for all GPUs')),
        ('--epochs', dict(default=400, type=int, help='number of total epochs to run')),
        ('--workers', dict(default=8, type=int, help='number of data loading workers (default: 8)')),
        ('--device', dict(default='0', type=str, help='cuda device, i.e. 0 or 0,1,2,3 or cpu')),
        ('--noval', dict(action='store_true', help='only evaluate in final epoch')),
        ('--check-images', dict(action='store_true', help='check images when initializing datasets')),
        ('--check-labels', dict(action='store_true', help='check label files when initializing datasets')),
        ('--output-dir', dict(default='./runs/train', type=str, help='path to save outputs')),
        ('--name', dict(default='exp', type=str, help='experiment name, save to output_dir/name')),
        ('--dist_url', dict(type=str, default="tcp://127.0.0.1:8888")),
        ('--gpu_count', dict(type=int, default=0)),
        ('--local_rank', dict(type=int, default=-1, help='DDP parameter, do not modify')),
    )

    cli = argparse.ArgumentParser(description='YOLOv6 PyTorch Training', add_help=add_help)
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli
def check_and_init(args):
    """Prepare the run directory, config, device and random seed.

    Side effects: sets ``args.save_dir`` to ``<output_dir>/<name>``, creates
    that directory if needed, and writes the argument namespace to
    ``args.yaml`` inside it.

    Args:
        args: Parsed CLI namespace. Must already carry ``rank`` in addition to
            ``output_dir``, ``name``, ``conf_file`` and ``device``.

    Returns:
        tuple: ``(cfg, device)`` - the config loaded from ``args.conf_file``
        and the device returned by ``select_device``.
    """
    # Output location for this experiment.
    run_dir = osp.join(args.output_dir, args.name)
    args.save_dir = run_dir
    os.makedirs(run_dir, exist_ok=True)

    # Experiment config, then the compute device.
    cfg = Config.fromfile(args.conf_file)
    device = select_device(args.device)

    # Seed is offset by the process rank; deterministic mode is requested
    # only when rank is -1.
    wants_deterministic = args.rank == -1
    set_random_seed(1 + args.rank, deterministic=wants_deterministic)

    # Persist the arguments (including save_dir) next to the run outputs.
    args_file = osp.join(run_dir, 'args.yaml')
    save_yaml(vars(args), args_file)

    return cfg, device
def main(args):
    """Run training end to end: environment setup, DDP init, train, teardown.

    Populates ``args.rank``, ``args.local_rank`` and ``args.world_size`` from
    ``get_envs()``, prepares the run via ``check_and_init``, initializes the
    ``torch.distributed`` process group when ``args.local_rank != -1``, trains
    with ``Trainer``, and finally destroys the process group on every process
    that has one.

    Args:
        args: Namespace produced by ``get_args_parser().parse_args()``.
    """
    # Setup
    args.rank, args.local_rank, args.world_size = get_envs()
    LOGGER.info(f'training args are: {args}\n')
    cfg, device = check_and_init(args)

    if args.local_rank != -1:  # if DDP mode
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        LOGGER.info('Initializing process group... ')
        backend = "nccl" if dist.is_nccl_available() else "gloo"
        # NOTE(review): the group rank is taken from local_rank; that matches
        # the global rank only on a single node - confirm before multi-node use.
        dist.init_process_group(
            backend=backend,
            init_method=args.dist_url,
            rank=args.local_rank,
            world_size=args.world_size,
        )

    # Start
    trainer = Trainer(args, cfg, device)
    trainer.train()

    # End
    # Tear down on every process that actually holds an initialized group,
    # not just rank 0, so teardown mirrors the setup above and is never
    # attempted when no group was created.
    if dist.is_available() and dist.is_initialized():
        if args.rank == 0:
            LOGGER.info('Destroying process group... ')
        dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: parse the command line, then start training.
    cli_parser = get_args_parser()
    args = cli_parser.parse_args()
    main(args)