from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors

import torch.optim
import torch.utils.data
import torch
from packaging import version


def get_torch_version():
    """Return torch.__version__ stripped of dev/CUDA/local-version suffixes."""
    torch_version = torch.__version__
    # Drop nightly ("dev") and CUDA ("cu") suffixes, remove a trailing dot left
    # behind by the split, and strip the local-version separator ("+") so the
    # result parses cleanly with packaging.version.
    torch_version = torch_version.split("dev")[0]
    torch_version = torch_version.split("cu")[0]
    if torch_version[-1] == '.':
        torch_version = torch_version[:-1]
    torch_version = torch_version.replace("+", "")
    return torch_version
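
# For example, get_torch_version() maps "1.13.1+cu117" to "1.13.1" and a nightly
# build string such as "2.0.0.dev20230101+cpu" to "2.0.0" (the version strings
# here are illustrative only).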


class DDP(DistributedDataParallel):
    """
    Override DistributedDataParallel's forward call so that, for Lightning-style
    modules, it is routed to training_step, test_step, or validation_step
    according to the wrapped module's current mode.
    """

    def forward(self, *inputs, **kwargs):
        torch_version = get_torch_version()
        if version.parse(torch_version) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # Tell the reducer which outputs participate in the backward pass
                # so it can prepare gradient reduction (and detect unused
                # parameters when that option is enabled).
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        elif version.parse(torch_version) < version.parse("2.1"):
            # Lazy import: these are private torch helpers that are only needed
            # when this branch runs.
            from torch.nn.parallel.distributed import (
                logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref,
            )
            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this rank has not yet joined, so
                # ranks stay consistent when inputs are uneven.
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(
                        work, self._divide_by_initial_world_size
                    )

                # Rebuild the gradient buckets once the reducer has observed the
                # communication order of the first iteration.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                buffer_hook_registered = hasattr(self, 'buffer_hook')
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # As in the first branch, tell the reducer which outputs
                    # participate in the backward pass.
                    if self.find_unused_parameters and not self.static_graph:
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

                # _DDPSink is used for unused-parameter detection and for the
                # first iteration of static-graph training.
                if (self.find_unused_parameters and not self.static_graph) or (
                    self.static_graph and self.num_iterations == 1
                ):
                    state_dict = {
                        'static_graph': self.static_graph,
                        'num_iterations': self.num_iterations,
                    }

                    output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                        output
                    )
                    output_placeholders = [None for _ in range(len(output_tensor_list))]
                    # Tensors without a grad_fn cannot participate in autograd;
                    # pass them through untouched.
                    for i, output in enumerate(output_tensor_list):
                        if torch.is_tensor(output) and output.grad_fn is None:
                            output_placeholders[i] = output

                    # Pass the remaining outputs through _DDPSink and substitute
                    # its pass-through results back into the placeholders.
                    passthrough_tensor_list = _DDPSink.apply(
                        self.reducer,
                        state_dict,
                        *output_tensor_list,
                    )
                    for i in range(len(output_placeholders)):
                        if output_placeholders[i] is None:
                            output_placeholders[i] = passthrough_tensor_list[i]

                    output = _tree_unflatten_with_rref(
                        output_placeholders, treespec, output_is_rref
                    )
        else:
            output = super().forward(*inputs, **kwargs)
        return output

    def _run_ddp_forward(self, *inputs, **kwargs):
        torch_version = get_torch_version()
        if version.parse(torch_version) < version.parse("2.1"):
            # For torch < 2.1 the dispatch already happens in forward() above,
            # so defer to the stock implementation here.
            return super()._run_ddp_forward(*inputs, **kwargs)
        with self._inside_ddp_forward():
            if self.module.training:
                output = self.module.training_step(*inputs, **kwargs)
            elif self.module.testing:
                output = self.module.test_step(*inputs, **kwargs)
            else:
                output = self.module.validation_step(*inputs, **kwargs)
            return output
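
    # Summary of the version handling: for torch < 2.1, forward() above performs
    # the training/test/validation dispatch itself; for torch >= 2.1 it defers to
    # super().forward(), which (in the 2.x releases this was written against)
    # calls self._run_ddp_forward(), so the dispatch happens in the override
    # directly above instead.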