# extrasneo/modules/voice_conversion/fairseq/distributed/legacy_distributed_data_parallel.py
# Reiner4's picture
# Upload 539 files
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A modified version of the legacy DistributedDataParallel module that uses c10d
communication primitives. This version is simpler than the latest PyTorch
version and is useful for debugging. Notably it does not overlap gradient
communication with the backward pass, which makes it slower but more robust
than the PyTorch version.

This version also supports the *no_sync* context manager, which allows faster
training with `--update-freq`.
"""
from collections import OrderedDict
from contextlib import contextmanager
import torch
from torch import nn
from fairseq.distributed import utils
class LegacyDistributedDataParallel(nn.Module):
"""Implements distributed data parallelism at the module level.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and does not
broadcast buffers.

Args:
    module (~torch.nn.Module): module to be parallelized
    process_group: the c10d process group to be used for distributed data
        parallel all-reduction.
    buffer_size (int, optional): number of elements to buffer before
        performing all-reduce (default: 256M).
"""
def __init__(self, module, process_group, buffer_size=2**28):
    """Wrap *module* so its gradients can be all-reduced over a process group.

    Args:
        module (~torch.nn.Module): module to be parallelized.
        process_group: the c10d process group used for the all-reduce.
        buffer_size (int, optional): number of elements to buffer before
            performing all-reduce (default: 2**28). Clamped to the total
            number of parameter elements in *module*.
    """
    # nn.Module.__init__ must run before a sub-module is assigned as an
    # attribute; otherwise ``self.module = module`` raises AttributeError.
    super().__init__()

    self.module = module
    self.process_group = process_group
    self.world_size = utils.get_world_size(self.process_group)

    # Never use a bigger buffer than the number of model params
    self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters()))
    # The flat communication buffer is created later, on first use.
    self.buffer = None

    # We can also forcibly accumulate grads locally and only do the
    # all-reduce at some later time
    self.accumulate_grads = False

    # make per-device lists of parameters, keeping the order in which each
    # device is first seen while iterating over the module's parameters
    paramlists = OrderedDict()
    for param in self.module.parameters():
        paramlists.setdefault(param.device, []).append(param)
    self.per_device_params = list(paramlists.values())
@contextmanager
def no_sync(self):
    """A context manager to disable gradient synchronization.

    While the ``with`` block is active ``self.accumulate_grads`` is True;
    on exit the flag is reset to whatever value it had on entry, so nested
    uses behave correctly.
    """
    old_accumulate_grads = self.accumulate_grads
    self.accumulate_grads = True
    try:
        yield
    finally:
        # Restore even if the body raises, so a failed step does not leave
        # synchronization permanently disabled.
        self.accumulate_grads = old_accumulate_grads
def forward(self, *inputs, **kwargs):
    """Call the wrapped module with the given arguments and return its result."""
    wrapped = self.module
    output = wrapped(*inputs, **kwargs)
    return output
def all_reduce_grads(self):
This function must be called explicitly after backward to reduce
gradients. There is no automatic hook like c10d.
def all_reduce_params(params):
buffer = self.buffer
nonzero_buffer = False
if len(params) > 1:
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
nonzero_buffer = True
buffer[offset : offset + sz].zero_()
offset += sz
# we only have a single grad to all-reduce
p = params[0]
if p.grad is not None:
buffer = p.grad.data
nonzero_buffer = True
elif p.numel() <= self.buffer.numel():
buffer = buffer[: p.numel()]
buffer = torch.zeros_like(p)
if nonzero_buffer:
utils.all_reduce(buffer, self.process_group)
# copy all-reduced grads back into their original place
offset = 0
for p in params:
sz = p.numel()
if p.grad is not None:
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))
p.grad = buffer[offset : offset + sz].view_as(p).clone()
offset += sz
def reduction_fn():
# This function only needs to be called once
if self.accumulate_grads:
if self.buffer is None:
self.buffer = next(self.module.parameters()).new(self.buffer_size)
for params in self.per_device_params:
# All-reduce the gradients in buckets
offset = 0
buffered_params = []
for param in params:
if not param.requires_grad:
if param.grad is None:
param.grad = torch.zeros_like(param)
if hasattr(param, "expert"):
# Skip gradient sync for unshared parameters
if param.grad.requires_grad:
raise RuntimeError(
"DistributedDataParallel only works "
"with gradients that don't require "
sz = param.numel()
if sz > self.buffer.numel():
# all-reduce big params directly
if offset + sz > self.buffer.numel():
offset = 0
offset += sz
if len(buffered_params) > 0: