File size: 8,568 Bytes
d5175d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Any, Dict, List, Tuple, Union
import torch
import torch.utils.checkpoint as checkpoint
from fairseq import utils
def checkpoint_wrapper(m, offload_to_cpu=False):
    """
    Patch a module in place so its forward runs with activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward

    The module's current ``forward`` is saved on the module as
    ``precheckpoint_forward`` and ``forward`` is rebound to a partial of
    ``_checkpointed_forward``. The module itself (not a copy) is returned.

    Args:
        m: module whose ``forward`` should be checkpointed
        offload_to_cpu: passed through to ``_checkpointed_forward``

    Raises:
        AssertionError: if *m* already has a ``precheckpoint_forward``
            attribute, i.e. it appears to be wrapped already

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
    """
    # A second wrap would overwrite the saved original forward with the
    # checkpointing partial, so refuse to wrap twice.
    already_wrapped = hasattr(m, "precheckpoint_forward")
    assert not already_wrapped, "checkpoint function has already been applied?"

    m.precheckpoint_forward = m.forward
    checkpointed = functools.partial(
        _checkpointed_forward,
        m.precheckpoint_forward,  # original_forward
        offload_to_cpu,
    )
    m.forward = checkpointed
    return m
def unwrap_checkpoint(m: torch.nn.Module):
    """
    Undo ``checkpoint_wrapper`` on a module and all of its submodules.

    Every module yielded by ``m.modules()`` that carries a
    ``precheckpoint_forward`` attribute gets that value assigned back to
    ``forward``, and the attribute is then removed. Modules without the
    attribute are left untouched.

    Args:
        m: root module to unwrap

    Returns:
        the same module object *m*
    """
    for submodule in m.modules():
        if not hasattr(submodule, "precheckpoint_forward"):
            continue
        submodule.forward = submodule.precheckpoint_forward
        del submodule.precheckpoint_forward
    return m
def _checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
    """
    Call *original_forward* through ``CheckpointFunction``.

    Args:
        original_forward: the callable to run under checkpointing
        offload_to_cpu: stored under the ``"offload"`` key of the dict handed
            to ``CheckpointFunction``
        *args, **kwargs: arguments for *original_forward*

    Returns:
        The Tensor returned by ``CheckpointFunction.apply`` when it is a single
        Tensor; otherwise its output recombined with any non-Tensor outputs
        via ``unpack_non_tensors``.
    """
    # Autograd Functions in PyTorch work best with positional args, since
    # the backward must return gradients (or None) for every input argument.
    # Flattening the keyword arguments makes this easier.
    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)

    # This dict carries the offload flag in, and CheckpointFunction.forward
    # writes the packed non-Tensor outputs back out through it.
    shared_ctx = {"offload": offload_to_cpu}
    result = CheckpointFunction.apply(
        original_forward, shared_ctx, kwarg_keys, *flat_args
    )
    if isinstance(result, torch.Tensor):
        return result

    packed = shared_ctx["packed_non_tensor_outputs"]
    if not packed:
        return result
    return unpack_non_tensors(result, packed)
def pack_kwargs(*args, **kwargs) -> Tuple[List[str], List[Any]]:
    """
    Flatten positional and keyword arguments into one list.

    Returns:
        ``(kwarg_keys, flat_args)`` where *kwarg_keys* holds the keyword names
        in call order and *flat_args* holds the positional values followed by
        the keyword values in that same order.

    Usage::

        kwarg_keys, flat_args = pack_kwargs(1, 2, a=3, b=4)
        args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
        assert args == [1, 2]
        assert kwargs == {"a": 3, "b": 4}
    """
    names = list(kwargs)
    flattened = [*args, *kwargs.values()]
    return names, flattened
def unpack_kwargs(
    kwarg_keys: List[str], flat_args: List[Any]
) -> Tuple[List[Any], Dict[str, Any]]:
    """
    Inverse of ``pack_kwargs``: split a flat argument sequence back apart.

    The last ``len(kwarg_keys)`` entries of *flat_args* are paired with
    *kwarg_keys*; everything before them is treated as positional.

    Returns:
        ``(args, kwargs)``. When *kwarg_keys* is empty, *flat_args* is returned
        unchanged (same object) together with an empty dict.
    """
    num_keyword = len(kwarg_keys)
    if num_keyword == 0:
        return flat_args, {}
    positional = flat_args[:-num_keyword]
    keyword = dict(zip(kwarg_keys, flat_args[-num_keyword:]))
    return positional, keyword
def split_non_tensors(
    mixed: Union[torch.Tensor, Tuple[Any]]
) -> Tuple[Tuple[torch.Tensor], Dict[str, List[Any]]]:
    """
    Separate the Tensors in *mixed* from everything else.

    Returns:
        ``(tensors, packed_non_tensors)``. For a single Tensor input this is
        ``((mixed,), None)``. Otherwise *tensors* is a tuple of the Tensor
        elements in order, and *packed_non_tensors* is a dict with
        ``"is_tensor"`` (one bool per element of *mixed*) and ``"objects"``
        (the non-Tensor elements in order).

    Usage::

        x = torch.Tensor([1])
        y = torch.Tensor([2])
        tensors, packed_non_tensors = split_non_tensors((x, y, None, 3))
        recon = unpack_non_tensors(tensors, packed_non_tensors)
        assert recon == (x, y, None, 3)
    """
    if isinstance(mixed, torch.Tensor):
        return (mixed,), None

    tensor_items = []
    flags = []
    other_items = []
    for item in mixed:
        item_is_tensor = isinstance(item, torch.Tensor)
        flags.append(item_is_tensor)
        bucket = tensor_items if item_is_tensor else other_items
        bucket.append(item)
    return tuple(tensor_items), {"is_tensor": flags, "objects": other_items}
def unpack_non_tensors(
    tensors: Tuple[torch.Tensor],
    packed_non_tensors: Dict[str, List[Any]],
) -> Tuple[Any]:
    """
    Inverse of ``split_non_tensors``: interleave Tensors and other objects.

    Args:
        tensors: the Tensor elements, in order
        packed_non_tensors: dict with ``"is_tensor"`` and ``"objects"`` as
            produced by ``split_non_tensors``, or None

    Returns:
        *tensors* unchanged when *packed_non_tensors* is None; otherwise a
        tuple with one element per ``"is_tensor"`` flag, drawn in order from
        *tensors* (flag true) or from ``"objects"`` (flag false).
    """
    if packed_non_tensors is None:
        return tensors
    assert isinstance(packed_non_tensors, dict)

    flags = packed_non_tensors["is_tensor"]
    others = packed_non_tensors["objects"]
    assert len(tensors) + len(others) == len(flags)

    restored = []
    next_tensor = 0
    next_other = 0
    for from_tensors in flags:
        if not from_tensors:
            restored.append(others[next_other])
            next_other += 1
        else:
            restored.append(tensors[next_tensor])
            next_tensor += 1
    return tuple(restored)
class CheckpointFunction(torch.autograd.Function):
    """Similar to the torch version, but support non-Tensor outputs.

    The caller is expected to provide a dict (*parent_ctx_dict*) that will hold
    the non-Tensor outputs. These should be combined with the Tensor *outputs*
    by calling ``unpack_non_tensors``.
    """

    @staticmethod
    def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args):
        """Run *run_function* without grad and record what backward needs.

        Args:
            ctx: autograd context; receives the function, kwarg keys, RNG
                state, saved tensor inputs and packed non-Tensor inputs
            run_function: callable to execute
            parent_ctx_dict: dict read for its ``"offload"`` flag and written
                with ``"packed_non_tensor_outputs"`` when the output is not a
                single Tensor
            kwarg_keys: keyword names for the trailing entries of *args*
            *args: flattened positional + keyword values (see ``pack_kwargs``)

        Returns:
            The output Tensor, or a tuple of the Tensor outputs when
            *run_function* returned a mixed collection.
        """
        if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
            checkpoint.check_backward_validity(args)

        ctx.run_function = run_function
        ctx.kwarg_keys = kwarg_keys
        # Captured so that backward can replay the forward with the same RNG.
        ctx.fwd_rng_state = utils.get_rng_state()

        tensor_inputs, packed_non_tensor_inputs = split_non_tensors(args)
        if parent_ctx_dict["offload"]:
            # Remember where each input lived and whether it needed grad, then
            # keep only CPU copies for the backward pass.
            ctx.fwd_device = tuple(x.device for x in tensor_inputs)
            ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
        else:
            ctx.fwd_device, ctx.grad_requirements = None, None

        ctx.save_for_backward(*tensor_inputs)
        ctx.packed_non_tensor_inputs = packed_non_tensor_inputs

        with torch.no_grad():
            unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
            outputs = run_function(*unpacked_args, **unpacked_kwargs)

        if isinstance(outputs, torch.Tensor):
            return outputs

        # Autograd Functions don't like non-Tensor outputs. We can split the
        # non-Tensor and Tensor outputs, returning the former by reference
        # through *parent_ctx_dict* and returning the latter directly.
        outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
        parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Re-run the forward with grad enabled and backpropagate through it.

        Args:
            ctx: autograd context populated by ``forward``
            *args: incoming gradients, indexed in step with the Tensor
                outputs of the recomputed forward

        Returns:
            ``(None, None, None)`` for *run_function*, *parent_ctx_dict* and
            *kwarg_keys*, followed by one entry per forward input: its
            ``.grad`` for Tensors and None for anything else.

        Raises:
            RuntimeError: if ``torch.autograd._is_checkpoint_valid()`` is
                false, or if no recomputed output requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad(), please use .backward() if possible"
            )

        tensor_inputs: Tuple = ctx.saved_tensors
        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
        if ctx.fwd_device is not None:
            # Inputs were offloaded in forward: move each one back to its
            # original device and reapply its recorded requires_grad flag.
            tensor_inputs = [
                t.to(device) for t, device in zip(tensor_inputs, ctx.fwd_device)
            ]
            for t, need_grad in zip(tensor_inputs, ctx.grad_requirements):
                t.requires_grad = need_grad
        inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)

        # Store the current states.
        bwd_rng_state = utils.get_rng_state()

        # Set the states to what it used to be before the forward pass.
        utils.set_rng_state(ctx.fwd_rng_state)
        try:
            with torch.enable_grad():
                unpacked_args, unpacked_kwargs = unpack_kwargs(ctx.kwarg_keys, inputs)
                outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
                tensor_outputs, _ = split_non_tensors(outputs)
        finally:
            # Set the states back to what it was at the start of this function,
            # even if the recomputation raised, so a failure here does not
            # leave the process running on the rewound forward-pass RNG state.
            utils.set_rng_state(bwd_rng_state)

        # Run backward() with only Tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i, out in enumerate(tensor_outputs):
            if out.requires_grad:
                outputs_with_grad.append(out)
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "None of the outputs have requires_grad=True, "
                "this checkpoint() is not necessary"
            )

        torch.autograd.backward(outputs_with_grad, args_with_grad)

        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None for inp in inputs
        )
        return (None, None, None) + grads
|