#!/usr/bin/env python3

import typing
from enum import Enum
from functools import reduce
from inspect import signature
from typing import Any, Callable, cast, Dict, List, overload, Tuple, Union

import numpy as np
import torch
from captum._utils.typing import (
    BaselineType,
    Literal,
    TargetType,
    TensorOrTupleOfTensorsGeneric,
    TupleOrTensorOrBoolGeneric,
)
from torch import device, Tensor
from torch.nn import Module


class ExpansionTypes(Enum):
    repeat = 1
    repeat_interleave = 2


def safe_div(
    numerator: Tensor,
    denom: Union[Tensor, int, float],
    default_denom: Union[Tensor, int, float] = 1.0,
) -> Tensor:
    r"""
    A simple utility function to perform `numerator / denom`.
    Wherever `denom` is zero, the result is computed as
    `numerator / default_denom` instead.
    """
    if isinstance(denom, (int, float)):
        return numerator / (denom if denom != 0 else default_denom)

    # convert default_denom to a tensor if it is a float
    if not torch.is_tensor(default_denom):
        default_denom = torch.tensor(
            default_denom, dtype=denom.dtype, device=denom.device
        )

    return numerator / torch.where(denom != 0, denom, default_denom)
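

# Illustrative example: element-wise division that falls back to the default
# denominator wherever `denom` is zero.
#   safe_div(torch.tensor([1.0, 2.0]), torch.tensor([2.0, 0.0]), 1.0)
#   -> tensor([0.5000, 2.0000])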


@typing.overload
def _is_tuple(inputs: Tensor) -> Literal[False]:
    ...


@typing.overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]:
    ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    return isinstance(inputs, tuple)


def _validate_target(num_samples: int, target: TargetType) -> None:
    if isinstance(target, list) or (
        isinstance(target, torch.Tensor) and torch.numel(target) > 1
    ):
        assert num_samples == len(target), (
            "The number of samples provided in the "
            "input {} does not match the number of targets {}".format(
                num_samples, len(target)
            )
        )


def _validate_input(
    inputs: Tuple[Tensor, ...],
    baselines: Tuple[Union[Tensor, int, float], ...],
    draw_baseline_from_distrib: bool = False,
) -> None:
    assert len(inputs) == len(baselines), (
        "Input and baseline tuples must have the same length; baselines has "
        "{} elements whereas inputs has {}.".format(len(baselines), len(inputs))
    )
    for input, baseline in zip(inputs, baselines):
        if draw_baseline_from_distrib:
            assert (
                isinstance(baseline, (int, float))
                or input.shape[1:] == baseline.shape[1:]
            ), (
                "The samples in input and baseline batches must have"
                " the same shape or the baseline corresponding to the"
                " input tensor must be a scalar."
                " Found baseline: {} and input: {}".format(baseline, input)
            )
        else:
            assert (
                isinstance(baseline, (int, float))
                or input.shape == baseline.shape
                or baseline.shape[0] == 1
            ), (
                "Baseline can be provided as a tensor for just one input and"
                " broadcasted to the batch or input and baseline must have the"
                " same shape or the baseline corresponding to each input tensor"
                " must be a scalar. Found baseline: {} and input: {}".format(
                    baseline, input
                )
            )


def _zeros(inputs: Tuple[Tensor, ...]) -> Tuple[int, ...]:
    r"""
    Takes a tuple of tensors as input and returns a tuple that has the same
    length as `inputs`, with each element being the scalar zero for the
    corresponding tensor's dtype (the integer 0, or False for boolean tensors).
    """
    return tuple(0 if input.dtype is not torch.bool else False for input in inputs)
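

# Illustrative example:
#   _zeros((torch.rand(2, 3), torch.zeros(2, dtype=torch.bool)))
#   -> (0, False)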


def _format_baseline(
    baselines: BaselineType, inputs: Tuple[Tensor, ...]
) -> Tuple[Union[Tensor, int, float], ...]:
    if baselines is None:
        return _zeros(inputs)

    if not isinstance(baselines, tuple):
        baselines = (baselines,)

    for baseline in baselines:
        assert isinstance(baseline, (torch.Tensor, int, float)), (
            "baseline input argument must be either a torch.Tensor or a number"
            " however {} detected".format(type(baseline))
        )

    return baselines


@typing.overload
def _format_tensor_into_tuples(inputs: None) -> None:
    ...


@typing.overload
def _format_tensor_into_tuples(
    inputs: Union[Tensor, Tuple[Tensor, ...]]
) -> Tuple[Tensor, ...]:
    ...


def _format_tensor_into_tuples(
    inputs: Union[None, Tensor, Tuple[Tensor, ...]]
) -> Union[None, Tuple[Tensor, ...]]:
    if inputs is None:
        return None
    if not isinstance(inputs, tuple):
        assert isinstance(
            inputs, torch.Tensor
        ), "`inputs` must have type torch.Tensor but {} found".format(type(inputs))
        inputs = (inputs,)
    return inputs


def _format_inputs(inputs: Any, unpack_inputs: bool = True) -> Any:
    return (
        inputs
        if isinstance(inputs, (tuple, list)) and unpack_inputs
        else (inputs,)
    )
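

# Illustrative example:
#   _format_inputs(torch.rand(2, 3)) -> (tensor,)  (single tensor, wrapped)
#   _format_inputs((x, y))           -> (x, y)     (already a tuple, unchanged)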


def _format_float_or_tensor_into_tuples(
    inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
) -> Tuple[Union[float, Tensor], ...]:
    if not isinstance(inputs, tuple):
        assert isinstance(
            inputs, (torch.Tensor, float)
        ), "`inputs` must have type float or torch.Tensor but {} found".format(
            type(inputs)
        )
        inputs = (inputs,)
    return inputs


@typing.overload
def _format_additional_forward_args(additional_forward_args: None) -> None:
    ...


@typing.overload
def _format_additional_forward_args(
    additional_forward_args: Union[Tensor, Tuple]
) -> Tuple:
    ...


@typing.overload
def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
    ...


def _format_additional_forward_args(additional_forward_args: Any) -> Union[None, Tuple]:
    if additional_forward_args is not None and not isinstance(
        additional_forward_args, tuple
    ):
        additional_forward_args = (additional_forward_args,)
    return additional_forward_args


def _expand_additional_forward_args(
    additional_forward_args: Any,
    n_steps: int,
    expansion_type: ExpansionTypes = ExpansionTypes.repeat,
) -> Union[None, Tuple]:
    def _expand_tensor_forward_arg(
        additional_forward_arg: Tensor,
        n_steps: int,
        expansion_type: ExpansionTypes = ExpansionTypes.repeat,
    ) -> Tensor:
        if len(additional_forward_arg.size()) == 0:
            return additional_forward_arg
        if expansion_type == ExpansionTypes.repeat:
            return torch.cat([additional_forward_arg] * n_steps, dim=0)
        elif expansion_type == ExpansionTypes.repeat_interleave:
            return additional_forward_arg.repeat_interleave(n_steps, dim=0)
        else:
            raise NotImplementedError(
                "Currently only `repeat` and `repeat_interleave`"
                " expansion_types are supported"
            )

    if additional_forward_args is None:
        return None

    return tuple(
        _expand_tensor_forward_arg(additional_forward_arg, n_steps, expansion_type)
        if isinstance(additional_forward_arg, torch.Tensor)
        else additional_forward_arg
        for additional_forward_arg in additional_forward_args
    )
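

# Illustrative difference between the two expansion types for a batch [a, b]
# with n_steps=2:
#   ExpansionTypes.repeat            -> [a, b, a, b]  (whole batch concatenated)
#   ExpansionTypes.repeat_interleave -> [a, a, b, b]  (each sample repeated in place)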


def _expand_target(
    target: TargetType,
    n_steps: int,
    expansion_type: ExpansionTypes = ExpansionTypes.repeat,
) -> TargetType:
    if isinstance(target, list):
        if expansion_type == ExpansionTypes.repeat:
            return target * n_steps
        elif expansion_type == ExpansionTypes.repeat_interleave:
            expanded_target = []
            for i in target:
                expanded_target.extend([i] * n_steps)
            return cast(Union[List[Tuple[int, ...]], List[int]], expanded_target)
        else:
            raise NotImplementedError(
                "Currently only `repeat` and `repeat_interleave`"
                " expansion_types are supported"
            )
    elif isinstance(target, torch.Tensor) and torch.numel(target) > 1:
        if expansion_type == ExpansionTypes.repeat:
            return torch.cat([target] * n_steps, dim=0)
        elif expansion_type == ExpansionTypes.repeat_interleave:
            return target.repeat_interleave(n_steps, dim=0)
        else:
            raise NotImplementedError(
                "Currently only `repeat` and `repeat_interleave`"
                " expansion_types are supported"
            )

    return target


def _expand_feature_mask(
    feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
) -> Union[Tensor, Tuple[Tensor, ...]]:
    is_feature_mask_tuple = _is_tuple(feature_mask)
    feature_mask = _format_tensor_into_tuples(feature_mask)
    feature_mask_new = tuple(
        feature_mask_elem.repeat_interleave(n_samples, dim=0)
        if feature_mask_elem.size(0) > 1
        else feature_mask_elem
        for feature_mask_elem in feature_mask
    )
    return _format_output(is_feature_mask_tuple, feature_mask_new)
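

# Illustrative behavior: a mask whose batch dimension is 1 broadcasts as-is,
# while a per-sample mask (batch dimension > 1) is repeat-interleaved so that
# each original sample's mask is duplicated n_samples times.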


def _expand_and_update_baselines(
    inputs: Tuple[Tensor, ...],
    n_samples: int,
    kwargs: dict,
    draw_baseline_from_distrib: bool = False,
):
    def get_random_baseline_indices(bsz, baseline):
        num_ref_samples = baseline.shape[0]
        return np.random.choice(num_ref_samples, n_samples * bsz).tolist()

    # expand baselines to match the sizes of input
    if "baselines" not in kwargs:
        return

    baselines = kwargs["baselines"]
    baselines = _format_baseline(baselines, inputs)
    _validate_input(
        inputs, baselines, draw_baseline_from_distrib=draw_baseline_from_distrib
    )

    if draw_baseline_from_distrib:
        bsz = inputs[0].shape[0]
        baselines = tuple(
            baseline[get_random_baseline_indices(bsz, baseline)]
            if isinstance(baseline, torch.Tensor)
            else baseline
            for baseline in baselines
        )
    else:
        baselines = tuple(
            baseline.repeat_interleave(n_samples, dim=0)
            if isinstance(baseline, torch.Tensor)
            and baseline.shape[0] == input.shape[0]
            and baseline.shape[0] > 1
            else baseline
            for input, baseline in zip(inputs, baselines)
        )
    # update kwargs with expanded baseline
    kwargs["baselines"] = baselines


def _expand_and_update_additional_forward_args(n_samples: int, kwargs: dict):
    if "additional_forward_args" not in kwargs:
        return
    additional_forward_args = kwargs["additional_forward_args"]
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    if additional_forward_args is None:
        return
    additional_forward_args = _expand_additional_forward_args(
        additional_forward_args,
        n_samples,
        expansion_type=ExpansionTypes.repeat_interleave,
    )
    # update kwargs with the expanded additional forward args
    kwargs["additional_forward_args"] = additional_forward_args


def _expand_and_update_target(n_samples: int, kwargs: dict):
    if "target" not in kwargs:
        return
    target = kwargs["target"]
    target = _expand_target(
        target, n_samples, expansion_type=ExpansionTypes.repeat_interleave
    )
    # update kwargs with the expanded target
    kwargs["target"] = target


def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
    if "feature_mask" not in kwargs:
        return

    feature_mask = kwargs["feature_mask"]
    if feature_mask is None:
        return

    feature_mask = _expand_feature_mask(feature_mask, n_samples)
    kwargs["feature_mask"] = feature_mask


@typing.overload
def _format_output(
    is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
) -> Tuple[Tensor, ...]:
    ...


@typing.overload
def _format_output(
    is_inputs_tuple: Literal[False], output: Tuple[Tensor, ...]
) -> Tensor:
    ...


@typing.overload
def _format_output(
    is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    ...


def _format_output(
    is_inputs_tuple: bool, output: Tuple[Tensor, ...]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    r"""
    In case the input is a single tensor and the output is returned as a
    tuple, we take the first element of the output's tuple to match the
    shape signature of the input.
    """
    assert isinstance(output, tuple), "Output must be in shape of a tuple"
    assert is_inputs_tuple or len(output) == 1, (
        "The input is a single tensor however the output isn't. "
        "The number of output tensors is: {}".format(len(output))
    )
    return output if is_inputs_tuple else output[0]
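

# Illustrative example:
#   _format_output(False, (t,))     -> t         (single input, tuple unwrapped)
#   _format_output(True, (t1, t2))  -> (t1, t2)  (tuple input, tuple kept)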


@typing.overload
def _format_outputs(
    is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...]]:
    ...


@typing.overload
def _format_outputs(
    is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
    ...


@typing.overload
def _format_outputs(
    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
    ...


def _format_outputs(
    is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
    assert isinstance(outputs, list), "Outputs must be a list"
    assert is_multiple_inputs or len(outputs) == 1, (
        "outputs should contain multiple inputs or have a single output, "
        f"however the number of outputs is: {len(outputs)}"
    )

    return (
        [_format_output(len(output) > 1, output) for output in outputs]
        if is_multiple_inputs
        else _format_output(len(outputs[0]) > 1, outputs[0])
    )


def _run_forward(
    forward_func: Callable,
    inputs: Any,
    target: TargetType = None,
    additional_forward_args: Any = None,
) -> Tensor:
    forward_func_args = signature(forward_func).parameters
    if len(forward_func_args) == 0:
        output = forward_func()
        return output if target is None else _select_targets(output, target)

    # make everything a tuple so that it is easy to unpack without
    # using if-statements
    inputs = _format_inputs(inputs)
    additional_forward_args = _format_additional_forward_args(additional_forward_args)

    output = forward_func(
        *(*inputs, *additional_forward_args)
        if additional_forward_args is not None
        else inputs
    )
    return _select_targets(output, target)
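

# Illustrative example, with a hypothetical model and inputs:
#   out = _run_forward(model, (x,), target=[1, 0], additional_forward_args=(mask,))
# runs model(x, mask) and then selects column 1 of row 0 and column 0 of row 1
# from a 2D output.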


def _select_targets(output: Tensor, target: TargetType) -> Tensor:
    if target is None:
        return output

    num_examples = output.shape[0]
    dims = len(output.shape)
    device = output.device
    if isinstance(target, (int, tuple)):
        return _verify_select_column(output, target)
    elif isinstance(target, torch.Tensor):
        if torch.numel(target) == 1 and isinstance(target.item(), int):
            return _verify_select_column(output, cast(int, target.item()))
        elif len(target.shape) == 1 and torch.numel(target) == num_examples:
            assert dims == 2, "Output must be 2D to select tensor of targets."
            return torch.gather(output, 1, target.reshape(len(output), 1))
        else:
            raise AssertionError(
                "Tensor target dimension %r is not valid. %r"
                % (target.shape, output.shape)
            )
    elif isinstance(target, list):
        assert len(target) == num_examples, "Target list length does not match output!"
        if isinstance(target[0], int):
            assert dims == 2, "Output must be 2D to select tensor of targets."
            return torch.gather(
                output, 1, torch.tensor(target, device=device).reshape(len(output), 1)
            )
        elif isinstance(target[0], tuple):
            return torch.stack(
                [
                    output[(i,) + cast(Tuple, targ_elem)]
                    for i, targ_elem in enumerate(target)
                ]
            )
        else:
            raise AssertionError("Target element type in list is not valid.")
    else:
        raise AssertionError("Target type %r is not valid." % target)


def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool:
    if isinstance(target, tuple):
        return any(isinstance(index, slice) for index in target)
    return isinstance(target, slice)


def _verify_select_column(
    output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]]
) -> Tensor:
    target = (target,) if isinstance(target, int) else target
    assert (
        len(target) <= len(output.shape) - 1
    ), "Cannot choose target column with output shape %r." % (output.shape,)
    return output[(slice(None), *target)]
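

# Illustrative example: the target indexes the dimensions after the batch
# dimension, so for a (2, 3, 2) output:
#   _verify_select_column(torch.arange(12.0).reshape(2, 3, 2), (1, 0))
#   -> output[:, 1, 0] -> tensor([2., 8.])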


def _verify_select_neuron(
    layer_output: Tuple[Tensor, ...],
    selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tensor:
    if callable(selector):
        return selector(layer_output if len(layer_output) > 1 else layer_output[0])

    assert len(layer_output) == 1, (
        "Cannot select neuron index from layer with multiple tensors, "
        "consider providing a neuron selector function instead."
    )

    selected_neurons = _verify_select_column(layer_output[0], selector)
    if _contains_slice(selector):
        return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1)
    return selected_neurons


def _extract_device(
    module: Module,
    hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]],
    hook_outputs: Union[None, Tensor, Tuple[Tensor, ...]],
) -> device:
    params = list(module.parameters())
    if (
        (hook_inputs is None or len(hook_inputs) == 0)
        and (hook_outputs is None or len(hook_outputs) == 0)
        and len(params) == 0
    ):
        raise RuntimeError(
            """Unable to extract device information for the module
            {}. Both inputs and outputs to the forward hook and
            `module.parameters()` are empty.
            The reason that the inputs to the forward hook are empty
            could be due to the fact that the arguments to that
            module {} are all named and are passed as named
            variables to its forward function.
            """.format(
                module, module
            )
        )

    if hook_inputs is not None and len(hook_inputs) > 0:
        return hook_inputs[0].device
    if hook_outputs is not None and len(hook_outputs) > 0:
        return hook_outputs[0].device

    return params[0].device


def _reduce_list(
    val_list: List[TupleOrTensorOrBoolGeneric],
    red_func: Callable[[List], Any] = torch.cat,
) -> TupleOrTensorOrBoolGeneric:
    """
    Applies the reduction function to a given list. If each element in the list
    is a Tensor, the reduction function is applied to all elements of the list,
    and the output Tensor / value is returned. If each element is a boolean,
    the reduction is a logical `any` (i.e., `or`). If each element is a tuple,
    the reduction function is applied to the corresponding elements of each
    tuple in the list, and a tuple of reduction function outputs with length
    matching the length of tuple val_list[0] is returned. It is assumed that
    all tuples in the list have the same length and that red_func can be
    applied to all elements in each corresponding position.
    """
    assert len(val_list) > 0, "Cannot reduce empty list!"
    if isinstance(val_list[0], torch.Tensor):
        first_device = val_list[0].device
        return red_func([elem.to(first_device) for elem in val_list])
    elif isinstance(val_list[0], bool):
        return any(val_list)
    elif isinstance(val_list[0], tuple):
        final_out = []
        for i in range(len(val_list[0])):
            final_out.append(
                _reduce_list([val_elem[i] for val_elem in val_list], red_func)
            )
        return tuple(final_out)
    else:
        raise AssertionError(
            "Elements to be reduced can only be "
            "either Tensors or tuples containing Tensors."
        )
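

# Illustrative example, with hypothetical tensors a1, b1, a2, b2:
#   _reduce_list([(a1, b1), (a2, b2)])
#   -> (torch.cat([a1, a2]), torch.cat([b1, b2]))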


def _sort_key_list(
    keys: List[device], device_ids: Union[None, List[int]] = None
) -> List[device]:
    """
    Sorts the list of torch devices (keys) by the given index list, device_ids.
    If keys contains only one device, the list is returned unchanged. If keys
    contains a device whose id is not contained in device_ids, an error is
    raised. This method is used to identify the order of DataParallel batched
    devices, given the device ID ordering.
    """
    if len(keys) == 1:
        return keys
    id_dict: Dict[int, device] = {}
    assert device_ids is not None, "Device IDs must be provided with multiple devices."
    for key in keys:
        if key.index in id_dict:
            raise AssertionError("Duplicate CUDA Device ID identified in device list.")
        id_dict[key.index] = key

    out_list = [
        id_dict[device_id]
        for device_id in filter(lambda device_id: device_id in id_dict, device_ids)
    ]

    assert len(out_list) == len(keys), (
        "Given Device ID List does not match devices with computed tensors."
    )

    return out_list
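

# Illustrative example, assuming two hypothetical CUDA devices:
#   _sort_key_list([torch.device("cuda:1"), torch.device("cuda:0")], [0, 1])
#   -> [device(type='cuda', index=0), device(type='cuda', index=1)]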


def _flatten_tensor_or_tuple(inp: TensorOrTupleOfTensorsGeneric) -> Tensor:
    if isinstance(inp, Tensor):
        return inp.flatten()
    return torch.cat([single_inp.flatten() for single_inp in inp])


def _get_module_from_name(model: Module, layer_name: str) -> Any:
    r"""
    Returns the module (layer) object, given its (string) name
    in the model.

    Args:
        layer_name (str): Module or nested modules name string in self.model

    Returns:
        The module (layer) in self.model.
    """
    return reduce(getattr, layer_name.split("."), model)
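

# Illustrative example, for a hypothetical model with a nested submodule:
#   _get_module_from_name(model, "encoder.layers.0")
# resolves getattr(getattr(model, "encoder"), "layers") and then its "0" entry
# (attribute access by string index works for nn.ModuleList / nn.Sequential).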


def _register_backward_hook(
    module: Module, hook: Callable, attr_obj: Any
) -> torch.utils.hooks.RemovableHandle:
    # Special case for supporting output attributions for neuron methods
    # This can be removed after deprecation of neuron output attributions
    # for NeuronDeepLift, NeuronDeconvolution, and NeuronGuidedBackprop
    # in v0.6.0
    if (
        hasattr(attr_obj, "skip_new_hook_layer")
        and attr_obj.skip_new_hook_layer == module
    ):
        return module.register_backward_hook(hook)

    # Compare parsed (major, minor) tuples rather than raw version strings,
    # since lexicographic string comparison misorders versions
    # (e.g. "1.10" < "1.9" as strings).
    torch_version = tuple(
        int(part) for part in torch.__version__.split("+")[0].split(".")[:2]
    )
    if torch_version >= (1, 9):
        # register_full_backward_hook is only supported for torch >= 1.9
        return module.register_full_backward_hook(hook)
    else:
        # Fallback for previous versions of PyTorch
        return module.register_backward_hook(hook)