{
    "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
    "network": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@network_def.to(@device)",
        "device_ids": [
            "@device"
        ]
    },
    "train#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@train#dataset",
        "even_divisible": true,
        "shuffle": true
    },
    "train#dataloader#sampler": "@train#sampler",
    "train#dataloader#shuffle": false,
    "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
    "validate#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@validate#dataset",
        "even_divisible": false,
        "shuffle": false
    },
    "validate#dataloader#sampler": "@validate#sampler",
    "validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
    "training": [
        "$import torch.distributed as dist",
        "$import os",
        "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
        "$torch.cuda.set_device(@device)",
        "$monai.utils.set_determinism(seed=123)",
        "$setattr(torch.backends.cudnn, 'benchmark', True)",
        "$import logging",
        "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
        "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
        "$@train#trainer.run()",
        "$dist.is_initialized() and dist.destroy_process_group()"
    ]
}