import contextlib |
|
|
|
import torch |
|
from torch import _C |
|
from torch.cuda import _lazy_call, device as device_ctx_manager |
|
from torch.utils.checkpoint import detach_variable |
|
|
|
from megatron.memory import allocate_mem_buff |
|
|
|
from .initialize import get_data_parallel_rank |
|
from .initialize import get_tensor_model_parallel_group |
|
from .initialize import get_tensor_model_parallel_rank |
|
from .initialize import get_tensor_model_parallel_world_size |
|
|
|
|
|
|
|
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' |
|
|
|
|
|
def _set_cuda_rng_state(new_state, device=-1): |
|
"""Sets the random number generator state of the current GPU. |
|
|
|
Argumentss: |
|
new_state (torch.ByteTensor): The desired state |
|
This function is adapted from PyTorch repo (torch.cuda.set_rng_state) |
|
with a single change: the input state is not cloned. Cloning caused |
|
major performance issues for +4 GPU cases. |
|
""" |
|
    # Older PyTorch exposes _C._cuda_setRNGState directly; newer versions go
    # through the per-device default generator instead.
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
|
|
|
def cb(): |
|
with device_ctx_manager(device): |
|
_C._cuda_setRNGState(new_state) |
|
else: |
|
|
|
if device == -1: |
|
device = torch.device('cuda') |
|
elif isinstance(device, str): |
|
device = torch.device(device) |
|
elif isinstance(device, int): |
|
device = torch.device('cuda', device) |
|
|
|
def cb(): |
|
idx = device.index |
|
if idx is None: |
|
idx = torch.cuda.current_device() |
|
default_generator = torch.cuda.default_generators[idx] |
|
default_generator.set_state(new_state) |
|
|
|
_lazy_call(cb) |
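
# Illustrative usage sketch (not part of the original module): the helper above
# behaves like torch.cuda.set_rng_state but skips cloning the state tensor.
#
#     saved = torch.cuda.get_rng_state()
#     torch.cuda.manual_seed(1234)     # perturb the current GPU's generator
#     _set_cuda_rng_state(saved)       # restore it without copying `saved`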
|
|
|
|
|
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): |
|
"""Break a tensor into equal 1D chunks.""" |
|
partition_size = torch.numel(tensor) // \ |
|
get_tensor_model_parallel_world_size() |
|
start_index = partition_size * get_tensor_model_parallel_rank() |
|
end_index = start_index + partition_size |
|
if new_buffer: |
|
data = torch.empty(partition_size, dtype=tensor.dtype, |
|
device=torch.cuda.current_device(), |
|
requires_grad=False) |
|
data.copy_(tensor.view(-1)[start_index:end_index]) |
|
else: |
|
data = tensor.view(-1)[start_index:end_index] |
|
return data |
|
|
|
|
|
def gather_split_1d_tensor(tensor): |
|
"""Opposite of above function, gather values from model parallel ranks.""" |
|
numel_gathered = torch.numel(tensor) * \ |
|
get_tensor_model_parallel_world_size() |
|
gathered = torch.empty(numel_gathered, dtype=tensor.dtype, |
|
device=torch.cuda.current_device(), |
|
requires_grad=False) |
|
|
|
|
|
|
|
|
|
|
|
    # _all_gather_base gathers directly into the preallocated `gathered` buffer
    # and, unlike torch.distributed.all_gather, avoids building a per-rank list
    # of tensors and the extra copies that come with it. Note that it is a
    # private PyTorch API and may change between releases.
    torch.distributed._all_gather_base(
        gathered, tensor, group=get_tensor_model_parallel_group())
|
return gathered |
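
# Illustrative round-trip sketch (not part of the original module), assuming a
# tensor-model-parallel group of size W and a hypothetical CUDA tensor `full`
# whose numel is divisible by W:
#
#     shard = split_tensor_into_1d_equal_chunks(full, new_buffer=True)
#     restored = gather_split_1d_tensor(shard).view(full.shape)
#     # On every rank in the group, `restored` now equals `full`.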
|
|
|
|
|
def _kernel_make_viewless_tensor(inp, requires_grad): |
|
    '''Make a viewless tensor.

    View tensors have the undesirable side-effect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that points to the old tensor's
    data without linking to the viewed tensor via the '._base' field.
    '''
|
    out = torch.empty(
        (1,),
        dtype=inp.dtype,
        device=inp.device,
        requires_grad=requires_grad,
    )
|
out.data = inp.data |
|
return out |
|
|
|
class MakeViewlessTensor(torch.autograd.Function): |
|
''' |
|
Autograd function to make a viewless tensor. |
|
|
|
This function should be used in cases where the computation graph needs |
|
to be propagated, but we only want a viewless tensor (e.g., |
|
ParallelTransformer's hidden_states). Call this function by passing |
|
'keep_graph = True' to 'make_viewless_tensor()'. |
|
''' |
|
@staticmethod |
|
def forward(ctx, inp, requires_grad): |
|
return _kernel_make_viewless_tensor(inp, requires_grad) |
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
return grad_output, None |
|
|
|
def make_viewless_tensor(inp, requires_grad, keep_graph): |
|
''' |
|
Entry-point for creating viewless tensors. |
|
|
|
This method should be used, rather than calling 'MakeViewlessTensor' |
|
or '_kernel_make_viewless_tensor' directly. This method acts as a |
|
switch for determining if an autograd function or a regular method |
|
should be used to create the tensor. |
|
''' |
|
|
|
|
|
if inp._base is None: |
|
return inp |
|
|
|
|
|
if keep_graph: |
|
return MakeViewlessTensor.apply(inp, requires_grad) |
|
else: |
|
return _kernel_make_viewless_tensor(inp, requires_grad) |
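
# Illustrative sketch (not part of the original module): `hidden` below is a
# hypothetical view tensor, e.g. produced by .view() or .transpose():
#
#     hidden = make_viewless_tensor(hidden, requires_grad=True, keep_graph=True)
#     assert hidden._base is None      # safe to overwrite hidden.data later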
|
|
|
def assert_viewless_tensor(tensor, extra_msg=None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        for t in tensor:
            assert_viewless_tensor(t)
        return tensor
|
if not isinstance(tensor, torch.Tensor): |
|
return tensor |
|
assert tensor._base is None, ( |
|
"Ensure tensor._base is None before setting tensor.data or storing " |
|
"tensor to memory buffer. Otherwise, a memory leak will occur (and " |
|
"likely accumulate over iterations). %s" |
|
) % extra_msg |
|
return tensor |
|
|
|
def safely_set_viewless_tensor_data(tensor, new_data_tensor): |
|
'''Safely set tensor's '.data' field. |
|
|
|
Check first that the tensor is viewless (i.e., '._base' not set). If not, |
|
raise an exception. |
|
''' |
|
    extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % (
        "--" if tensor._base is None else tensor._base.shape,
        new_data_tensor.shape)
    assert_viewless_tensor(tensor, extra_msg=extra_msg)
|
tensor.data = new_data_tensor |
|
|
|
|
|
class CudaRNGStatesTracker: |
|
"""Tracker for the cuda RNG states. |
|
|
|
Using the `add` method, a cuda rng state is initialized based on |
|
the input `seed` and is assigned to `name`. Later, by forking the |
|
rng state, we can perform operations and return to our starting |
|
cuda state. |
|
""" |
|
|
|
def __init__(self): |
|
|
|
self.states_ = {} |
|
|
|
self.seeds_ = set() |
|
|
|
def reset(self): |
|
"""Set to the initial state (no tracker).""" |
|
self.states_ = {} |
|
self.seeds_ = set() |
|
|
|
def get_states(self): |
|
"""Get rng states. Copy the dictionary so we have direct |
|
pointers to the states, not just a pointer to the dictionary.""" |
|
states = {} |
|
for name in self.states_: |
|
states[name] = self.states_[name] |
|
return states |
|
|
|
def set_states(self, states): |
|
"""Set the rng states. For efficiency purposes, we do not check |
|
the size of seed for compatibility.""" |
|
self.states_ = states |
|
|
|
def add(self, name, seed): |
|
"""Track the rng state.""" |
|
|
|
if seed in self.seeds_: |
|
raise Exception('seed {} already exists'.format(seed)) |
|
self.seeds_.add(seed) |
|
|
|
if name in self.states_: |
|
raise Exception('cuda rng state {} already exists'.format(name)) |
|
|
|
orig_rng_state = torch.cuda.get_rng_state() |
|
|
|
torch.cuda.manual_seed(seed) |
|
self.states_[name] = torch.cuda.get_rng_state() |
|
|
|
_set_cuda_rng_state(orig_rng_state) |
|
|
|
@contextlib.contextmanager |
|
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): |
|
"""Fork the cuda rng state, perform operations, and exit with |
|
the original state.""" |
|
|
|
if name not in self.states_: |
|
raise Exception('cuda rng state {} is not added'.format(name)) |
|
|
|
orig_cuda_rng_state = torch.cuda.get_rng_state() |
|
|
|
_set_cuda_rng_state(self.states_[name]) |
|
|
|
try: |
|
yield |
|
finally: |
|
|
|
self.states_[name] = torch.cuda.get_rng_state() |
|
|
|
_set_cuda_rng_state(orig_cuda_rng_state) |
|
|
|
|
|
|
|
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() |
|
|
|
|
|
def get_cuda_rng_tracker(): |
|
"""Get cuda rng tracker.""" |
|
return _CUDA_RNG_STATE_TRACKER |
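
# Illustrative sketch (not part of the original module): code in a
# tensor-parallel region typically forks the tracked state so its randomness
# (e.g. dropout masks) differs across tensor-model-parallel ranks; `out` is a
# placeholder activation here:
#
#     with get_cuda_rng_tracker().fork():
#         out = torch.nn.functional.dropout(out, p=0.1, training=True)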
|
|
|
|
|
def model_parallel_cuda_manual_seed(seed): |
|
"""Initialize model parallel cuda seed. |
|
|
|
This function should be called after the model parallel is |
|
initialized. Also, no torch.cuda.manual_seed should be called |
|
after this function. Basically, this is replacement for that |
|
function. |
|
Two set of RNG states are tracked: |
|
default state: This is for data parallelism and is the same among a |
|
set of model parallel GPUs but different across |
|
different model paralle groups. This is used for |
|
example for dropout in the non-tensor-model-parallel regions. |
|
tensor-model-parallel state: This state is different among a set of model |
|
parallel GPUs, but the same across data parallel |
|
groups. This is used for example for dropout in |
|
model parallel regions. |
|
""" |
|
|
|
    # Offset the base seed so that the tensor-model-parallel RNG streams differ
    # from the data-parallel stream seeded below; any positive constant works.
    offset = seed + 2718
|
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank() |
|
|
|
data_parallel_seed = seed |
|
|
|
if torch.distributed.get_rank() == 0: |
|
print('> initializing model parallel cuda seeds on global rank {}, ' |
|
'model parallel rank {}, and data parallel rank {} with ' |
|
'model parallel seed: {} and data parallel seed: {}'.format( |
|
torch.distributed.get_rank(), get_tensor_model_parallel_rank(), |
|
get_data_parallel_rank(), tensor_model_parallel_seed, |
|
data_parallel_seed), flush=True) |
|
_CUDA_RNG_STATE_TRACKER.reset() |
|
|
|
torch.cuda.manual_seed(data_parallel_seed) |
|
|
|
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, |
|
tensor_model_parallel_seed) |
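
# Worked example (not part of the original module): with seed=1234 and a
# tensor-model-parallel size of 4, the function above produces
#
#     data-parallel (default) CUDA state seeded with:  1234 on every rank
#     tensor-model-parallel state seeded with:         1234 + 2718 + tp_rank,
#                                                      i.e. 3952, 3953, 3954, 3955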
|
|
|
|
|
class CheckpointFunction(torch.autograd.Function): |
|
"""This function is adapted from torch.utils.checkpoint with |
|
two main changes: |
|
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` |
|
2) the states in the model parallel tracker are also properly |
|
tracked/set/reset. |
|
""" |
|
@staticmethod |
|
def forward(ctx, run_function, distribute_saved_activations, *args): |
|
ctx.run_function = run_function |
|
        ctx.distribute_saved_activations = distribute_saved_activations
|
|
|
|
|
        # Copy the CPU and CUDA RNG states (including the tracked model-parallel
        # states) so they can be restored before recomputation in backward().
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
|
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() |
|
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() |
|
|
|
with torch.no_grad(): |
|
outputs = run_function(*args) |
|
|
|
|
|
|
|
        # If requested, shard the first input across the tensor-model-parallel
        # group and keep only this rank's chunk; it is gathered back in
        # backward() before recomputation.
        if distribute_saved_activations:
|
ctx.input_0_shape = args[0].data.shape |
|
safely_set_viewless_tensor_data( |
|
args[0], |
|
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True)) |
|
|
|
|
|
ctx.save_for_backward(*args) |
|
|
|
return outputs |
|
|
|
@staticmethod |
|
def backward(ctx, *args): |
|
if not torch.autograd._is_checkpoint_valid(): |
|
raise RuntimeError("Checkpointing is not compatible with .grad(), " |
|
"please use .backward() if possible") |
|
inputs = ctx.saved_tensors |
|
if ctx.distribute_saved_activations: |
|
safely_set_viewless_tensor_data( |
|
inputs[0], |
|
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape)) |
|
|
|
|
|
        # Store the current RNG states so they can be restored after the
        # recomputation below.
        bwd_cpu_rng_state = torch.get_rng_state()
|
bwd_cuda_rng_state = torch.cuda.get_rng_state() |
|
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() |
|
|
|
|
|
        # Set the RNG states to what they were at the start of the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
|
_set_cuda_rng_state(ctx.fwd_cuda_rng_state) |
|
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) |
|
|
|
|
|
detached_inputs = detach_variable(inputs) |
|
with torch.enable_grad(): |
|
outputs = ctx.run_function(*detached_inputs) |
|
|
|
|
|
        # Restore the RNG states captured at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
|
_set_cuda_rng_state(bwd_cuda_rng_state) |
|
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) |
|
|
|
if isinstance(outputs, torch.Tensor): |
|
outputs = (outputs,) |
|
torch.autograd.backward(outputs, args) |
|
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp |
|
for inp in detached_inputs) |
|
return (None, None) + grads |
|
|
|
|
|
def checkpoint(function, distribute_saved_activations, *args): |
|
"""Checkpoint a model or part of the model. |
|
This has been directly copied from torch.utils.checkpoint.""" |
|
return CheckpointFunction.apply(function, |
|
distribute_saved_activations, *args) |
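
# Illustrative sketch (not part of the original module): `layer` below stands in
# for any callable (e.g. a transformer layer) whose activations should be
# recomputed during the backward pass instead of being stored:
#
#     def custom_forward(hidden_states, attention_mask):
#         return layer(hidden_states, attention_mask)
#
#     hidden_states = checkpoint(custom_forward,
#                                False,   # distribute_saved_activations
#                                hidden_states, attention_mask)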
|
|