import functools
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.utils.checkpoint as checkpoint

from fairseq import utils


def checkpoint_wrapper(m, offload_to_cpu=False):
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
    """
    assert not hasattr(
        m, "precheckpoint_forward"
    ), "checkpoint function has already been applied?"
    m.precheckpoint_forward = m.forward
    m.forward = functools.partial(
        _checkpointed_forward,
        m.precheckpoint_forward,
        offload_to_cpu,
    )
    return m


def unwrap_checkpoint(m: torch.nn.Module):
    """
    Unwrap a module and its children from checkpoint_wrapper.
    """
    for module in m.modules():
        if hasattr(module, "precheckpoint_forward"):
            module.forward = module.precheckpoint_forward
            del module.precheckpoint_forward
    return m


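# Example (illustrative sketch, not part of the original module): wrap a module,
# run a checkpointed forward/backward, then restore the original forward.
# The Linear layer and shapes are arbitrary.
#
#     layer = checkpoint_wrapper(torch.nn.Linear(8, 8))
#     out = layer(torch.randn(2, 8, requires_grad=True))
#     out.sum().backward()              # activations are recomputed here
#     layer = unwrap_checkpoint(layer)  # back to the plain forward

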
def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # We can flatten keyword arguments to make this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
    parent_ctx_dict = {"offload": offload_to_cpu}
    output = CheckpointFunction.apply(
        original_forward, parent_ctx_dict, kwarg_keys, *flat_args
    )
    if isinstance(output, torch.Tensor):
        return output
    else:
        packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
        if packed_non_tensor_outputs:
            output = unpack_non_tensors(output, packed_non_tensor_outputs)
        return output


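# Example (illustrative sketch): the wrapper accepts keyword arguments and
# non-Tensor outputs in the wrapped forward, which is what the flattening
# above enables. The MixedOutput module is hypothetical.
#
#     class MixedOutput(torch.nn.Module):
#         def forward(self, x, scale=1.0):
#             return x * scale, "not a tensor"
#
#     m = checkpoint_wrapper(MixedOutput())
#     y, note = m(torch.ones(2, requires_grad=True), scale=2.0)

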
def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]:
    """
    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == [1, 2]
        assert kwargs == {"a": 3, "b": 4}
    """
    kwarg_keys = []
    flat_args = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)
    return kwarg_keys, flat_args


def unpack_kwargs(
    kwarg_keys: List[str], flat_args: List[Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """Reverse of pack_kwargs: split *flat_args* back into positional args and kwargs."""
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = {k: v for k, v in zip(kwarg_keys, flat_args[-len(kwarg_keys) :])}
    return args, kwargs


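# Example (illustrative): with no keyword arguments, pack_kwargs returns an
# empty key list and unpack_kwargs passes the flat argument list through
# unchanged.
#
#     kwarg_keys, flat_args = pack_kwargs(1, 2)   # ([], [1, 2])
#     args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
#     assert args == [1, 2] and kwargs == {}

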
def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any]]
) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]:
    """
    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None
    tensors = []
    packed_non_tensors = {"is_tensor": [], "objects": []}
    for o in mixed:
        if isinstance(o, torch.Tensor):
            packed_non_tensors["is_tensor"].append(True)
            tensors.append(o)
        else:
            packed_non_tensors["is_tensor"].append(False)
            packed_non_tensors["objects"].append(o)
    return tuple(tensors), packed_non_tensors


def unpack_non_tensors(
    tensors: Tuple[torch.Tensor],
    packed_non_tensors: Dict[str, List[Any]],
) -> Tuple[Any]:
    """Reverse of split_non_tensors: re-interleave tensors and packed non-tensors."""
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict)
    mixed = []
    is_tensor_list = packed_non_tensors["is_tensor"]
    objects = packed_non_tensors["objects"]
    assert len(tensors) + len(objects) == len(is_tensor_list)
    obj_i = tnsr_i = 0
    for is_tensor in is_tensor_list:
        if is_tensor:
            mixed.append(tensors[tnsr_i])
            tnsr_i += 1
        else:
            mixed.append(objects[obj_i])
            obj_i += 1
    return tuple(mixed)


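# Example (illustrative): a bare Tensor input is wrapped in a 1-tuple with no
# packed non-tensors, and unpack_non_tensors returns such a tuple unchanged.
#
#     t = torch.Tensor([1])
#     tensors, packed = split_non_tensors(t)   # ((t,), None)
#     assert packed is None
#     assert unpack_non_tensors(tensors, packed) == (t,)

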
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but supports non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling ``unpack_non_tensors``.
    """

    @staticmethod
    def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args):
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        ctx.fwd_rng_state = utils.get_rng_state()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            # Remember the original devices and requires_grad flags, then move
            # the saved activations to CPU to free device memory.
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)

        if isinstance(outputs, torch.Tensor):
            return outputs
        else:
            # Autograd Functions don't like non-Tensor outputs. We can split the
            # non-Tensor and Tensor outputs, returning the former by reference
            # through *parent_ctx_dict* and returning the latter directly.
            outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
            parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
            return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            # Move offloaded inputs back to the devices they were on during the
            # forward pass and restore their requires_grad flags.
            tensor_inputs = [
                t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs)
            ]
            for i, need_grad in enumerate(ctx.grad_requirements):
                tensor_inputs[i].requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

        # Store the current RNG state, then restore the state from the forward
        # pass so the recomputation is bitwise identical.
        bwd_rng_state = utils.get_rng_state()
        utils.set_rng_state(ctx.fwd_rng_state)

        with torch.enable_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
            outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
            tensor_outputs, _ = split_non_tensors(outputs)

        # Set the RNG state back to what it was at the start of this function.
        utils.set_rng_state(bwd_rng_state)

        # Run backward() with only the Tensors that require grad.
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(tensor_outputs)):
            if tensor_outputs[i].requires_grad:
                outputs_with_grad.append(tensor_outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "None of the outputs have requires_grad=True, "
                "this checkpoint() is not necessary"
            )

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs
        )
        return (None, None, None) + grads
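

if __name__ == "__main__":
    # Minimal smoke test (illustrative only, not part of the original module):
    # gradients through a checkpointed module with CPU offload should match
    # those of an identical non-checkpointed copy. The Linear layer and shapes
    # are arbitrary choices for demonstration.
    import copy

    torch.manual_seed(0)
    base = torch.nn.Linear(4, 4)
    wrapped = checkpoint_wrapper(copy.deepcopy(base), offload_to_cpu=True)

    x = torch.randn(2, 4, requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)

    wrapped(x).sum().backward()  # the forward is recomputed during backward
    base(x_ref).sum().backward()

    assert torch.allclose(x.grad, x_ref.grad)
    print("checkpointed and reference gradients match")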