#!/usr/bin/env python3

from copy import deepcopy
from inspect import signature
from typing import Any, Callable, cast, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _expand_and_update_additional_forward_args,
    _expand_and_update_baselines,
    _expand_and_update_target,
    _format_baseline,
    _format_tensor_into_tuples,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.log import log_usage
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
from torch import Tensor


def default_perturb_func(
    inputs: TensorOrTupleOfTensorsGeneric, perturb_radius: float = 0.02
) -> Tuple[Tensor, ...]:
r"""A default function for generating perturbations of `inputs` | |
within perturbation radius of `perturb_radius`. | |
This function samples uniformly random from the L_Infinity ball | |
with `perturb_radius` radius. | |
The users can override this function if they prefer to use a | |
different perturbation function. | |
Args: | |
inputs (tensor or a tuple of tensors): The input tensors that we'd | |
like to perturb by adding a random noise sampled unifromly | |
random from an L_infinity ball with a radius `perturb_radius`. | |
radius (float): A radius used for sampling from | |
an L_infinity ball. | |
Returns: | |
perturbed_input (tuple(tensor)): A list of perturbed inputs that | |
are createed by adding noise sampled uniformly random | |
from L_infiniy ball with a radius `perturb_radius` to the | |
original inputs. | |
""" | |
    inputs = _format_tensor_into_tuples(inputs)
    perturbed_input = tuple(
        input
        + torch.FloatTensor(input.size())  # type: ignore
        .uniform_(-perturb_radius, perturb_radius)
        .to(input.device)
        for input in inputs
    )
    return perturbed_input
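

# A minimal sketch of a custom `perturb_func`, for readers who prefer a
# different noise distribution over the default L_Infinity sampling.
# `_example_gaussian_perturb_func` is a hypothetical name used purely for
# illustration and is not part of Captum's API. Any callable with the same
# contract -- inputs in, an equally-shaped tuple of perturbed tensors out --
# can be passed to `sensitivity_max` via the `perturb_func` argument.
def _example_gaussian_perturb_func(
    inputs: TensorOrTupleOfTensorsGeneric, perturb_radius: float = 0.02
) -> Tuple[Tensor, ...]:
    inputs = _format_tensor_into_tuples(inputs)
    # Here `perturb_radius` is reused as the standard deviation of the noise.
    return tuple(
        input + torch.randn_like(input) * perturb_radius for input in inputs
    )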


@log_usage()
def sensitivity_max(
    explanation_func: Callable,
    inputs: TensorOrTupleOfTensorsGeneric,
    perturb_func: Callable = default_perturb_func,
    perturb_radius: float = 0.02,
    n_perturb_samples: int = 10,
    norm_ord: str = "fro",
    max_examples_per_batch: Optional[int] = None,
    **kwargs: Any,
) -> Tensor:
r""" | |
Explanation sensitivity measures the extent of explanation change when | |
the input is slightly perturbed. It has been shown that the models that | |
have high explanation sensitivity are prone to adversarial attacks: | |
`Interpretation of Neural Networks is Fragile` | |
https://www.aaai.org/ojs/index.php/AAAI/article/view/4252 | |
`sensitivity_max` metric measures maximum sensitivity of an explanation | |
using Monte Carlo sampling-based approximation. By default in order to | |
do so it samples multiple data points from a sub-space of an L-Infinity | |
ball that has a `perturb_radius` radius using `default_perturb_func` | |
default perturbation function. In a general case users can | |
use any L_p ball or any other custom sampling technique that they | |
prefer by providing a custom `perturb_func`. | |
Note that max sensitivity is similar to Lipschitz Continuity metric | |
however it is more robust and easier to estimate. | |
Since the explanation, for instance an attribution function, | |
may not always be continuous, can lead to unbounded | |
Lipschitz continuity. Therefore the latter isn't always appropriate. | |
More about the Lipschitz Continuity Metric can also be found here | |
`On the Robustness of Interpretability Methods` | |
https://arxiv.org/pdf/1806.08049.pdf | |
and | |
`Towards Robust Interpretability with Self-Explaining Neural Networks` | |
https://papers.nips.cc/paper\ | |
8003-towards-robust-interpretability- | |
with-self-explaining-neural-networks.pdf | |
More details about sensitivity max can be found here: | |
`On the (In)fidelity and Sensitivity of Explanations` | |
https://arxiv.org/pdf/1901.09392.pdf | |
    Args:

        explanation_func (callable): This function can be the `attribute`
                method of an attribution algorithm or any other explanation
                method that returns the explanations.

        inputs (tensor or tuple of tensors): Input for which
                explanations are computed. If `explanation_func` takes a
                single tensor as input, a single input tensor should
                be provided.
                If `explanation_func` takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples (aka batch size), and if
                multiple input tensors are provided, the examples must
                be aligned appropriately.

        perturb_func (callable): The perturbation function of model inputs.
                This function takes model inputs and optionally
                `perturb_radius` if the function takes more than one argument,
                and returns perturbed inputs.
                If multiple inputs are passed to the sensitivity function,
                they will be passed to `perturb_func` as tuples in the same
                order as they are passed to the sensitivity function.
                It is important to note that for performance reasons
                `perturb_func` isn't called for each example individually but
                on a batch of input examples that are repeated
                `max_examples_per_batch / batch_size` times within the batch.

                Default: default_perturb_func

        perturb_radius (float, optional): The epsilon radius used for
                sampling. In the `default_perturb_func` it is used as the
                radius of the L-Infinity ball. In the general case it can
                serve as the radius of any L_p norm.
                This argument is passed to `perturb_func` if it takes more
                than one argument.

                Default: 0.02

        n_perturb_samples (int, optional): The number of times input tensors
                are perturbed. Each input example in the inputs tensor is
                expanded `n_perturb_samples` times before calling the
                `perturb_func` function.

                Default: 10

        norm_ord (int, float, inf, -inf, 'fro', 'nuc', optional): The type of
                norm that is used to compute the norm of the sensitivity
                matrix, which is defined as the difference between the
                explanation function at its input and perturbed input.

                Default: 'fro'

        max_examples_per_batch (int, optional): The maximum number of input
                examples that are processed together. In case the number of
                examples (`input batch size * n_perturb_samples`) exceeds
                `max_examples_per_batch`, they will be sliced
                into batches of `max_examples_per_batch` examples and
                processed in a sequential order. If `max_examples_per_batch`
                is None, all examples are processed together.
                `max_examples_per_batch` should be at least equal to
                `input batch size` and at most
                `input batch size * n_perturb_samples`.

                Default: None

        **kwargs (Any, optional): Contains a list of arguments that are passed
                to the `explanation_func` explanation function, which in some
                cases could be the `attribute` function of an attribution
                algorithm. Any additional arguments that need to be passed
                to the explanation function should be included here.
                For instance, such arguments include:
                `additional_forward_args`, `baselines` and `target`.

    Returns:

        sensitivities (tensor): A tensor of scalar sensitivity scores per
                input example. The first dimension is equal to the
                number of examples in the input batch and the second
                dimension is one. Returned sensitivities are normalized by
                the magnitudes of the input explanations.

    Examples::
        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> saliency = Saliency(net)
        >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
        >>> # Computes sensitivity score for saliency maps of class 3
        >>> sens = sensitivity_max(saliency.attribute, input, target=3)
    """

    def _generate_perturbations(
        current_n_perturb_samples: int,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        The perturbations are generated for each example
        `current_n_perturb_samples` times.

        For performance reasons we are not calling `perturb_func` on each
        example but on a batch that contains `current_n_perturb_samples`
        repeated instances per example.
        """
        inputs_expanded: Union[Tensor, Tuple[Tensor, ...]] = tuple(
            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
            for input in inputs
        )
        if len(inputs_expanded) == 1:
            inputs_expanded = inputs_expanded[0]

        return (
            perturb_func(inputs_expanded, perturb_radius)
            if len(signature(perturb_func).parameters) > 1
            else perturb_func(inputs_expanded)
        )
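
    # `repeat_interleave` keeps the perturbed copies of each example
    # contiguous: rows [i * current_n_perturb_samples,
    # (i + 1) * current_n_perturb_samples) all belong to input example `i`.
    # The `view(bsz, -1)` in `_next_sensitivity_max` below relies on exactly
    # this ordering when taking the per-example maximum.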

    def max_values(input_tnsr: Tensor) -> Tensor:
        return torch.max(input_tnsr, dim=1).values  # type: ignore

    kwarg_expanded_for = None
    kwargs_copy: Any = None

    def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
        inputs_perturbed = _generate_perturbations(current_n_perturb_samples)

        # copy kwargs and update some of the arguments that need to be expanded
        nonlocal kwarg_expanded_for
        nonlocal kwargs_copy
        if (
            kwarg_expanded_for is None
            or kwarg_expanded_for != current_n_perturb_samples
        ):
            kwarg_expanded_for = current_n_perturb_samples
            kwargs_copy = deepcopy(kwargs)
            _expand_and_update_additional_forward_args(
                current_n_perturb_samples, kwargs_copy
            )
            _expand_and_update_target(current_n_perturb_samples, kwargs_copy)
            if "baselines" in kwargs:
                baselines = kwargs["baselines"]
                baselines = _format_baseline(
                    baselines, cast(Tuple[Tensor, ...], inputs)
                )
                if (
                    isinstance(baselines[0], Tensor)
                    and baselines[0].shape == inputs[0].shape
                ):
                    _expand_and_update_baselines(
                        cast(Tuple[Tensor, ...], inputs),
                        current_n_perturb_samples,
                        kwargs_copy,
                    )

        expl_perturbed_inputs = explanation_func(inputs_perturbed, **kwargs_copy)

        # tuplize `expl_perturbed_inputs` in case it is not a tuple
        expl_perturbed_inputs = _format_tensor_into_tuples(expl_perturbed_inputs)

        expl_inputs_expanded = tuple(
            expl_input.repeat_interleave(current_n_perturb_samples, dim=0)
            for expl_input in expl_inputs
        )

        sensitivities = torch.cat(
            [
                (expl_input - expl_perturbed).view(expl_perturbed.size(0), -1)
                for expl_perturbed, expl_input in zip(
                    expl_perturbed_inputs, expl_inputs_expanded
                )
            ],
            dim=1,
        )

        # compute the norm of original input explanations
        expl_inputs_norm_expanded = torch.norm(
            torch.cat(
                [expl_input.view(expl_input.size(0), -1) for expl_input in expl_inputs],
                dim=1,
            ),
            p=norm_ord,
            dim=1,
            keepdim=True,
        ).repeat_interleave(current_n_perturb_samples, dim=0)
        expl_inputs_norm_expanded = torch.where(
            expl_inputs_norm_expanded == 0.0,
            torch.tensor(
                1.0,
                device=expl_inputs_norm_expanded.device,
                dtype=expl_inputs_norm_expanded.dtype,
            ),
            expl_inputs_norm_expanded,
        )

        # compute the norm for each perturbed input example, normalized by
        # the norm of the corresponding original explanation
        sensitivities_norm = (
            torch.norm(sensitivities, p=norm_ord, dim=1, keepdim=True)
            / expl_inputs_norm_expanded
        )
        return max_values(sensitivities_norm.view(bsz, -1))

    inputs = _format_tensor_into_tuples(inputs)  # type: ignore
    bsz = inputs[0].size(0)

    with torch.no_grad():
        expl_inputs = explanation_func(inputs, **kwargs)
        metrics_max = _divide_and_aggregate_metrics(
            cast(Tuple[Tensor, ...], inputs),
            n_perturb_samples,
            _next_sensitivity_max,
            max_examples_per_batch=max_examples_per_batch,
            agg_func=torch.max,
        )
    return metrics_max
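

# A minimal usage sketch, assuming a trained classifier `model` and the
# hypothetical `_example_gaussian_perturb_func` defined above; it is not
# part of this module's API:
#
#     from captum.attr import IntegratedGradients
#
#     ig = IntegratedGradients(model)
#     inputs = torch.randn(4, 3, 32, 32)
#     sens = sensitivity_max(
#         ig.attribute,
#         inputs,
#         perturb_func=_example_gaussian_perturb_func,
#         n_perturb_samples=5,
#         max_examples_per_batch=20,
#         target=0,
#     )
#     # `sens` holds one normalized max-sensitivity score per input example.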