File size: 804 Bytes
ad947b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import List
from torch import distributed


def barrier():
    """Synchronize all processes at a barrier.

    No-op when torch.distributed has not been initialized (e.g. a
    single-process run), so callers need not guard for that case.
    """
    if distributed.is_initialized():
        distributed.barrier()


def broadcast(data, src):
    """Broadcast ``data`` from process ``src`` to all processes.

    No-op when torch.distributed has not been initialized: in a
    single-process run the caller already holds the data.

    Args:
        data: tensor to broadcast (filled in-place on non-source ranks).
        src: rank of the source process.
    """
    if distributed.is_initialized():
        distributed.broadcast(data, src)


def all_gather(data: List, src):
    """Gather ``src`` from every process into the list ``data``.

    When torch.distributed is not initialized (single-process run),
    emulate a world of size one by placing ``src`` into ``data[0]``.
    """
    if not distributed.is_initialized():
        data[0] = src
        return
    distributed.all_gather(data, src)


def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if not distributed.is_initialized():
        return 0
    return distributed.get_rank()


def get_world_size():
    """Return the number of distributed processes, or 1 when not distributed."""
    if not distributed.is_initialized():
        return 1
    return distributed.get_world_size()


def chunk_size(size, rank, world_size):
    """Number of items assigned to ``rank`` when ``size`` items are split
    as evenly as possible across ``world_size`` ranks.

    The first ``size % world_size`` ranks each receive one extra item.
    """
    base, remainder = divmod(size, world_size)
    return base + (1 if rank < remainder else 0)