import functools
import os
import pickle
import platform
import subprocess
import warnings

import cv2
import torch
import torch.utils.data.distributed
from torch import distributed as dist
from torch import multiprocessing as mp

_LOCAL_PROCESS_GROUP = None


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not is_dist_avail_and_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not is_dist_avail_and_initialized():
        return 1
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def barrier():
    if not is_dist_avail_and_initialized():
        return
    dist.barrier()


def is_main_process():
    return get_rank() == 0


def is_rank_zero(args):
    return args.rank == 0


def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
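
# Example (sketch): these helpers degrade gracefully to rank 0 / world size 1
# when torch.distributed is not initialized, so they can guard logging or
# checkpointing in both single-process and distributed runs:
#
#     rank, world_size = get_dist_info()
#     if is_main_process():
#         print(f"running on {world_size} process(es)")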


def setup_multi_processes(cfg):
    """Setup multi-processing environment variables."""
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != "Windows":
        mp_start_method = cfg.get("mp_start_method", "fork")
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f"Multi-processing start method `{mp_start_method}` is "
                f"different from the previous setting `{current_method}`. "
                f"It will be forcibly set to `{mp_start_method}`. You can change "
                f"this behavior by setting `mp_start_method` in your config."
            )
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    # opencv_num_threads = cfg.get('opencv_num_threads', 0)
    # cv2.setNumThreads(opencv_num_threads)

    # setup OMP threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    # workers_per_gpu = cfg.get('workers_per_gpu', 4)

    # if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
    #     omp_num_threads = 1
    #     warnings.warn(
    #         f'Setting OMP_NUM_THREADS environment variable for each process '
    #         f'to be {omp_num_threads} in default, to avoid your system being '
    #         f'overloaded, please further tune the variable for optimal '
    #         f'performance in your application as needed.')
    #     os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)

    # setup MKL threads
    # if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
    #     mkl_num_threads = os.environ.get('OMP_NUM_THREADS', 1)
    #     warnings.warn(
    #         f'Setting MKL_NUM_THREADS environment variable for each process '
    #         f'to be {mkl_num_threads} in default, to avoid your system being '
    #         f'overloaded, please further tune the variable for optimal '
    #         f'performance in your application as needed.')
    #     os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
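
# Example (sketch): `cfg` only needs a dict-style `.get`, so a plain dict (or an
# mmcv-style Config) works:
#
#     setup_multi_processes({"mp_start_method": "fork"})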


def setup_slurm(backend: str, port: str) -> None:
    """Initialize torch.distributed from SLURM env vars (SLURM_PROCID, SLURM_NTASKS, SLURM_NODELIST)."""
    proc_id = int(os.environ["SLURM_PROCID"])
    ntasks = int(os.environ["SLURM_NTASKS"])
    node_list = os.environ["SLURM_NODELIST"]

    num_gpus = torch.cuda.device_count()

    torch.cuda.set_device(proc_id % num_gpus)
    if "MASTER_ADDR" not in os.environ:
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
        os.environ["MASTER_PORT"] = str(port)
        os.environ["MASTER_ADDR"] = addr
    else:
        addr = os.environ["MASTER_ADDR"]
    os.environ["WORLD_SIZE"] = str(ntasks)
    os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
    os.environ["RANK"] = str(proc_id)
    print(
        proc_id,
        ntasks,
        num_gpus,
        proc_id % num_gpus,
        node_list,
        addr,
        os.environ["MASTER_PORT"],
        os.system("nvidia-smi -L"),
    )
    dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
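
# Example (hypothetical sketch): under an `srun` launch, each task can initialize
# the process group before building the model:
#
#     if "SLURM_PROCID" in os.environ:
#         setup_slurm("nccl", port="29500")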


def sync_tensor_across_gpus(t, dim=0, cat=True):
    """Gather a tensor whose size may differ across ranks.

    Each rank's tensor is zero-padded to the largest per-rank size before the
    all_gather and trimmed back afterwards; whenever sizes differ, the padding
    step assumes 1-D tensors (i.e. dim=0).
    """
    if t is None or not (dist.is_available() and dist.is_initialized()):
        return t
    t = torch.atleast_1d(t)
    group = dist.group.WORLD
    group_size = dist.get_world_size(group)

    # Exchange per-rank sizes so every rank knows how much padding to add.
    local_size = torch.tensor(t.size(dim), device=t.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
    dist.all_gather(all_sizes, local_size)
    max_size = max(all_sizes)
    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
        t = torch.cat((t, padding))

    gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
    dist.all_gather(gather_t_tensor, t)
    # Trim each gathered tensor back to its original (pre-padding) size.
    all_ts = []
    for t, size in zip(gather_t_tensor, all_sizes):
        all_ts.append(t[:size])
    if cat:
        return torch.cat(all_ts, dim=0)
    return all_ts
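
# Example (sketch, assuming an initialized default process group): each rank may
# hold a different number of 1-D predictions and receives the full concatenation:
#
#     preds = torch.randn(num_local_preds, device="cuda")  # size varies per rank
#     all_preds = sync_tensor_across_gpus(preds, dim=0, cat=True)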


def sync_string_across_gpus(keys: list[str], device, dim=0):
    """Gather a list of strings from all ranks by pickling it into a uint8 tensor."""
    keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
    keys_serialized_tensor = (
        torch.frombuffer(keys_serialized, dtype=torch.uint8).clone().to(device)
    )
    keys_serialized_tensor = sync_tensor_across_gpus(
        keys_serialized_tensor, dim=0, cat=False
    )
    keys = [
        key
        for rank_keys in keys_serialized_tensor
        for key in pickle.loads(bytes(rank_keys.cpu().tolist()))
    ]
    return keys
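
# Example (sketch): gathering per-rank sample identifiers during evaluation:
#
#     local_ids = ["img_000123", "img_000456"]  # hypothetical identifiers
#     all_ids = sync_string_across_gpus(local_ids, device=torch.device("cuda"))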


def create_local_process_group() -> None:
    """Create a per-machine process group (assumes one process per GPU on every machine)."""
    num_workers_per_machine = torch.cuda.device_count()
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    assert get_world_size() % num_workers_per_machine == 0
    num_machines = get_world_size() // num_workers_per_machine
    machine_rank = get_rank() // num_workers_per_machine
    for i in range(num_machines):
        ranks_on_i = list(
            range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine)
        )
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            _LOCAL_PROCESS_GROUP = pg
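
# Example (sketch): call once on every rank after dist.init_process_group (e.g.
# via setup_slurm) so that get_local_rank()/get_local_size() become usable:
#
#     setup_slurm("nccl", port="29500")
#     create_local_process_group()
#     print(get_rank(), get_local_rank(), get_local_size())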


@functools.lru_cache()
def _get_global_gloo_group():
    """Return a gloo group for CPU-side collectives; cached so it is created only once."""
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def all_gather(data, group=None):
    """All-gather arbitrary picklable objects from every rank into a list."""
    if get_world_size() == 1:
        return [data]
    if group is None:
        # Use a CPU (gloo) group by default to reduce GPU RAM usage.
        group = _get_global_gloo_group()
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return [data]

    output = [None for _ in range(world_size)]
    dist.all_gather_object(output, data, group=group)
    return output
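
# Example (sketch): gathering small picklable Python objects from every rank,
# e.g. per-rank metric summaries:
#
#     stats = {"rank": get_rank(), "num_samples": 128}
#     all_stats = all_gather(stats)  # list with world_size entries on every rank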


def local_broadcast_process_authkey():
    """Share the local leader's multiprocessing authkey with all processes on this machine."""
    if get_local_size() == 1:
        return
    local_rank = get_local_rank()
    authkey = bytes(mp.current_process().authkey)
    all_keys = all_gather(authkey)
    local_leader_key = all_keys[get_rank() - local_rank]
    if authkey != local_leader_key:
        # The authkey differs from the local leader's, which suggests the workers
        # were launched independently; adopt the leader's key so that shared
        # multiprocessing resources can still be authenticated.
        mp.current_process().authkey = local_leader_key
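
# Example (sketch): typically invoked once per rank right after the local process
# group exists, so processes on one machine can authenticate shared
# multiprocessing resources:
#
#     create_local_process_group()
#     local_broadcast_process_authkey()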