# PatchFusion / tools/train.py
# Last commit by Zhyever: 1f418ff ("refactor")
import os
import os.path as osp
import argparse
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from mmengine.utils import mkdir_or_exist
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger
from estimator.utils import RunnerInfo, setup_env, log_env, fix_random_seed
from estimator.models.builder import build_model
from estimator.datasets.builder import build_dataset
from estimator.trainer import Trainer
def parse_args():
    """Build the training command-line interface and parse ``sys.argv``.

    As a side effect, ``LOCAL_RANK`` is exported to the environment from
    ``--local_rank`` / ``--local-rank`` unless it is already present.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    ap = argparse.ArgumentParser(description='Train a segmentor')
    ap.add_argument('config', help='train config file path')
    ap.add_argument('--work-dir', help='the dir to save logs and models')
    ap.add_argument('--resume', action='store_true', default=False,
                    help='resume from the latest checkpoint in the work_dir automatically')
    ap.add_argument('--debug', action='store_true', default=False,
                    help='debug mode')
    ap.add_argument('--log-name', type=str, default='',
                    help='log_name for wandb')
    ap.add_argument('--tags', type=str, default='',
                    help='tags for wandb')
    ap.add_argument('--amp', action='store_true', default=False,
                    help='enable automatic-mixed-precision training')
    ap.add_argument('--seed', type=int, default=621,
                    help='for debug')

    cfg_options_help = (
        'override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    ap.add_argument('--cfg-options', nargs='+', action=DictAction,
                    help=cfg_options_help)
    ap.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'],
                    default='none', help='job launcher')

    # PyTorch >= 2.0.0's `torch.distributed.launch` passes `--local-rank`
    # instead of `--local_rank`, so both spellings are accepted.
    ap.add_argument('--local_rank', '--local-rank', type=int, default=0)

    parsed = ap.parse_args()
    # Only fill in LOCAL_RANK when nothing upstream has exported it already.
    os.environ.setdefault('LOCAL_RANK', str(parsed.local_rank))
    return parsed
def _resolve_work_dir(cfg, args):
    """Return the base work directory.

    Priority: ``--work-dir`` on the CLI, then ``work_dir`` set in the config,
    then ``./work_dirs/<config file name without extension>``.
    """
    if args.work_dir is not None:
        return args.work_dir
    if cfg.get('work_dir', None) is not None:
        return cfg.work_dir
    return osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])


def _parse_tags(tags):
    """Split a comma-separated ``--tags`` string into a list of non-empty tags.

    ``''`` -> ``[]``, ``'a'`` -> ``['a']``, ``'a, b,'`` -> ``['a', 'b']``.
    """
    return [tag.strip() for tag in tags.split(',') if tag.strip()]


def _build_loader(dataset, batch_size, num_workers, sampler, shuffle):
    """Wrap *dataset* in a pinned-memory ``DataLoader``.

    ``persistent_workers`` is enabled only when worker processes exist,
    because ``DataLoader`` raises ``ValueError`` for
    ``persistent_workers=True`` combined with ``num_workers == 0``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        persistent_workers=num_workers > 0,
        sampler=sampler)


def main():
    """Parse CLI args, build the model and dataloaders, and run the Trainer.

    Side effects: creates ``<work_dir>/<log_name>``, configures a text logger
    targeting ``<timestamp>.log`` in it, dumps the final config to
    ``config.py`` there, and on rank 0 outside ``--debug`` starts and
    finishes a wandb run.
    """
    args = parse_args()

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir priority: CLI > `work_dir` in config > ./work_dirs/<config name>;
    # each run is then placed in a sub-directory named after --log-name.
    cfg.work_dir = osp.join(_resolve_work_dir(cfg, args), args.log_name)
    mkdir_or_exist(cfg.work_dir)

    cfg.debug = args.debug
    cfg.log_name = args.log_name
    cfg.tags = _parse_tags(args.tags)

    # fix seed
    seed = args.seed
    fix_random_seed(seed)

    # start dist training
    distributed = cfg.launcher != 'none'
    env_cfg = cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl')))
    rank, world_size, timestamp = setup_env(env_cfg, distributed, cfg.launcher)

    # prepare basic text logger
    log_cfg = dict(
        log_level='INFO',
        log_file=osp.join(cfg.work_dir, f'{timestamp}.log'),
        name=timestamp,
        logger_name='patchstitcher',
        # `torch.compile` in PyTorch 2.0 could close all user-defined handlers
        # unexpectedly. File mode 'a' keeps the FileHandler usable so the log
        # file continues to be updated during the lifespan of the runner.
        file_mode='a')
    logger = MMLogger.get_instance(**log_cfg)

    # Save information needed during training. cfg should ideally stay
    # unchanged from here on; transient state belongs in runner_info.
    runner_info = RunnerInfo()
    runner_info.config = cfg
    runner_info.logger = logger  # alternative: print_log("infos", logger='current')
    runner_info.rank = rank
    runner_info.distributed = distributed
    runner_info.launcher = cfg.launcher
    runner_info.seed = seed
    runner_info.world_size = world_size
    runner_info.work_dir = cfg.work_dir
    runner_info.timestamp = timestamp

    # start wandb (rank 0 only, never in --debug runs)
    use_wandb = runner_info.rank == 0 and not cfg.debug
    if use_wandb:
        wandb.init(
            project=cfg.project,
            name=cfg.log_name + "_" + runner_info.timestamp,
            tags=cfg.tags,
            dir=runner_info.work_dir,
            # NOTE(review): cfg is an mmengine Config object, not a plain
            # dict; confirm wandb records it as intended.
            config=cfg,
            settings=wandb.Settings(start_method="fork"))
        # Val/* and Train/* metrics are plotted against their own step keys.
        wandb.define_metric("Val/step")
        wandb.define_metric("Val/*", step_metric="Val/step")
        wandb.define_metric("Train/step")
        wandb.define_metric("Train/*", step_metric="Train/step")

    log_env(cfg, env_cfg, runner_info, logger)

    # Resume training is marked as future work; only the flag is recorded here.
    cfg.resume = args.resume

    # build model
    model = build_model(cfg.model)
    if runner_info.distributed:
        # NOTE(review): the device index is the rank returned by setup_env;
        # if that is the global rank, multi-node runs need the local rank.
        torch.cuda.set_device(runner_info.rank)
        if cfg.get('convert_syncbn', False):
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(runner_info.rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[runner_info.rank],
            output_device=runner_info.rank,
            find_unused_parameters=cfg.get('find_unused_parameters', False))
    else:
        model = model.cuda(runner_info.rank)
    logger.info(model)

    # build dataloaders; a DistributedSampler handles shuffling in DDP mode
    train_dataset = build_dataset(cfg.train_dataloader.dataset)
    if runner_info.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_dataloader = _build_loader(
        train_dataset,
        batch_size=cfg.train_dataloader.batch_size,
        num_workers=cfg.train_dataloader.num_workers,
        sampler=train_sampler,
        shuffle=train_sampler is None)

    val_dataset = build_dataset(cfg.val_dataloader.dataset)
    if runner_info.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False)
    else:
        val_sampler = None
    val_dataloader = _build_loader(
        val_dataset,
        batch_size=1,
        num_workers=cfg.val_dataloader.num_workers,
        sampler=val_sampler,
        shuffle=False)

    # everything is ready; save the final config before training starts
    cfg.dump(osp.join(cfg.work_dir, 'config.py'))

    # build trainer
    trainer = Trainer(
        config=cfg,
        runner_info=runner_info,
        train_sampler=train_sampler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        model=model)
    trainer.run()

    # Only close the wandb run on the process that opened it.
    if use_wandb:
        wandb.finish()


if __name__ == '__main__':
    main()