# sakharamg's picture
# Uploading all files
# 158b61b
""" Pytorch Distributed utils
This piece of code was heavily inspired by the equivalent of Fairseq-py
https://github.com/pytorch/fairseq
"""
import os
import signal
import math
import pickle
import torch.distributed
from onmt.utils.misc import set_random_seed
from onmt.utils.logging import init_logger, logger
def is_master(opt, device_id):
    """Tell whether the process attached to `device_id` is the master.

    Args:
        opt: options object exposing `gpu_ranks`, indexable by device id.
        device_id: index into `opt.gpu_ranks` for the current process.

    Returns:
        True when the rank listed for `device_id` equals 0.
    """
    rank = opt.gpu_ranks[device_id]
    return rank == 0
def multi_init(opt, device_id):
    """Join the torch.distributed process group for this device.

    The rendezvous address is built from `opt.master_ip` and
    `opt.master_port`; the backend, world size and rank come from
    `opt.gpu_backend`, `opt.world_size` and `opt.gpu_ranks[device_id]`.
    Logging is disabled on every process that is not the master.

    Args:
        opt: options object carrying the attributes listed above.
        device_id: index into `opt.gpu_ranks` for the current process.

    Returns:
        The rank reported by `torch.distributed.get_rank()` once the group
        has been initialised.
    """
    init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip=opt.master_ip, master_port=opt.master_port)
    world_size = opt.world_size

    torch.distributed.init_process_group(
        backend=opt.gpu_backend,
        init_method=init_method,
        world_size=world_size,
        rank=opt.gpu_ranks[device_id])

    gpu_rank = torch.distributed.get_rank()
    if is_master(opt, device_id):
        return gpu_rank

    # only the master process keeps its logger enabled
    logger.disabled = True
    return gpu_rank
def all_reduce_and_rescale_tensors(tensors, rescale_denom,
                                   buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Tensors that fit in the staging buffer are packed together so that
    several of them share one all-reduce call; a tensor larger than the
    buffer is all-reduced on its own. Every tensor is updated in place.
    An empty `tensors` list is a no-op.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    if not tensors:
        # nothing to reduce, and no tensor to size the staging buffer from
        return

    # buffer size in bytes, determine equiv. # of elements based on data type
    # NOTE(review): the buffer is sized from tensors[0] while `filled` below
    # counts bytes with each tensor's own element size; a tensor with a
    # smaller element size than tensors[0] could overrun the buffer --
    # confirm callers always pass tensors of a single dtype.
    buffer_t = tensors[0].new(
        math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # all-reduce and rescale only the part of the buffer that is in use
        used = buffer_t[:offset]
        torch.distributed.all_reduce(used)
        used.div_(rescale_denom)

        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz

    if buffer:
        all_reduce_buffer()
def all_gather_list(data, max_size=4096):
    """Gathers arbitrary data from all nodes into a list.

    `data` is pickled, prefixed with a two-byte length header and exchanged
    through fixed-size CUDA byte buffers that are cached on the function
    object between calls.

    Args:
        data: any picklable object contributed by this process.
        max_size: size in bytes of each exchange buffer; must be below
            255*256 because the length header is only two bytes.

    Returns:
        A list with one unpickled object per process in the group.

    Raises:
        ValueError: if the pickled payload plus header exceeds `max_size`.
    """
    world_size = torch.distributed.get_world_size()
    # (Re)allocate the cached buffers only when they are missing or no longer
    # match the request. numel() is compared because size() returns a
    # torch.Size, which never equals a plain int and would force a fresh
    # allocation on every call.
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.numel() or \
            world_size != len(all_gather_list._out_buffers):
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for _ in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError(
            'encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255*256
    # two-byte header: size == 255 * byte0 + byte1
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    results = []
    for out_buffer in out_buffers:
        size = (255 * out_buffer[0].item()) + out_buffer[1].item()
        payload = bytes(out_buffer[2:size+2].tolist())
        # NOTE: pickle.loads executes whatever the peer sent; this is only
        # safe while every process in the group is trusted.
        results.append(pickle.loads(payload))
    return results
class ErrorHandler(object):
    """A class that listens for exceptions in children processes and propagates
    the tracebacks to the parent process."""

    def __init__(self, error_queue):
        """Install the SIGUSR1 handler, then start the listener thread.

        Args:
            error_queue: queue on which failing children put
                `(rank, traceback_string)` tuples.
        """
        import threading
        self.error_queue = error_queue
        self.children_pids = []
        # Register the handler before the listener can run: if an error is
        # already waiting on the queue the listener sends SIGUSR1 right away,
        # and without a handler the default action would kill this process.
        signal.signal(signal.SIGUSR1, self.signal_handler)
        self.error_thread = threading.Thread(
            target=self.error_listener, daemon=True)
        self.error_thread.start()

    def add_child(self, pid):
        """Record the pid of a child process to interrupt on error."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for the first reported error and signal this process."""
        (rank, original_trace) = self.error_queue.get()
        # put the item back so that signal_handler can read it
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Interrupt every registered child and raise the reported traceback.

        Raises:
            Exception: always; the message ends with the traceback string
                taken from the error queue.
        """
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)  # kill children processes
        (_rank, original_trace) = self.error_queue.get()
        msg = ("\n\n-- Tracebacks above this line can probably\n"
               "be ignored --\n\n")
        msg += original_trace
        raise Exception(msg)
def batch_producer(generator_to_serve, queue, semaphore, opt, device_id):
    """Feed `queue` with the batches of `generator_to_serve` kept for this node.

    Batches are numbered as they come out of the generator; one is kept when
    its index modulo `opt.world_size` equals one of `opt.gpu_ranks`. The
    semaphore is acquired once after each batch is fetched, before it is
    put on the queue. The loop has no exit condition of its own.

    Args:
        generator_to_serve: iterable yielding batch objects.
        queue: queue receiving the prepared batches.
        semaphore: semaphore acquired once per fetched batch.
        opt: options object (`verbose`, `log_file`, `seed`, `world_size`,
            `gpu_ranks` are read here).
        device_id: device index; device 0 logs at INFO level.
    """
    init_logger(
        opt.log_file,
        log_level="INFO" if opt.verbose or device_id == 0 else "WARNING")
    set_random_seed(opt.seed, False)

    def belongs_to_node(indexed_batch):
        """Keep only the batches whose index maps to a rank of this node."""
        return any(indexed_batch[0] % opt.world_size == rank
                   for rank in opt.gpu_ranks)

    owned_batches = filter(belongs_to_node, enumerate(generator_to_serve))

    def fetch():
        # NOTE: stride (if needed) is handled at the
        # generator (train_iter) level
        _, batch = next(owned_batches)
        semaphore.acquire()
        return batch

    while True:
        batch = fetch()
        batch.dataset = None
        # the batch is moved to its device later, on the consumer side;
        # turning `fields` into a list dodges unpicklable `dict_keys`
        batch.fields = list(batch.fields)
        queue.put(batch)
def consumer(process_fn, opt, device_id, error_queue, batch_queue, semaphore):  # noqa: E501
    """Run `process_fn` on `device_id` with data from `batch_queue`.

    The process first joins the distributed group and checks that the rank
    it was given matches `opt.gpu_ranks[device_id]`. A KeyboardInterrupt is
    swallowed; any other exception is reported on `error_queue` as a
    `(rank, formatted_traceback)` tuple instead of being re-raised.
    """
    try:
        if multi_init(opt, device_id) != opt.gpu_ranks[device_id]:
            raise AssertionError(
                "An error occurred in "
                "Distributed initialization")
        process_fn(opt, device_id=device_id,
                   batch_queue=batch_queue, semaphore=semaphore)
    except KeyboardInterrupt:
        return  # killed by parent, do nothing
    except Exception:
        # propagate exception to parent process, keeping original traceback
        import traceback
        rank = opt.gpu_ranks[device_id]
        error_queue.put((rank, traceback.format_exc()))