File size: 474 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
from torch.distributed.distributed_c10d import is_initialized
from torch.utils.data import Dataset, DistributedSampler
def get_ddp_sampler(dataset: Dataset, epoch: int):
    """
    Build the sampler to use for ``dataset`` during the given ``epoch``.

    If the distributed process group has not been initialized, no sampler
    is created and ``None`` is returned. Otherwise a ``DistributedSampler``
    wrapping ``dataset`` is constructed, its epoch is set to ``epoch``, and
    the sampler is returned.
    """
    # Guard clause: outside of DDP there is nothing to shard across ranks.
    if not is_initialized():
        return None

    ddp_sampler = DistributedSampler(dataset)
    # Pass the current epoch to the sampler before handing it to the caller.
    ddp_sampler.set_epoch(epoch)
    return ddp_sampler
|