Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import typing | |
from inspect import signature | |
from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union | |
import torch | |
from captum._utils.common import ( | |
_format_baseline, | |
_format_output, | |
_format_tensor_into_tuples, | |
_validate_input as _validate_input_basic, | |
) | |
from captum._utils.typing import ( | |
BaselineType, | |
Literal, | |
TargetType, | |
TensorOrTupleOfTensorsGeneric, | |
) | |
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS | |
from torch import Tensor | |
if TYPE_CHECKING: | |
from captum.attr._utils.attribution import GradientAttribution | |
def _sum_rows(input: Tensor) -> Tensor: | |
return input.reshape(input.shape[0], -1).sum(1) | |
def _validate_target(num_samples: int, target: TargetType) -> None: | |
if isinstance(target, list) or ( | |
isinstance(target, torch.Tensor) and torch.numel(target) > 1 | |
): | |
assert num_samples == len(target), ( | |
"The number of samples provied in the" | |
"input {} does not match with the number of targets. {}".format( | |
num_samples, len(target) | |
) | |
) | |
def _validate_input(
    inputs: Tuple[Tensor, ...],
    baselines: Tuple[Union[Tensor, int, float], ...],
    n_steps: int = 50,
    method: str = "riemann_trapezoid",
    draw_baseline_from_distrib: bool = False,
) -> None:
    """Validate inputs, baselines and approximation settings.

    The input/baseline checks are delegated to ``_validate_input_basic``
    (imported from ``captum._utils.common``); this function additionally
    checks the step count and the approximation method name.

    Args:
        inputs: Input tensors, already formatted as a tuple.
        baselines: Baselines, already formatted as a tuple.
        n_steps: Number of approximation steps; must not be negative.
        method: Approximation method name; must be in ``SUPPORTED_METHODS``.
        draw_baseline_from_distrib: Forwarded unchanged to
            ``_validate_input_basic``.

    Raises:
        AssertionError: If ``n_steps`` is negative or ``method`` is not in
            ``SUPPORTED_METHODS``.
    """
    _validate_input_basic(inputs, baselines, draw_baseline_from_distrib)
    # The check accepts zero, so the message says "non-negative" to match it.
    assert (
        n_steps >= 0
    ), "The number of steps must be a non-negative integer. Given: {}".format(
        n_steps
    )
    assert (
        method in SUPPORTED_METHODS
    ), "Approximation method must be one of the following {}. Given {}".format(
        SUPPORTED_METHODS, method
    )
def _validate_noise_tunnel_type( | |
nt_type: str, supported_noise_tunnel_types: List[str] | |
) -> None: | |
assert nt_type in supported_noise_tunnel_types, ( | |
"Noise types must be either `smoothgrad`, `smoothgrad_sq` or `vargrad`. " | |
"Given {}".format(nt_type) | |
) | |
@typing.overload
def _format_input_baseline(
    inputs: Union[Tensor, Tuple[Tensor, ...]],
    baselines: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    ...


@typing.overload
def _format_input_baseline(
    inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
    ...


def _format_input_baseline(
    inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
    """Format ``inputs`` and ``baselines`` into tuples.

    ``inputs`` is passed through ``_format_tensor_into_tuples`` and the
    result is then used to format ``baselines`` via ``_format_baseline``.

    The two stubs above are marked with ``typing.overload`` so that type
    checkers treat them as alternative signatures instead of shadowed
    redefinitions; only this final definition exists at runtime.

    Args:
        inputs: A single tensor or a tuple of tensors.
        baselines: Baselines in any form accepted by ``_format_baseline``.

    Returns:
        A ``(inputs, baselines)`` pair holding the formatted values.
    """
    formatted_inputs = _format_tensor_into_tuples(inputs)
    formatted_baselines = _format_baseline(baselines, formatted_inputs)
    return formatted_inputs, formatted_baselines
# This function could potentially be merged with the `_format_baseline`
# function. However, since not all algorithms currently support baselines of
# type callable, it is kept as a separate function.
@typing.overload
def _format_callable_baseline(
    baselines: Union[
        None,
        Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
        Tensor,
        Tuple[Tensor, ...],
    ],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _format_callable_baseline(
    baselines: Union[
        None,
        Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
        Tensor,
        int,
        float,
        Tuple[Union[Tensor, int, float], ...],
    ],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Union[Tensor, int, float], ...]:
    ...


def _format_callable_baseline(
    baselines: Union[
        None,
        Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
        Tensor,
        int,
        float,
        Tuple[Union[Tensor, int, float], ...],
    ],
    inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Union[Tensor, int, float], ...]:
    """Resolve a possibly-callable baseline and format it into a tuple.

    If ``baselines`` is callable it is invoked first: with no arguments when
    its signature has no parameters, otherwise with ``inputs`` as the only
    argument. The resulting value (or the original non-callable value) is
    then formatted with ``_format_baseline`` against the tuple form of
    ``inputs``.

    The two stubs above are marked with ``typing.overload`` so that type
    checkers treat them as alternative signatures instead of shadowed
    redefinitions; only this final definition exists at runtime.

    Args:
        baselines: ``None``, a baseline value, a tuple of baseline values, or
            a callable producing baselines.
        inputs: A single tensor or a tuple of tensors.

    Returns:
        The value returned by ``_format_baseline``.
    """
    if callable(baselines):
        # Note: this assumes that if baselines is a function and if it takes
        # arguments, then the first argument is the `inputs`.
        # This can be expanded in the future with better type checks
        takes_no_arguments = len(signature(baselines).parameters) == 0
        baselines = baselines() if takes_no_arguments else baselines(inputs)
    return _format_baseline(baselines, _format_tensor_into_tuples(inputs))
def _format_and_verify_strides( | |
strides: Union[None, int, Tuple[int, ...], Tuple[Union[int, Tuple[int, ...]], ...]], | |
inputs: Tuple[Tensor, ...], | |
) -> Tuple[Union[int, Tuple[int, ...]], ...]: | |
# Formats strides, which are necessary for occlusion | |
# Assumes inputs are already formatted (in tuple) | |
if strides is None: | |
strides = tuple(1 for input in inputs) | |
if len(inputs) == 1 and not (isinstance(strides, tuple) and len(strides) == 1): | |
strides = (strides,) # type: ignore | |
assert isinstance(strides, tuple) and len(strides) == len( | |
inputs | |
), "Strides must be provided for each input tensor." | |
for i in range(len(inputs)): | |
assert isinstance(strides[i], int) or ( | |
isinstance(strides[i], tuple) | |
and len(strides[i]) == len(inputs[i].shape) - 1 # type: ignore | |
), ( | |
"Stride for input index {} is {}, which is invalid for input with " | |
"shape {}. It must be either an int or a tuple with length equal to " | |
"len(input_shape) - 1." | |
).format( | |
i, strides[i], inputs[i].shape | |
) | |
return strides | |
def _format_and_verify_sliding_window_shapes( | |
sliding_window_shapes: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]], | |
inputs: Tuple[Tensor, ...], | |
) -> Tuple[Tuple[int, ...], ...]: | |
# Formats shapes of sliding windows, which is necessary for occlusion | |
# Assumes inputs is already formatted (in tuple) | |
if isinstance(sliding_window_shapes[0], int): | |
sliding_window_shapes = (sliding_window_shapes,) # type: ignore | |
sliding_window_shapes: Tuple[Tuple[int, ...], ...] | |
assert len(sliding_window_shapes) == len( | |
inputs | |
), "Must provide sliding window dimensions for each input tensor." | |
for i in range(len(inputs)): | |
assert ( | |
isinstance(sliding_window_shapes[i], tuple) | |
and len(sliding_window_shapes[i]) == len(inputs[i].shape) - 1 | |
), ( | |
"Occlusion shape for input index {} is {} but should be a tuple with " | |
"{} dimensions." | |
).format( | |
i, sliding_window_shapes[i], len(inputs[i].shape) - 1 | |
) | |
return sliding_window_shapes | |
@typing.overload
def _compute_conv_delta_and_format_attrs(
    attr_algo: "GradientAttribution",
    return_convergence_delta: bool,
    attributions: Tuple[Tensor, ...],
    start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
    end_point: Union[Tensor, Tuple[Tensor, ...]],
    additional_forward_args: Any,
    target: TargetType,
    is_inputs_tuple: Literal[False] = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
    ...


@typing.overload
def _compute_conv_delta_and_format_attrs(
    attr_algo: "GradientAttribution",
    return_convergence_delta: bool,
    attributions: Tuple[Tensor, ...],
    start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
    end_point: Union[Tensor, Tuple[Tensor, ...]],
    additional_forward_args: Any,
    target: TargetType,
    is_inputs_tuple: Literal[True],
) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]:
    ...


# FIXME: GradientAttribution is provided as a string due to a circular import.
# This should be fixed when common is refactored into separate files.
def _compute_conv_delta_and_format_attrs(
    attr_algo: "GradientAttribution",
    return_convergence_delta: bool,
    attributions: Tuple[Tensor, ...],
    start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
    end_point: Union[Tensor, Tuple[Tensor, ...]],
    additional_forward_args: Any,
    target: TargetType,
    is_inputs_tuple: bool = False,
) -> Union[
    Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
]:
    """Format attributions and optionally compute the convergence delta.

    The two stubs above are marked with ``typing.overload`` so that type
    checkers treat them as alternative signatures instead of shadowed
    redefinitions; only this final definition exists at runtime.

    Args:
        attr_algo: Attribution algorithm whose ``compute_convergence_delta``
            method is called when a delta is requested.
        return_convergence_delta: Whether to compute and return the delta.
        attributions: Tuple of attribution tensors.
        start_point: Forwarded to ``compute_convergence_delta``.
        end_point: Forwarded to ``compute_convergence_delta``.
        additional_forward_args: Forwarded to ``compute_convergence_delta``.
        target: Forwarded to ``compute_convergence_delta``.
        is_inputs_tuple: Forwarded to ``_format_output`` together with
            ``attributions``. NOTE(review): presumably selects whether a
            tuple or a single tensor is returned - confirm against
            ``_format_output``.

    Returns:
        The formatted attributions, or a ``(formatted attributions, delta)``
        pair when ``return_convergence_delta`` is true.
    """
    if not return_convergence_delta:
        return _format_output(is_inputs_tuple, attributions)

    # computes convergence error
    delta = attr_algo.compute_convergence_delta(
        attributions,
        start_point,
        end_point,
        additional_forward_args=additional_forward_args,
        target=target,
    )
    return _format_output(is_inputs_tuple, attributions), delta
def _tensorize_baseline( | |
inputs: Tuple[Tensor, ...], baselines: Tuple[Union[int, float, Tensor], ...] | |
) -> Tuple[Tensor, ...]: | |
def _tensorize_single_baseline(baseline, input): | |
if isinstance(baseline, (int, float)): | |
return torch.full_like(input, baseline) | |
if input.shape[0] > baseline.shape[0] and baseline.shape[0] == 1: | |
return torch.cat([baseline] * input.shape[0]) | |
return baseline | |
assert isinstance(inputs, tuple) and isinstance(baselines, tuple), ( | |
"inputs and baselines must" | |
"have tuple type but found baselines: {} and inputs: {}".format( | |
type(baselines), type(inputs) | |
) | |
) | |
return tuple( | |
_tensorize_single_baseline(baseline, input) | |
for baseline, input in zip(baselines, inputs) | |
) | |
def _reshape_and_sum( | |
tensor_input: Tensor, num_steps: int, num_examples: int, layer_size: Tuple[int, ...] | |
) -> Tensor: | |
# Used for attribution methods which perform integration | |
# Sums across integration steps by reshaping tensor to | |
# (num_steps, num_examples, (layer_size)) and summing over | |
# dimension 0. Returns a tensor of size (num_examples, (layer_size)) | |
return torch.sum( | |
tensor_input.reshape((num_steps, num_examples) + layer_size), dim=0 | |
) | |
def _call_custom_attribution_func( | |
custom_attribution_func: Callable[..., Tuple[Tensor, ...]], | |
multipliers: Tuple[Tensor, ...], | |
inputs: Tuple[Tensor, ...], | |
baselines: Tuple[Tensor, ...], | |
) -> Tuple[Tensor, ...]: | |
assert callable(custom_attribution_func), ( | |
"`custom_attribution_func`" | |
" must be a callable function but {} provided".format( | |
type(custom_attribution_func) | |
) | |
) | |
custom_attr_func_params = signature(custom_attribution_func).parameters | |
if len(custom_attr_func_params) == 1: | |
return custom_attribution_func(multipliers) | |
elif len(custom_attr_func_params) == 2: | |
return custom_attribution_func(multipliers, inputs) | |
elif len(custom_attr_func_params) == 3: | |
return custom_attribution_func(multipliers, inputs, baselines) | |
else: | |
raise AssertionError( | |
"`custom_attribution_func` must take at least one and at most 3 arguments." | |
) | |
def _find_output_mode_and_verify( | |
initial_eval: Union[int, float, Tensor], | |
num_examples: int, | |
perturbations_per_eval: int, | |
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric], | |
) -> bool: | |
""" | |
This method identifies whether the model outputs a single output for a batch | |
(agg_output_mode = True) or whether it outputs a single output per example | |
(agg_output_mode = False) and returns agg_output_mode. The method also | |
verifies that perturbations_per_eval is 1 in the case that agg_output_mode is True | |
and also verifies that the first dimension of each feature mask if the model | |
returns a single output for a batch. | |
""" | |
if isinstance(initial_eval, (int, float)) or ( | |
isinstance(initial_eval, torch.Tensor) | |
and ( | |
len(initial_eval.shape) == 0 | |
or (num_examples > 1 and initial_eval.numel() == 1) | |
) | |
): | |
agg_output_mode = True | |
assert ( | |
perturbations_per_eval == 1 | |
), "Cannot have perturbations_per_eval > 1 when function returns scalar." | |
if feature_mask is not None: | |
for single_mask in feature_mask: | |
assert single_mask.shape[0] == 1, ( | |
"Cannot provide different masks for each example when function " | |
"returns a scalar." | |
) | |
else: | |
agg_output_mode = False | |
assert ( | |
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1 | |
), "Target should identify a single element in the model output." | |
return agg_output_mode | |
def _construct_default_feature_mask( | |
inputs: Tuple[Tensor, ...] | |
) -> Tuple[Tuple[Tensor, ...], int]: | |
feature_mask = [] | |
current_num_features = 0 | |
for i in range(len(inputs)): | |
num_features = torch.numel(inputs[i][0]) | |
feature_mask.append( | |
current_num_features | |
+ torch.reshape( | |
torch.arange(num_features, device=inputs[i].device), | |
inputs[i][0:1].shape, | |
) | |
) | |
current_num_features += num_features | |
total_features = current_num_features | |
feature_mask = tuple(feature_mask) | |
return feature_mask, total_features | |