strexp / captum /attr /_utils /lrp_rules.py
markytools's picture
added strexp
d61b9c7
raw
history blame
6.29 kB
#!/usr/bin/env python3
from abc import ABC, abstractmethod
import torch
from ..._utils.common import _format_tensor_into_tuples
class PropagationRule(ABC):
    """
    Base class for all propagation rule classes, also called Z-Rule.

    A rule is wired into a model through hooks: ``forward_hook`` attaches
    tensor-level backward hooks so that the gradient pass carries relevance
    instead of plain gradients. STABILITY_FACTOR is added (with the sign of
    the output) to the denominator of the relevance quotient to assure that
    no zero division occurs.

    Subclasses implement ``_manipulate_weights`` to alter a module's
    parameters according to their rule.

    NOTE(review): instances are expected to carry ``relevance_input`` and
    ``relevance_output`` dicts keyed by device, assigned by the caller
    (presumably the LRP attribution driver) before the backward pass; this
    class only reads and writes them.
    """

    STABILITY_FACTOR = 1e-9

    @staticmethod
    def _as_tuple(inputs):
        """Return ``inputs`` unchanged if it is a tuple, else wrap it in one."""
        return inputs if isinstance(inputs, tuple) else (inputs,)

    def forward_hook(self, module, inputs, outputs):
        """Register backward hooks on input and output
        tensors of linear layers in the model.

        Each input tensor receives at most one input hook; a
        ``hook_registered`` attribute is set on the tensor so that a tensor
        seen again (e.g. shared by several layers) is not hooked twice.
        The hook handles are kept on ``self`` so they can be removed later.

        Returns:
            A clone of ``outputs``. A forward hook's non-None return value
            replaces the module's output.
        """
        inputs = _format_tensor_into_tuples(inputs)
        self._has_single_input = len(inputs) == 1
        self._handle_input_hooks = []
        for tensor in inputs:
            if hasattr(tensor, "hook_registered"):
                continue
            input_hook = self._create_backward_hook_input(tensor.data)
            self._handle_input_hooks.append(tensor.register_hook(input_hook))
            tensor.hook_registered = True
        output_hook = self._create_backward_hook_output(outputs.data)
        self._handle_output_hook = outputs.register_hook(output_hook)
        return outputs.clone()

    @staticmethod
    def backward_hook_activation(module, grad_input, grad_output):
        """Backward hook to propagate relevance over non-linear activations.

        Relevance passes through unchanged: ``grad_output`` is returned as
        the new ``grad_input``.
        """
        if (
            isinstance(grad_input, tuple)
            and isinstance(grad_output, tuple)
            and len(grad_input) > len(grad_output)
        ):
            # Adds any additional elements of grad_input if applicable.
            # This occurs when registering a backward hook on nn.Dropout
            # modules, which has an additional element of None in
            # grad_input.
            return grad_output + grad_input[len(grad_output) :]
        return grad_output

    def _create_backward_hook_input(self, inputs):
        """Build the tensor hook that turns an input gradient into relevance.

        The hook multiplies the incoming gradient by the saved input
        activations ``inputs`` and records the result in
        ``self.relevance_input`` under the gradient's device. For
        single-input layers the entry is overwritten. For multi-input layers
        the result is appended to a list there, which is assumed to be
        pre-created by the caller. The relevance is returned as the new
        gradient.
        """

        def _backward_hook_input(grad):
            relevance = grad * inputs
            device = grad.device
            if self._has_single_input:
                self.relevance_input[device] = relevance.data
            else:
                self.relevance_input[device].append(relevance.data)
            return relevance

        return _backward_hook_input

    def _create_backward_hook_output(self, outputs):
        """Build the tensor hook that divides output relevance by z.

        The hook records the incoming gradient (the output relevance) in
        ``self.relevance_output`` under its device. It then returns
        ``grad / (outputs + sign(outputs) * STABILITY_FACTOR)``.
        """

        def _backward_hook_output(grad):
            sign = torch.sign(outputs)
            # Treat zero outputs as positive so the stabiliser never vanishes.
            sign[sign == 0] = 1
            relevance = grad / (outputs + sign * self.STABILITY_FACTOR)
            self.relevance_output[grad.device] = grad.data
            return relevance

        return _backward_hook_output

    def forward_hook_weights(self, module, inputs, outputs):
        """Save initial activations a_j before modules are changed.

        The ``.data`` of every input is stored in ``module.activations``
        under the inputs' device; the dict is created if missing. The
        subclass's ``_manipulate_weights`` is then applied.

        Raises:
            RuntimeError: if activations for this device were already saved,
                i.e. the module is used more than once in the network.
        """
        tensors = self._as_tuple(inputs)
        device = tensors[0].device
        if not hasattr(module, "activations"):
            module.activations = {}
        elif device in module.activations:
            raise RuntimeError(
                "Module {} is being used more than once in the network, which "
                "is not supported by LRP. "
                "Please ensure that module is being used only once in the "
                "network.".format(module)
            )
        module.activations[device] = tuple(tensor.data for tensor in tensors)
        self._manipulate_weights(module, inputs, outputs)

    @abstractmethod
    def _manipulate_weights(self, module, inputs, outputs):
        raise NotImplementedError

    def forward_pre_hook_activations(self, module, inputs):
        """Pass initial activations to graph generation pass.

        The ``.data`` of each input tensor is replaced in place with the
        activation saved by ``forward_hook_weights`` for the same device.
        The original ``inputs`` object is returned.
        """
        tensors = self._as_tuple(inputs)
        device = tensors[0].device
        for tensor, activation in zip(tensors, module.activations[device]):
            tensor.data = activation
        return inputs
class EpsilonRule(PropagationRule):
    """
    Rule for relevance propagation using a small value of epsilon
    to avoid numerical instabilities and remove noise.

    Use for middle layers.

    Args:
        epsilon (integer, float): Value added (with the sign of the layer
            output) to the denominator during propagation.
            Defaults to 1e-9.
    """

    def __init__(self, epsilon=1e-9) -> None:
        # Override the class-level stabiliser for this instance only; the
        # output backward hook reads self.STABILITY_FACTOR.
        self.STABILITY_FACTOR = epsilon

    def _manipulate_weights(self, module, inputs, outputs):
        # The epsilon rule leaves the module's parameters untouched.
        pass
class GammaRule(PropagationRule):
    """
    Gamma rule for relevance propagation, gives more importance to
    positive relevance.

    Use for lower layers.

    Args:
        gamma (float): The gamma parameter determines by how much
            the positive relevance is increased.
        set_bias_to_zero (bool): If True, the module's bias (when present
            and not None) is replaced with zeros during weight manipulation.
    """

    def __init__(self, gamma=0.25, set_bias_to_zero=False) -> None:
        self.gamma = gamma
        self.set_bias_to_zero = set_bias_to_zero

    def _manipulate_weights(self, module, inputs, outputs):
        """Replace ``w`` by ``w + gamma * max(w, 0)`` and optionally zero the bias.

        A module may register ``weight``/``bias`` as ``None`` (for example
        ``nn.BatchNorm2d(affine=False)``). Such parameters are skipped
        instead of failing on ``None.data``.
        """
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data = weight.data + self.gamma * weight.data.clamp(min=0)
        if self.set_bias_to_zero:
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data = torch.zeros_like(bias.data)
class Alpha1_Beta0_Rule(PropagationRule):
    """
    Alpha1_Beta0 rule for relevance backpropagation, also known
    as Deep-Taylor. Only positive relevance is propagated, resulting
    in stable results, therefore recommended as the initial choice.

    Warning: Does not work for BatchNorm modules because weight and bias
    are defined differently.

    Use for lower layers.

    Args:
        set_bias_to_zero (bool): If True, the module's bias (when present
            and not None) is replaced with zeros during weight manipulation.
    """

    def __init__(self, set_bias_to_zero=False) -> None:
        self.set_bias_to_zero = set_bias_to_zero

    def _manipulate_weights(self, module, inputs, outputs):
        """Clamp negative weights to zero and optionally zero the bias.

        Parameters registered as ``None`` are skipped instead of failing
        on ``None.data``.
        """
        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data = weight.data.clamp(min=0)
        if self.set_bias_to_zero:
            bias = getattr(module, "bias", None)
            if bias is not None:
                bias.data = torch.zeros_like(bias.data)
class IdentityRule(EpsilonRule):
    """
    Identity rule: the layer is not manipulated and relevance is passed
    straight through it.

    Only valid for modules whose inputs and outputs share the same
    dimensions. Can be used for BatchNorm2D.
    """

    def _create_backward_hook_input(self, inputs):
        # The saved activations are irrelevant here. The hook ignores the
        # gradient's values and hands back the relevance recorded at this
        # layer's output for the same device.
        rule = self

        def _backward_hook_input(grad):
            stored_output_relevance = rule.relevance_output
            return stored_output_relevance[grad.device]

        return _backward_hook_input