import torch
from torch import distributed as dist
from torch.utils import data


def get_rank():
    """Return the rank of the current process, or 0 when not running distributed."""
    if not dist.is_available() or not dist.is_initialized():
        return 0
    return dist.get_rank()


def synchronize():
    """Barrier across all processes; a no-op outside multi-process training."""
    if (
        not dist.is_available()
        or not dist.is_initialized()
        or dist.get_world_size() == 1
    ):
        return
    dist.barrier()


def get_world_size():
    """Return the number of participating processes, or 1 when not distributed."""
    if not dist.is_available() or not dist.is_initialized():
        return 1
    return dist.get_world_size()
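

# Usage sketch (illustrative, not part of the original gist): side effects such
# as checkpointing usually run on rank 0 only, with synchronize() acting as a
# barrier so the remaining processes wait until rank 0 is done.
def example_save_checkpoint(model, path="checkpoint.pt"):
    if get_rank() == 0:
        torch.save(model.state_dict(), path)
    synchronize()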


def reduce_loss_dict(loss_dict):
    """Reduce a dict of scalar loss tensors across processes.

    After the reduction, rank 0 holds the losses averaged over all processes;
    the other ranks keep their local values.
    """
    world_size = get_world_size()
    if world_size < 2:
        return loss_dict
    with torch.no_grad():
        keys = []
        losses = []
        for k in loss_dict.keys():
            keys.append(k)
            losses.append(loss_dict[k])
        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)
        if dist.get_rank() == 0:
            # Only rank 0 receives the summed losses, so only it averages them.
            losses /= world_size
        reduced_losses = {k: v for k, v in zip(keys, losses)}
    return reduced_losses
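

# Usage sketch (an assumption about the intended training loop, not part of the
# original gist): reduce the per-iteration losses so rank 0 can log the average
# over all processes.
def example_log_losses(loss_dict, logger):
    # loss_dict maps names to scalar loss tensors, e.g. {"cls": ..., "reg": ...}.
    reduced = reduce_loss_dict(loss_dict)
    if get_rank() == 0:
        logger.info("losses: %s", {k: v.item() for k, v in reduced.items()})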


def get_sampler(dataset, shuffle, distributed):
    """Pick a sampler: DistributedSampler for multi-process loading, otherwise
    a random or sequential sampler depending on `shuffle`."""
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return data.RandomSampler(dataset)
    return data.SequentialSampler(dataset)
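

# Usage sketch (illustrative, not part of the original gist): build a DataLoader
# on top of get_sampler(). When the returned sampler is a DistributedSampler,
# the caller should also call sampler.set_epoch(epoch) each epoch so shuffling
# differs between epochs.
def example_build_loader(dataset, batch_size, shuffle, distributed):
    sampler = get_sampler(dataset, shuffle=shuffle, distributed=distributed)
    return data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)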


def get_dp_wrapper(distributed):
    """Return a parallel wrapper class (DistributedDataParallel or DataParallel)
    that forwards unknown attribute lookups to the wrapped module."""

    class DPWrapper(
        torch.nn.parallel.DistributedDataParallel
        if distributed
        else torch.nn.DataParallel
    ):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                # Fall back to the wrapped module's own attributes and methods.
                return getattr(self.module, name)

    return DPWrapper
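

# Usage sketch (illustrative, not part of the original gist): wrap a model with
# the class returned by get_dp_wrapper(). The constructor arguments differ
# between the two wrappers; the distributed branch assumes one GPU per process.
def example_wrap_model(model, distributed, local_rank=0):
    if distributed:
        model = model.to(local_rank)
        return get_dp_wrapper(distributed)(model, device_ids=[local_rank])
    if torch.cuda.is_available():
        model = model.cuda()
    return get_dp_wrapper(distributed)(model)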