# File size: 9,108 Bytes
# fe781a6
import logging
import random
import subprocess
from datetime import datetime
import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
from packaging import version
from omegaconf import OmegaConf
def set_random_seed(seed):
    """Seed the random number generators of ``random``, NumPy and PyTorch.

    The same value is applied to Python's ``random`` module, NumPy's global
    generator, and PyTorch (the default generator and every CUDA device).

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Harmless on CPU-only machines: the call is ignored when CUDA is absent.
    torch.cuda.manual_seed_all(seed)
def is_logging_process():
    """Return whether the current process is the one that should emit logs.

    Returns:
        ``True`` when ``torch.distributed`` has not been initialized, or when
        this process has rank 0 in the default process group; ``False``
        otherwise.
    """
    return not dist.is_initialized() or dist.get_rank() == 0
def get_logger(cfg, name=None):
    """Apply the job logging configuration and return a named logger.

    Only the logging process (see ``is_logging_process``) configures logging
    and receives a logger.

    Args:
        cfg: OmegaConf config exposing ``job_logging_config``; it is converted
            to plain containers and handed to ``logging.config.dictConfig``.
        name: Logger name passed to ``logging.getLogger``; ``None`` selects
            the root logger.

    Returns:
        The requested ``logging.Logger`` on the logging process, ``None`` on
        every other process.
    """
    # ``import logging`` alone does not load the ``logging.config`` submodule.
    import logging.config

    if not is_logging_process():
        return None

    # resolve=True expands interpolations so dictConfig sees concrete values.
    logging_config = OmegaConf.to_container(cfg.job_logging_config, resolve=True)
    logging.config.dictConfig(logging_config)
    return logging.getLogger(name)
# Adapted from an external source (the reference URL is missing from this copy).
class SyncFunction(torch.autograd.Function):
    """All-gather along dim 0 that stays differentiable.

    Forward gathers ``tensor`` from every rank of the default process group
    and concatenates the pieces along dim 0. Backward sums the incoming
    gradient across ranks and returns the rows that belong to this rank.

    A ``torch.distributed`` process group must already be initialized. The
    slicing in ``backward`` assumes every rank contributes the same number of
    rows.
    """

    @staticmethod
    def forward(ctx, tensor):
        """Gather ``tensor`` from all ranks and concatenate along dim 0.

        Args:
            ctx: Autograd context; stores the local batch size for backward.
            tensor: Local tensor to share with the other ranks.

        Returns:
            The concatenation of every rank's tensor along dim 0.
        """
        # Remembered so backward can cut this rank's rows out of the gradient.
        ctx.batch_size = tensor.shape[0]

        world_size = torch.distributed.get_world_size()
        gathered_tensor = [torch.zeros_like(tensor) for _ in range(world_size)]
        torch.distributed.all_gather(gathered_tensor, tensor)
        return torch.cat(gathered_tensor, 0)

    @staticmethod
    def backward(ctx, grad_output):
        """Sum gradients across ranks and return this rank's slice.

        Args:
            ctx: Autograd context carrying ``batch_size`` from forward.
            grad_output: Gradient with respect to the gathered tensor.

        Returns:
            Gradient with respect to the local input tensor.
        """
        # Clone first: all_reduce works in place and grad_output is not ours.
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(
            grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False
        )

        rank = torch.distributed.get_rank()
        idx_from = rank * ctx.batch_size
        idx_to = (rank + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]
def get_timestamp():
    """Return the current local time as a compact ``YYMMDD-HHMMSS`` string."""
    return datetime.now().strftime("%y%m%d-%H%M%S")
def get_commit_hash():
    """Return the abbreviated hash of the current git ``HEAD`` commit.

    Runs ``git rev-parse --short HEAD`` in the current working directory.

    Returns:
        The short commit hash as a ``str`` without surrounding whitespace.

    Raises:
        subprocess.CalledProcessError: If git exits with a non-zero status.
        FileNotFoundError: If the ``git`` executable cannot be found.
    """
    raw_output = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    # check_output returns bytes that end with a newline.
    return raw_output.strip().decode("utf-8")
class DDP(DistributedDataParallel):
# DistributedDataParallel subclass whose forward() calls the wrapped module's
# training_step / test_step / validation_step.
# NOTE(review): this copy of the class has lost its indentation and a number
# of statements (branch bodies, `else:` lines, call heads, closing brackets).
# The NOTE(review) comments below mark each gap that is visible from the
# surrounding code; restore them from the original file before using this.
Override the forward call in lightning so it goes to training and validation step respectively
def forward(self, *inputs, **kwargs): # pragma: no cover
# Branch on the installed torch version; only the first six characters of
# the version string are parsed. The statements up to the local import below
# presumably form the "older than 1.11" path - TODO confirm.
if version.parse(torch.__version__[:6]) < version.parse("1.11"):
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
# Only inputs[0] / kwargs[0] are used below, so exactly one device is expected.
assert len(self.device_ids) == 1
# NOTE(review): the `if` that the following `elif` belongs to is missing, so
# the condition selecting training_step is not visible here.
output = self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
# NOTE(review): an `else:` presumably preceded this validation_step call.
output = self.module.validation_step(*inputs[0], **kwargs[0])
if torch.is_grad_enabled():
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
# NOTE(review): the body of this `if` (and any `else:` branch) is missing.
if self.find_unused_parameters:
# NOTE(review): an `else:` for the version check presumably preceded this
# import; the names are pulled from torch's DDP module at call time.
from torch.nn.parallel.distributed import \
logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref
with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
# NOTE(review): only the iteration counter update is visible in this branch;
# further statements may be missing.
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.num_iterations += 1
# Notify the join context that this process has not joined, if
# needed
work = Join.notify_join_context(self)
if work:
# NOTE(review): the next line looks like the argument list of a call whose
# opening line and closing parenthesis are missing.
work, self._divide_by_initial_world_size
# Calling _rebuild_buckets before forward computation,
# It may allocate new buckets before deallocating old buckets
# inside _rebuild_buckets. To save peak memory usage,
# call _rebuild_buckets before the peak memory usage increases
# during forward computation.
# This should be called only once during whole training period.
# NOTE(review): the call that takes the message string below (between the
# colon and the string literal) is missing; as written the line is invalid.
if torch.is_grad_enabled() and self.reducer._rebuild_buckets():"Reducer buckets have been rebuilt in this iteration.")
self._has_rebuilt_buckets = True
# sync params according to location (before/after forward) user
# specified as part of hook, if hook was specified.
# buffer_hook_registered is assigned here but not read in the visible code.
buffer_hook_registered = hasattr(self, 'buffer_hook')
# NOTE(review): the body of this `if` is missing.
if self._check_sync_bufs_pre_fwd():
# NOTE(review): the body of this `if` is missing; only its comment remains.
if self._join_config.enable:
# Notify joined ranks whether they should sync in backwards pass or not.
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
# NOTE(review): as in the first path, the leading `if` and the `else:` of
# this step dispatch are missing.
output = self.module.training_step(*inputs[0], **kwargs[0])
elif self.module.testing:
output = self.module.test_step(*inputs[0], **kwargs[0])
output = self.module.validation_step(*inputs[0], **kwargs[0])
# sync params according to location (before/after forward) user
# specified as part of hook, if hook was specified.
# NOTE(review): the body of this `if` is missing.
if self._check_sync_bufs_post_fwd():
if torch.is_grad_enabled() and self.require_backward_grad_sync:
self.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
# object. We need to find any tensors in this object, though,
# because we need to figure out which parameters were used during
# this forward pass, to ensure we short circuit reduction for any
# unused parameters. Only if `find_unused_parameters` is set.
# NOTE(review): the body of this `if` and the `else:` lines that presumably
# separate it from the `= False` assignment below are missing.
if self.find_unused_parameters and not self.static_graph:
# Do not need to populate this for static graph.
self.require_forward_param_sync = False
# TODO: DDPSink is currently enabled for unused parameter detection and
# static graph training for first iteration.
# NOTE(review): the closing `):` of this condition is missing.
if (self.find_unused_parameters and not self.static_graph) or (
self.static_graph and self.num_iterations == 1
# NOTE(review): the closing `}` of this dict literal is missing.
state_dict = {
'static_graph': self.static_graph,
'num_iterations': self.num_iterations,
# NOTE(review): the argument and closing `)` of this call are missing.
output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
output_placeholders = [None for _ in range(len(output_tensor_list))]
# Do not touch tensors that have no grad_fn, which can cause issues
# such as
# The loop variable reuses the name `output`, rebinding the step result;
# `output` is assigned again by the unflatten call further down.
for i, output in enumerate(output_tensor_list):
if torch.is_tensor(output) and output.grad_fn is None:
output_placeholders[i] = output
# When find_unused_parameters=True, makes tensors which require grad
# run through the DDPSink backward pass. When not all outputs are
# used in loss, this makes those corresponding tensors receive
# undefined gradient which the reducer then handles to ensure
# param.grad field is not touched and we don't error out.
# NOTE(review): the arguments and closing `)` of this call are missing.
passthrough_tensor_list = _DDPSink.apply(
# Fill the slots that were not preset above with the passthrough tensors.
for i in range(len(output_placeholders)):
if output_placeholders[i] is None:
output_placeholders[i] = passthrough_tensor_list[i]
# Reconstruct output data structure.
# NOTE(review): the closing `)` of this call is missing.
output = _tree_unflatten_with_rref(
output_placeholders, treespec, output_is_rref
return output