File size: 804 Bytes
ad947b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from typing import List
from torch import distributed
def barrier():
    """Block until every process in the default process group reaches this call.

    A no-op when ``torch.distributed`` is not available in this PyTorch build
    or when no process group has been initialized (single-process runs).
    """
    # is_available() must be checked first: on builds without distributed
    # support, torch.distributed exposes no other APIs, so calling
    # is_initialized() directly would raise AttributeError.
    if distributed.is_available() and distributed.is_initialized():
        distributed.barrier()
def broadcast(data, src):
    """Broadcast ``data`` in place from rank ``src`` to all other ranks.

    Args:
        data: Tensor sent by rank ``src`` and overwritten with the received
            values on every other rank.
        src: Rank that owns the source data.

    A no-op when ``torch.distributed`` is not available or not initialized:
    a single process already holds the only copy of ``data``.
    """
    # Guard with is_available() first; without distributed support the
    # module exposes no other APIs (is_initialized() would be missing).
    if distributed.is_available() and distributed.is_initialized():
        distributed.broadcast(data, src)
def all_gather(data: List, src):
    """Gather ``src`` from every rank into the output list ``data``.

    Args:
        data: Output list with one slot per rank; afterwards ``data[i]``
            holds rank ``i``'s contribution.
        src: This rank's contribution.

    Without distributed support or an initialized process group, this
    process is the only rank, so ``src`` is stored in ``data[0]``.

    NOTE(review): ``distributed.all_gather`` writes into the tensors already
    held in ``data``, whereas the fallback rebinds ``data[0]`` to ``src``
    itself (no copy), so later in-place changes to ``src`` are visible
    through ``data[0]``. Confirm callers do not rely on either behaviour.
    """
    # Guard with is_available() first; without distributed support the
    # module exposes no other APIs (is_initialized() would be missing).
    if distributed.is_available() and distributed.is_initialized():
        distributed.all_gather(data, src)
    else:
        data[0] = src
def get_rank():
    """Return this process's rank in the default process group.

    Returns 0 when ``torch.distributed`` is not available or not initialized,
    i.e. a single-process run is treated as rank 0.
    """
    # Guard with is_available() first; without distributed support the
    # module exposes no other APIs (is_initialized() would be missing).
    if distributed.is_available() and distributed.is_initialized():
        return distributed.get_rank()
    return 0
def get_world_size():
    """Return the number of processes in the default process group.

    Returns 1 when ``torch.distributed`` is not available or not initialized,
    i.e. a single-process run is a world of one.
    """
    # Guard with is_available() first; without distributed support the
    # module exposes no other APIs (is_initialized() would be missing).
    if distributed.is_available() and distributed.is_initialized():
        return distributed.get_world_size()
    return 1
def chunk_size(size, rank, world_size):
    """Return how many of ``size`` items belong to ``rank``.

    Items are split as evenly as possible over ``world_size`` ranks: each
    rank gets ``size // world_size`` items, and the first
    ``size % world_size`` ranks receive one extra item each.
    """
    per_rank, leftover = divmod(size, world_size)
    # The comparison yields a bool, which adds as 0 or 1.
    return per_rank + (rank < leftover)