#!/usr/bin/env python3
import typing
import warnings
from typing import Any, Callable, Iterator, Tuple, Union
import torch
from captum._utils.common import (
_format_additional_forward_args,
_format_output,
_format_tensor_into_tuples,
_reduce_list,
)
from captum._utils.typing import (
TargetType,
TensorOrTupleOfTensorsGeneric,
TupleOrTensorOrBoolGeneric,
)
from captum.attr._utils.approximation_methods import approximation_parameters
from torch import Tensor
def _batch_attribution(
attr_method,
num_examples,
internal_batch_size,
n_steps,
include_endpoint=False,
**kwargs,
):
"""
    This method applies internal batching to the given attribution method,
    dividing the total steps into batches, running each independently and
    sequentially, and adding the results to compute the total attribution.
    Step sizes and alphas are spliced for each batch and passed explicitly
    for each call to _attribute.
    kwargs include all arguments necessary to pass to each _attribute call,
    except for n_steps, which is computed based on the number of steps for
    the batch.
    include_endpoint ensures that one step overlaps between consecutive
    batches, which is necessary for some methods, particularly
    LayerConductance.
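
    Illustrative example (hypothetical values, mirroring the call pattern
    used by IntegratedGradients; `ig`, `inputs` and `baselines` are assumed
    to be defined elsewhere)::

        >>> # With num_examples=2 and internal_batch_size=8, step_count is
        >>> # 8 // 2 = 4, so n_steps=10 is processed in step ranges [0:4],
        >>> # [4:8], [8:10]; with include_endpoint=True, consecutive batches
        >>> # share one step: [0:4], [3:7], [6:10].
        >>> attr = _batch_attribution(
        ...     ig, num_examples=2, internal_batch_size=8, n_steps=10,
        ...     inputs=inputs, baselines=baselines, target=0,
        ...     method="gausslegendre",
        ... )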
"""
if internal_batch_size < num_examples:
warnings.warn(
"Internal batch size cannot be less than the number of input examples. "
"Defaulting to internal batch size of %d equal to the number of examples."
% num_examples
)
# Number of steps for each batch
step_count = max(1, internal_batch_size // num_examples)
if include_endpoint:
if step_count < 2:
step_count = 2
warnings.warn(
"This method computes finite differences between evaluations at "
"consecutive steps, so internal batch size must be at least twice "
"the number of examples. Defaulting to internal batch size of %d"
" equal to twice the number of examples." % (2 * num_examples)
)
total_attr = None
cumulative_steps = 0
step_sizes_func, alphas_func = approximation_parameters(kwargs["method"])
full_step_sizes = step_sizes_func(n_steps)
full_alphas = alphas_func(n_steps)
while cumulative_steps < n_steps:
start_step = cumulative_steps
end_step = min(start_step + step_count, n_steps)
batch_steps = end_step - start_step
if include_endpoint:
batch_steps -= 1
step_sizes = full_step_sizes[start_step:end_step]
alphas = full_alphas[start_step:end_step]
current_attr = attr_method._attribute(
**kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
)
if total_attr is None:
total_attr = current_attr
else:
if isinstance(total_attr, Tensor):
total_attr = total_attr + current_attr.detach()
else:
total_attr = tuple(
current.detach() + prev_total
for current, prev_total in zip(current_attr, total_attr)
)
if include_endpoint and end_step < n_steps:
cumulative_steps = end_step - 1
else:
cumulative_steps = end_step
return total_attr
@typing.overload
def _tuple_splice_range(inputs: None, start: int, end: int) -> None:
...
@typing.overload
def _tuple_splice_range(inputs: Tuple, start: int, end: int) -> Tuple:
...
def _tuple_splice_range(
inputs: Union[None, Tuple], start: int, end: int
) -> Union[None, Tuple]:
"""
    Splices each tensor element of the given tuple (inputs) from range start
    (inclusive) to end (non-inclusive) on its first dimension. If an element
    is not a Tensor, it is left unchanged. It is assumed that all tensor elements
have the same first dimension (corresponding to number of examples).
The returned value is a tuple with the same length as inputs, with Tensors
spliced appropriately.
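
    Illustrative example (hypothetical values)::

        >>> x = torch.arange(8).reshape(4, 2)  # 4 examples along dim 0
        >>> spliced = _tuple_splice_range((x, "meta"), 1, 3)
        >>> spliced[0]  # rows 1 and 2 of x
        tensor([[2, 3],
                [4, 5]])
        >>> spliced[1]  # non-tensor elements are passed through unchanged
        'meta'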
"""
assert start < end, "Start point must precede end point for batch splicing."
if inputs is None:
return None
return tuple(
inp[start:end] if isinstance(inp, torch.Tensor) else inp for inp in inputs
)
def _batched_generator(
inputs: TensorOrTupleOfTensorsGeneric,
additional_forward_args: Any = None,
target_ind: TargetType = None,
internal_batch_size: Union[None, int] = None,
) -> Iterator[Tuple[Tuple[Tensor, ...], Any, TargetType]]:
"""
    Returns a generator which yields corresponding chunks of size
    internal_batch_size for inputs, additional_forward_args, and targets. If
    internal_batch_size is None, the generator yields only the original
    inputs, additional args, and targets.
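
    Illustrative example (hypothetical values)::

        >>> inp = torch.rand(4, 3, requires_grad=True)
        >>> for batch, args, target in _batched_generator(
        ...     inp, internal_batch_size=2
        ... ):
        ...     print(batch[0].shape)
        torch.Size([2, 3])
        torch.Size([2, 3])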
"""
assert internal_batch_size is None or (
isinstance(internal_batch_size, int) and internal_batch_size > 0
), "Batch size must be greater than 0."
inputs = _format_tensor_into_tuples(inputs)
additional_forward_args = _format_additional_forward_args(additional_forward_args)
num_examples = inputs[0].shape[0]
# TODO Reconsider this check if _batched_generator is used for non gradient-based
# attribution algorithms
if not (inputs[0] * 1).requires_grad:
        warnings.warn(
            """It looks like the attribution for a gradient-based method is
            being computed inside a `torch.no_grad` block, or the inputs do
            not have requires_grad set."""
        )
if internal_batch_size is None:
yield inputs, additional_forward_args, target_ind
else:
for current_total in range(0, num_examples, internal_batch_size):
with torch.autograd.set_grad_enabled(True):
inputs_splice = _tuple_splice_range(
inputs, current_total, current_total + internal_batch_size
)
            # Slice targets only when they are per-example (a list, or a
            # tensor with more than one element); scalar targets apply to
            # every chunk unchanged.
            target_splice = (
                target_ind[current_total : current_total + internal_batch_size]
                if isinstance(target_ind, list)
                or (isinstance(target_ind, torch.Tensor) and target_ind.numel() > 1)
                else target_ind
            )
            yield inputs_splice, _tuple_splice_range(
                additional_forward_args,
                current_total,
                current_total + internal_batch_size,
            ), target_splice
def _batched_operator(
operator: Callable[..., TupleOrTensorOrBoolGeneric],
inputs: TensorOrTupleOfTensorsGeneric,
additional_forward_args: Any = None,
target_ind: TargetType = None,
internal_batch_size: Union[None, int] = None,
**kwargs: Any,
) -> TupleOrTensorOrBoolGeneric:
"""
Batches the operation of the given operator, applying the given batch size
to inputs and additional forward arguments, and returning the concatenation
of the results of each batch.
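
    Illustrative example (a hypothetical operator that doubles each input
    chunk; _reduce_list concatenates the per-batch results)::

        >>> def op(inputs, additional_forward_args, target_ind):
        ...     return inputs[0] * 2
        >>> out = _batched_operator(
        ...     op, torch.rand(4, 3, requires_grad=True), internal_batch_size=2
        ... )
        >>> out.shape
        torch.Size([4, 3])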
"""
all_outputs = [
operator(
inputs=input,
additional_forward_args=additional,
target_ind=target,
**kwargs,
)
for input, additional, target in _batched_generator(
inputs, additional_forward_args, target_ind, internal_batch_size
)
]
return _reduce_list(all_outputs)
def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
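    """
    Selects the example at the given index from each Tensor or list element
    of curr_arg whose length matches the batch size bsz, slicing with
    [index : index + 1] so the batch dimension is preserved; all other
    elements are passed through unchanged. Non-tuple arguments are treated
    as single-element tuples and unwrapped again on return.
    """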
if curr_arg is None:
return None
is_tuple = isinstance(curr_arg, tuple)
if not is_tuple:
curr_arg = (curr_arg,)
selected_arg = []
for i in range(len(curr_arg)):
if isinstance(curr_arg[i], (Tensor, list)) and len(curr_arg[i]) == bsz:
selected_arg.append(curr_arg[i][index : index + 1])
else:
selected_arg.append(curr_arg[i])
return _format_output(is_tuple, tuple(selected_arg))
def _batch_example_iterator(bsz: int, *args) -> Iterator:
"""
    Iterates over the batch dimension, yielding a tuple containing one example
    (selected via _select_example) from each of the provided arguments;
    arguments whose length does not match the batch size bsz are passed
    through unchanged.
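
    Illustrative example (hypothetical values)::

        >>> xs = torch.arange(4).reshape(2, 2)
        >>> ys = ["a", "b"]
        >>> for x, y in _batch_example_iterator(2, xs, ys):
        ...     print(x.shape, y)
        torch.Size([1, 2]) ['a']
        torch.Size([1, 2]) ['b']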
"""
for i in range(bsz):
curr_args = [_select_example(args[j], i, bsz) for j in range(len(args))]
yield tuple(curr_args)