from collections import defaultdict
from enum import Enum
from typing import Any, Callable, cast, DefaultDict, Iterable, List, Tuple, Union

import torch
from captum._utils.common import _format_tensor_into_tuples, _register_backward_hook
from torch import Tensor
from torch.nn import Module


def _reset_sample_grads(module: Module) -> None:
    r"""
    Resets the ``sample_grad`` attribute of the module's weight (and bias, if
    present) to 0 so that the next per-sample gradients start a fresh
    accumulation.
    """
    module.weight.sample_grad = 0  # type: ignore
    if module.bias is not None:
        module.bias.sample_grad = 0  # type: ignore


def _prepare_sample_grads(module: Module, reset: bool) -> None:
    r"""
    Resets stored per-sample gradients when ``reset`` is True; otherwise only
    initializes (to 0) the ``sample_grad`` of parameters that have none yet, so
    accumulating into a fresh parameter does not raise AttributeError.
    """
    if reset:
        _reset_sample_grads(module)
        return
    for param in (module.weight, module.bias):
        if param is not None and not hasattr(param, "sample_grad"):
            param.sample_grad = 0  # type: ignore


def linear_param_grads(
    module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
    r"""
    Computes parameter gradients per sample for nn.Linear module, given module
    input activations and output gradients.

    Gradients are accumulated in the sample_grad attribute of each parameter
    (weight and bias). If reset = True, any current sample_grad values are reset,
    otherwise computed gradients are accumulated and added to the existing
    stored gradients. Parameters without a stored sample_grad start from zero.

    Inputs with more than 2 dimensions are only supported with torch 1.8 or later
    """
    _prepare_sample_grads(module, reset)

    # Per-sample outer product of output gradient and input, summed over any
    # intermediate dimensions: result is (batch, out_features, in_features).
    module.weight.sample_grad += torch.einsum(  # type: ignore
        "n...i,n...j->nij", gradient_out, activation
    )
    if module.bias is not None:
        module.bias.sample_grad += torch.einsum(  # type: ignore
            "n...i->ni", gradient_out
        )


def _as_pair(value: Union[int, Iterable[int]]) -> Tuple[int, int]:
    r"""Returns ``value`` as a 2-tuple, repeating it if it is a single int."""
    if isinstance(value, int):
        return (value, value)
    first, second = value
    return (first, second)


def _conv2d_padding(module: Module) -> Tuple[int, int, int, int]:
    r"""
    Returns the padding ``module`` applies to the two spatial input dimensions,
    in ``torch.nn.functional.pad`` order: (left, right, top, bottom).

    Handles integer/tuple padding and the string options 'valid' and 'same'.
    For 'same', the total padding per dimension is ``dilation * (kernel - 1)``
    with any odd remainder placed on the right/bottom, matching PyTorch.
    """
    padding = module.padding
    if isinstance(padding, str):
        if padding == "valid":
            return (0, 0, 0, 0)
        if padding == "same":
            totals = [
                d * (k - 1)
                for d, k in zip(
                    _as_pair(cast(Any, module.dilation)),
                    _as_pair(cast(Any, module.kernel_size)),
                )
            ]
            (top, bottom), (left, right) = [(t // 2, t - t // 2) for t in totals]
            return (left, right, top, bottom)
        raise ValueError(f"Unsupported padding option {padding!r} for nn.Conv2d.")
    pad_h, pad_w = _as_pair(cast(Any, padding))
    return (pad_w, pad_w, pad_h, pad_h)


def conv2d_param_grads(
    module: Module, activation: Tensor, gradient_out: Tensor, reset: bool = False
) -> None:
    r"""
    Computes parameter gradients per sample for nn.Conv2d module, given module
    input activations and output gradients.

    Integer padding, the string options 'valid' and 'same', all padding modes
    ('zeros', 'reflect', 'replicate', 'circular') and grouped convolutions are
    supported: the captured input is padded the same way the module pads it
    before patches are extracted with ``torch.nn.functional.unfold``.

    Gradients are accumulated in the sample_grad attribute of each parameter
    (weight and bias). If reset = True, any current sample_grad values are reset,
    otherwise computed gradients are accumulated and added to the existing
    stored gradients. Parameters without a stored sample_grad start from zero.

    Assumes a batched input whose first dimension is the batch dimension.
    """
    _prepare_sample_grads(module, reset)

    batch_size = cast(int, activation.shape[0])
    pad = _conv2d_padding(module)
    left, right, top, bottom = pad
    padding_mode = cast(str, getattr(module, "padding_mode", "zeros"))

    # The forward hook sees the module input *before* the module pads it.
    # unfold can only reproduce symmetric zero padding, so anything else
    # (asymmetric 'same' padding, reflect/replicate/circular) is applied here.
    if any(pad) and (padding_mode != "zeros" or left != right or top != bottom):
        pad_mode = "constant" if padding_mode == "zeros" else padding_mode
        activation = torch.nn.functional.pad(activation, pad, mode=pad_mode)
        unfold_padding = (0, 0)
    else:
        unfold_padding = (top, left)

    unfolded_act = torch.nn.functional.unfold(
        activation,
        cast(Union[int, Tuple[int, ...]], module.kernel_size),
        dilation=cast(Union[int, Tuple[int, ...]], module.dilation),
        padding=unfold_padding,
        stride=cast(Union[int, Tuple[int, ...]], module.stride),
    )
    num_patches = unfolded_act.shape[-1]
    reshaped_grad = gradient_out.reshape(batch_size, -1, num_patches)

    # Each group of output channels only sees its own group of input channels;
    # unfold keeps the input channel as the outermost feature index, so the
    # feature axis splits cleanly into (groups, per-group features).
    groups = cast(int, module.groups)
    grouped_grad = reshaped_grad.reshape(batch_size, groups, -1, num_patches)
    grouped_act = unfolded_act.reshape(batch_size, groups, -1, num_patches)
    grad1 = torch.einsum("ngok,ngik->ngoi", grouped_grad, grouped_act)

    shape = [batch_size] + list(cast(Iterable[int], module.weight.shape))
    module.weight.sample_grad += grad1.reshape(shape)  # type: ignore
    if module.bias is not None:
        module.bias.sample_grad += torch.sum(reshaped_grad, dim=2)  # type: ignore


SUPPORTED_MODULES = {
    torch.nn.Conv2d: conv2d_param_grads,
    torch.nn.Linear: linear_param_grads,
}


def _get_sample_grad_fn(module: Module) -> Callable[..., None]:
    r"""
    Returns the per-sample gradient function for ``module``, resolving
    subclasses of supported modules through the class MRO. Raises KeyError
    for unsupported module types.
    """
    for klass in type(module).__mro__:
        if klass in SUPPORTED_MODULES:
            return SUPPORTED_MODULES[klass]
    raise KeyError(type(module))


class LossMode(Enum):
    SUM = 0
    MEAN = 1


class SampleGradientWrapper:
    r"""
    Wrapper which allows computing sample-wise gradients in a single backward pass.

    This is accomplished by adding hooks to capture activations and output
    gradients for supported modules, and using these activations and gradients
    to compute the parameter gradients per-sample.

    Currently, only nn.Linear and nn.Conv2d modules (including subclasses) are
    supported.

    Typical usage: call ``add_hooks()``, run the forward pass, then call
    ``compute_param_sample_gradients(loss)``; each supported parameter then
    holds its per-sample gradients in ``param.sample_grad``.

    Similar reference implementations of sample-based gradients include:
    - autograd-hacks: https://github.com/cybertronai/autograd-hacks
    - Opacus: https://github.com/pytorch/opacus
    """

    def __init__(self, model: Module) -> None:
        self.model = model
        self.hooks_added = False
        # Module inputs, appended in forward execution order.
        self.activation_dict: DefaultDict[Module, List[Tensor]] = defaultdict(list)
        # Module output gradients, appended in backward execution order.
        self.gradient_dict: DefaultDict[Module, List[Tensor]] = defaultdict(list)
        self.forward_hooks: List[Any] = []
        self.backward_hooks: List[Any] = []

    def add_hooks(self) -> None:
        r"""Registers capture hooks on every supported submodule of the model."""
        if self.hooks_added:
            # Registering twice would record every activation and gradient
            # twice and double the accumulated per-sample gradients.
            return
        self.hooks_added = True
        self.model.apply(self._register_module_hooks)

    def _register_module_hooks(self, module: torch.nn.Module) -> None:
        if isinstance(module, tuple(SUPPORTED_MODULES)):
            self.forward_hooks.append(
                module.register_forward_hook(self._forward_hook_fn)
            )
            self.backward_hooks.append(
                _register_backward_hook(module, self._backward_hook_fn, None)
            )

    def _forward_hook_fn(
        self,
        module: Module,
        module_input: Union[Tensor, Tuple[Tensor, ...]],
        module_output: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        inp_tuple = _format_tensor_into_tuples(module_input)
        # Detach before cloning so the copy is not recorded in the graph.
        self.activation_dict[module].append(inp_tuple[0].detach().clone())

    def _backward_hook_fn(
        self,
        module: Module,
        grad_input: Union[Tensor, Tuple[Tensor, ...]],
        grad_output: Union[Tensor, Tuple[Tensor, ...]],
    ) -> None:
        grad_output_tuple = _format_tensor_into_tuples(grad_output)
        self.gradient_dict[module].append(grad_output_tuple[0].detach().clone())

    def remove_hooks(self) -> None:
        r"""Removes all hooks registered by ``add_hooks``."""
        self.hooks_added = False

        for hook in self.forward_hooks:
            hook.remove()

        for hook in self.backward_hooks:
            hook.remove()

        self.forward_hooks = []
        self.backward_hooks = []

    def _reset(self) -> None:
        self.activation_dict = defaultdict(list)
        self.gradient_dict = defaultdict(list)

    def compute_param_sample_gradients(
        self, loss_blob: Tensor, loss_mode: str = "mean"
    ) -> None:
        r"""
        Zeroes the model's gradients, runs a backward pass on ``loss_blob`` and
        stores the per-sample gradient of every hooked parameter in its
        ``sample_grad`` attribute.

        Args:
            loss_blob: Loss tensor to backpropagate. ``backward`` is called with
                a gradient of ones, so a non-scalar loss behaves like its sum.
            loss_mode: ``"sum"`` or ``"mean"`` (case-insensitive), describing
                how ``loss_blob`` was reduced over the batch. For ``"mean"``,
                output gradients are multiplied by the batch size.

        Saved activations and gradients are cleared afterwards, even if an
        error is raised.
        """
        assert (
            loss_mode.upper() in LossMode.__members__
        ), f"Provided loss mode {loss_mode} is not valid"
        mode = LossMode[loss_mode.upper()]

        self.model.zero_grad()
        try:
            loss_blob.backward(gradient=torch.ones_like(loss_blob))

            for module in self.gradient_dict:
                sample_grad_fn = _get_sample_grad_fn(module)
                activations = self.activation_dict[module]
                gradients = self.gradient_dict[module]
                assert len(activations) == len(gradients), (
                    "Number of saved activations do not match number of saved gradients."
                    " This may occur if multiple forward passes are run without calling"
                    " reset or computing param gradients."
                )
                # Reversing grads since when a module is used multiple times,
                # the activations will be aligned with the reverse order of the
                # gradients, since the order is reversed in backprop.
                for i, (act, grad) in enumerate(zip(activations, reversed(gradients))):
                    # A mean-reduced loss scales each sample's output gradient
                    # by 1 / batch_size; undo that to get per-sample gradients.
                    mult = 1 if mode is LossMode.SUM else act.shape[0]
                    sample_grad_fn(module, act, grad * mult, reset=(i == 0))
        finally:
            self._reset()