""" Pytorch Distributed utils |
|
This piece of code was heavily inspired by the equivalent of Fairseq-py |
|
https://github.com/pytorch/fairseq |
|
""" |
|
import os |
|
import signal |
|
import math |
|
import pickle |
|
|
|
import torch.distributed |
|
|
|
from onmt.utils.misc import set_random_seed |
|
from onmt.utils.logging import init_logger, logger |
|
|
|
|
|
def is_master(opt, device_id):
    """Return True if this process is the master (global rank 0)."""
    return opt.gpu_ranks[device_id] == 0


def multi_init(opt, device_id):
    """Initialize the torch.distributed process group for this process."""
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip=opt.master_ip,
        master_port=opt.master_port)
    dist_world_size = opt.world_size
    torch.distributed.init_process_group(
        backend=opt.gpu_backend, init_method=dist_init_method,
        world_size=dist_world_size, rank=opt.gpu_ranks[device_id])
    gpu_rank = torch.distributed.get_rank()
    if not is_master(opt, device_id):
        # Silence logging on every process except the master.
        logger.disabled = True

    return gpu_rank
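
# Usage sketch (illustrative, not part of the original module): `multi_init`
# would typically be called once per spawned worker before training starts;
# `opt` is assumed to carry `master_ip`, `master_port`, `world_size`,
# `gpu_backend` and `gpu_ranks`, exactly as used above, e.g.
#
#     gpu_rank = multi_init(opt, device_id)
#     # ... build the model/optimizer, then enter the training loop ...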


def all_reduce_and_rescale_tensors(tensors, rescale_denom,
                                   buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes
    """
    # Flat buffer holding as many elements as fit into `buffer_size` bytes.
    buffer_t = tensors[0].new(
        math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # Copy the buffered tensors into the flat buffer.
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel

        # All-reduce the used portion of the buffer and rescale.
        torch.distributed.all_reduce(buffer_t[:offset])
        buffer_t.div_(rescale_denom)

        # Copy the reduced values back into the original tensors.
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # Tensor is bigger than the buffer: all-reduce it directly.
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # Buffer is full: flush it, then start a new one with this tensor.
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # Tensor still fits: add it to the buffer.
            buffer.append(t)
            filled += sz

    if len(buffer) > 0:
        all_reduce_buffer()
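
# Illustrative call site (assumed, not part of this module): after a backward
# pass on each worker, gradients would be summed across processes and turned
# into an average by rescaling with the world size, e.g.
#
#     grads = [p.grad.data for p in model.parameters() if p.grad is not None]
#     world_size = float(torch.distributed.get_world_size())
#     all_reduce_and_rescale_tensors(grads, world_size)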


def all_gather_list(data, max_size=4096):
    """Gathers arbitrary picklable data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
    # Lazily allocate one input buffer and one output buffer per rank, and
    # reuse them across calls as long as `max_size` does not change.
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.size(0):
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError(
            'encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255*256
    # The first two bytes store the payload size in base 255; the pickled
    # payload follows.
    in_buffer[0] = enc_size // 255
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    results = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        # Decode the size header, then unpickle the payload.
        size = (255 * out_buffer[0].item()) + out_buffer[1].item()
        bytes_list = bytes(out_buffer[2:size+2].tolist())
        result = pickle.loads(bytes_list)
        results.append(result)
    return results
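
# Worked example of the size header above: for a payload of enc_size = 1000
# bytes, the header stores 1000 // 255 = 3 and 1000 % 255 = 235, and the
# receiver recovers 255 * 3 + 235 = 1000. A typical (illustrative) call would
# gather small per-rank statistics, e.g.
#
#     all_stats = all_gather_list({'rank': gpu_rank, 'nb_sents': nb_sents})
#
# where `nb_sents` stands in for whatever picklable value the caller shares.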


class ErrorHandler(object):
    """Listen for exceptions in child processes and propagate
    their tracebacks to the parent process."""

    def __init__(self, error_queue):
        """Start a daemon thread that waits on `error_queue`."""
        import threading
        self.error_queue = error_queue
        self.children_pids = []
        self.error_thread = threading.Thread(
            target=self.error_listener, daemon=True)
        self.error_thread.start()
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Register a child process to be killed if an error occurs."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for an error, re-queue it and signal the parent process."""
        (rank, original_trace) = self.error_queue.get()
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Kill all registered children and re-raise the original traceback."""
        for pid in self.children_pids:
            os.kill(pid, signal.SIGINT)
        (rank, original_trace) = self.error_queue.get()
        msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
        msg += original_trace
        raise Exception(msg)
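
# Typical wiring (illustrative, not defined in this module): the parent
# process creates a multiprocessing error queue, registers each spawned
# trainer with `add_child`, and workers push `(rank, traceback.format_exc())`
# into the queue on failure, as `consumer` does below, e.g.
#
#     error_queue = mp.SimpleQueue()     # assumes `mp = torch.multiprocessing`
#     error_handler = ErrorHandler(error_queue)
#     error_handler.add_child(proc.pid)  # `proc` is an assumed worker Process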


def batch_producer(generator_to_serve, queue, semaphore, opt, device_id):
    """Produce batches from `generator_to_serve` and feed them to `queue`."""
    log_level = "INFO" if opt.verbose or device_id == 0 else "WARNING"
    init_logger(opt.log_file, log_level=log_level)
    set_random_seed(opt.seed, False)

    def pred(x):
        """Keep only the batches assigned to the gpu_ranks of this node."""
        for rank in opt.gpu_ranks:
            if x[0] % opt.world_size == rank:
                return True
        return False

    generator_to_serve = filter(
        pred, enumerate(generator_to_serve))

    def next_batch():
        new_batch = next(generator_to_serve)
        # Block until the consumer signals room for another batch.
        semaphore.acquire()
        return new_batch[1]

    b = next_batch()

    while True:
        # Drop the reference to the dataset so it is not sent through the
        # queue, and convert `fields` to a list (a dict view cannot be
        # pickled) before queueing the batch.
        b.dataset = None
        b.fields = list(b.fields)
        queue.put(b)
        b = next_batch()
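
# Illustrative launch pattern (assumed, not part of this module): the parent
# would typically run one producer process per GPU next to one `consumer`,
# sharing a bounded queue and a semaphore so the producer cannot run
# arbitrarily far ahead; `mp` and `opt.queue_size` are assumed names, e.g.
#
#     q = mp.Queue(opt.queue_size)       # assumes `mp = torch.multiprocessing`
#     sem = mp.Semaphore(opt.queue_size)
#     mp.Process(target=batch_producer,
#                args=(train_iter, q, sem, opt, device_id),
#                daemon=True).start()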


def consumer(process_fn, opt, device_id, error_queue, batch_queue, semaphore):
    """Run `process_fn` on `device_id` with data from `batch_queue`."""
    try:
        gpu_rank = multi_init(opt, device_id)
        if gpu_rank != opt.gpu_ranks[device_id]:
            raise AssertionError(
                "An error occurred in distributed initialization")
        process_fn(opt, device_id=device_id,
                   batch_queue=batch_queue, semaphore=semaphore)
    except KeyboardInterrupt:
        pass  # killed by the parent, nothing to do
    except Exception:
        # Propagate the exception to the parent process, keeping the
        # original traceback.
        import traceback
        error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc()))
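
# Sketch of the spawn side (assumed, mirroring the producer example above):
# each GPU gets one consumer process whose `process_fn` would be the
# per-device training entry point, e.g.
#
#     mp.Process(target=consumer,
#                args=(single_main, opt, device_id, error_queue, q, sem),
#                daemon=True).start()
#
# `single_main` is a stand-in name for the single-device training function.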