import torch
from torch import distributed as dist
from torch.utils import data


def get_rank():
    # Rank of the current process; 0 when torch.distributed is unavailable or uninitialized.
    if not dist.is_available() or not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    # Barrier across all processes; a no-op in single-process runs.
    if not dist.is_available() or not dist.is_initialized() or dist.get_world_size() == 1:
        return

    dist.barrier()


def get_world_size():
    # Number of participating processes; 1 when not running distributed.
    if not dist.is_available() or not dist.is_initialized():
        return 1

    return dist.get_world_size()
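

# A minimal sketch of how the helpers above are typically used to keep side
# effects (logging, checkpointing) on a single process. The message text and
# checkpoint file name are illustrative assumptions, not part of this module.
def _example_rank_zero_logging(model):
    if get_rank() == 0:
        print(f"world size: {get_world_size()}")
        torch.save(model.state_dict(), "checkpoint.pt")

    # Make every process wait until rank 0 has finished writing.
    synchronize()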


def reduce_loss_dict(loss_dict):
    # Average a dict of scalar loss tensors across processes. dist.reduce sums
    # onto rank 0, which then divides by world size; only rank 0 is guaranteed
    # to hold the averaged values.
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        keys = []
        losses = []

        for k in loss_dict.keys():
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
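

# Illustrative sketch of reduce_loss_dict in a training step: losses are kept
# as 0-dim tensors, reduced across processes, and logged from rank 0 only.
# The loss names and values here are assumptions made for the example.
def _example_log_reduced_losses():
    loss_dict = {"g": torch.tensor(0.5), "d": torch.tensor(1.2)}
    reduced = reduce_loss_dict(loss_dict)

    if get_rank() == 0:
        print({k: v.item() for k, v in reduced.items()})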


def get_sampler(dataset, shuffle, distributed):
    # Shard the dataset across processes when distributed; otherwise fall back
    # to plain random or sequential sampling.
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)
    else:
        return data.SequentialSampler(dataset)
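

# Sketch of wiring get_sampler into a DataLoader. Shuffling is delegated to the
# sampler, so the loader's own shuffle argument is left at its default; the
# dataset argument and batch size are placeholders for the example.
def _example_loader(dataset, distributed=False):
    sampler = get_sampler(dataset, shuffle=True, distributed=distributed)

    return data.DataLoader(dataset, batch_size=16, sampler=sampler)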


def get_dp_wrapper(distributed):
    # Build a (Distributed)DataParallel subclass whose __getattr__ falls back to
    # the wrapped module, so custom attributes of the model stay reachable.
    class DPWrapper(torch.nn.parallel.DistributedDataParallel if distributed else torch.nn.DataParallel):
        def __getattr__(self, name):
            try:
                return super().__getattr__(name)
            except AttributeError:
                return getattr(self.module, name)

    return DPWrapper
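

# Sketch of wrapping a model with get_dp_wrapper. With distributed=True this
# assumes the default process group has already been initialized (e.g. via
# torch.distributed.init_process_group) and that the caller supplies device_ids;
# with distributed=False it falls back to DataParallel. Attributes defined on the
# underlying model remain reachable through the wrapper via __getattr__ above.
def _example_wrap_model(model, distributed=False, device_ids=None):
    wrapper = get_dp_wrapper(distributed)

    if distributed:
        return wrapper(model, device_ids=device_ids)

    return wrapper(model)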