#!/usr/bin/env python3

from abc import ABC, abstractmethod

import torch

from ..._utils.common import _format_tensor_into_tuples


class PropagationRule(ABC):
    """
    Base class for all propagation rule classes, also called Z-Rule.
    STABILITY_FACTOR is used to ensure that no zero division occurs.
    """

    STABILITY_FACTOR = 1e-9

    def forward_hook(self, module, inputs, outputs):
        """Register backward hooks on input and output
        tensors of linear layers in the model."""
        inputs = _format_tensor_into_tuples(inputs)
        self._has_single_input = len(inputs) == 1
        self._handle_input_hooks = []
        for input in inputs:
            if not hasattr(input, "hook_registered"):
                input_hook = self._create_backward_hook_input(input.data)
                self._handle_input_hooks.append(input.register_hook(input_hook))
                input.hook_registered = True
        output_hook = self._create_backward_hook_output(outputs.data)
        self._handle_output_hook = outputs.register_hook(output_hook)
        return outputs.clone()

    @staticmethod
    def backward_hook_activation(module, grad_input, grad_output):
        """Backward hook to propagate relevance over non-linear activations."""
        if (
            isinstance(grad_input, tuple)
            and isinstance(grad_output, tuple)
            and len(grad_input) > len(grad_output)
        ):
            # Pass through any additional elements of grad_input if applicable.
            # This occurs when registering a backward hook on nn.Dropout
            # modules, which have an additional element of None in grad_input.
            return grad_output + grad_input[len(grad_output) :]
        return grad_output

    def _create_backward_hook_input(self, inputs):
        def _backward_hook_input(grad):
            relevance = grad * inputs
            device = grad.device
            if self._has_single_input:
                self.relevance_input[device] = relevance.data
            else:
                self.relevance_input[device].append(relevance.data)
            return relevance

        return _backward_hook_input

    def _create_backward_hook_output(self, outputs):
        def _backward_hook_output(grad):
            sign = torch.sign(outputs)
            sign[sign == 0] = 1
            relevance = grad / (outputs + sign * self.STABILITY_FACTOR)
            self.relevance_output[grad.device] = grad.data
            return relevance

        return _backward_hook_output
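
    # Note (added explanatory sketch): together, the two hooks above implement
    # the generic stabilized LRP propagation step. With a_j the saved input
    # activations, z_k the layer outputs and R_k the relevance arriving from
    # the layer above, the output hook rescales the incoming gradient to
    #
    #     s_k = R_k / (z_k + sign(z_k) * STABILITY_FACTOR)
    #
    # Autograd then back-propagates s_k through the (possibly
    # weight-manipulated) layer, so that for a linear layer the input hook
    # receives grad_j = sum_k w_jk * s_k and returns the relevance
    #
    #     R_j = a_j * sum_k w_jk * R_k / (z_k + sign(z_k) * STABILITY_FACTOR)
    #
    # which is the standard epsilon-rule of LRP.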

    def forward_hook_weights(self, module, inputs, outputs):
        """Save initial activations a_j before modules are changed."""
        device = inputs[0].device if isinstance(inputs, tuple) else inputs.device
        if hasattr(module, "activations") and device in module.activations:
            raise RuntimeError(
                "Module {} is being used more than once in the network, which "
                "is not supported by LRP. "
                "Please ensure that module is being used only once in the "
                "network.".format(module)
            )
        module.activations[device] = tuple(input.data for input in inputs)
        self._manipulate_weights(module, inputs, outputs)

    @abstractmethod
    def _manipulate_weights(self, module, inputs, outputs):
        raise NotImplementedError

    def forward_pre_hook_activations(self, module, inputs):
        """Pass initial activations to graph generation pass."""
        device = inputs[0].device if isinstance(inputs, tuple) else inputs.device
        for input, activation in zip(inputs, module.activations[device]):
            input.data = activation
        return inputs


class EpsilonRule(PropagationRule):
    """
    Rule for relevance propagation using a small value of epsilon
    to avoid numerical instabilities and remove noise.

    Use for middle layers.

    Args:
        epsilon (integer, float): Value added to the denominator
            for numerical stability during relevance propagation.
    """

    def __init__(self, epsilon=1e-9) -> None:
        self.STABILITY_FACTOR = epsilon

    def _manipulate_weights(self, module, inputs, outputs):
        pass
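

# Example (illustrative note): EpsilonRule does not manipulate the weights at
# all; it only overrides STABILITY_FACTOR, which PropagationRule uses to
# stabilize the denominator in _create_backward_hook_output.
#
#     rule = EpsilonRule(epsilon=1e-6)
#     # rule.STABILITY_FACTOR == 1e-6 is the eps in grad / (outputs + sign * eps)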


class GammaRule(PropagationRule):
    """
    Gamma rule for relevance propagation, which gives more importance to
    positive relevance.

    Use for lower layers.

    Args:
        gamma (float): The gamma parameter determines by how much
            the positive relevance is increased.
    """

    def __init__(self, gamma=0.25, set_bias_to_zero=False) -> None:
        self.gamma = gamma
        self.set_bias_to_zero = set_bias_to_zero

    def _manipulate_weights(self, module, inputs, outputs):
        if hasattr(module, "weight"):
            module.weight.data = (
                module.weight.data + self.gamma * module.weight.data.clamp(min=0)
            )
        if self.set_bias_to_zero and hasattr(module, "bias"):
            if module.bias is not None:
                module.bias.data = torch.zeros_like(module.bias.data)
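
# Example (illustrative, hypothetical numbers): with gamma=0.25 the transform
# above maps w = [-1.0, 0.0, 2.0] to w + 0.25 * clamp(w, min=0) = [-1.0, 0.0, 2.5],
# i.e. negative weights stay unchanged while positive weights are scaled by
# (1 + gamma), favoring positive contributions during relevance propagation.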


class Alpha1_Beta0_Rule(PropagationRule):
    """
    Alpha1_Beta0 rule for relevance backpropagation, also known
    as Deep-Taylor. Only positive relevance is propagated, which yields
    stable results and makes this rule a recommended initial choice.

    Warning: Does not work for BatchNorm modules because weight and bias
    are defined differently.

    Use for lower layers.
    """

    def __init__(self, set_bias_to_zero=False) -> None:
        self.set_bias_to_zero = set_bias_to_zero

    def _manipulate_weights(self, module, inputs, outputs):
        if hasattr(module, "weight"):
            module.weight.data = module.weight.data.clamp(min=0)
        if self.set_bias_to_zero and hasattr(module, "bias"):
            if module.bias is not None:
                module.bias.data = torch.zeros_like(module.bias.data)
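
# Note (added sketch): this corresponds to the alpha-beta rule with alpha=1 and
# beta=0. For non-negative activations (e.g. after a ReLU), clamping the weights
# at zero, w = [-1.0, 2.0] -> [0.0, 2.0], keeps only the positive contributions
# a_j * w_jk^+, so only positive relevance flows back through the layer.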


class IdentityRule(EpsilonRule):
    """
    Identity rule for skipping layer manipulation and propagating the
    relevance over a layer. Only valid for modules with the same dimensions
    for inputs and outputs.

    Can be used for BatchNorm2D.
    """

    def _create_backward_hook_input(self, inputs):
        def _backward_hook_input(grad):
            return self.relevance_output[grad.device]

        return _backward_hook_input
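

# Usage sketch (illustrative only; assumes these rules are consumed by an LRP
# attribution class that honors a per-module `rule` attribute, as Captum's LRP
# implementation does, and that `model` and `inputs` are defined elsewhere):
#
#     from captum.attr import LRP
#
#     model.conv1.rule = GammaRule(gamma=0.25)    # lower convolutional layer
#     model.fc1.rule = EpsilonRule(epsilon=1e-9)  # middle fully-connected layer
#     model.bn1.rule = IdentityRule()             # BatchNorm2d layer
#
#     lrp = LRP(model)
#     attribution = lrp.attribute(inputs, target=0)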