#!/usr/bin/env python3

from copy import deepcopy
from inspect import signature
from typing import Any, Callable, cast, Optional, Tuple, Union

import torch
from captum._utils.common import (
    _expand_and_update_additional_forward_args,
    _expand_and_update_baselines,
    _expand_and_update_target,
    _format_baseline,
    _format_tensor_into_tuples,
)
from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from captum.log import log_usage
from captum.metrics._utils.batching import _divide_and_aggregate_metrics
from torch import Tensor


def default_perturb_func(
    inputs: TensorOrTupleOfTensorsGeneric, perturb_radius: float = 0.02
) -> Tuple[Tensor, ...]:
    r"""A default function for generating perturbations of `inputs`
    within a perturbation radius of `perturb_radius`.
    This function samples uniformly at random from the L_infinity ball
    with radius `perturb_radius`.
    Users can override this function if they prefer to use a
    different perturbation function.

    Args:

        inputs (tensor or a tuple of tensors): The input tensors that we'd
                like to perturb by adding random noise sampled uniformly
                at random from an L_infinity ball with radius
                `perturb_radius`.

        perturb_radius (float, optional): The radius of the L_infinity ball
                that the noise is sampled from.
                Default: 0.02

    Returns:

        perturbed_input (tuple(tensor)): A tuple of perturbed inputs that are
                created by adding noise, sampled uniformly at random from an
                L_infinity ball with radius `perturb_radius`, to the original
                inputs. It has one entry per input tensor, also when a single
                tensor is passed in.
    """
    inputs = _format_tensor_into_tuples(inputs)
    # The noise is drawn on the CPU as a float32 tensor of the input's size and
    # then moved to the device of the corresponding input tensor.
    perturbed_input = tuple(
        input
        + torch.FloatTensor(input.size())  # type: ignore
        .uniform_(-perturb_radius, perturb_radius)
        .to(input.device)
        for input in inputs
    )
    return perturbed_input


@log_usage()
def sensitivity_max(
    explanation_func: Callable,
    inputs: TensorOrTupleOfTensorsGeneric,
    perturb_func: Callable = default_perturb_func,
    perturb_radius: float = 0.02,
    n_perturb_samples: int = 10,
    norm_ord: Union[str, int, float] = "fro",
    max_examples_per_batch: Optional[int] = None,
    **kwargs: Any,
) -> Tensor:
    r"""
    Explanation sensitivity measures the extent of explanation change when
    the input is slightly perturbed. It has been shown that the models that
    have high explanation sensitivity are prone to adversarial attacks:
    `Interpretation of Neural Networks is Fragile`
    https://www.aaai.org/ojs/index.php/AAAI/article/view/4252

    `sensitivity_max` metric measures maximum sensitivity of an explanation
    using a Monte Carlo sampling-based approximation. By default, in order to
    do so, it samples multiple data points from a sub-space of an L-Infinity
    ball that has a `perturb_radius` radius, using the `default_perturb_func`
    default perturbation function. In the general case users can use any
    L_p ball or any other custom sampling technique that they prefer by
    providing a custom `perturb_func`.

    Note that max sensitivity is similar to the Lipschitz Continuity metric,
    however it is more robust and easier to estimate.
    Since the explanation (for instance, an attribution function) may not
    always be continuous, its Lipschitz constant can be unbounded.
    Therefore the latter isn't always appropriate.

    More about the Lipschitz Continuity Metric can also be found here:
    `On the Robustness of Interpretability Methods`
    https://arxiv.org/pdf/1806.08049.pdf
    and
    `Towards Robust Interpretability with Self-Explaining Neural Networks`
    https://papers.nips.cc/paper/8003-towards-robust-interpretability-with-self-explaining-neural-networks.pdf

    More details about sensitivity max can be found here:
    `On the (In)fidelity and Sensitivity of Explanations`
    https://arxiv.org/pdf/1901.09392.pdf

    Args:

        explanation_func (callable):
                This function can be the `attribute` method of an
                attribution algorithm or any other explanation method
                that returns the explanations.

        inputs (tensor or tuple of tensors): Input for which
                explanations are computed. If `explanation_func` takes a
                single tensor as input, a single input tensor should
                be provided.
                If `explanation_func` takes multiple tensors as input, a tuple
                of the input tensors should be provided. It is assumed
                that for all given input tensors, dimension 0 corresponds
                to the number of examples (aka batch size), and if
                multiple input tensors are provided, the examples must
                be aligned appropriately.

        perturb_func (callable):
                The perturbation function of model inputs. This function takes
                model inputs and optionally `perturb_radius` if
                the function takes more than one argument and returns
                perturbed inputs.

                If there is more than one input passed to the sensitivity
                function, they will be passed to `perturb_func` as a tuple
                in the same order as they are passed to the sensitivity
                function.

                It is important to note that for performance reasons
                `perturb_func` isn't called for each example individually
                but on a batch of input examples that are repeated
                `max_examples_per_batch / batch_size` times within the batch.

                Default: default_perturb_func

        perturb_radius (float, optional): The epsilon radius used for sampling.
                In the `default_perturb_func` it is used as the radius of
                the L-Infinity ball. In the general case it can serve as a
                radius of any L_p norm.
                This argument is passed to `perturb_func` if it takes more
                than one argument.

                Default: 0.02

        n_perturb_samples (int, optional): The number of times input tensors
                are perturbed. Each input example in the inputs tensor is
                expanded `n_perturb_samples` times before calling
                the `perturb_func` function.

                Default: 10

        norm_ord (int, float, inf, -inf, 'fro', optional): The type of norm
                that is used to compute the norm of the sensitivity matrix,
                which is defined as the difference between the explanation
                function at its input and perturbed input. The norm is taken
                over a single (flattened) dimension per example, so
                matrix-only norms such as 'nuc' are not applicable.

                Default: 'fro'

        max_examples_per_batch (int, optional): The maximum number of input
                examples that are processed together. In case the number of
                examples (`input batch size * n_perturb_samples`) exceeds
                `max_examples_per_batch`, they will be sliced
                into batches of `max_examples_per_batch` examples and processed
                in a sequential order. If `max_examples_per_batch` is None, all
                examples are processed together. `max_examples_per_batch` should
                at least be equal to `input batch size` and at most
                `input batch size * n_perturb_samples`.

                Default: None

        **kwargs (Any, optional): Contains a list of arguments that are passed
                to the `explanation_func` explanation function, which in some
                cases could be the `attribute` function of an attribution
                algorithm. Any additional arguments that need to be passed to
                the explanation function should be included here.
                For instance, such arguments include:
                `additional_forward_args`, `baselines` and `target`.

    Returns:

        sensitivities (tensor): A 1D tensor of scalar sensitivity scores, one
                per input example; its size equals the number of examples
                in the input batch.
                Returned sensitivities are normalized by the magnitudes of the
                input explanations.

    Examples::
        >>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
        >>> # and returns an Nx10 tensor of class probabilities.
        >>> net = ImageClassifier()
        >>> saliency = Saliency(net)
        >>> input = torch.randn(2, 3, 32, 32, requires_grad=True)
        >>> # Computes sensitivity score for saliency maps of class 3
        >>> sens = sensitivity_max(saliency.attribute, input, target = 3)

    """

    def _generate_perturbations(
        current_n_perturb_samples: int,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Generates `current_n_perturb_samples` perturbations for each example.
        For performance reasons `perturb_func` is not called on each example
        individually but once on a batch that contains
        `current_n_perturb_samples` consecutive repeated copies of every
        example.
        """
        inputs_expanded: Union[Tensor, Tuple[Tensor, ...]] = tuple(
            torch.repeat_interleave(input, current_n_perturb_samples, dim=0)
            for input in inputs
        )
        # A single input tensor is handed to `perturb_func` unwrapped.
        if len(inputs_expanded) == 1:
            inputs_expanded = inputs_expanded[0]

        # `perturb_radius` is only forwarded to perturbation functions that
        # declare more than one parameter.
        return (
            perturb_func(inputs_expanded, perturb_radius)
            if len(signature(perturb_func).parameters) > 1
            else perturb_func(inputs_expanded)
        )

    def max_values(input_tnsr: Tensor) -> Tensor:
        # Row-wise maximum: reduces dimension 1, one value per row.
        return torch.max(input_tnsr, dim=1).values  # type: ignore

    # `kwargs_copy` caches `kwargs` expanded for the chunk size stored in
    # `kwarg_expanded_for`; the deepcopy + expansion is only redone when the
    # chunk size changes between calls of `_next_sensitivity_max`.
    kwarg_expanded_for: Optional[int] = None
    kwargs_copy: Any = None

    def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
        inputs_perturbed = _generate_perturbations(current_n_perturb_samples)

        # copy kwargs and update some of the arguments that need to be expanded
        nonlocal kwarg_expanded_for
        nonlocal kwargs_copy
        if kwarg_expanded_for != current_n_perturb_samples:
            kwarg_expanded_for = current_n_perturb_samples
            kwargs_copy = deepcopy(kwargs)
            _expand_and_update_additional_forward_args(
                current_n_perturb_samples, kwargs_copy
            )
            _expand_and_update_target(current_n_perturb_samples, kwargs_copy)
            if "baselines" in kwargs:
                baselines = _format_baseline(
                    kwargs["baselines"], cast(Tuple[Tensor, ...], inputs)
                )
                # Only baseline tensors shaped like the inputs are expanded
                # alongside the repeated inputs.
                if (
                    isinstance(baselines[0], Tensor)
                    and baselines[0].shape == inputs[0].shape
                ):
                    _expand_and_update_baselines(
                        cast(Tuple[Tensor, ...], inputs),
                        current_n_perturb_samples,
                        kwargs_copy,
                    )

        expl_perturbed_inputs = explanation_func(inputs_perturbed, **kwargs_copy)

        # tuplize `expl_perturbed_inputs` in case it is not
        expl_perturbed_inputs = _format_tensor_into_tuples(expl_perturbed_inputs)

        # Repeat each original explanation so that it lines up row by row
        # with the explanations of its perturbed copies.
        expl_inputs_expanded = tuple(
            expl_input.repeat_interleave(current_n_perturb_samples, dim=0)
            for expl_input in expl_inputs
        )

        sensitivities = torch.cat(
            [
                (expl_input - expl_perturbed).reshape(expl_perturbed.size(0), -1)
                for expl_perturbed, expl_input in zip(
                    expl_perturbed_inputs, expl_inputs_expanded
                )
            ],
            dim=1,
        )

        # compute the norm for each input noisy example, normalized by the
        # norm of the corresponding original explanation
        sensitivities_norm = torch.norm(
            sensitivities, p=norm_ord, dim=1, keepdim=True
        ) / expl_inputs_norm.repeat_interleave(current_n_perturb_samples, dim=0)

        # Rows are grouped per example by `repeat_interleave`, so viewing as
        # (bsz, -1) puts all samples of one example in one row.
        return max_values(sensitivities_norm.view(bsz, -1))

    inputs = _format_tensor_into_tuples(inputs)  # type: ignore

    bsz = inputs[0].size(0)

    with torch.no_grad():
        # Tuplize the original explanations as well, so that the loops in
        # `_next_sensitivity_max` iterate over explanation tensors and not
        # over the batch dimension of a single bare tensor.
        expl_inputs = _format_tensor_into_tuples(
            explanation_func(inputs, **kwargs)
        )

        # The norm of the original explanations is identical for every chunk
        # of perturbations, so it is computed once here.
        expl_inputs_norm = torch.norm(
            torch.cat(
                [
                    expl_input.reshape(expl_input.size(0), -1)
                    for expl_input in expl_inputs
                ],
                dim=1,
            ),
            p=norm_ord,
            dim=1,
            keepdim=True,
        )
        # Avoid division by zero for examples whose explanation norm is zero.
        expl_inputs_norm = torch.where(
            expl_inputs_norm == 0.0,
            torch.tensor(
                1.0,
                device=expl_inputs_norm.device,
                dtype=expl_inputs_norm.dtype,
            ),
            expl_inputs_norm,
        )

        metrics_max = _divide_and_aggregate_metrics(
            cast(Tuple[Tensor, ...], inputs),
            n_perturb_samples,
            _next_sensitivity_max,
            max_examples_per_batch=max_examples_per_batch,
            agg_func=torch.max,
        )
    return metrics_max