#!/usr/bin/env python3

from abc import ABC, abstractmethod

import torch

from ..._utils.common import _format_tensor_into_tuples


class PropagationRule(ABC):
    """
    Base class for all propagation rule classes, also called Z-Rule.
    STABILITY_FACTOR is used to ensure that no zero division occurs.
    """

    STABILITY_FACTOR = 1e-9

    def forward_hook(self, module, inputs, outputs):
        """Register backward hooks on input and output
        tensors of linear layers in the model."""
        inputs = _format_tensor_into_tuples(inputs)
        self._has_single_input = len(inputs) == 1
        self._handle_input_hooks = []
        for input in inputs:
            if not hasattr(input, "hook_registered"):
                input_hook = self._create_backward_hook_input(input.data)
                self._handle_input_hooks.append(input.register_hook(input_hook))
                input.hook_registered = True
        output_hook = self._create_backward_hook_output(outputs.data)
        self._handle_output_hook = outputs.register_hook(output_hook)
        return outputs.clone()

    @staticmethod
    def backward_hook_activation(module, grad_input, grad_output):
        """Backward hook to propagate relevance over non-linear activations."""
        if (
            isinstance(grad_input, tuple)
            and isinstance(grad_output, tuple)
            and len(grad_input) > len(grad_output)
        ):
            # Pass through any additional elements of grad_input if applicable.
            # This occurs when registering a backward hook on nn.Dropout
            # modules, which have an additional element of None in grad_input.
            return grad_output + grad_input[len(grad_output) :]
        return grad_output

    def _create_backward_hook_input(self, inputs):
        def _backward_hook_input(grad):
            relevance = grad * inputs
            device = grad.device
            if self._has_single_input:
                self.relevance_input[device] = relevance.data
            else:
                self.relevance_input[device].append(relevance.data)
            return relevance

        return _backward_hook_input

    def _create_backward_hook_output(self, outputs):
        def _backward_hook_output(grad):
            sign = torch.sign(outputs)
            sign[sign == 0] = 1
            relevance = grad / (outputs + sign * self.STABILITY_FACTOR)
            self.relevance_output[grad.device] = grad.data
            return relevance

        return _backward_hook_output

    def forward_hook_weights(self, module, inputs, outputs):
        """Save initial activations a_j before modules are changed"""
        device = inputs[0].device if isinstance(inputs, tuple) else inputs.device
        if hasattr(module, "activations") and device in module.activations:
            raise RuntimeError(
                "Module {} is being used more than once in the network, which "
                "is not supported by LRP. "
                "Please ensure that the module is used only once in the "
                "network.".format(module)
            )
        module.activations[device] = tuple(input.data for input in inputs)
        self._manipulate_weights(module, inputs, outputs)

    @abstractmethod
    def _manipulate_weights(self, module, inputs, outputs):
        raise NotImplementedError

    def forward_pre_hook_activations(self, module, inputs):
        """Pass initial activations to graph generation pass"""
        device = inputs[0].device if isinstance(inputs, tuple) else inputs.device
        for input, activation in zip(inputs, module.activations[device]):
            input.data = activation
        return inputs
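

# Illustrative sketch (not part of the original module): the two hooks created
# above jointly implement the z-rule redistribution
#     R_j = a_j * sum_k( w_kj * R_k / (z_k + eps * sign(z_k)) ),
# shown here manually for a single linear mapping. All names in this helper
# are hypothetical example code.
def _example_z_rule_redistribution():
    eps = 1e-9
    a = torch.tensor([1.0, 2.0, 3.0])  # input activations a_j
    w = torch.tensor([[0.5, -1.0, 2.0], [1.0, 0.0, -0.5]])  # weights w_kj
    z = w @ a  # pre-activations z_k
    relevance_out = torch.tensor([1.0, 1.0])  # relevance arriving at the outputs
    sign = torch.sign(z)
    sign[sign == 0] = 1
    s = relevance_out / (z + sign * eps)  # what _backward_hook_output returns
    relevance_in = a * (w.t() @ s)  # what _backward_hook_input then computes
    # Relevance is (approximately) conserved across the layer.
    assert torch.allclose(relevance_in.sum(), relevance_out.sum(), atol=1e-4)
    return relevance_in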


class EpsilonRule(PropagationRule):
    """
    Rule for relevance propagation using a small value of epsilon
    to avoid numerical instabilities and remove noise.

    Use for middle layers.

    Args:
        epsilon (integer, float): Value added to the denominator during
            propagation to stabilize the division.
    """

    def __init__(self, epsilon=1e-9) -> None:
        self.STABILITY_FACTOR = epsilon

    def _manipulate_weights(self, module, inputs, outputs):
        pass
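

# Hedged usage sketch, assuming this module ships as part of Captum's LRP
# implementation: custom rules are typically attached per layer via a `rule`
# attribute and consumed by captum.attr.LRP. The toy model, layer choice, and
# target below are illustrative assumptions, not part of this module.
def _example_assign_epsilon_rule():
    import torch.nn as nn

    from captum.attr import LRP

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    model.eval()
    model[2].rule = EpsilonRule(epsilon=1e-6)  # override the default rule
    inputs = torch.randn(1, 4, requires_grad=True)
    return LRP(model).attribute(inputs, target=0)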


class GammaRule(PropagationRule):
    """
    Gamma rule for relevance propagation, which gives greater weight to
    positive relevance.

    Use for lower layers.

    Args:
        gamma (float): The gamma parameter determines by how much
            the positive relevance is increased.
        set_bias_to_zero (bool, optional): If True, the layer's bias is set
            to zero during propagation.
    """

    def __init__(self, gamma=0.25, set_bias_to_zero=False) -> None:
        self.gamma = gamma
        self.set_bias_to_zero = set_bias_to_zero

    def _manipulate_weights(self, module, inputs, outputs):
        if hasattr(module, "weight"):
            module.weight.data = (
                module.weight.data + self.gamma * module.weight.data.clamp(min=0)
            )
        if self.set_bias_to_zero and hasattr(module, "bias"):
            if module.bias is not None:
                module.bias.data = torch.zeros_like(module.bias.data)
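

# Illustrative sketch (hypothetical helper): the weight manipulation above is
# w <- w + gamma * max(w, 0), i.e. negative weights are left untouched while
# positive weights are amplified by a factor of (1 + gamma).
def _example_gamma_transform(gamma=0.25):
    w = torch.tensor([-1.0, 0.0, 2.0])
    w_gamma = w + gamma * w.clamp(min=0)
    assert torch.equal(w_gamma, torch.tensor([-1.0, 0.0, 2.5]))
    return w_gamma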


class Alpha1_Beta0_Rule(PropagationRule):
    """
    Alpha1_Beta0 rule for relevance backpropagation, also known
    as Deep-Taylor. Only positive relevance is propagated, which yields
    stable results and makes this rule a recommended initial choice.

    Warning: Does not work for BatchNorm modules because weight and bias
    are defined differently.

    Use for lower layers.

    Args:
        set_bias_to_zero (bool, optional): If True, the layer's bias is set
            to zero during propagation.
    """

    def __init__(self, set_bias_to_zero=False) -> None:
        self.set_bias_to_zero = set_bias_to_zero

    def _manipulate_weights(self, module, inputs, outputs):
        if hasattr(module, "weight"):
            module.weight.data = module.weight.data.clamp(min=0)
        if self.set_bias_to_zero and hasattr(module, "bias"):
            if module.bias is not None:
                module.bias.data = torch.zeros_like(module.bias.data)
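

# Illustrative sketch (hypothetical helper): clamping the weights to their
# positive part realizes the alpha=1, beta=0 case of the alpha-beta rule, so
# relevance is redistributed only along positively contributing connections.
def _example_alpha1_beta0_redistribution():
    eps = 1e-9
    a = torch.tensor([1.0, 2.0])  # input activations
    w = torch.tensor([[-0.5, 1.0]])  # single output neuron
    w_pos = w.clamp(min=0)  # keep only the positive part of the weights
    z = w_pos @ a
    relevance_out = torch.tensor([1.0])
    relevance_in = a * (w_pos.t() @ (relevance_out / (z + eps)))
    # The negatively weighted input receives zero relevance.
    assert relevance_in[0] == 0.0
    return relevance_in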


class IdentityRule(EpsilonRule):
    """
    Identity rule for skipping layer manipulation and propagating the
    relevance over a layer. Only valid for modules whose inputs and outputs
    have the same dimensions.

    Can be used for BatchNorm2d.
    """

    def _create_backward_hook_input(self, inputs):
        def _backward_hook_input(grad):
            return self.relevance_output[grad.device]

        return _backward_hook_input
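

# Hedged usage sketch, assuming the per-layer `rule` convention shown earlier:
# IdentityRule passes the relevance stored for the layer's output straight to
# its input, so it is only meaningful when the two shapes match, e.g. for
# normalization layers. Illustrative only.
def _example_identity_rule_for_batchnorm():
    import torch.nn as nn

    bn = nn.BatchNorm2d(3)
    bn.rule = IdentityRule()
    return bn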