#!/usr/bin/env python3 | |
import math | |
from typing import Any, Callable, cast, Tuple, Union | |
import torch | |
from captum._utils.common import ( | |
_expand_additional_forward_args, | |
_expand_target, | |
_format_additional_forward_args, | |
_format_output, | |
_format_tensor_into_tuples, | |
_is_tuple, | |
_run_forward, | |
) | |
from captum._utils.progress import progress | |
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric | |
from captum.attr._utils.attribution import PerturbationAttribution | |
from captum.attr._utils.common import _format_input_baseline | |
from captum.log import log_usage | |
from torch import dtype, Tensor | |
class FeatureAblation(PerturbationAttribution):
    r"""
    Perturbation-based attribution that measures the effect of replacing
    input features with a baseline / reference value.

    By default, every scalar in every input tensor is treated as its own
    feature and ablated independently. Supplying a ``feature_mask`` groups
    scalars into features that are ablated together -- useful e.g. for
    images, where a whole segment or region can be ablated and the
    importance of that segment (feature group) measured. Every scalar in a
    group receives the same attribution: the change in the target output
    caused by ablating the entire group.

    The forward function may either return a scalar per example, or a
    fixed-size tensor (or scalar value) for the full batch, i.e. an output
    that does not grow as the batch size increases. The latter case is
    treated as an "aggregation" of the inputs; in that mode
    ``perturbations_per_eval`` must be 1 and ``feature_mask`` must either be
    None or consist of masks whose first dimension is 1 (i.e. a single
    feature mask applied to all inputs).
    """

    def __init__(self, forward_func: Callable) -> None:
        r"""
        Args:
            forward_func (callable): The forward function of the model or
                any modification of it
        """
        super().__init__(forward_func)
        # Subclasses whose ablations may overlap (e.g. Occlusion) set this to
        # True so that attributions are normalized by per-feature ablation
        # counts at the end of attribute().
        self.use_weights = False
    def attribute(
        self,
        inputs: TensorOrTupleOfTensorsGeneric,
        baselines: BaselineType = None,
        target: TargetType = None,
        additional_forward_args: Any = None,
        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
        perturbations_per_eval: int = 1,
        show_progress: bool = False,
        **kwargs: Any,
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Args:

            inputs (tensor or tuple of tensors):  Input for which ablation
                        attributions are computed. If forward_func takes a single
                        tensor as input, a single input tensor should be provided.
                        If forward_func takes multiple tensors as input, a tuple
                        of the input tensors should be provided. It is assumed
                        that for all given input tensors, dimension 0 corresponds
                        to the number of examples (aka batch size), and if
                        multiple input tensors are provided, the examples must
                        be aligned appropriately.
            baselines (scalar, tensor, tuple of scalars or tensors, optional):
                        Baselines define reference value which replaces each
                        feature when ablated.
                        Baselines can be provided as:

                        - a single tensor, if inputs is a single tensor, with
                          exactly the same dimensions as inputs or
                          broadcastable to match the dimensions of inputs

                        - a single scalar, if inputs is a single tensor, which will
                          be broadcasted for each input value in input tensor.

                        - a tuple of tensors or scalars, the baseline corresponding
                          to each tensor in the inputs' tuple can be:

                          - either a tensor with matching dimensions to
                            corresponding tensor in the inputs' tuple
                            or the first dimension is one and the remaining
                            dimensions match with the corresponding
                            input tensor.

                          - or a scalar, corresponding to a tensor in the
                            inputs' tuple. This scalar value is broadcasted
                            for corresponding input tensor.

                        In the cases when `baselines` is not provided, we internally
                        use zero scalar corresponding to each input tensor.
                        Default: None
            target (int, tuple, tensor or list, optional):  Output indices for
                        which gradients are computed (for classification cases,
                        this is usually the target class).
                        If the network returns a scalar value per example,
                        no target index is necessary.
                        For general 2D outputs, targets can be either:

                        - a single integer or a tensor containing a single
                          integer, which is applied to all input examples

                        - a list of integers or a 1D tensor, with length matching
                          the number of examples in inputs (dim 0). Each integer
                          is applied as the target for the corresponding example.

                        For outputs with > 2 dimensions, targets can be either:

                        - A single tuple, which contains #output_dims - 1
                          elements. This target index is applied to all examples.

                        - A list of tuples with length equal to the number of
                          examples in inputs (dim 0), and each tuple containing
                          #output_dims - 1 elements. Each tuple is applied as the
                          target for the corresponding example.

                        Default: None
            additional_forward_args (any, optional): If the forward function
                        requires additional arguments other than the inputs for
                        which attributions should not be computed, this argument
                        can be provided. It must be either a single additional
                        argument of a Tensor or arbitrary (non-tuple) type or a
                        tuple containing multiple additional arguments including
                        tensors or any arbitrary python types. These arguments
                        are provided to forward_func in order following the
                        arguments in inputs.
                        For a tensor, the first dimension of the tensor must
                        correspond to the number of examples. For all other types,
                        the given argument is used for all forward evaluations.
                        Note that attributions are not computed with respect
                        to these arguments.
                        Default: None
            feature_mask (tensor or tuple of tensors, optional):
                        feature_mask defines a mask for the input, grouping
                        features which should be ablated together. feature_mask
                        should contain the same number of tensors as inputs.
                        Each tensor should
                        be the same size as the corresponding input or
                        broadcastable to match the input tensor. Each tensor
                        should contain integers in the range 0 to num_features
                        - 1, and indices corresponding to the same feature should
                        have the same value.
                        Note that features within each input tensor are ablated
                        independently (not across tensors).
                        If the forward function returns a single scalar per batch,
                        we enforce that the first dimension of each mask must be 1,
                        since attributions are returned batch-wise rather than per
                        example, so the attributions must correspond to the
                        same features (indices) in each input example.
                        If None, then a feature mask is constructed which assigns
                        each scalar within a tensor as a separate feature, which
                        is ablated independently.
                        Default: None
            perturbations_per_eval (int, optional): Allows ablation of multiple
                        features to be processed simultaneously in one call to
                        forward_fn.
                        Each forward pass will contain a maximum of
                        perturbations_per_eval * #examples samples.
                        For DataParallel models, each batch is split among the
                        available devices, so evaluations on each available
                        device contain at most
                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function's number of outputs does not
                        change as the batch size grows (e.g. if it outputs a
                        scalar value), you must set perturbations_per_eval to 1
                        and use a single feature mask to describe the features
                        for all examples in the batch.
                        Default: 1
            show_progress (bool, optional): Displays the progress of computation.
                        It will try to use tqdm if available for advanced features
                        (e.g. time estimation). Otherwise, it will fallback to
                        a simple output of progress.
                        Default: False
            **kwargs (Any, optional): Any additional arguments used by child
                        classes of FeatureAblation (such as Occlusion) to construct
                        ablations. These arguments are ignored when using
                        FeatureAblation directly.
                        Default: None

        Returns:
            *tensor* or tuple of *tensors* of **attributions**:
            - **attributions** (*tensor* or tuple of *tensors*):
                        The attributions with respect to each input feature.
                        If the forward function returns
                        a scalar value per example, attributions will be
                        the same size as the provided inputs, with each value
                        providing the attribution of the corresponding input index.
                        If the forward function returns a scalar per batch, then
                        attribution tensor(s) will have first dimension 1 and
                        the remaining dimensions will match the input.
                        If a single tensor is provided as inputs, a single tensor is
                        returned. If a tuple of tensors is provided for inputs, a
                        tuple of corresponding sized tensors is returned.

        Examples::

            >>> # SimpleClassifier takes a single input tensor of size Nx4x4,
            >>> # and returns an Nx3 tensor of class probabilities.
            >>> net = SimpleClassifier()
            >>> # Generating random input with size 2 x 4 x 4
            >>> input = torch.randn(2, 4, 4)
            >>> # Defining FeatureAblation interpreter
            >>> ablator = FeatureAblation(net)
            >>> # Computes ablation attribution, ablating each of the 16
            >>> # scalar input independently.
            >>> attr = ablator.attribute(input, target=1)
            >>> # Alternatively, we may want to ablate features in groups, e.g.
            >>> # grouping each 2x2 square of the inputs and ablating them together.
            >>> # This can be done by creating a feature mask as follows, which
            >>> # defines the feature groups, e.g.:
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 0 | 0 | 1 | 1 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # | 2 | 2 | 3 | 3 |
            >>> # +---+---+---+---+
            >>> # With this mask, all inputs with the same value are ablated
            >>> # simultaneously, and the attribution for each input in the same
            >>> # group (0, 1, 2, and 3) per example are the same.
            >>> # The attributions can be calculated as follows:
            >>> # feature mask has dimensions 1 x 4 x 4
            >>> feature_mask = torch.tensor([[[0,0,1,1],[0,0,1,1],
            >>>                               [2,2,3,3],[2,2,3,3]]])
            >>> attr = ablator.attribute(input, target=1, feature_mask=feature_mask)
        """
        # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
        is_inputs_tuple = _is_tuple(inputs)
        inputs, baselines = _format_input_baseline(inputs, baselines)
        additional_forward_args = _format_additional_forward_args(
            additional_forward_args
        )
        # dim 0 of every input tensor is assumed to be the batch dimension
        # (see docstring above).
        num_examples = inputs[0].shape[0]
        feature_mask = (
            _format_tensor_into_tuples(feature_mask)
            if feature_mask is not None
            else None
        )
        assert (
            isinstance(perturbations_per_eval, int) and perturbations_per_eval >= 1
        ), "Perturbations per evaluation must be an integer and at least 1."
        with torch.no_grad():
            if show_progress:
                # Pre-compute the total number of forward passes so the
                # progress bar can show a meaningful total.
                feature_counts = self._get_feature_counts(
                    inputs, feature_mask, **kwargs
                )
                total_forwards = (
                    sum(
                        math.ceil(count / perturbations_per_eval)
                        for count in feature_counts
                    )
                    + 1
                )  # add 1 for the initial eval
                attr_progress = progress(
                    desc=f"{self.get_name()} attribution", total=total_forwards
                )
                attr_progress.update(0)
            # Computes initial evaluation with all features, which is compared
            # to each ablated result.
            initial_eval = _run_forward(
                self.forward_func, inputs, target, additional_forward_args
            )
            if show_progress:
                attr_progress.update()
            # "Aggregation" mode: the model output does not grow with batch
            # size (see _find_output_mode for the exact criterion).
            agg_output_mode = FeatureAblation._find_output_mode(
                perturbations_per_eval, feature_mask
            )
            # get as a 2D tensor (if it is not a scalar)
            if isinstance(initial_eval, torch.Tensor):
                initial_eval = initial_eval.reshape(1, -1)
                num_outputs = initial_eval.shape[1]
            else:
                num_outputs = 1
            if not agg_output_mode:
                # Outside aggregation mode the output must have one element
                # per example, otherwise the per-example diff below is invalid.
                assert (
                    isinstance(initial_eval, torch.Tensor)
                    and num_outputs == num_examples
                ), (
                    "expected output of `forward_func` to have "
                    + "`batch_size` elements for perturbations_per_eval > 1 "
                    + "and all feature_mask.shape[0] > 1"
                )
            # Initialize attribution totals and counts
            # NOTE(review): for a non-tensor (python scalar) initial_eval this
            # passes the python type (e.g. float) as a torch dtype — appears
            # intentional; torch.zeros accepts it.
            attrib_type = cast(
                dtype,
                initial_eval.dtype
                if isinstance(initial_eval, Tensor)
                else type(initial_eval),
            )
            total_attrib = [
                torch.zeros(
                    (num_outputs,) + input.shape[1:],
                    dtype=attrib_type,
                    device=input.device,
                )
                for input in inputs
            ]
            # Weights are used in cases where ablations may be overlapping.
            if self.use_weights:
                weights = [
                    torch.zeros(
                        (num_outputs,) + input.shape[1:], device=input.device
                    ).float()
                    for input in inputs
                ]
            # Iterate through each feature tensor for ablation
            for i in range(len(inputs)):
                # Skip any empty input tensors
                if torch.numel(inputs[i]) == 0:
                    continue
                for (
                    current_inputs,
                    current_add_args,
                    current_target,
                    current_mask,
                ) in self._ith_input_ablation_generator(
                    i,
                    inputs,
                    additional_forward_args,
                    target,
                    baselines,
                    feature_mask,
                    perturbations_per_eval,
                    **kwargs,
                ):
                    # modified_eval dimensions: 1D tensor with length
                    # equal to #num_examples * #features in batch
                    modified_eval = _run_forward(
                        self.forward_func,
                        current_inputs,
                        current_target,
                        current_add_args,
                    )
                    if show_progress:
                        attr_progress.update()
                    # (contains 1 more dimension than inputs). This adds extra
                    # dimensions of 1 to make the tensor broadcastable with the inputs
                    # tensor.
                    if not isinstance(modified_eval, torch.Tensor):
                        eval_diff = initial_eval - modified_eval
                    else:
                        if not agg_output_mode:
                            assert (
                                modified_eval.numel() == current_inputs[0].shape[0]
                            ), """expected output of forward_func to grow with
                            batch_size. If this is not the case for your model
                            please set perturbations_per_eval = 1"""
                        eval_diff = (
                            initial_eval - modified_eval.reshape((-1, num_outputs))
                        ).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
                        # modified_eval may live on a different device than the
                        # accumulator (e.g. DataParallel) — move it over.
                        eval_diff = eval_diff.to(total_attrib[i].device)
                    if self.use_weights:
                        weights[i] += current_mask.float().sum(dim=0)
                    # Sum over the per-perturbation dimension; current_mask
                    # broadcasts eval_diff over the ablated positions.
                    total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(
                        dim=0
                    )
            if show_progress:
                attr_progress.close()
            # Divide total attributions by counts and return formatted attributions
            if self.use_weights:
                attrib = tuple(
                    single_attrib.float() / weight
                    for single_attrib, weight in zip(total_attrib, weights)
                )
            else:
                attrib = tuple(total_attrib)
            _result = _format_output(is_inputs_tuple, attrib)
        return _result
    def _ith_input_ablation_generator(
        self,
        i,
        inputs,
        additional_args,
        target,
        baselines,
        input_mask,
        perturbations_per_eval,
        **kwargs,
    ):
        """
        Returns a generator of ablation perturbations of the i-th input.

        Each yielded item ablates up to `perturbations_per_eval` features of
        inputs[i] at once by stacking the perturbed batches along dim 0.

        Returns:
            ablation_iter (generator): yields each perturbation to be evaluated
                as a tuple (inputs, additional_forward_args, targets, mask).
        """
        extra_args = {}
        for key, value in kwargs.items():
            # For any tuple argument in kwargs, we choose index i of the tuple.
            if isinstance(value, tuple):
                extra_args[key] = value[i]
            else:
                extra_args[key] = value
        # Select the mask for the i-th input (may still be None; the helper
        # then builds a per-scalar index mask).
        input_mask = input_mask[i] if input_mask is not None else None
        min_feature, num_features, input_mask = self._get_feature_range_and_mask(
            inputs[i], input_mask, **extra_args
        )
        num_examples = inputs[0].shape[0]
        # Never batch more perturbations than there are features.
        perturbations_per_eval = min(perturbations_per_eval, num_features)
        baseline = baselines[i] if isinstance(baselines, tuple) else baselines
        if isinstance(baseline, torch.Tensor):
            # Prepend a dim of 1 so the baseline broadcasts over the
            # (num_features, num_examples, ...) expanded input.
            baseline = baseline.reshape((1,) + baseline.shape)
        if perturbations_per_eval > 1:
            # Repeat features and additional args for batch size.
            all_features_repeated = [
                torch.cat([inputs[j]] * perturbations_per_eval, dim=0)
                for j in range(len(inputs))
            ]
            additional_args_repeated = (
                _expand_additional_forward_args(additional_args, perturbations_per_eval)
                if additional_args is not None
                else None
            )
            target_repeated = _expand_target(target, perturbations_per_eval)
        else:
            all_features_repeated = list(inputs)
            additional_args_repeated = additional_args
            target_repeated = target
        num_features_processed = min_feature
        while num_features_processed < num_features:
            # The last chunk may contain fewer features than
            # perturbations_per_eval.
            current_num_ablated_features = min(
                perturbations_per_eval, num_features - num_features_processed
            )
            # Store appropriate inputs and additional args based on batch size.
            if current_num_ablated_features != perturbations_per_eval:
                current_features = [
                    feature_repeated[0 : current_num_ablated_features * num_examples]
                    for feature_repeated in all_features_repeated
                ]
                current_additional_args = (
                    _expand_additional_forward_args(
                        additional_args, current_num_ablated_features
                    )
                    if additional_args is not None
                    else None
                )
                current_target = _expand_target(target, current_num_ablated_features)
            else:
                current_features = all_features_repeated
                current_additional_args = additional_args_repeated
                current_target = target_repeated
            # Store existing tensor before modifying
            original_tensor = current_features[i]
            # Construct ablated batch for features in range num_features_processed
            # to num_features_processed + current_num_ablated_features and return
            # mask with same size as ablated batch. ablated_features has dimension
            # (current_num_ablated_features, num_examples, inputs[i].shape[1:])
            # Note that in the case of sparse tensors, the second dimension
            # may not necessarilly be num_examples and will match the first
            # dimension of this tensor.
            current_reshaped = current_features[i].reshape(
                (current_num_ablated_features, -1) + current_features[i].shape[1:]
            )
            ablated_features, current_mask = self._construct_ablated_input(
                current_reshaped,
                input_mask,
                baseline,
                num_features_processed,
                num_features_processed + current_num_ablated_features,
                **extra_args,
            )
            # current_features[i] has dimension
            # (current_num_ablated_features * num_examples, inputs[i].shape[1:]),
            # which can be provided to the model as input.
            current_features[i] = ablated_features.reshape(
                (-1,) + ablated_features.shape[2:]
            )
            yield tuple(
                current_features
            ), current_additional_args, current_target, current_mask
            # Replace existing tensor at index i, so the shared
            # all_features_repeated list is unperturbed for the next chunk.
            current_features[i] = original_tensor
            num_features_processed += current_num_ablated_features
def _construct_ablated_input( | |
self, expanded_input, input_mask, baseline, start_feature, end_feature, **kwargs | |
): | |
r""" | |
Ablates given expanded_input tensor with given feature mask, feature range, | |
and baselines. expanded_input shape is (`num_features`, `num_examples`, ...) | |
with remaining dimensions corresponding to remaining original tensor | |
dimensions and `num_features` = `end_feature` - `start_feature`. | |
input_mask has same number of dimensions as original input tensor (one less | |
than `expanded_input`), and can have first dimension either 1, applying same | |
feature mask to all examples, or `num_examples`. baseline is expected to | |
be broadcastable to match `expanded_input`. | |
This method returns the ablated input tensor, which has the same | |
dimensionality as `expanded_input` as well as the corresponding mask with | |
either the same dimensionality as `expanded_input` or second dimension | |
being 1. This mask contains 1s in locations which have been ablated (and | |
thus counted towards ablations for that feature) and 0s otherwise. | |
""" | |
current_mask = torch.stack( | |
[input_mask == j for j in range(start_feature, end_feature)], dim=0 | |
).long() | |
ablated_tensor = ( | |
expanded_input * (1 - current_mask).to(expanded_input.dtype) | |
) + (baseline * current_mask.to(expanded_input.dtype)) | |
return ablated_tensor, current_mask | |
def _get_feature_range_and_mask(self, input, input_mask, **kwargs): | |
if input_mask is None: | |
# Obtain feature mask for selected input tensor, matches size of | |
# 1 input example, (1 x inputs[i].shape[1:]) | |
input_mask = torch.reshape( | |
torch.arange(torch.numel(input[0]), device=input.device), | |
input[0:1].shape, | |
).long() | |
return ( | |
torch.min(input_mask).item(), | |
torch.max(input_mask).item() + 1, | |
input_mask, | |
) | |
def _get_feature_counts(self, inputs, feature_mask, **kwargs): | |
"""return the numbers of input features""" | |
if not feature_mask: | |
return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs) | |
return tuple( | |
(mask.max() - mask.min()).item() + 1 | |
if mask is not None | |
else (inp[0].numel() if inp.numel() else 0) | |
for inp, mask in zip(inputs, feature_mask) | |
) | |
def _find_output_mode( | |
perturbations_per_eval: int, | |
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric], | |
) -> bool: | |
""" | |
Returns True if the output mode is "aggregation output mode" | |
Aggregation output mode is defined as: when there is no 1:1 correspondence | |
with the `num_examples` (`batch_size`) and the amount of outputs your model | |
produces, i.e. the model output does not grow in size as the input becomes | |
larger. | |
We assume this is the case if `perturbations_per_eval == 1` | |
and your feature mask is None or is associated to all | |
examples in a batch (fm.shape[0] == 1 for all fm in feature_mask). | |
""" | |
return perturbations_per_eval == 1 and ( | |
feature_mask is None | |
or all(len(sm.shape) == 0 or sm.shape[0] == 1 for sm in feature_mask) | |
) | |