# This file contains the changes to implement DDP training with the train.yaml config.
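# When run together with train.yaml, the definitions here override or extend the base config;
# ids referenced below but not defined here (@network_def, @train_dataset, @val_dataset,
# @train_dataloader, @val_dataloader, @trainer) are expected to come from train.yaml.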
is_dist: '$dist.is_initialized()'
rank: '$dist.get_rank() if @is_dist else 0'
device: '$torch.device(f"cuda:{@rank}" if torch.cuda.is_available() else "cpu")'  # assumes GPU number matches rank
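# Note: bundle expressions are evaluated lazily, so @rank and @device are only resolved after
# init_process_group is called in the "run" section below, giving each process its own GPU.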
# wrap the network in a DistributedDataParallel instance, moving it to the chosen device for this process
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: $@network_def.to(@device)
  device_ids: ['@device']
  find_unused_parameters: true
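# find_unused_parameters=true lets DDP tolerate parameters that receive no gradient in a given
# iteration, at the cost of an extra graph traversal per step; it can be set to false if every
# parameter of @network_def contributes to the loss in each forward/backward pass.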
train_sampler:
  _target_: DistributedSampler
  dataset: '@train_dataset'
  even_divisible: true
  shuffle: true
train_dataloader#sampler: '@train_sampler'
train_dataloader#shuffle: false
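# The sampler shards @train_dataset across ranks and takes over shuffling, so the dataloader's
# own shuffle flag is disabled. The bare DistributedSampler target resolves to
# monai.data.DistributedSampler, whose even_divisible=true repeats samples as needed so that
# every rank gets the same number of batches per epoch.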
val_sampler:
  _target_: DistributedSampler
  dataset: '@val_dataset'
  even_divisible: false
  shuffle: false
val_dataloader#sampler: '@val_sampler'
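# For validation, even_divisible is false so no samples are duplicated when @val_dataset does not
# divide evenly across ranks, and shuffling is left off to keep the evaluation order deterministic.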
run:
- $import torch.distributed as dist
- $dist.init_process_group(backend='nccl')
- $torch.cuda.set_device(@device)
- $monai.utils.set_determinism(seed=123)  # may want to choose a different seed or not do this here
- $@trainer.run()  # run the trainer defined in train.yaml (assumed id)
- $dist.destroy_process_group()
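# Example launch, a sketch only (the config paths, process count, and exact monai.bundle CLI
# arguments are assumptions that vary between setups and MONAI versions):
#   torchrun --nproc_per_node=2 -m monai.bundle run \
#       --config_file "['configs/train.yaml','configs/multi_gpu_train.yaml']"
# torchrun starts one process per GPU and sets the environment variables (RANK, WORLD_SIZE,
# MASTER_ADDR, MASTER_PORT) that init_process_group(backend='nccl') reads to form the process group.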