# Utilities for single-GPU and multi-process (torch.distributed) training.
import torch
from torch import distributed as dist
from torch.utils import data

def get_rank():
    # Rank of the current process; 0 when torch.distributed is unavailable
    # or has not been initialized (single-process runs).
    if not dist.is_available() or not dist.is_initialized():
        return 0

    return dist.get_rank()

def synchronize():
    # Barrier across all processes; a no-op for non-distributed or
    # single-process runs.
    if (
        not dist.is_available()
        or not dist.is_initialized()
        or dist.get_world_size() == 1
    ):
        return

    dist.barrier()

def get_world_size():
    # Number of participating processes; 1 when not running distributed.
    if not dist.is_available() or not dist.is_initialized():
        return 1

    return dist.get_world_size()
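
# Illustrative usage sketch (an assumption, not part of the original module):
# a rank-0-only checkpoint save followed by a barrier, so other processes do
# not run ahead while rank 0 writes to disk. The `model` argument and the
# default path are hypothetical.
def _example_save_checkpoint(model, path="checkpoint.pt"):
    if get_rank() == 0:
        torch.save(model.state_dict(), path)

    synchronize()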

def reduce_loss_dict(loss_dict):
    # Sum the per-process loss tensors onto rank 0 and average them there,
    # so rank 0 can log losses that reflect all processes. Only rank 0
    # receives the averaged values; other ranks should not rely on the
    # returned dict when running distributed.
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in loss_dict.keys():
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
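
# Illustrative usage sketch (an assumption, not part of the original module):
# inside a training step, each process passes its loss tensors to
# reduce_loss_dict and only rank 0 logs the averaged values. The loss names
# "d" and "g" are hypothetical placeholders.
def _example_log_losses(d_loss, g_loss):
    loss_dict = {"d": d_loss, "g": g_loss}
    loss_reduced = reduce_loss_dict(loss_dict)

    if get_rank() == 0:
        # .mean() collapses any non-scalar loss tensors before logging.
        print({k: v.mean().item() for k, v in loss_reduced.items()})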

def get_sampler(dataset, shuffle, distributed):
    # Pick the appropriate sampler: a DistributedSampler that shards the
    # dataset across processes when running distributed, otherwise a plain
    # random or sequential sampler.
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)
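
# Illustrative usage sketch (an assumption, not part of the original module):
# building a DataLoader with get_sampler. DataLoader rejects shuffle=True when
# a sampler is given, which is why shuffling is delegated to the sampler.
# `dataset`, `batch_size`, and `distributed` are hypothetical arguments.
def _example_make_loader(dataset, batch_size, distributed):
    sampler = get_sampler(dataset, shuffle=True, distributed=distributed)

    return data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        drop_last=True,
    )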

def get_dp_wrapper(distributed):
    # Return a DistributedDataParallel or DataParallel subclass whose
    # __getattr__ falls back to the wrapped module, so custom attributes and
    # methods of the underlying model stay accessible after wrapping.
    class DPWrapper(
        torch.nn.parallel.DistributedDataParallel
        if distributed
        else torch.nn.DataParallel
    ):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    return DPWrapper
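
# Illustrative usage sketch (an assumption, not part of the original module):
# wrapping a model with the class returned by get_dp_wrapper. The distributed
# branch assumes torch.distributed has already been initialized (e.g. via
# dist.init_process_group) and that `local_rank` identifies this process's
# GPU; both are hypothetical here.
def _example_wrap_model(model, distributed, local_rank):
    wrapper = get_dp_wrapper(distributed)

    if distributed:
        model = wrapper(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )
    else:
        model = wrapper(model)

    # Thanks to the __getattr__ forwarding above, attributes of the underlying
    # module remain reachable directly on the wrapped model.
    return model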