"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""

import os
import pickle
import platform
import subprocess
import warnings

import cv2

import torch
import torch.utils.data.distributed
from torch import multiprocessing as mp
from torch import distributed as dist


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def barrier():
    if not is_dist_avail_and_initialized():
        return
    dist.barrier()


def is_main_process():
    return get_rank() == 0


def is_rank_zero(args):
    return args.rank == 0


def get_dist_info():
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size
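
# Example usage (illustrative): run rank-0-only side effects, then re-synchronize.
#   if is_main_process():
#       print("saving checkpoint on rank 0 only")
#   barrier()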


def setup_multi_processes(cfg):
    """Setup multi-processing environment variables."""
    # set the multi-process start method to `fork` to speed up training
    if platform.system() != "Windows":
        mp_start_method = cfg.get("mp_start_method", "fork")
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f"Multi-processing start method `{mp_start_method}` is "
                f"different from the previous setting `{current_method}`. "
                f"It will be forced to `{mp_start_method}`. You can change "
                "this behavior by setting `mp_start_method` in your config."
            )
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get("opencv_num_threads", 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP threads
    # adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    workers_per_gpu = cfg.get("workers_per_gpu", 4)

    if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
        omp_num_threads = 1
        warnings.warn(
            f"Setting the OMP_NUM_THREADS environment variable of each process "
            f"to {omp_num_threads} by default, to avoid overloading your system. "
            "Please tune the variable further for optimal performance as needed."
        )
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    # setup MKL threads
    if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
        mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1)
        warnings.warn(
            f"Setting the MKL_NUM_THREADS environment variable of each process "
            f"to {mkl_num_threads} by default, to avoid overloading your system. "
            "Please tune the variable further for optimal performance as needed."
        )
        os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)


def setup_slurm(backend: str, port: str) -> None:
    """Initialize the distributed training environment on a SLURM cluster.

    Reads ``SLURM_PROCID``, ``SLURM_NTASKS`` and ``SLURM_NODELIST`` from the
    environment, binds this process to a local GPU, exports the variables
    expected by ``torch.distributed`` (``MASTER_ADDR``, ``MASTER_PORT``,
    ``WORLD_SIZE``, ``LOCAL_RANK``, ``RANK``) and initializes the process group.

    Args:
        backend (str): Backend of torch.distributed (e.g. "nccl").
        port (str): Master port used for rendezvous.
    """
    proc_id = int(os.environ["SLURM_PROCID"])
    ntasks = int(os.environ["SLURM_NTASKS"])
    node_list = os.environ["SLURM_NODELIST"]

    num_gpus = torch.cuda.device_count()

    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
    os.environ["MASTER_PORT"] = str(port)
    os.environ["MASTER_ADDR"] = addr
    os.environ["WORLD_SIZE"] = str(ntasks)
    os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
    os.environ["RANK"] = str(proc_id)
    # Debug: print the SLURM/GPU mapping of this process (os.system also lists the visible GPUs).
    print(
        proc_id,
        ntasks,
        num_gpus,
        proc_id % num_gpus,
        node_list,
        addr,
        os.environ["MASTER_PORT"],
        os.system("nvidia-smi -L"),
    )
    dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
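
# Example usage (illustrative): under `srun`, each task calls this once before
# building its DistributedDataParallel model, e.g. setup_slurm("nccl", port="29500").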


def sync_tensor_across_gpus(t, dim=0, cat=True):
    if t is None or not (dist.is_available() and dist.is_initialized()):
        return t
    t = torch.atleast_1d(t)
    group = dist.group.WORLD
    group_size = torch.distributed.get_world_size(group)

    # Gather the per-rank sizes along `dim`, since ranks may hold a different
    # number of elements.
    local_size = torch.tensor(t.size(dim), device=t.device)
    all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
    dist.all_gather(all_sizes, local_size)
    # Pad the local tensor up to the largest size so that all_gather, which
    # requires identical shapes on every rank, can be used.
    # NOTE: the zero-padding and concatenation below assume 1-D tensors (dim=0).
    max_size = max(all_sizes)
    size_diff = max_size.item() - local_size.item()
    if size_diff:
        padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
        t = torch.cat((t, padding))

    gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
    dist.all_gather(gather_t_tensor, t)
    # Trim each gathered tensor back to its original (unpadded) size.
    all_ts = []
    for gathered, size in zip(gather_t_tensor, all_sizes):
        all_ts.append(gathered[:size])
    if cat:
        return torch.cat(all_ts, dim=0)
    return all_ts
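
# Example usage (illustrative): gather variable-length 1-D tensors from all ranks.
#   local_preds = torch.randn(num_local_samples, device="cuda")  # length may differ per rank
#   merged = sync_tensor_across_gpus(local_preds)  # 1-D tensor concatenated across ranks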


def sync_string_across_gpus(keys: list[str], device, dim=0):
    # Serialize the local list of strings and move it across ranks as a uint8
    # tensor, reusing sync_tensor_across_gpus to handle unequal lengths.
    # `bytearray` gives torch.frombuffer a writable buffer, avoiding a warning.
    keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
    keys_serialized_tensor = torch.frombuffer(
        bytearray(keys_serialized), dtype=torch.uint8
    ).to(device)
    gathered = sync_tensor_across_gpus(keys_serialized_tensor, dim=0, cat=False)
    # Deserialize each rank's payload and flatten into a single list of strings.
    all_keys = [
        key
        for rank_keys in gathered
        for key in pickle.loads(bytes(rank_keys.cpu().tolist()))
    ]
    return all_keys
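
# Example usage (illustrative): gather the per-rank lists of sample identifiers,
# where `local_ids` is a hypothetical list[str] held by this rank.
#   all_ids = sync_string_across_gpus(local_ids, device=torch.device("cuda"))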