#!/usr/bin/env python3
import typing
from inspect import signature
from typing import Any, Callable, List, Tuple, TYPE_CHECKING, Union
import torch
from captum._utils.common import (
_format_baseline,
_format_output,
_format_tensor_into_tuples,
_validate_input as _validate_input_basic,
)
from captum._utils.typing import (
BaselineType,
Literal,
TargetType,
TensorOrTupleOfTensorsGeneric,
)
from captum.attr._utils.approximation_methods import SUPPORTED_METHODS
from torch import Tensor
if TYPE_CHECKING:
from captum.attr._utils.attribution import GradientAttribution
def _sum_rows(input: Tensor) -> Tensor:
return input.reshape(input.shape[0], -1).sum(1)
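# Illustrative example (hypothetical values): _sum_rows flattens every
# non-batch dimension and sums over it, so an input of shape (2, 3, 4) yields
# a tensor of shape (2,):
#   _sum_rows(torch.ones(2, 3, 4))  # -> tensor([12., 12.])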
def _validate_target(num_samples: int, target: TargetType) -> None:
if isinstance(target, list) or (
isinstance(target, torch.Tensor) and torch.numel(target) > 1
):
        assert num_samples == len(target), (
            "The number of samples provided in the "
            "input {} does not match the number of targets {}".format(
                num_samples, len(target)
            )
        )
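# Illustrative checks (hypothetical values): a list or multi-element tensor
# target must provide one entry per sample:
#   _validate_target(3, [0, 1, 2])  # passes
#   _validate_target(2, [0, 1, 2])  # raises AssertionError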
def _validate_input(
inputs: Tuple[Tensor, ...],
baselines: Tuple[Union[Tensor, int, float], ...],
n_steps: int = 50,
method: str = "riemann_trapezoid",
draw_baseline_from_distrib: bool = False,
) -> None:
_validate_input_basic(inputs, baselines, draw_baseline_from_distrib)
    assert (
        n_steps >= 0
    ), "The number of steps must be a non-negative integer. Given: {}".format(n_steps)
    assert (
        method in SUPPORTED_METHODS
    ), "Approximation method must be one of the following: {}. Given: {}".format(
        SUPPORTED_METHODS, method
    )
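# Illustrative check (hypothetical tensors; "riemann_trapezoid" is the
# function's own default method): inputs and baselines are expected as
# matching tuples:
#   _validate_input(
#       (torch.rand(2, 3),), (torch.zeros(2, 3),), n_steps=25,
#       method="riemann_trapezoid",
#   )  # passes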
def _validate_noise_tunnel_type(
nt_type: str, supported_noise_tunnel_types: List[str]
) -> None:
    assert nt_type in supported_noise_tunnel_types, (
        "Noise tunnel type must be one of {}. "
        "Given: {}".format(supported_noise_tunnel_types, nt_type)
    )
@typing.overload
def _format_input_baseline(
inputs: Union[Tensor, Tuple[Tensor, ...]],
baselines: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:
...
@typing.overload
def _format_input_baseline(
inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
...
def _format_input_baseline(
inputs: Union[Tensor, Tuple[Tensor, ...]], baselines: BaselineType
) -> Tuple[Tuple[Tensor, ...], Tuple[Union[Tensor, int, float], ...]]:
inputs = _format_tensor_into_tuples(inputs)
baselines = _format_baseline(baselines, inputs)
return inputs, baselines
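# Illustrative usage (hypothetical values): bare tensors and scalar baselines
# are normalized into tuples:
#   inputs, baselines = _format_input_baseline(torch.rand(2, 3), 0.0)
#   # inputs    -> (tensor of shape (2, 3),)
#   # baselines -> (0.0,)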
# This function can potentially be merged with the `_format_baseline` function.
# However, since not all algorithms currently support baselines of type
# callable, it is kept as a separate function.
@typing.overload
def _format_callable_baseline(
baselines: Union[
None,
Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
Tensor,
Tuple[Tensor, ...],
],
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Tensor, ...]:
...
@typing.overload
def _format_callable_baseline(
baselines: Union[
None,
Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
Tensor,
int,
float,
Tuple[Union[Tensor, int, float], ...],
],
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Union[Tensor, int, float], ...]:
...
def _format_callable_baseline(
baselines: Union[
None,
Callable[..., Union[Tensor, Tuple[Tensor, ...]]],
Tensor,
int,
float,
Tuple[Union[Tensor, int, float], ...],
],
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> Tuple[Union[Tensor, int, float], ...]:
if callable(baselines):
# Note: this assumes that if baselines is a function and if it takes
# arguments, then the first argument is the `inputs`.
# This can be expanded in the future with better type checks
baseline_parameters = signature(baselines).parameters
if len(baseline_parameters) == 0:
baselines = baselines()
else:
baselines = baselines(inputs)
return _format_baseline(baselines, _format_tensor_into_tuples(inputs))
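# Illustrative usage (hypothetical callable): a zero-argument callable is
# invoked directly; a callable with parameters receives `inputs` as its first
# argument:
#   _format_callable_baseline(lambda: torch.zeros(2, 3), torch.rand(2, 3))
#   # -> (tensor of zeros with shape (2, 3),)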
def _format_and_verify_strides(
strides: Union[None, int, Tuple[int, ...], Tuple[Union[int, Tuple[int, ...]], ...]],
inputs: Tuple[Tensor, ...],
) -> Tuple[Union[int, Tuple[int, ...]], ...]:
    # Formats strides, which are necessary for occlusion.
    # Assumes inputs are already formatted (as a tuple).
if strides is None:
strides = tuple(1 for input in inputs)
if len(inputs) == 1 and not (isinstance(strides, tuple) and len(strides) == 1):
strides = (strides,) # type: ignore
assert isinstance(strides, tuple) and len(strides) == len(
inputs
), "Strides must be provided for each input tensor."
for i in range(len(inputs)):
assert isinstance(strides[i], int) or (
isinstance(strides[i], tuple)
and len(strides[i]) == len(inputs[i].shape) - 1 # type: ignore
), (
"Stride for input index {} is {}, which is invalid for input with "
"shape {}. It must be either an int or a tuple with length equal to "
"len(input_shape) - 1."
).format(
i, strides[i], inputs[i].shape
)
return strides
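# Illustrative usage (assumed 4-D image input of shape (N, C, H, W)): a stride
# may be a single int or a tuple with one entry per non-batch dimension:
#   _format_and_verify_strides(3, (torch.rand(1, 3, 8, 8),))  # -> (3,)
#   _format_and_verify_strides(((1, 2, 2),), (torch.rand(1, 3, 8, 8),))
#   # -> ((1, 2, 2),)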
def _format_and_verify_sliding_window_shapes(
sliding_window_shapes: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]],
inputs: Tuple[Tensor, ...],
) -> Tuple[Tuple[int, ...], ...]:
    # Formats the shapes of sliding windows, which are necessary for occlusion.
    # Assumes inputs are already formatted (as a tuple).
if isinstance(sliding_window_shapes[0], int):
sliding_window_shapes = (sliding_window_shapes,) # type: ignore
sliding_window_shapes: Tuple[Tuple[int, ...], ...]
assert len(sliding_window_shapes) == len(
inputs
), "Must provide sliding window dimensions for each input tensor."
for i in range(len(inputs)):
assert (
isinstance(sliding_window_shapes[i], tuple)
and len(sliding_window_shapes[i]) == len(inputs[i].shape) - 1
), (
"Occlusion shape for input index {} is {} but should be a tuple with "
"{} dimensions."
).format(
i, sliding_window_shapes[i], len(inputs[i].shape) - 1
)
return sliding_window_shapes
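# Illustrative usage (assumed 4-D image input): a single window shape is
# wrapped into a per-input tuple and must have len(input.shape) - 1 entries:
#   _format_and_verify_sliding_window_shapes((3, 2, 2), (torch.rand(1, 3, 8, 8),))
#   # -> ((3, 2, 2),)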
@typing.overload
def _compute_conv_delta_and_format_attrs(
attr_algo: "GradientAttribution",
return_convergence_delta: bool,
attributions: Tuple[Tensor, ...],
start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
end_point: Union[Tensor, Tuple[Tensor, ...]],
additional_forward_args: Any,
target: TargetType,
is_inputs_tuple: Literal[False] = False,
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
...
@typing.overload
def _compute_conv_delta_and_format_attrs(
attr_algo: "GradientAttribution",
return_convergence_delta: bool,
attributions: Tuple[Tensor, ...],
start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
end_point: Union[Tensor, Tuple[Tensor, ...]],
additional_forward_args: Any,
target: TargetType,
is_inputs_tuple: Literal[True],
) -> Union[Tuple[Tensor, ...], Tuple[Tuple[Tensor, ...], Tensor]]:
...
# FIXME: GradientAttribution is provided as a string due to a circular import.
# This should be fixed when common is refactored into separate files.
def _compute_conv_delta_and_format_attrs(
attr_algo: "GradientAttribution",
return_convergence_delta: bool,
attributions: Tuple[Tensor, ...],
start_point: Union[int, float, Tensor, Tuple[Union[int, float, Tensor], ...]],
end_point: Union[Tensor, Tuple[Tensor, ...]],
additional_forward_args: Any,
target: TargetType,
is_inputs_tuple: bool = False,
) -> Union[
Tensor, Tuple[Tensor, ...], Tuple[Union[Tensor, Tuple[Tensor, ...]], Tensor]
]:
if return_convergence_delta:
# computes convergence error
delta = attr_algo.compute_convergence_delta(
attributions,
start_point,
end_point,
additional_forward_args=additional_forward_args,
target=target,
)
return _format_output(is_inputs_tuple, attributions), delta
else:
return _format_output(is_inputs_tuple, attributions)
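# Illustrative call (hypothetical attr_algo, e.g. an IntegratedGradients
# instance; for path-based methods start_point/end_point are typically the
# baselines and inputs): with return_convergence_delta=True the formatted
# attributions are returned together with the per-example delta:
#   attrs, delta = _compute_conv_delta_and_format_attrs(
#       attr_algo, True, attributions, baselines, inputs,
#       additional_forward_args, target, is_inputs_tuple,
#   )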
def _tensorize_baseline(
inputs: Tuple[Tensor, ...], baselines: Tuple[Union[int, float, Tensor], ...]
) -> Tuple[Tensor, ...]:
def _tensorize_single_baseline(baseline, input):
if isinstance(baseline, (int, float)):
return torch.full_like(input, baseline)
if input.shape[0] > baseline.shape[0] and baseline.shape[0] == 1:
return torch.cat([baseline] * input.shape[0])
return baseline
    assert isinstance(inputs, tuple) and isinstance(baselines, tuple), (
        "inputs and baselines must "
        "have tuple type but found baselines: {} and inputs: {}".format(
            type(baselines), type(inputs)
        )
    )
return tuple(
_tensorize_single_baseline(baseline, input)
for baseline, input in zip(baselines, inputs)
)
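# Illustrative usage (hypothetical values): scalar baselines are expanded to
# full tensors, and single-example baselines are repeated across the batch:
#   _tensorize_baseline((torch.rand(4, 3),), (0.5,))
#   # -> (tensor of shape (4, 3) filled with 0.5,)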
def _reshape_and_sum(
tensor_input: Tensor, num_steps: int, num_examples: int, layer_size: Tuple[int, ...]
) -> Tensor:
# Used for attribution methods which perform integration
# Sums across integration steps by reshaping tensor to
# (num_steps, num_examples, (layer_size)) and summing over
# dimension 0. Returns a tensor of size (num_examples, (layer_size))
return torch.sum(
tensor_input.reshape((num_steps, num_examples) + layer_size), dim=0
)
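# Worked example (hypothetical sizes): with num_steps=5, num_examples=2 and
# layer_size=(3,), a (10, 3) tensor is viewed as (5, 2, 3) and summed over the
# step dimension:
#   _reshape_and_sum(torch.ones(10, 3), 5, 2, (3,))
#   # -> tensor of shape (2, 3) filled with 5.0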
def _call_custom_attribution_func(
custom_attribution_func: Callable[..., Tuple[Tensor, ...]],
multipliers: Tuple[Tensor, ...],
inputs: Tuple[Tensor, ...],
baselines: Tuple[Tensor, ...],
) -> Tuple[Tensor, ...]:
    assert callable(custom_attribution_func), (
        "`custom_attribution_func` must be a callable function, "
        "but {} was provided".format(type(custom_attribution_func))
    )
custom_attr_func_params = signature(custom_attribution_func).parameters
if len(custom_attr_func_params) == 1:
return custom_attribution_func(multipliers)
elif len(custom_attr_func_params) == 2:
return custom_attribution_func(multipliers, inputs)
elif len(custom_attr_func_params) == 3:
return custom_attribution_func(multipliers, inputs, baselines)
else:
        raise AssertionError(
            "`custom_attribution_func` must take at least one "
            "and at most three arguments."
        )
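# Illustrative usage (hypothetical three-argument function): the callable's
# parameter count decides which arguments it receives:
#   def attr_fn(multipliers, inputs, baselines):
#       return tuple(m * (i - b) for m, i, b in zip(multipliers, inputs, baselines))
#   _call_custom_attribution_func(attr_fn, mults, inps, bases)
#   # equivalent to attr_fn(mults, inps, bases), where mults, inps and bases
#   # are tuples of tensors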
def _find_output_mode_and_verify(
initial_eval: Union[int, float, Tensor],
num_examples: int,
perturbations_per_eval: int,
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
) -> bool:
"""
This method identifies whether the model outputs a single output for a batch
(agg_output_mode = True) or whether it outputs a single output per example
(agg_output_mode = False) and returns agg_output_mode. The method also
verifies that perturbations_per_eval is 1 in the case that agg_output_mode is True
and also verifies that the first dimension of each feature mask if the model
returns a single output for a batch.
"""
if isinstance(initial_eval, (int, float)) or (
isinstance(initial_eval, torch.Tensor)
and (
len(initial_eval.shape) == 0
or (num_examples > 1 and initial_eval.numel() == 1)
)
):
agg_output_mode = True
assert (
perturbations_per_eval == 1
), "Cannot have perturbations_per_eval > 1 when function returns scalar."
if feature_mask is not None:
for single_mask in feature_mask:
assert single_mask.shape[0] == 1, (
"Cannot provide different masks for each example when function "
"returns a scalar."
)
else:
agg_output_mode = False
assert (
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
), "Target should identify a single element in the model output."
return agg_output_mode
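# Illustrative checks (hypothetical evaluations): a scalar initial evaluation
# selects aggregate-output mode, while a per-example evaluation does not:
#   _find_output_mode_and_verify(torch.tensor(1.0), 4, 1, None)  # -> True
#   _find_output_mode_and_verify(torch.rand(4, 1), 4, 2, None)   # -> False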
def _construct_default_feature_mask(
inputs: Tuple[Tensor, ...]
) -> Tuple[Tuple[Tensor, ...], int]:
feature_mask = []
current_num_features = 0
for i in range(len(inputs)):
num_features = torch.numel(inputs[i][0])
feature_mask.append(
current_num_features
+ torch.reshape(
torch.arange(num_features, device=inputs[i].device),
inputs[i][0:1].shape,
)
)
current_num_features += num_features
total_features = current_num_features
feature_mask = tuple(feature_mask)
return feature_mask, total_features
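# Illustrative usage (hypothetical shapes): every scalar position within an
# example becomes its own feature, with feature ids continuing across inputs:
#   mask, total = _construct_default_feature_mask(
#       (torch.rand(2, 2), torch.rand(2, 3))
#   )
#   # mask  -> (tensor([[0, 1]]), tensor([[2, 3, 4]]))
#   # total -> 5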