from collections import defaultdict
from enum import Enum
from typing import cast, Iterable, Tuple, Union

import torch
from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
from torch import Tensor
from torch.nn import Module


def _reset_sample_grads(module: Module):
    module.weight.sample_grad = 0  # type: ignore
    if module.bias is not None:
        module.bias.sample_grad = 0  # type: ignore


def linear_param_grads(
    module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
    r"""
    Computes parameter gradients per sample for an nn.Linear module, given the
    module's input activations and output gradients.

    Gradients are accumulated in the sample_grad attribute of each parameter
    (weight and bias). If reset is True, any existing sample_grad values are
    cleared first; otherwise, the computed gradients are added to the stored
    gradients.

    Inputs with more than 2 dimensions are only supported with torch 1.8 or later.
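
    Example (an illustrative sketch; the toy module and tensors are arbitrary
    placeholders)::

        >>> lin = torch.nn.Linear(4, 3)
        >>> act = torch.randn(8, 4)        # per-sample input activations
        >>> grad_out = torch.randn(8, 3)   # gradients w.r.t. the module output
        >>> linear_param_grads(lin, act, grad_out, reset=True)
        >>> lin.weight.sample_grad.shape
        torch.Size([8, 3, 4])
        >>> lin.bias.sample_grad.shape
        torch.Size([8, 3])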
    """
    if reset:
        _reset_sample_grads(module)

    module.weight.sample_grad += torch.einsum(  # type: ignore
        "n...i,n...j->nij", gradient_out, activation
    )
    if module.bias is not None:
        module.bias.sample_grad += torch.einsum(  # type: ignore
            "n...i->ni", gradient_out
        )


def conv2d_param_grads(
    module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
    r"""
    Computes parameter gradients per sample for an nn.Conv2d module, given the
    module's input activations and output gradients.

    nn.Conv2d modules with padding set to a string option ('same' or 'valid') are
    currently unsupported.

    Gradients are accumulated in the sample_grad attribute of each parameter
    (weight and bias). If reset is True, any existing sample_grad values are
    cleared first; otherwise, the computed gradients are added to the stored
    gradients.
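
    Example (an illustrative sketch; the toy module and tensors are arbitrary
    placeholders)::

        >>> conv = torch.nn.Conv2d(3, 6, kernel_size=3)
        >>> act = torch.randn(8, 3, 16, 16)       # per-sample input activations
        >>> grad_out = torch.randn(8, 6, 14, 14)  # gradients w.r.t. the conv output
        >>> conv2d_param_grads(conv, act, grad_out, reset=True)
        >>> conv.weight.sample_grad.shape
        torch.Size([8, 6, 3, 3, 3])
        >>> conv.bias.sample_grad.shape
        torch.Size([8, 6])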
    """
    if reset:
        _reset_sample_grads(module)

    batch_size = cast(int, activation.shape[0])
    unfolded_act = torch.nn.functional.unfold(
        activation,
        cast(Union[int, Tuple[int, ...]], module.kernel_size),
        dilation=cast(Union[int, Tuple[int, ...]], module.dilation),
        padding=cast(Union[int, Tuple[int, ...]], module.padding),
        stride=cast(Union[int, Tuple[int, ...]], module.stride),
    )
    reshaped_grad = gradient_out.reshape(batch_size, -1, unfolded_act.shape[-1])
    grad1 = torch.einsum("ijk,ilk->ijl", reshaped_grad, unfolded_act)
    shape = [batch_size] + list(cast(Iterable[int], module.weight.shape))
    module.weight.sample_grad += grad1.reshape(shape)  # type: ignore
    if module.bias is not None:
        module.bias.sample_grad += torch.sum(reshaped_grad, dim=2)  # type: ignore


SUPPORTED_MODULES = {
    torch.nn.Conv2d: conv2d_param_grads,
    torch.nn.Linear: linear_param_grads,
}


class LossMode(Enum):
    SUM = 0
    MEAN = 1


class SampleGradientWrapper:
    r"""
    Wrapper which allows computing sample-wise gradients in a single backward pass.

    This is accomplished by adding hooks to capture activations and output
    gradients for supported modules, and using these activations and gradients
    to compute the parameter gradients per-sample.

    Currently, only nn.Linear and nn.Conv2d modules are supported.

    Similar reference implementations of sample-based gradients include:
    - https://github.com/cybertronai/autograd-hacks
    - https://github.com/pytorch/opacus/tree/main/opacus/grad_sample
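
    Example (an illustrative end-to-end sketch; the toy model, data, and loss
    are arbitrary placeholders)::

        >>> model = torch.nn.Linear(10, 2)
        >>> wrapper = SampleGradientWrapper(model)
        >>> wrapper.add_hooks()
        >>> inp = torch.randn(8, 10)
        >>> loss = torch.nn.functional.mse_loss(model(inp), torch.zeros(8, 2))
        >>> wrapper.compute_param_sample_gradients(loss, loss_mode="mean")
        >>> model.weight.sample_grad.shape
        torch.Size([8, 2, 10])
        >>> wrapper.remove_hooks()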
    """

    def __init__(self, model):
        self.model = model
        self.hooks_added = False
        self.activation_dict = defaultdict(list)
        self.gradient_dict = defaultdict(list)
        self.forward_hooks = []
        self.backward_hooks = []

    def add_hooks(self):
        self.hooks_added = True
        self.model.apply(self._register_module_hooks)

    def _register_module_hooks(self, module: torch.nn.Module):
        if isinstance(module, tuple(SUPPORTED_MODULES.keys())):
            self.forward_hooks.append(
                module.register_forward_hook(self._forward_hook_fn)
            )
            self.backward_hooks.append(
                _register_backward_hook(module, self._backward_hook_fn, None)
            )

    def _forward_hook_fn(
        self,
        module: Module,
        module_input: Union[Tensor, Tuple[Tensor, ...]],
        module_output: Union[Tensor, Tuple[Tensor, ...]],
    ):
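        # Save a detached copy of the module's (first) input activation; one
        # entry is appended per forward call of this module.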
        inp_tuple = _format_tensor_into_tuples(module_input)
        self.activation_dict[module].append(inp_tuple[0].clone().detach())

    def _backward_hook_fn(
        self,
        module: Module,
        grad_input: Union[Tensor, Tuple[Tensor, ...]],
        grad_output: Union[Tensor, Tuple[Tensor, ...]],
    ):
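        # Save a detached copy of the gradient w.r.t. the module's (first)
        # output; one entry is appended per backward call of this module.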
        grad_output_tuple = _format_tensor_into_tuples(grad_output)
        self.gradient_dict[module].append(grad_output_tuple[0].clone().detach())

    def remove_hooks(self):
        self.hooks_added = False

        for hook in self.forward_hooks:
            hook.remove()

        for hook in self.backward_hooks:
            hook.remove()

        self.forward_hooks = []
        self.backward_hooks = []

    def _reset(self):
        self.activation_dict = defaultdict(list)
        self.gradient_dict = defaultdict(list)

    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
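        """
        Computes per-sample parameter gradients for the wrapped model's supported
        modules, given a loss tensor.

        loss_mode must be "sum" or "mean" and should match the reduction used to
        compute loss_blob; for a mean-reduced loss, gradients are rescaled by the
        batch size to recover per-sample values. Results are stored in the
        sample_grad attribute of each supported module's weight and bias.
        """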
        assert (
            loss_mode.upper() in LossMode.__members__
        ), f"Provided loss mode {loss_mode} is not valid"
        mode = LossMode[loss_mode.upper()]

        self.model.zero_grad()
        loss_blob.backward(gradient=torch.ones_like(loss_blob))

        for module in self.gradient_dict:
            sample_grad_fn = SUPPORTED_MODULES[type(module)]
            activations = self.activation_dict[module]
            gradients = self.gradient_dict[module]
            assert len(activations) == len(gradients), (
                "Number of saved activations do not match number of saved gradients."
                " This may occur if multiple forward passes are run without calling"
                " reset or computing param gradients."
            )
            # Reverse the gradient list: when a module is used multiple times in
            # a forward pass, backprop visits it in the reverse order, so the
            # stored gradients are reversed relative to the stored activations.
            for i, (act, grad) in enumerate(
                zip(activations, list(reversed(gradients)))
            ):
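                # Assuming a mean reduction averages over the batch dimension,
                # each sample's gradient is scaled by 1 / batch_size, so multiply
                # by the batch size to recover per-sample gradients; a sum
                # reduction needs no rescaling.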
                mult = 1 if mode is LossMode.SUM else act.shape[0]
                sample_grad_fn(module, act, grad * mult, reset=(i == 0))
        self._reset()