""" |
|
A modified version of the legacy DistributedDataParallel module that uses c10d |
|
communication primitives. This version is simpler than the latest PyTorch |
|
version and is useful for debugging. Notably it does not overlap gradient |
|
communication with the backward pass, which makes it slower but more robust |
|
than the PyTorch version. |
|
|
|
This version also supports the *no_sync* context manager, which allows faster |
|
training with `--update-freq`. |
|
""" |

from collections import OrderedDict
from contextlib import contextmanager

import torch
from torch import nn

from fairseq.distributed import utils

class LegacyDistributedDataParallel(nn.Module):
    """Implements distributed data parallelism at the module level.

    A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
    This version uses a c10d process group for communication and does not
    broadcast buffers.

    Args:
        module (~torch.nn.Module): module to be parallelized
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 256M).
    """

    def __init__(self, module, process_group, buffer_size=2**28):
        super().__init__()

        self.module = module
        self.process_group = process_group
        self.world_size = utils.get_world_size(self.process_group)

        # Never allocate a buffer larger than the total number of model
        # parameters.
        self.buffer_size = min(
            buffer_size, sum(p.numel() for p in module.parameters())
        )
        self.buffer = None

        # When set, gradients are accumulated locally and the all-reduce is
        # skipped until the flag is cleared again (see no_sync below).
        self.accumulate_grads = False

        # Build per-device lists of parameters so that each bucket only
        # contains parameters from a single device.
        paramlists = OrderedDict()
        for param in self.module.parameters():
            device = param.device
            if paramlists.get(device) is None:
                paramlists[device] = []
            paramlists[device] += [param]
        self.per_device_params = list(paramlists.values())

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronization."""
        old_accumulate_grads = self.accumulate_grads
        self.accumulate_grads = True
        try:
            yield
        finally:
            # Restore the previous value even if the wrapped block raises.
            self.accumulate_grads = old_accumulate_grads
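
    # Example (a sketch): while no_sync is active, all_reduce_grads() is a
    # no-op, so gradients from backward() simply accumulate locally:
    #
    #     with ddp_model.no_sync():
    #         loss.backward()
    #         ddp_model.all_reduce_grads()  # skipped; grads stay local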

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def all_reduce_grads(self):
        """
        This function must be called explicitly after the backward pass to
        reduce gradients. There is no automatic hook like in c10d.
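
        Example (a sketch; ``loss`` and ``optimizer`` are assumed to come
        from the caller's training loop)::

            loss.backward()           # no gradient communication here
            model.all_reduce_grads()  # explicit, bucketed all-reduce
            optimizer.step()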
        """

        def all_reduce_params(params):
            buffer = self.buffer
            nonzero_buffer = False
            if len(params) > 1:
                # Flatten all grads into one contiguous buffer so that a
                # single all-reduce call covers the whole bucket.
                offset = 0
                for p in params:
                    sz = p.numel()
                    if p.grad is not None:
                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
                        nonzero_buffer = True
                    else:
                        buffer[offset : offset + sz].zero_()
                    offset += sz
            else:
                # We only have a single grad to all-reduce.
                p = params[0]
                if p.grad is not None:
                    buffer = p.grad.data
                    nonzero_buffer = True
                elif p.numel() <= self.buffer.numel():
                    buffer = buffer[: p.numel()]
                    buffer.zero_()
                else:
                    buffer = torch.zeros_like(p)

            if nonzero_buffer:
                buffer.div_(self.world_size)

            utils.all_reduce(buffer, self.process_group)

            # Copy the all-reduced grads back into their original place.
            offset = 0
            for p in params:
                sz = p.numel()
                if p.grad is not None:
                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
                else:
                    p.grad = buffer[offset : offset + sz].view_as(p).clone()
                offset += sz
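
        # Layout sketch: for two params with 3 and 2 elements, the flat
        # buffer holds [g0_0, g0_1, g0_2, g1_0, g1_1], so one all-reduce
        # synchronizes both gradients at once.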

        def reduction_fn():
            # Skip the all-reduce entirely while gradients are being
            # accumulated locally (see no_sync).
            if self.accumulate_grads:
                return

            # Lazily allocate the flat communication buffer.
            if self.buffer is None:
                self.buffer = next(self.module.parameters()).new(self.buffer_size)

            for params in self.per_device_params:
                # All-reduce the gradients in buckets.
                offset = 0
                buffered_params = []
                for param in params:
                    if not param.requires_grad:
                        continue
                    if param.grad is None:
                        param.grad = torch.zeros_like(param)

                    if hasattr(param, "expert"):
                        # Skip gradient sync for unshared (expert) parameters.
                        continue

                    if param.grad.requires_grad:
                        raise RuntimeError(
                            "DistributedDataParallel only works "
                            "with gradients that don't require grad"
                        )
                    sz = param.numel()
                    if sz > self.buffer.numel():
                        # Param is too big to fit in the buffer; all-reduce
                        # it directly.
                        all_reduce_params([param])
                    else:
                        # Flush the current bucket before it overflows.
                        if offset + sz > self.buffer.numel():
                            all_reduce_params(buffered_params)
                            offset = 0
                            buffered_params.clear()
                        buffered_params.append(param)
                        offset += sz

                # Flush any remaining buffered params.
                if len(buffered_params) > 0:
                    all_reduce_params(buffered_params)

        reduction_fn()