import random
import warnings

import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, Fp16OptimizerHook,
                         GradientCumulativeFp16OptimizerHook,
                         GradientCumulativeOptimizerHook, OptimizerHook,
                         build_runner)

from mogen.core.distributed_wrapper import DistributedDataParallelWrapper
from mogen.core.evaluation import DistEvalHook, EvalHook
from mogen.core.optimizer import build_optimizers
from mogen.datasets import build_dataloader, build_dataset
from mogen.utils import get_root_logger


def set_random_seed(seed, deterministic=False):
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for
            CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
            to True and `torch.backends.cudnn.benchmark` to False.
            Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def train_model(model,
                dataset,
                cfg,
                distributed=False,
                validate=False,
                timestamp=None,
                device='cuda',
                meta=None):
    """Main API for training a model.

    Args:
        model (nn.Module): Model to be trained.
        dataset (:obj:`Dataset` | list): Training dataset(s).
        cfg (mmcv.Config): Full training config.
        distributed (bool): Whether to launch distributed training.
            Default: False.
        validate (bool): Whether to evaluate on the validation set during
            training. Default: False.
        timestamp (str, optional): Timestamp attached to the runner.
            Default: None.
        device (str): Device for non-distributed training, 'cuda' or 'cpu'.
            Default: 'cuda'.
        meta (dict, optional): Meta info passed to the runner. Default: None.
    """
    logger = get_root_logger(cfg.log_level)

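    # Wrap a single dataset into a list and build one training dataloader per
    # dataset according to the settings in cfg.data.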
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]

    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            num_gpus=len(cfg.gpu_ids),
            dist=distributed,
            round_up=True,
            sampler_cfg=cfg.data.sampler_cfg,
            batch_sampler_cfg=cfg.data.batch_sampler_cfg,
            seed=cfg.seed) for ds in dataset
    ]

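    # Adversarial training changes both how the model is wrapped and how the
    # optimizer step is hooked in below.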
    use_adversarial_train = cfg.get('use_adversarial_train', False)

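    # Put the model on the target device(s); adversarial training uses the
    # dedicated DistributedDataParallelWrapper instead of the plain
    # MMDistributedDataParallel.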
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', True)
        if use_adversarial_train:
            model = DistributedDataParallelWrapper(
                model,
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
        else:
            model = MMDistributedDataParallel(
                model.cuda(),
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=find_unused_parameters)
    else:
        if device == 'cuda':
            model = MMDataParallel(model.cuda(cfg.gpu_ids[0]),
                                   device_ids=cfg.gpu_ids)
        elif device == 'cpu':
            model = model.cpu()
        else:
            raise ValueError(f'unsupported device name {device}.')

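    # Build the optimizer(s) from cfg.optimizer.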
    optimizer = build_optimizers(model, cfg.optimizer)

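    # Configs without a `runner` section fall back to an EpochBasedRunner
    # capped at cfg.total_epochs.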
    if cfg.get('runner') is None:
        cfg.runner = {
            'type': 'EpochBasedRunner',
            'max_epochs': cfg.total_epochs
        }
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(cfg.runner,
                          default_args=dict(model=model,
                                            batch_processor=None,
                                            optimizer=optimizer,
                                            work_dir=cfg.work_dir,
                                            logger=logger,
                                            meta=meta))

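    # Attach the launch timestamp to the runner (used when naming log files).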
    runner.timestamp = timestamp

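    # For adversarial training the optimization step is expected to happen
    # inside the model's train_step, so no OptimizerHook is registered here.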
    if use_adversarial_train:
        optimizer_config = None
    else:
        if distributed and 'type' not in cfg.optimizer_config:
            optimizer_config = OptimizerHook(**cfg.optimizer_config)
        else:
            optimizer_config = cfg.optimizer_config

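    # Register training hooks: lr updater, optimizer, checkpointing, logging,
    # momentum scheduling and any custom hooks from the config.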
    runner.register_training_hooks(cfg.lr_config,
                                   optimizer_config,
                                   cfg.checkpoint_config,
                                   cfg.log_config,
                                   cfg.get('momentum_config', None),
                                   custom_hooks_config=cfg.get(
                                       'custom_hooks', None))
    if distributed:
        runner.register_hook(DistSamplerSeedHook())

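    # Optionally register a (distributed) evaluation hook on the validation
    # set.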
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=cfg.data.samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            round_up=True)
        eval_cfg = cfg.get('evaluation', {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

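    # Resume a previous run or load pretrained weights if requested, then
    # start the training workflow.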
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)