#!/usr/bin/env python3
import typing
import warnings
from typing import Any, Callable, Iterator, Tuple, Union

import torch
from captum._utils.common import (
    _format_additional_forward_args,
    _format_output,
    _format_tensor_into_tuples,
    _reduce_list,
)
from captum._utils.typing import (
    TargetType,
    TensorOrTupleOfTensorsGeneric,
    TupleOrTensorOrBoolGeneric,
)
from captum.attr._utils.approximation_methods import approximation_parameters
from torch import Tensor


def _batch_attribution(
    attr_method,
    num_examples,
    internal_batch_size,
    n_steps,
    include_endpoint=False,
    **kwargs,
):
    """
    This method applies internal batching to the given attribution method,
    dividing the total steps into batches and running each independently and
    sequentially, summing each batch's result to compute the total
    attribution.

    Step sizes and alphas are spliced for each batch and passed explicitly
    for each call to _attribute.

    kwargs include all arguments necessary for each attribute call, except
    for n_steps, which is computed based on the number of steps for the
    batch.

    include_endpoint ensures that one step overlaps between consecutive
    batches, which is necessary for some methods, particularly
    LayerConductance.
    """
    if internal_batch_size < num_examples:
        warnings.warn(
            "Internal batch size cannot be less than the number of input examples. "
            "Defaulting to internal batch size of %d equal to the number of examples."
            % num_examples
        )
    # Number of steps for each batch
    step_count = max(1, internal_batch_size // num_examples)
    if include_endpoint:
        if step_count < 2:
            step_count = 2
            warnings.warn(
                "This method computes finite differences between evaluations at "
                "consecutive steps, so internal batch size must be at least twice "
                "the number of examples. Defaulting to internal batch size of %d"
                " equal to twice the number of examples." % (2 * num_examples)
            )

    total_attr = None
    cumulative_steps = 0
    step_sizes_func, alphas_func = approximation_parameters(kwargs["method"])
    full_step_sizes = step_sizes_func(n_steps)
    full_alphas = alphas_func(n_steps)

    while cumulative_steps < n_steps:
        start_step = cumulative_steps
        end_step = min(start_step + step_count, n_steps)
        batch_steps = end_step - start_step

        if include_endpoint:
            batch_steps -= 1

        step_sizes = full_step_sizes[start_step:end_step]
        alphas = full_alphas[start_step:end_step]
        current_attr = attr_method._attribute(
            **kwargs, n_steps=batch_steps, step_sizes_and_alphas=(step_sizes, alphas)
        )

        if total_attr is None:
            total_attr = current_attr
        else:
            if isinstance(total_attr, Tensor):
                total_attr = total_attr + current_attr.detach()
            else:
                total_attr = tuple(
                    current.detach() + prev_total
                    for current, prev_total in zip(current_attr, total_attr)
                )
        if include_endpoint and end_step < n_steps:
            cumulative_steps = end_step - 1
        else:
            cumulative_steps = end_step
    return total_attr
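
# Example (illustrative sketch, not part of this module): internal batching
# is normally reached through the public `attribute` call of a step-based
# method such as IntegratedGradients. With internal_batch_size=16 and 4
# examples below, each batch evaluates at most 16 // 4 = 4 steps at a time.
# The model and shapes here are assumptions chosen for illustration.
#
#     import torch
#     from captum.attr import IntegratedGradients
#
#     model = torch.nn.Linear(3, 2)
#     ig = IntegratedGradients(model)
#     inputs = torch.randn(4, 3, requires_grad=True)
#     attr = ig.attribute(inputs, target=0, n_steps=50, internal_batch_size=16)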


@typing.overload
def _tuple_splice_range(inputs: None, start: int, end: int) -> None:
    ...


@typing.overload
def _tuple_splice_range(inputs: Tuple, start: int, end: int) -> Tuple:
    ...


def _tuple_splice_range(
    inputs: Union[None, Tuple], start: int, end: int
) -> Union[None, Tuple]:
    """
    Splices each tensor element of the given tuple (inputs) from range start
    (inclusive) to end (non-inclusive) on its first dimension. If an element
    is not a Tensor, it is left unchanged. It is assumed that all tensor
    elements have the same first dimension (corresponding to the number of
    examples). The returned value is a tuple with the same length as inputs,
    with Tensors spliced appropriately.
    """
    assert start < end, "Start point must precede end point for batch splicing."
    if inputs is None:
        return None
    return tuple(
        inp[start:end] if isinstance(inp, torch.Tensor) else inp for inp in inputs
    )
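
# Example (illustrative): tensor elements are sliced on the first dimension,
# other elements pass through unchanged.
#
#     x = torch.arange(12).reshape(6, 2)
#     _tuple_splice_range((x, "flag"), 2, 4)
#     # -> (tensor([[4, 5],
#     #             [6, 7]]), "flag")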


def _batched_generator(
    inputs: TensorOrTupleOfTensorsGeneric,
    additional_forward_args: Any = None,
    target_ind: TargetType = None,
    internal_batch_size: Union[None, int] = None,
) -> Iterator[Tuple[Tuple[Tensor, ...], Any, TargetType]]:
    """
    Returns a generator which yields corresponding chunks of size
    internal_batch_size for inputs, additional_forward_args, and targets. If
    internal_batch_size is None, the generator yields the original inputs and
    additional args once.
    """
    assert internal_batch_size is None or (
        isinstance(internal_batch_size, int) and internal_batch_size > 0
    ), "Batch size must be greater than 0."
    inputs = _format_tensor_into_tuples(inputs)
    additional_forward_args = _format_additional_forward_args(additional_forward_args)
    num_examples = inputs[0].shape[0]
    # TODO Reconsider this check if _batched_generator is used for non gradient-based
    # attribution algorithms
    # Multiplying by 1 creates a non-leaf tensor, so requires_grad is False
    # here whenever gradient tracking is disabled (e.g. under torch.no_grad),
    # even if the input tensor itself has requires_grad set.
    if not (inputs[0] * 1).requires_grad:
        warnings.warn(
            """It looks like the attribution for a gradient-based method is
            computed in a `torch.no_grad` block or perhaps the inputs have no
            requires_grad."""
        )
    if internal_batch_size is None:
        yield inputs, additional_forward_args, target_ind
    else:
        for current_total in range(0, num_examples, internal_batch_size):
            with torch.autograd.set_grad_enabled(True):
                inputs_splice = _tuple_splice_range(
                    inputs, current_total, current_total + internal_batch_size
                )
            additional_splice = _tuple_splice_range(
                additional_forward_args,
                current_total,
                current_total + internal_batch_size,
            )
            # Slice targets only when they are given per-example (a list or a
            # tensor with more than one element); scalar targets apply to all
            # examples and are passed through unchanged.
            if isinstance(target_ind, list) or (
                isinstance(target_ind, torch.Tensor) and target_ind.numel() > 1
            ):
                target_splice = target_ind[
                    current_total : current_total + internal_batch_size
                ]
            else:
                target_splice = target_ind
            yield inputs_splice, additional_splice, target_splice
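
# Example (illustrative sketch): with 10 examples and internal_batch_size=4,
# the generator yields input chunks of 4, 4, and 2 examples, with target_ind
# sliced to match since it is a tensor with more than one element.
#
#     inputs = torch.randn(10, 3, requires_grad=True)
#     targets = torch.randint(0, 2, (10,))
#     for inp, extra, tgt in _batched_generator(
#         inputs, target_ind=targets, internal_batch_size=4
#     ):
#         ...  # inp is a 1-tuple of tensors; tgt is the matching target slice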


def _batched_operator(
    operator: Callable[..., TupleOrTensorOrBoolGeneric],
    inputs: TensorOrTupleOfTensorsGeneric,
    additional_forward_args: Any = None,
    target_ind: TargetType = None,
    internal_batch_size: Union[None, int] = None,
    **kwargs: Any,
) -> TupleOrTensorOrBoolGeneric:
    """
    Batches the operation of the given operator, applying the given batch size
    to inputs and additional forward arguments, and returning the concatenation
    of the results of each batch.
    """
    all_outputs = [
        operator(
            inputs=input,
            additional_forward_args=additional,
            target_ind=target,
            **kwargs,
        )
        for input, additional, target in _batched_generator(
            inputs, additional_forward_args, target_ind, internal_batch_size
        )
    ]
    return _reduce_list(all_outputs)
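
# Example (illustrative sketch; `row_sums` is a toy operator, not part of the
# library): the operator is called once per chunk with keyword arguments
# inputs, additional_forward_args, and target_ind, and the per-chunk results
# are combined by _reduce_list (tensors are concatenated along dim 0).
#
#     def row_sums(inputs, additional_forward_args, target_ind):
#         return inputs[0].sum(dim=1)
#
#     out = _batched_operator(
#         row_sums, torch.randn(10, 3, requires_grad=True), internal_batch_size=4
#     )
#     # out has shape (10,): chunk results of 4, 4, and 2 rows, concatenated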


def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
    """
    Selects the example at the given index from each batched element of
    curr_arg. Tensor and list elements whose length matches the batch size
    (bsz) are sliced to the single example at index; all other elements are
    returned unchanged.
    """
    if curr_arg is None:
        return None
    is_tuple = isinstance(curr_arg, tuple)
    if not is_tuple:
        curr_arg = (curr_arg,)
    selected_arg = []
    for i in range(len(curr_arg)):
        if isinstance(curr_arg[i], (Tensor, list)) and len(curr_arg[i]) == bsz:
            selected_arg.append(curr_arg[i][index : index + 1])
        else:
            selected_arg.append(curr_arg[i])
    return _format_output(is_tuple, tuple(selected_arg))
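
# Example (illustrative): Tensor or list elements whose length equals bsz are
# sliced to the example at `index`; anything else passes through unchanged.
#
#     args = (torch.randn(4, 2), "static")
#     _select_example(args, 1, 4)
#     # -> (tensor of shape (1, 2), "static")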


def _batch_example_iterator(bsz: int, *args) -> Iterator:
    """
    Iterates over the examples of a batch of size bsz, yielding at each step
    a tuple containing the corresponding single-example slice of each
    provided argument.
    """
    for i in range(bsz):
        curr_args = [_select_example(args[j], i, bsz) for j in range(len(args))]
        yield tuple(curr_args)
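
# Example (illustrative): iterating a batch of 3 examples yields one slice of
# each argument per step; non-batched arguments would pass through unchanged.
#
#     xs = torch.randn(3, 5)
#     masks = [True, False, True]
#     for x_i, mask_i in _batch_example_iterator(3, xs, masks):
#         ...  # x_i has shape (1, 5); mask_i is a one-element list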