File size: 13,328 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
#!/usr/bin/env python3

from typing import Callable, Optional, Tuple, Union, Any, List

import torch
import torch.nn as nn
from captum._utils.progress import progress
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset


def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:
    r"""
    Computes pairwise dot product between two tensors

    Each input is flattened to 2D, keeping its first (batch) dimension, and the
    two resulting matrices are multiplied.

    Args:
        t1 (Tensor): Feature vectors with dimension (batch_size_1, *).
        t2 (Tensor): Feature vectors with dimension (batch_size_2, *). The *
                dimensions of `t1` and `t2` must match in total number of
                elements.

    Returns:
        Tensor with shape (batch_size_1, batch_size_2) containing the pairwise dot
        products. For example, Tensor[i][j] would be the dot product between
        t1[i] and t2[j].

    Raises:
        AssertionError: If the number of elements per batch member differs
                between `t1` and `t2`.
        ZeroDivisionError: If either input has a first dimension of size 0.
    """

    # Number of elements per batch member; computed once and reused for both the
    # check and the error message.
    t1_features = torch.numel(t1) / t1.shape[0]
    t2_features = torch.numel(t2) / t2.shape[0]

    msg = (
        "Please ensure each batch member has the same feature dimension. "
        f"First input has {t1_features} features, and "
        f"second input has {t2_features} features."
    )
    assert t1_features == t2_features, msg

    # `reshape` instead of `view`: `view` raises on tensors whose memory layout
    # cannot be flattened without a copy (e.g. permuted tensors), whereas
    # `reshape` returns a view when possible and copies otherwise.
    return torch.mm(
        t1.reshape(t1.shape[0], -1),
        t2.reshape(t2.shape[0], -1).T,
    )


def _gradient_dot_product(
    input_grads: Tuple[Tensor], src_grads: Tuple[Tensor]
) -> Tensor:
    r"""
    Computes the dot product between the gradient vector for a model on an input batch
    and src batch, for each pairwise batch member. Gradients are passed in as a tuple
    corresponding to the trainable parameters returned by model.parameters(). Output
    corresponds to a tensor of size (inputs_batch_size, src_batch_size) with all
    pairwise dot products.

    Args:
        input_grads (tuple of Tensor): Per-parameter gradients for the input
                batch; each entry has the input batch as its first dimension.
        src_grads (tuple of Tensor): Per-parameter gradients for the src batch,
                in the same order and of the same length as `input_grads`.

    Returns:
        Tensor: The sum, over all parameters, of the pairwise dot products
                computed by `_tensor_batch_dot` for each gradient pair.

    Raises:
        AssertionError: If the two tuples have different lengths.
        StopIteration: If both tuples are empty.
    """

    assert len(input_grads) == len(src_grads), "Mismatching gradient parameters."

    iterator = zip(input_grads, src_grads)
    # Seed the accumulator with the first pair so it takes the dtype and device of
    # the gradients themselves.
    total = _tensor_batch_dot(*next(iterator))
    for input_grad, src_grad in iterator:
        total += _tensor_batch_dot(input_grad, src_grad)

    # `total` is already a Tensor (the result of `torch.mm`), so it is returned
    # as-is; no re-wrapping through the `torch.Tensor` constructor is needed.
    return total


def _jacobian_loss_wrt_inputs(
    loss_fn: Union[Module, Callable],
    out: Tensor,
    targets: Tensor,
    vectorize: bool,
    reduction_type: str,
) -> Tensor:
    r"""
    Often, we have a loss function that computes a per-sample loss given a 1D tensor
    input, and we want to calculate the jacobian of the loss w.r.t. that input.  For
    example, the input could be a length K tensor specifying the probability a given
    sample belongs to each of K possible classes, and the loss function could be
    cross-entropy loss. This function performs that calculation, but does so for a
    *batch* of inputs. We create this helper function for two reasons: 1) to handle
    differences between Pytorch versions for vectorized jacobian calculations, and
    2) this function does not accept the aforementioned per-sample loss function.
    Instead, it accepts a "reduction" loss function that *reduces* the per-sample loss
    for a batch into a single loss. Using a "reduction" loss improves speed.
    We will allow this reduction to either be the mean or sum of the per-sample losses,
    and this function provides a uniform way to handle different possible reductions,
    and also check if the reduction used is valid. Regardless of the reduction used,
    this function returns the jacobian for the per-sample loss (for each sample in the
    batch).

    Args:
        loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
                defined loss function is provided, it would be expected to be a
                torch.nn.Module. If a custom loss is provided, it can be either type,
                but must behave as a library loss function would if `reduction='sum'`
                or `reduction='mean'`.
        out (tensor): This is a tensor that represents the batch of inputs to
                `loss_fn`. In practice, this will be the output of a model; this is
                why this argument is named `out`. `out` is a 2D tensor of shape
                (batch size, model output dimensionality). We will call `loss_fn` via
                `loss_fn(out, targets)`.
        targets (tensor): The labels for the batch of inputs.
        vectorize (bool): Flag to use experimental vectorize functionality for
                `torch.autograd.functional.jacobian`.
        reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn`
                has the "reduction" attribute, we will check that they match. Can
                only be "mean" or "sum".

    Returns:
        jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly
                defined by `loss_fn` and `reduction_type`) w.r.t each sample
                in the batch represented by `out`. This is a 2D tensor, where the
                first dimension is the batch dimension.

    Raises:
        AssertionError: If `loss_fn` is a Module with a `reduction` attribute that
                is "none" or differs from `reduction_type`.
        ValueError: If `reduction_type` is neither "sum" nor "mean".
    """
    # TODO: allow loss_fn to be Callable
    if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
        msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"

        assert loss_fn.reduction != "none", msg0
        # Note the trailing space after "match": without it the two fragments
        # were glued together as "does not matchreduction type".
        msg1 = (
            f"loss_fn.reduction ({loss_fn.reduction}) does not match "
            f"reduction type ({reduction_type}). Please ensure they are"
            " matching."
        )
        assert loss_fn.reduction == reduction_type, msg1

    if reduction_type not in ("sum", "mean"):
        raise ValueError(
            f"{reduction_type} is not a valid value for reduction_type. "
            "Must be either 'sum' or 'mean'."
        )

    def _reduced_loss(batch_out: Tensor) -> Tensor:
        # Single-argument closure over `targets`, as required by `jacobian`.
        return loss_fn(batch_out, targets)

    # The `vectorize` keyword is only passed on versions that accept it.
    # NOTE(review): this relies on `torch.__version__` comparing correctly against
    # the string "1.8"; a plain lexicographic comparison would order "1.10" before
    # "1.8" -- confirm against the PyTorch versions this file must support.
    if torch.__version__ >= "1.8":
        input_jacobians = torch.autograd.functional.jacobian(
            _reduced_loss, out, vectorize=vectorize
        )
    else:
        input_jacobians = torch.autograd.functional.jacobian(_reduced_loss, out)

    if reduction_type == "mean":
        # Scale by the batch size (length of the first dimension) to convert the
        # jacobian of a "mean"-reduced loss into the "sum"-reduced equivalent.
        input_jacobians = input_jacobians * len(input_jacobians)

    return input_jacobians


def _load_flexible_state_dict(
    model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
) -> int:
    r"""
    Helper to load pytorch models. This function attempts to find compatibility for
    loading models that were trained on different devices / with DataParallel but are
    being loaded in a different environment.

    Assumes that the model has been saved as a state_dict in some capacity. This can
    either be a single state dict, or a nesting dictionary which contains the model
    state_dict and other information.

    Args:
        model: The model for which to load a checkpoint
        path: The filepath to the checkpoint
        device_ids: Passed to `torch.load` as `map_location`, i.e. the device the
                checkpoint's tensors are mapped to when loading.
                Default: "cpu"
        keyname: The key under which the model state_dict is stored, if any.

    The module state_dict is modified in-place, and the learning rate is returned.
    The learning rate is read from the top-level "learning_rate" entry of the
    checkpoint and defaults to 1 when that entry is absent.
    """

    # NOTE(security): `torch.load` deserializes with pickle; only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(path, map_location=device_ids)

    learning_rate = checkpoint.get("learning_rate", 1)
    # can get learning rate from optimizer state_dict?

    if keyname is not None:
        checkpoint = checkpoint[keyname]

    # A state dict saved from an `nn.DataParallel` wrapper has every key prefixed
    # with "module.". Test for a *prefix*, not a substring, so that keys such as
    # "submodule.weight" are not mistaken for DataParallel keys. An empty state
    # dict (no first key) is treated as not prefixed.
    first_key = next(iter(checkpoint), "")
    saved_from_data_parallel = first_key.startswith("module.")
    model_is_data_parallel = isinstance(model, nn.DataParallel)

    if saved_from_data_parallel == model_is_data_parallel:
        # Key layout already matches the model.
        model.load_state_dict(checkpoint)
    elif saved_from_data_parallel:
        # Prefixed keys, plain model: load through a temporary wrapper, which
        # shares its parameters with `model`.
        nn.DataParallel(model).load_state_dict(checkpoint)
    else:
        # Plain keys, wrapped model: load into the wrapped module directly.
        model.module.load_state_dict(checkpoint)

    return learning_rate


def _get_k_most_influential_helper(
    influence_src_dataloader: DataLoader,
    influence_batch_fn: Callable,
    inputs: Tuple[Any, ...],
    targets: Optional[Tensor],
    k: int = 5,
    proponents: bool = True,
    show_progress: bool = False,
    desc: Optional[str] = None,
) -> Tuple[Tensor, Tensor]:
    r"""
    Helper function that computes the quantities returned by
    `TracInCPBase._get_k_most_influential`, using a specific implementation that is
    constant memory.

    Args:
        influence_src_dataloader (DataLoader): The DataLoader, representing training
                data, for which we want to compute proponents / opponents.
        influence_batch_fn (Callable): A callable that will be called via
                `influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
                in the `influence_src_dataloader` argument. The tensor it returns
                is read but not modified.
        inputs (Tuple of Any): A batch of examples. Does not represent labels,
                which are passed as `targets`.
        targets (Tensor, optional): If computing TracIn scores on a loss function,
                these are the labels corresponding to the batch `inputs`.
                Default: None
        k (int, optional): The number of proponents or opponents to return per test
                instance.
                Default: 5
        proponents (bool, optional): Whether seeking proponents (`proponents=True`)
                or opponents (`proponents=False`)
                Default: True
        show_progress (bool, optional): To compute the proponents (or opponents)
                for the batch of examples, we perform computation for each batch in
                training dataset `influence_src_dataloader`. If `show_progress` is
                true, the progress of this computation will be displayed. In
                particular, the number of batches for which the computation has
                been performed will be displayed. It will try to use tqdm if
                available for advanced features (e.g. time estimation). Otherwise,
                it will fallback to a simple output of progress.
                Default: False
        desc (str, optional): If `show_progress` is true, this is the description to
                show when displaying progress. If `desc` is none, no description is
                shown.
                Default: None

    Returns:
        (indices, influence_scores): `indices` is a torch.long Tensor that contains the
                indices of the proponents (or opponents) for each test example. Its
                dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the
                number of examples in `inputs`. For example, if `proponents==True`,
                `indices[i][j]` is the index of the example in training dataset
                `influence_src_dataloader` with the j-th highest influence score for
                the i-th example in `inputs`. `indices` is a `torch.long` tensor so that
                it can directly be used to index other tensors. Each row of
                `influence_scores` contains the influence scores for a different test
                example, in sorted order. In particular, `influence_scores[i][j]` is
                the influence score of example `indices[i][j]` in training dataset
                `influence_src_dataloader` on example `i` in the test batch represented
                by `inputs` and `targets`. If fewer than `k` training examples are
                seen in total, the second dimension is that smaller number.
    """
    # For each test instance, maintain the best indices and corresponding distances
    # initially, these will be empty
    topk_indices = torch.Tensor().long()
    topk_tracin_scores = torch.Tensor()

    # Opponents are found by running the same top-k search on negated scores.
    multiplier = 1.0 if proponents else -1.0

    # needed to map from relative index in a batch to index within entire `dataloader`
    num_instances_processed = 0

    # if show_progress, create progress bar
    total: Optional[int] = None
    if show_progress:
        try:
            total = len(influence_src_dataloader)
        except (AttributeError, TypeError):
            # `len()` raises TypeError for objects without a usable length; in
            # that case the progress bar is created without a total.
            pass
        influence_src_dataloader = progress(
            influence_src_dataloader,
            desc=desc,
            total=total,
        )

    for batch in influence_src_dataloader:

        # calculate tracin_scores for the batch. The multiplication is done
        # out-of-place so the tensor returned by `influence_batch_fn` (which the
        # callable may still hold a reference to) is never modified.
        batch_tracin_scores = influence_batch_fn(inputs, targets, batch) * multiplier

        # get the top-k indices and tracin_scores for the batch
        batch_size = batch_tracin_scores.shape[1]
        batch_topk_tracin_scores, batch_topk_indices = torch.topk(
            batch_tracin_scores, min(batch_size, k), dim=1
        )
        # shift batch-relative indices to indices within the whole dataloader
        batch_topk_indices = batch_topk_indices + num_instances_processed
        num_instances_processed += batch_size

        # combine the top-k for the batch with those for previously seen batches.
        # The running tensors start out on the default device, so move them to
        # the device of the batch results before concatenating.
        topk_indices = torch.cat(
            [topk_indices.to(batch_topk_indices.device), batch_topk_indices], dim=1
        )
        topk_tracin_scores = torch.cat(
            [
                topk_tracin_scores.to(batch_topk_tracin_scores.device),
                batch_topk_tracin_scores,
            ],
            dim=1,
        )

        # retain only the top-k in terms of tracin_scores
        topk_tracin_scores, topk_argsort = torch.topk(
            topk_tracin_scores, min(k, topk_indices.shape[1]), dim=1
        )
        topk_indices = torch.gather(topk_indices, dim=1, index=topk_argsort)

    # if seeking opponents, we were actually keeping track of negative tracin_scores
    topk_tracin_scores *= multiplier

    return topk_indices, topk_tracin_scores


class _DatasetFromList(Dataset):
    r"""
    Minimal `Dataset` backed by an in-memory list: item `i` of the dataset is
    element `i` of the list, and the dataset length is the list length.
    """

    def __init__(self, _l: List[Any]):
        # The list is stored by reference, not copied.
        self._l = _l

    def __len__(self) -> int:
        r"""Returns the number of elements in the wrapped list."""
        items = self._l
        return len(items)

    def __getitem__(self, i: int) -> Any:
        r"""Returns the `i`-th element of the wrapped list."""
        items = self._l
        return items[i]