#!/usr/bin/env python3

from typing import Any, Callable, cast, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _expand_additional_forward_args,
    _expand_target,
    _format_additional_forward_args,
    _format_baseline,
    _format_tensor_into_tuples,
    _run_forward,
    ExpansionTypes,
    safe_div,
)
from captum._utils.typing import (
    BaselineType,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
)
from captum.log import log_usage
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
from torch import Tensor


def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callable:
    r"""An auxiliary decorator function that helps with computing
    perturbations given perturbed inputs. It is useful for cases
    when `pertub_func` returns only perturbed inputs and we
    internally compute the perturbations as
    (input - perturbed_input) / (input - baseline) if
    `multipy_by_inputs` is set to True and
    (input - perturbed_input) otherwise.

    If users decorate their `pertub_func` with
    `@infidelity_perturb_func_decorator`, then their `pertub_func`
    needs to return only the perturbed inputs.

    Args:

        multipy_by_inputs (bool): Indicates whether model inputs'
                multiplier is factored in the computation of
                attribution scores.

    """
    def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
        r"""
        Args:

            pertub_func(callable): Input perturbation function that takes inputs
                and optionally baselines and returns perturbed inputs.

        Returns:

            default_perturb_func(callable): Internal default perturbation
                function that computes the perturbations internally and returns
                perturbations and perturbed inputs.

        Examples::
            >>> import numpy as np
            >>> @infidelity_perturb_func_decorator(True)
            >>> def perturb_fn(inputs):
            >>>     noise = torch.tensor(np.random.normal(0, 0.003,
            >>>                          inputs.shape)).float()
            >>>     return inputs - noise
            >>> # Computes infidelity score using `perturb_fn`
            >>> infid = infidelity(model, perturb_fn, input, ...)
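
        A decorated `pertub_func` may also accept baselines as a second
        argument; in that case the perturbations are computed relative to
        `(input - baseline)` when `multipy_by_inputs` is True. A purely
        illustrative sketch (the name `perturb_fn_with_baselines` and the
        perturbation scheme are assumptions, not part of the API)::

            >>> @infidelity_perturb_func_decorator(True)
            >>> def perturb_fn_with_baselines(inputs, baselines):
            >>>     # move each feature a small random fraction toward its baseline
            >>>     alpha = torch.rand_like(inputs) * 0.1
            >>>     return inputs - alpha * (inputs - baselines)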
""" | |
        def default_perturb_func(
            inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None
        ):
            r""" """
            # call the user-provided function, passing baselines through only
            # when they are provided
            inputs_perturbed = (
                pertub_func(inputs, baselines)
                if baselines is not None
                else pertub_func(inputs)
            )
            inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
            inputs = _format_tensor_into_tuples(inputs)
            baselines = _format_baseline(baselines, inputs)
            if baselines is None:
                perturbations = tuple(
                    safe_div(
                        input - input_perturbed,
                        input,
                        default_denom=1.0,
                    )
                    if multipy_by_inputs
                    else input - input_perturbed
                    for input, input_perturbed in zip(inputs, inputs_perturbed)
                )
            else:
                perturbations = tuple(
                    safe_div(
                        input - input_perturbed,
                        input - baseline,
                        default_denom=1.0,
                    )
                    if multipy_by_inputs
                    else input - input_perturbed
                    for input, input_perturbed, baseline in zip(
                        inputs, inputs_perturbed, baselines
                    )
                )
            return perturbations, inputs_perturbed

        return default_perturb_func

    return sub_infidelity_perturb_func_decorator


@log_usage()
def infidelity(
    forward_func: Callable,
    perturb_func: Callable,
    inputs: TensorOrTupleOfTensorsGeneric,
    attributions: TensorOrTupleOfTensorsGeneric,
    baselines: BaselineType = None,
    additional_forward_args: Any = None,
    target: TargetType = None,
    n_perturb_samples: int = 10,
    max_examples_per_batch: Optional[int] = None,
    normalize: bool = False,
) -> Tensor:
r""" | |
Explanation infidelity represents the expected mean-squared error | |
between the explanation multiplied by a meaningful input perturbation | |
and the differences between the predictor function at its input | |
and perturbed input. | |
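
    Using the notation of the paper referenced below, with a perturbation
    :math:`I`, an explanation :math:`\Phi(f, x)`, and a predictor :math:`f`,
    this corresponds to:

    .. math::
        \text{INFD}(\Phi, f, x) = \mathbb{E}_{I \sim \mu_I}
        \left[ \left( I^T \Phi(f, x) - (f(x) - f(x - I)) \right)^2 \right]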

    More details about the measure can be found in the following paper:
    https://arxiv.org/pdf/1901.09392.pdf

    It is derived from the completeness property of well-known attribution
    algorithms and is a computationally more efficient and generalized
    notion of Sensitivity-n. The latter measures correlations between the sum
    of the attributions and the differences of the predictor function at
    its input and fixed baseline. More details about Sensitivity-n can
    be found here:
    https://arxiv.org/pdf/1711.06104.pdf

    The users can perturb the inputs in any desired way by providing any
    perturbation function that takes the inputs (and optionally baselines)
    and returns perturbed inputs or perturbed inputs and corresponding
    perturbations.

    This specific implementation is primarily tested for attribution-based
    explanation methods, but the idea can be extended to non-attribution-based
    interpretability methods as well.

    Args:

        forward_func (callable):
                The forward function of the model or any modification of it.

        perturb_func (callable):
                The perturbation function of model inputs. This function takes
                model inputs and optionally baselines as input arguments and
                returns either a tuple of perturbations and perturbed inputs or
                just perturbed inputs. For example:

                >>> def my_perturb_func(inputs):
                >>>     <MY-LOGIC-HERE>
                >>>     return perturbations, perturbed_inputs

                If we want to only return perturbed inputs and compute the
                perturbations internally, then we can wrap perturb_func with
                the `infidelity_perturb_func_decorator` decorator such as:

                >>> from captum.metrics import infidelity_perturb_func_decorator
                >>> @infidelity_perturb_func_decorator(<multipy_by_inputs flag>)
                >>> def my_perturb_func(inputs):
                >>>     <MY-LOGIC-HERE>
                >>>     return perturbed_inputs

                In case `multipy_by_inputs` is False we compute the perturbations
                as the `input - perturbed_input` difference, and in case
                `multipy_by_inputs` is True we compute them by dividing
                (input - perturbed_input) by (input - baselines).
                The user needs to only return perturbed inputs in `perturb_func`
                as described above.

                `infidelity_perturb_func_decorator` needs to be used with the
                `multipy_by_inputs` flag set to False in case the infidelity
                score is being computed for attribution maps that are local, i.e.
                that do not factor the inputs into the final attribution score.
                Such attribution algorithms include Saliency, GradCam and Guided
                Backprop, as well as Integrated Gradients and DeepLift attribution
                scores that are already computed with the
                `multiply_by_inputs=False` flag.

                If more than one input is passed to the infidelity function, those
                inputs will be passed to `perturb_func` as a tuple in the same
                order as they are passed to the infidelity function.

                If inputs

                 - is a single tensor, the function needs to return a tuple
                   of perturbations and perturbed input such as:
                   perturb, perturbed_input (and only perturbed_input in case
                   `infidelity_perturb_func_decorator` is used).

                 - is a tuple of tensors, corresponding perturbations and perturbed
                   inputs must be computed and returned as tuples in the
                   following format:

                   (perturb1, perturb2, ... perturbN), (perturbed_input1,
                   perturbed_input2, ... perturbed_inputN)

                   Similar to the previous case, here as well we need to return
                   only perturbed inputs in case `infidelity_perturb_func_decorator`
                   decorates our `perturb_func`. A sketch of a perturbation function
                   for tuple inputs is shown at the end of this description.

                It is important to note that for performance reasons `perturb_func`
                isn't called for each example individually but on a batch of
                input examples that are repeated `max_examples_per_batch / batch_size`
                times within the batch.
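
                As a purely illustrative sketch (the function name, the two-tensor
                layout, and the Gaussian-noise perturbation are assumptions, not
                part of the API), a perturbation function for a model that takes
                two input tensors could look like::

                >>> def my_tuple_perturb_func(inputs):
                >>>     # inputs arrives as a tuple in the same order as passed
                >>>     # to infidelity; perturb each tensor independently
                >>>     input1, input2 = inputs
                >>>     noise1 = torch.randn_like(input1) * 0.003
                >>>     noise2 = torch.randn_like(input2) * 0.003
                >>>     return (noise1, noise2), (input1 - noise1, input2 - noise2)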

        inputs (tensor or tuple of tensors): Input for which
                attributions are computed. If forward_func takes a single
                tensor as input, a single input tensor should be provided.
                If forward_func takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples (aka batch size), and if
                multiple input tensors are provided, the examples must
                be aligned appropriately.

        baselines (scalar, tensor, tuple of scalars or tensors, optional):
                Baselines define reference values which sometimes represent ablated
                values and are used to compare with the actual inputs to compute
                importance scores in attribution algorithms. They can be represented
                as:

                - a single tensor, if inputs is a single tensor, with
                  exactly the same dimensions as inputs or the first
                  dimension is one and the remaining dimensions match
                  with inputs.

                - a single scalar, if inputs is a single tensor, which will
                  be broadcasted for each input value in input tensor.

                - a tuple of tensors or scalars, the baseline corresponding
                  to each tensor in the inputs' tuple can be:

                  - either a tensor with matching dimensions to
                    corresponding tensor in the inputs' tuple
                    or the first dimension is one and the remaining
                    dimensions match with the corresponding
                    input tensor.

                  - or a scalar, corresponding to a tensor in the
                    inputs' tuple. This scalar value is broadcasted
                    for corresponding input tensor.

                Default: None

        attributions (tensor or tuple of tensors):
                Attribution scores computed based on an attribution algorithm.
                These attribution scores can be computed using the implementations
                provided in the `captum.attr` package. Some of those attribution
                approaches are so-called global methods, which means that
                they factor in model inputs' multiplier, as described in:
                https://arxiv.org/pdf/1711.06104.pdf

                Many global attribution algorithms can be used in local modes,
                meaning that the inputs' multiplier isn't factored into the
                attribution scores. This can be done during the definition of
                the attribution algorithm by passing the
                `multiply_by_inputs=False` flag.
                For example, in case of Integrated Gradients (IG) we can obtain
                local attribution scores if we define the constructor of IG as:
                ig = IntegratedGradients(forward_func, multiply_by_inputs=False)

                Some attribution algorithms are inherently local.
                Examples of inherently local attribution methods include:
                Saliency, Guided GradCam, Guided Backprop and Deconvolution.

                For local attributions we can use real-valued perturbations,
                whereas for global attributions that perturbation is binary.
                https://arxiv.org/pdf/1901.09392.pdf

                If we want to compute the infidelity of global attributions, we
                can use a binary perturbation matrix that will allow us to select
                a subset of features from the `inputs` or `inputs - baselines`
                space. This will allow us to approximate sensitivity-n for a
                global attribution algorithm; a sketch of such a binary
                perturbation function is shown right after this paragraph.
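
                As a purely illustrative sketch (the function name and the zero
                baseline are assumptions, not part of the API), a binary
                perturbation function that removes a random subset of features
                toward a zero baseline could look like::

                >>> def binary_perturb_fn(inputs):
                >>>     # 1 marks a feature that gets removed (set to the zero
                >>>     # baseline), 0 marks a feature that is kept
                >>>     mask = torch.randint(0, 2, inputs.shape, device=inputs.device)
                >>>     mask = mask.to(inputs.dtype)
                >>>     return mask, inputs * (1 - mask)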

                The `infidelity_perturb_func_decorator` function decorator is a
                helper that computes the perturbations under the hood when
                `perturb_func` returns only the perturbed inputs.
                For more details about how to use `infidelity_perturb_func_decorator`,
                please read the documentation about `perturb_func` above.

                Attributions have the same shape and dimensionality as the inputs.
                If inputs is a single tensor then the attributions is a single
                tensor as well. If inputs is provided as a tuple of tensors
                then attributions will be a tuple of tensors as well.

        additional_forward_args (any, optional): If the forward function
                requires additional arguments other than the inputs for
                which attributions should not be computed, this argument
                can be provided. It must be either a single additional
                argument of a Tensor or arbitrary (non-tuple) type or a tuple
                containing multiple additional arguments including tensors
                or any arbitrary python types. These arguments are provided to
                forward_func in order, following the arguments in inputs.
                Note that the perturbations are not computed with respect
                to these arguments. This means that these arguments aren't
                being passed to `perturb_func` as an input argument.

                Default: None
        target (int, tuple, tensor or list, optional): Indices for selecting
                predictions from output (for classification cases,
                this is usually the target class).
                If the network returns a scalar value per example, no target
                index is necessary.
                For general 2D outputs, targets can be either:

                - A single integer or a tensor containing a single
                  integer, which is applied to all input examples

                - A list of integers or a 1D tensor, with length matching
                  the number of examples in inputs (dim 0). Each integer
                  is applied as the target for the corresponding example.

                For outputs with > 2 dimensions, targets can be either:

                - A single tuple, which contains #output_dims - 1
                  elements. This target index is applied to all examples.

                - A list of tuples with length equal to the number of
                  examples in inputs (dim 0), and each tuple containing
                  #output_dims - 1 elements. Each tuple is applied as the
                  target for the corresponding example.

                Default: None
        n_perturb_samples (int, optional): The number of times input tensors
                are perturbed. Each input example in the inputs tensor is
                expanded `n_perturb_samples` times before calling the
                `perturb_func` function.

                Default: 10
        max_examples_per_batch (int, optional): The maximum number of input
                examples that are processed together. In case the number of
                examples (`input batch size * n_perturb_samples`) exceeds
                `max_examples_per_batch`, they will be sliced
                into batches of `max_examples_per_batch` examples and processed
                in a sequential order. If `max_examples_per_batch` is None, all
                examples are processed together. `max_examples_per_batch` should
                be at least equal to `input batch size` and at most
                `input batch size * n_perturb_samples`.

                Default: None
        normalize (bool, optional): Normalize the dot product of the input
                perturbation and the attribution so the infidelity value is
                invariant to constant scaling of the attribution values. The
                normalization factor beta is defined as the ratio of two mean
                values:

                .. math::
                    \beta = \frac{
                        \mathbb{E}_{I \sim \mu_I} [ I^T \Phi(f, x) (f(x) - f(x - I)) ]
                    }{
                        \mathbb{E}_{I \sim \mu_I} [ (I^T \Phi(f, x))^2 ]
                    }

                Please refer to the original paper for the meaning of the symbols.
                The same normalization can be found in the paper's official
                implementation:
                https://github.com/chihkuanyeh/saliency_evaluation
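
                When `normalize` is True, the returned score is the mean, over
                the perturbation samples, of the squared error with the scaled
                attribution term:

                .. math::
                    \mathbb{E}_{I \sim \mu_I}
                    \left[ \left( \beta \, I^T \Phi(f, x) - (f(x) - f(x - I)) \right)^2 \right]

                which the implementation below accumulates through the expansion
                :math:`\beta^2 a^2 - 2 \beta a b + b^2` with
                :math:`a = I^T \Phi(f, x)` and :math:`b = f(x) - f(x - I)`.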

                Default: False

    Returns:

        infidelities (tensor): A tensor of scalar infidelity scores per
                input example. The first dimension is equal to the
                number of examples in the input batch and the second
                dimension is one.

    Examples::
        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> saliency = Saliency(net)
        >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
        >>> # Computes saliency maps for class 3.
        >>> attribution = saliency.attribute(input, target=3)
        >>> # define a perturbation function for the input (assumes numpy is
        >>> # imported as np)
        >>> def perturb_fn(inputs):
        >>>     noise = torch.tensor(np.random.normal(0, 0.003, inputs.shape)).float()
        >>>     return noise, inputs - noise
        >>> # Computes infidelity score for saliency maps
        >>> infid = infidelity(net, perturb_fn, input, attribution)
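        >>> # The score can be made invariant to constant scaling of the
        >>> # attribution values by passing normalize=True (see the `normalize`
        >>> # argument above); `infid_norm` is just an illustrative variable name.
        >>> infid_norm = infidelity(net, perturb_fn, input, attribution,
        >>>                         normalize=True)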
""" | |

    def _generate_perturbations(
        current_n_perturb_samples: int,
    ) -> Tuple[TensorOrTupleOfTensorsGeneric, TensorOrTupleOfTensorsGeneric]:
        r"""
        The perturbations are generated for each example
        `current_n_perturb_samples` times.

        For performance reasons we are not calling `perturb_func` on each example
        but on a batch that contains `current_n_perturb_samples`
        repeated instances per example.
        """

        def call_perturb_func():
            r""" """
            baselines_pert = None
            inputs_pert: Union[Tensor, Tuple[Tensor, ...]]
            if len(inputs_expanded) == 1:
                inputs_pert = inputs_expanded[0]
                if baselines_expanded is not None:
                    baselines_pert = cast(Tuple, baselines_expanded)[0]
            else:
                inputs_pert = inputs_expanded
                baselines_pert = baselines_expanded
            return (
                perturb_func(inputs_pert, baselines_pert)
                if baselines_pert is not None
                else perturb_func(inputs_pert)
            )

        inputs_expanded = tuple(
            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
            for input in inputs
        )

        baselines_expanded = baselines
        if baselines is not None:
            baselines_expanded = tuple(
                baseline.repeat_interleave(current_n_perturb_samples, dim=0)
                if isinstance(baseline, torch.Tensor)
                and baseline.shape[0] == input.shape[0]
                and baseline.shape[0] > 1
                else baseline
                for input, baseline in zip(inputs, cast(Tuple, baselines))
            )

        return call_perturb_func()
    def _validate_inputs_and_perturbations(
        inputs: Tuple[Tensor, ...],
        inputs_perturbed: Tuple[Tensor, ...],
        perturbations: Tuple[Tensor, ...],
    ) -> None:
        # asserts the sizes of the perturbations and inputs
        assert len(perturbations) == len(inputs), (
            """The number of perturbed
            inputs and corresponding perturbations must have the same number of
            elements. Found number of inputs is: {} and perturbations:
            {}"""
        ).format(len(inputs), len(perturbations))

        # asserts the shapes of the perturbations and perturbed inputs
        for perturb, input_perturbed in zip(perturbations, inputs_perturbed):
            assert perturb[0].shape == input_perturbed[0].shape, (
                """Perturbed input
                and corresponding perturbation must have the same shape and
                dimensionality. Found perturbation shape is: {} and the input shape
                is: {}"""
            ).format(perturb[0].shape, input_perturbed[0].shape)
    def _next_infidelity_tensors(
        current_n_perturb_samples: int,
    ) -> Union[Tuple[Tensor], Tuple[Tensor, Tensor, Tensor]]:
        perturbations, inputs_perturbed = _generate_perturbations(
            current_n_perturb_samples
        )

        perturbations = _format_tensor_into_tuples(perturbations)
        inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)

        _validate_inputs_and_perturbations(
            cast(Tuple[Tensor, ...], inputs),
            cast(Tuple[Tensor, ...], inputs_perturbed),
            cast(Tuple[Tensor, ...], perturbations),
        )

        targets_expanded = _expand_target(
            target,
            current_n_perturb_samples,
            expansion_type=ExpansionTypes.repeat_interleave,
        )
        additional_forward_args_expanded = _expand_additional_forward_args(
            additional_forward_args,
            current_n_perturb_samples,
            expansion_type=ExpansionTypes.repeat_interleave,
        )

        inputs_perturbed_fwd = _run_forward(
            forward_func,
            inputs_perturbed,
            targets_expanded,
            additional_forward_args_expanded,
        )
        inputs_fwd = _run_forward(
            forward_func, inputs, target, additional_forward_args
        )
        inputs_fwd = torch.repeat_interleave(
            inputs_fwd, current_n_perturb_samples, dim=0
        )
        # b = f(x) - f(x - I): change in the model output caused by the perturbation
        perturbed_fwd_diffs = inputs_fwd - inputs_perturbed_fwd
        attributions_expanded = tuple(
            torch.repeat_interleave(attribution, current_n_perturb_samples, dim=0)
            for attribution in attributions
        )

        # a = I^T * attributions: elementwise product of each perturbation with the
        # corresponding attribution map, flattened per example before summation
        attributions_times_perturb = tuple(
            (attribution_expanded * perturbation).view(attribution_expanded.size(0), -1)
            for attribution_expanded, perturbation in zip(
                attributions_expanded, perturbations
            )
        )

        attr_times_perturb_sums = sum(
            torch.sum(attribution_times_perturb, dim=1)
            for attribution_times_perturb in attributions_times_perturb
        )
        attr_times_perturb_sums = cast(Tensor, attr_times_perturb_sums)

        # reshape as Tensor(bsz, current_n_perturb_samples)
        attr_times_perturb_sums = attr_times_perturb_sums.view(bsz, -1)
        perturbed_fwd_diffs = perturbed_fwd_diffs.view(bsz, -1)

        if normalize:
            # in order to normalize, we have to aggregate the following tensors
            # to calculate MSE in its polynomial expansion:
            # (a - b)^2 = a^2 - 2ab + b^2
            return (
                attr_times_perturb_sums.pow(2).sum(-1),
                (attr_times_perturb_sums * perturbed_fwd_diffs).sum(-1),
                perturbed_fwd_diffs.pow(2).sum(-1),
            )
        else:
            # returns (a - b)^2 if there is no need to normalize
            return ((attr_times_perturb_sums - perturbed_fwd_diffs).pow(2).sum(-1),)
    def _sum_infidelity_tensors(agg_tensors, tensors):
        return tuple(agg_t + t for agg_t, t in zip(agg_tensors, tensors))

    # perform argument formatting
    inputs = _format_tensor_into_tuples(inputs)  # type: ignore
    if baselines is not None:
        baselines = _format_baseline(baselines, cast(Tuple[Tensor, ...], inputs))
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    attributions = _format_tensor_into_tuples(attributions)  # type: ignore

    # Make sure that inputs and corresponding attributions have matching sizes.
    assert len(inputs) == len(attributions), (
        """The number of tensors in the inputs and
        attributions must match. Found number of tensors in the inputs is: {} and in
        the attributions: {}"""
    ).format(len(inputs), len(attributions))
    for inp, attr in zip(inputs, attributions):
        assert inp.shape == attr.shape, (
            """Inputs and attributions must have
            matching shapes. One of the input tensor's shape is {} and the
            attribution tensor's shape is: {}"""
        ).format(inp.shape, attr.shape)

    bsz = inputs[0].size(0)
    with torch.no_grad():
        # if not normalize, directly return the aggregated MSE ((a - b)^2,)
        # else return the aggregated MSE's polynomial expansion tensors
        # (a^2, ab, b^2)
        agg_tensors = _divide_and_aggregate_metrics(
            cast(Tuple[Tensor, ...], inputs),
            n_perturb_samples,
            _next_infidelity_tensors,
            agg_func=_sum_infidelity_tensors,
            max_examples_per_batch=max_examples_per_batch,
        )

    if normalize:
        # beta = E[ab] / E[a^2]; the sample counts cancel, so the ratio of the
        # aggregated sums equals the ratio of the means
        beta_num = agg_tensors[1]
        beta_denom = agg_tensors[0]
        beta = safe_div(beta_num, beta_denom)

        infidelity_values = (
            beta ** 2 * agg_tensors[0] - 2 * beta * agg_tensors[1] + agg_tensors[2]
        )
    else:
        infidelity_values = agg_tensors[0]

    infidelity_values /= n_perturb_samples

    return infidelity_values