"""
Helpers for distributed training.
"""

import datetime
import os
import socket

import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3


def get_rank():
    """
    Return the rank of the current process, or 0 if torch.distributed
    is unavailable or not initialized.
    """
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    """
    Barrier across all ranks; a no-op if torch.distributed is unavailable,
    not initialized, or running with a world size of 1.
    """
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    """
    Return the world size, or 1 if torch.distributed is unavailable
    or not initialized.
    """
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def setup_dist(args):
    """
    Set up a distributed process group with the NCCL backend, reading
    MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE from the environment.
    """
    if dist.is_initialized():
        return

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=54000),
    )
    print(f"{args.local_rank=} init complete")

    # Note: calling synchronize() here was observed to use extra memory on
    # rank 0 (cause unclear), so it is skipped.
    th.cuda.empty_cache()


def cleanup():
    """
    Destroy the torch.distributed process group.
    """
    dist.destroy_process_group()


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        if get_world_size() > 1:
            return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint from `path`.

    Each rank reads the file directly; the earlier MPI-based variant that
    broadcast the bytes from rank 0 to avoid redundant fetches is disabled.
    """
    return th.load(path, **kwargs)


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks by broadcasting from rank 0.
    """
    for p in params:
        with th.no_grad():
            try:
                dist.broadcast(p, 0)
            except Exception as e:
                print(f"sync_params: broadcast failed: {e}")


def _find_free_port():
    """
    Bind to port 0 so the OS picks an unused TCP port, and return that port.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        s.close()
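

# Illustrative sketch (an assumption, not part of the original launch code):
# one way _find_free_port() could be used to populate the env:// variables
# that setup_dist() expects for a single-node, single-process run.
def _setup_single_process_env():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")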


_num_moments   = 3           # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype  = th.float32  # Data type to use for initial per-tensor reduction.
_counter_dtype = th.float64  # Data type to use for the internal counters.
_rank          = 0           # Rank of the current process.
_sync_device   = None        # Device to use for multiprocess communication. None = single-process.
_sync_called   = False       # Has _sync() been called yet?
_counters      = dict()      # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative    = dict()      # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
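

# Illustrative sketch (an assumption; the real report()/_sync() helpers live in
# the full training_stats module these counters come from): how a report-style
# call could fold a value into the three moments tracked per name and device.
def _report_sketch(name, value):
    elems = th.as_tensor(value, dtype=_reduce_dtype).flatten()
    moments = th.stack([
        th.tensor(float(elems.numel()), dtype=_reduce_dtype, device=elems.device),
        elems.sum(),
        elems.square().sum(),
    ]).to(_counter_dtype)
    per_device = _counters.setdefault(name, dict())
    if elems.device not in per_device:
        per_device[elems.device] = th.zeros(_num_moments, dtype=_counter_dtype, device=elems.device)
    per_device[elems.device] += moments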

def init_multiprocessing(rank, sync_device):
    r"""Initializes `utils.torch_utils.training_stats` for collecting statistics
    across multiple processes.
    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.
    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device
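

if __name__ == "__main__":
    # Illustrative smoke test / usage sketch (an assumption, not part of the
    # original training entry point). Assumes a one-process-per-GPU launch
    # where MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE are already set
    # (e.g. by torchrun); `--local_rank` is a hypothetical argument.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    setup_dist(args)
    init_multiprocessing(rank=get_rank(), sync_device=dev() if get_world_size() > 1 else None)

    # Broadcast a dummy tensor from rank 0 to verify that the group works.
    t = th.zeros(1, device=dev())
    if get_rank() == 0:
        t += 1.0
    sync_params([t])
    print(f"rank {get_rank()}/{get_world_size()}: tensor after sync_params -> {t.item()}")

    cleanup()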