from collections import defaultdict
from enum import Enum
from typing import cast, Iterable, Tuple, Union
import torch
from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
from torch import Tensor
from torch.nn import Module
def _reset_sample_grads(module: Module):
module.weight.sample_grad = 0 # type: ignore
if module.bias is not None:
module.bias.sample_grad = 0 # type: ignore
def linear_param_grads(
module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
r"""
Computes parameter gradients per sample for nn.Linear module, given module
input activations and output gradients.
Gradients are accumulated in the sample_grad attribute of each parameter
(weight and bias). If reset = True, any current sample_grad values are reset,
otherwise computed gradients are accumulated and added to the existing
stored gradients.
Inputs with more than 2 dimensions are only supported with torch 1.8 or later
"""
if reset:
_reset_sample_grads(module)
module.weight.sample_grad += torch.einsum( # type: ignore
"n...i,n...j->nij", gradient_out, activation
)
if module.bias is not None:
module.bias.sample_grad += torch.einsum( # type: ignore
"n...i->ni", gradient_out
)
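

# Illustrative shape note (added for clarity; not part of the original module):
# for an nn.Linear layer with in_features=j and out_features=i, `activation`
# has shape (n, ..., j) and `gradient_out` has shape (n, ..., i). The einsum
# "n...i,n...j->nij" sums per-position outer products over any intermediate
# dimensions, giving a per-sample weight gradient of shape (n, i, j); the
# einsum "n...i->ni" similarly reduces the bias gradient to shape (n, i).
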
def conv2d_param_grads(
module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
r"""
Computes parameter gradients per sample for nn.Conv2d module, given module
input activations and output gradients.
nn.Conv2d modules with padding set to a string option ('same' or 'valid') are
currently unsupported.
Gradients are accumulated in the sample_grad attribute of each parameter
(weight and bias). If reset = True, any current sample_grad values are reset,
otherwise computed gradients are accumulated and added to the existing
stored gradients.
"""
if reset:
_reset_sample_grads(module)
batch_size = cast(int, activation.shape[0])
unfolded_act = torch.nn.functional.unfold(
activation,
cast(Union[int, Tuple[int, ...]], module.kernel_size),
dilation=cast(Union[int, Tuple[int, ...]], module.dilation),
padding=cast(Union[int, Tuple[int, ...]], module.padding),
stride=cast(Union[int, Tuple[int, ...]], module.stride),
)
reshaped_grad = gradient_out.reshape(batch_size, -1, unfolded_act.shape[-1])
grad1 = torch.einsum("ijk,ilk->ijl", reshaped_grad, unfolded_act)
shape = [batch_size] + list(cast(Iterable[int], module.weight.shape))
module.weight.sample_grad += grad1.reshape(shape) # type: ignore
if module.bias is not None:
module.bias.sample_grad += torch.sum(reshaped_grad, dim=2) # type: ignore
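

# Illustrative shape note (added for clarity; not part of the original module):
# unfold extracts every kernel-sized patch of `activation` into a tensor of
# shape (n, in_channels * kh * kw, num_patches), while `gradient_out` is
# reshaped to (n, out_channels, num_patches). The einsum "ijk,ilk->ijl" sums
# the per-patch outer products, producing a per-sample weight gradient of shape
# (n, out_channels, in_channels * kh * kw), which is then reshaped to
# (n, *weight.shape). The bias gradient is the sum of `gradient_out` over all
# spatial positions, with shape (n, out_channels).
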
SUPPORTED_MODULES = {
torch.nn.Conv2d: conv2d_param_grads,
torch.nn.Linear: linear_param_grads,
}
class LossMode(Enum):
SUM = 0
MEAN = 1
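

# Note (added for clarity; not part of the original module): LossMode records
# how the loss passed to compute_param_sample_gradients was reduced over the
# batch. A mean-reduced loss scales each sample's upstream gradient by
# 1 / batch_size, so per-sample gradients are multiplied by the batch size in
# compute_param_sample_gradients to recover the gradient of each individual
# sample's loss; a sum-reduced loss needs no rescaling.
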
class SampleGradientWrapper:
r"""
Wrapper which allows computing sample-wise gradients in a single backward pass.
This is accomplished by adding hooks to capture activations and output
gradients for supported modules, and using these activations and gradients
to compute the parameter gradients per-sample.
Currently, only nn.Linear and nn.Conv2d modules are supported.
Similar reference implementations of sample-based gradients include:
- https://github.com/cybertronai/autograd-hacks
- https://github.com/pytorch/opacus/tree/main/opacus/grad_sample
"""
def __init__(self, model):
self.model = model
self.hooks_added = False
self.activation_dict = defaultdict(list)
self.gradient_dict = defaultdict(list)
self.forward_hooks = []
self.backward_hooks = []
def add_hooks(self):
self.hooks_added = True
self.model.apply(self._register_module_hooks)
def _register_module_hooks(self, module: torch.nn.Module):
if isinstance(module, tuple(SUPPORTED_MODULES.keys())):
self.forward_hooks.append(
module.register_forward_hook(self._forward_hook_fn)
)
self.backward_hooks.append(
_register_backward_hook(module, self._backward_hook_fn, None)
)
def _forward_hook_fn(
self,
module: Module,
module_input: Union[Tensor, Tuple[Tensor, ...]],
module_output: Union[Tensor, Tuple[Tensor, ...]],
):
        # Save a detached copy of the module's input activation for this
        # forward pass; it is consumed later when computing per-sample
        # parameter gradients.
        inp_tuple = _format_tensor_into_tuples(module_input)
        self.activation_dict[module].append(inp_tuple[0].clone().detach())
def _backward_hook_fn(
self,
module: Module,
grad_input: Union[Tensor, Tuple[Tensor, ...]],
grad_output: Union[Tensor, Tuple[Tensor, ...]],
):
        # Save a detached copy of the gradient with respect to the module's
        # output; it is paired with the corresponding saved activation later.
        grad_output_tuple = _format_tensor_into_tuples(grad_output)
        self.gradient_dict[module].append(grad_output_tuple[0].clone().detach())
def remove_hooks(self):
self.hooks_added = False
for hook in self.forward_hooks:
hook.remove()
for hook in self.backward_hooks:
hook.remove()
self.forward_hooks = []
self.backward_hooks = []
def _reset(self):
self.activation_dict = defaultdict(list)
self.gradient_dict = defaultdict(list)
    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
        r"""
        Computes per-sample parameter gradients in a single backward pass on
        loss_blob, storing the results in the sample_grad attribute of each
        supported module's weight and bias parameters. loss_mode must be
        "sum" or "mean" and should match the reduction used to compute
        loss_blob over the batch.
        """
assert (
loss_mode.upper() in LossMode.__members__
), f"Provided loss mode {loss_mode} is not valid"
mode = LossMode[loss_mode.upper()]
self.model.zero_grad()
loss_blob.backward(gradient=torch.ones_like(loss_blob))
for module in self.gradient_dict:
sample_grad_fn = SUPPORTED_MODULES[type(module)]
activations = self.activation_dict[module]
gradients = self.gradient_dict[module]
            assert len(activations) == len(gradients), (
                "Number of saved activations does not match number of saved"
                " gradients. This may occur if multiple forward passes are run"
                " without calling reset or computing param gradients."
            )
            # When a module is used multiple times in a forward pass, backprop
            # visits those uses in reverse order, so the saved gradients are
            # stored in reverse order relative to the saved activations; the
            # gradient list is reversed here to align the two.
            for i, (act, grad) in enumerate(
                zip(activations, list(reversed(gradients)))
            ):
mult = 1 if mode is LossMode.SUM else act.shape[0]
sample_grad_fn(module, act, grad * mult, reset=(i == 0))
self._reset()
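

# The sketch below is added for illustration and is not part of the original
# module. It shows one plausible end-to-end use of SampleGradientWrapper on a
# toy model (the model and shapes are arbitrary): add hooks, run a forward
# pass, compute per-sample gradients from a mean-reduced loss, then read each
# parameter's sample_grad attribute.
if __name__ == "__main__":
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 2),
    )
    wrapper = SampleGradientWrapper(model)
    wrapper.add_hooks()
    try:
        inputs = torch.randn(5, 4)
        targets = torch.randint(0, 2, (5,))
        # cross_entropy defaults to reduction="mean", so loss_mode="mean" below.
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        wrapper.compute_param_sample_gradients(loss, loss_mode="mean")
        for name, param in model.named_parameters():
            # Each sample_grad has shape (batch_size, *param.shape).
            print(name, param.sample_grad.shape)  # type: ignore
    finally:
        wrapper.remove_hooks()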