Spaces:
Running
Running
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmcv.parallel import MODULE_WRAPPERS, MMDistributedDataParallel | |
from mmcv.parallel.scatter_gather import scatter_kwargs | |
from torch.cuda._utils import _get_device_index | |
# NOTE(review): ``MODULE_WRAPPERS`` is imported at the top of this file but is
# not used on this class. mmcv-style wrappers are normally registered with
# ``@MODULE_WRAPPERS.register_module()``; confirm whether the decorator was
# intentionally omitted.
class DistributedDataParallelWrapper(nn.Module):
    """A DistributedDataParallel wrapper for models in 3D mesh estimation task.

    In the 3D mesh estimation task, different sub-modules of a model need to
    be wrapped with separate DistributedDataParallel instances. Otherwise, GAN
    training will raise errors.

    More specifically, a GAN model usually has two sub-modules: a generator
    and a discriminator. If both of them are wrapped in one standard
    DistributedDataParallel, training fails: while the parameters of the
    generator (or discriminator) are being updated, the parameters of the
    discriminator (or generator) are not, which DistributedDataParallel does
    not allow.

    This wrapper therefore wraps the generator and the discriminator in
    separate DistributedDataParallel instances. It performs two operations:

    1. Wrap each direct child module of ``module`` with its own
       ``MMDistributedDataParallel``. Only children that have at least one
       parameter requiring grad are wrapped; all other children are just
       moved to CUDA.
    2. Scatter the inputs for ``forward``, ``train_step`` and ``val_step``.

    Note that the arguments of this wrapper are the same as those of
    ``torch.nn.parallel.distributed.DistributedDataParallel``.

    Args:
        module (nn.Module): Module that needs to be wrapped.
        device_ids (list[int | `torch.device`]): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Must contain exactly one device.
        dim (int, optional): Same as that in the official scatter function in
            pytorch. Defaults to 0.
        broadcast_buffers (bool): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Defaults to False.
        find_unused_parameters (bool, optional): Same as that in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
            Traverse the autograd graph of all tensors contained in returned
            value of the wrapped module's forward function. Defaults to False.
        kwargs (dict): Other arguments used in
            `torch.nn.parallel.distributed.DistributedDataParallel`.
    """

    def __init__(self,
                 module,
                 device_ids,
                 dim=0,
                 broadcast_buffers=False,
                 find_unused_parameters=False,
                 **kwargs):
        super().__init__()
        # Exactly one device per process: ``output_device`` below is derived
        # from ``device_ids[0]`` alone.
        assert len(device_ids) == 1, (
            'Currently, DistributedDataParallelWrapper only supports one '
            'single CUDA device for each process. '
            f'The length of device_ids must be 1, but got {len(device_ids)}.')
        self.module = module
        self.dim = dim
        self.to_ddp(
            device_ids=device_ids,
            dim=dim,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=find_unused_parameters,
            **kwargs)
        self.output_device = _get_device_index(device_ids[0], True)

    def to_ddp(self, device_ids, dim, broadcast_buffers,
               find_unused_parameters, **kwargs):
        """Wrap the children of ``self.module`` with separate
        ``MMDistributedDataParallel`` instances.

        Each direct child that has at least one parameter with
        ``requires_grad=True`` is moved to CUDA and wrapped. Children without
        trainable parameters (including parameter-free ones) are only moved
        to CUDA. ``None`` placeholder entries are left as they are.

        Args:
            device_ids (list[int | `torch.device`]): Forwarded to
                ``MMDistributedDataParallel``.
            dim (int): Forwarded to ``MMDistributedDataParallel``.
            broadcast_buffers (bool): Forwarded to
                ``MMDistributedDataParallel``.
            find_unused_parameters (bool): Forwarded to
                ``MMDistributedDataParallel``.
            kwargs (dict): Extra arguments forwarded to
                ``MMDistributedDataParallel``.
        """
        # Iterate over a snapshot because the loop writes back into
        # ``_modules``.
        for name, module in list(self.module._modules.items()):
            if module is None:
                # ``nn.Module.add_module`` accepts ``None`` as a placeholder;
                # there is nothing to move or wrap.
                continue
            if any(p.requires_grad for p in module.parameters()):
                module = MMDistributedDataParallel(
                    module.cuda(),
                    device_ids=device_ids,
                    dim=dim,
                    broadcast_buffers=broadcast_buffers,
                    find_unused_parameters=find_unused_parameters,
                    **kwargs)
            else:
                # Nothing trainable to synchronise, so DDP is skipped
                # (presumably because DDP rejects such modules).
                module = module.cuda()
            self.module._modules[name] = module

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter function.

        Args:
            inputs (tuple): Positional inputs to scatter.
            kwargs (dict): Keyword inputs to scatter, passed to
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
            device_ids (list[int]): Target device ids.

        Returns:
            tuple: The ``(inputs, kwargs)`` pair returned by
            ``scatter_kwargs``; one entry per target device.
        """
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def forward(self, *inputs, **kwargs):
        """Forward function.

        Scatters the inputs to the current CUDA device and calls the wrapped
        module with them.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        return self.module(*inputs[0], **kwargs[0])

    def train_step(self, *inputs, **kwargs):
        """Train step function.

        Scatters the inputs to the current CUDA device and delegates to the
        wrapped module's ``train_step``.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for
                ``mmcv.parallel.scatter_gather.scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.train_step(*inputs[0], **kwargs[0])
        return output

    def val_step(self, *inputs, **kwargs):
        """Validation step function.

        Scatters the inputs to the current CUDA device and delegates to the
        wrapped module's ``val_step``.

        Args:
            inputs (tuple): Input data.
            kwargs (dict): Args for ``scatter_kwargs``.
        """
        inputs, kwargs = self.scatter(inputs, kwargs,
                                      [torch.cuda.current_device()])
        output = self.module.val_step(*inputs[0], **kwargs[0])
        return output