ReMoDiffuse / mogen /core /distributed_wrapper.py
mingyuan's picture
initial commit
a0d91d3
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from torch.cuda._utils import _get_device_index
@MODULE_WRAPPERS.register_module()
class DistributedDataParallelWrapper(nn.Module):
    """A DistributedDataParallel wrapper that wraps sub-modules separately.

    Some models need each of their sub-modules wrapped in its own
    DistributedDataParallel instead of wrapping the whole model at once.
    GAN training is the typical example. A GAN usually has two
    sub-modules: a generator and a discriminator. If both are wrapped in one
    standard DistributedDataParallel, training fails. This is because
    updating the parameters of the generator (or discriminator) leaves the
    parameters of the discriminator (or generator) without gradients, which
    DistributedDataParallel does not allow.

    This wrapper therefore performs two operations:

    1. Wrap each direct sub-module of the model in a separate
       MMDistributedDataParallel. Only sub-modules that have at least one
       parameter with ``requires_grad=True`` are wrapped. The others are
       only moved to the current CUDA device.
    2. Scatter the arguments of ``forward``, ``train_step`` and
       ``val_step`` onto the current CUDA device before delegating to the
       wrapped model.

    The arguments of this wrapper are the same as those of
    ``torch.nn.parallel.distributed.DistributedDataParallel``.

    Args:
        module (nn.Module): Module that needs to be wrapped.
        device_ids (list[int | `torch.device`]): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Must contain exactly one device.
        dim (int, optional): Same as that in the official scatter function in
            pytorch. Defaults to 0.
        broadcast_buffers (bool): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Defaults to False.
        find_unused_parameters (bool, optional): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Traverse the autograd graph of all tensors contained in the
            returned value of the wrapped module's forward function.
            Defaults to False.
        kwargs (dict): Other arguments used in
            `torch.nn.parallel.distributed.DistributedDataParallel`.

    Raises:
        AssertionError: If ``device_ids`` does not contain exactly one entry.
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        # Adjacent literals are concatenated, so each piece needs its own
        # trailing space to keep the message readable.
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports one '
            'single CUDA device for each process. '
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(
            device_ids=device_ids,
            dim=dim,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap the direct sub-modules with separate MMDistributedDataParallel.

        Every direct child of ``self.module`` is moved to the current CUDA
        device. Children with at least one parameter whose
        ``requires_grad`` is True are additionally wrapped in
        ``MMDistributedDataParallel``. Children without trainable parameters,
        including parameter-free ones, are left unwrapped. ``None``
        placeholders registered in ``_modules`` are skipped.

        Args:
            device_ids (list[int | `torch.device`]): Forwarded to
                ``MMDistributedDataParallel``.
            dim (int): Forwarded to ``MMDistributedDataParallel``.
            broadcast_buffers (bool): Forwarded to
                ``MMDistributedDataParallel``.
            find_unused_parameters (bool): Forwarded to
                ``MMDistributedDataParallel``.
            kwargs (dict): Extra arguments forwarded to
                ``MMDistributedDataParallel``.
        """
        # Iterate over a snapshot so that reassigning entries below never
        # interacts with the live view of ``_modules``.
        for name, module in list(self.module._modules.items()):
            if module is None:
                # nn.Module allows ``None`` placeholders; there is nothing
                # to move or wrap.
                continue
            module = module.cuda()
            # ``any`` over an empty iterator is False, so parameter-free
            # modules fall through unwrapped as well.
            if any(p.requires_grad for p in module.parameters()):
                module = MMDistributedDataParallel(
                    module,
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (tuple): Positional arguments to scatter.
            kwargs (dict): Keyword arguments to scatter.
            device_ids (list[int]): Target device ids.

        Returns:
            tuple: ``(inputs, kwargs)`` as produced by
            ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def _scatter_to_current_device(self, inputs, kwargs):
        """Scatter arguments to the current CUDA device.

        Returns the per-device ``(args, kwargs)`` pair for that single
        device. If scattering yields no per-device entries (e.g. nothing was
        passed), this falls back to ``((), {})`` instead of raising
        ``IndexError``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        device_args = inputs[0] if inputs else ()
        device_kwargs = kwargs[0] if kwargs else {}
        return device_args, device_kwargs

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.

        Returns:
            The output of the wrapped module's forward call.
        """
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module(*args, **kw)

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Delegates to ``self.module.train_step`` after scattering the
        arguments to the current CUDA device.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.

        Returns:
            Whatever ``self.module.train_step`` returns.
        """
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module.train_step(*args, **kw)

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Delegates to ``self.module.val_step`` after scattering the
        arguments to the current CUDA device.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for ``scatter_kwargs``.

        Returns:
            Whatever ``self.module.val_step`` returns.
        """
        args, kw = self._scatter_to_current_device(inputs, kwargs)
        return self.module.val_step(*args, **kw)