#!/usr/bin/env python3
from typing import Callable, Optional, Tuple, Union, Any, List
import torch
import torch.nn as nn
from captum._utils.progress import progress
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:
r"""
Computes pairwise dot product between two tensors
    Args:
        t1 (Tensor): feature vectors of shape (batch_size_1, *)
        t2 (Tensor): feature vectors of shape (batch_size_2, *). The * dimensions
                of `t1` and `t2` must match in total number of elements.
Returns:
Tensor with shape (batch_size_1, batch_size_2) containing the pairwise dot
products. For example, Tensor[i][j] would be the dot product between
t1[i] and t2[j].
"""
msg = (
"Please ensure each batch member has the same feature dimension. "
f"First input has {torch.numel(t1) / t1.shape[0]} features, and "
f"second input has {torch.numel(t2) / t2.shape[0]} features."
)
assert torch.numel(t1) / t1.shape[0] == torch.numel(t2) / t2.shape[0], msg
return torch.mm(
t1.view(t1.shape[0], -1),
t2.view(t2.shape[0], -1).T,
)
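

# Illustrative usage sketch (a hypothetical helper added for this write-up, not
# part of the original module): `_tensor_batch_dot` flattens everything except
# the batch dimension, so the two batches only need to agree on the total number
# of features per example.
def _example_tensor_batch_dot() -> None:
    a = torch.randn(3, 4, 5)  # 3 examples, 20 features each
    b = torch.randn(2, 20)  # 2 examples, 20 features each
    d = _tensor_batch_dot(a, b)  # shape (3, 2)
    assert d.shape == (3, 2)
    # d[i][j] is the dot product of the flattened a[i] with b[j]
    assert torch.allclose(d[0, 1], torch.dot(a[0].reshape(-1), b[1]))
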
def _gradient_dot_product(
    input_grads: Tuple[Tensor, ...], src_grads: Tuple[Tensor, ...]
) -> Tensor:
r"""
Computes the dot product between the gradient vector for a model on an input batch
and src batch, for each pairwise batch member. Gradients are passed in as a tuple
corresponding to the trainable parameters returned by model.parameters(). Output
corresponds to a tensor of size (inputs_batch_size, src_batch_size) with all
pairwise dot products.
"""
assert len(input_grads) == len(src_grads), "Mismatching gradient parameters."
iterator = zip(input_grads, src_grads)
total = _tensor_batch_dot(*next(iterator))
for input_grad, src_grad in iterator:
total += _tensor_batch_dot(input_grad, src_grad)
    return total
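

# Illustrative sketch (a hypothetical helper added for this write-up, not part of
# the original module): `_gradient_dot_product` expects one tensor per parameter,
# each with a leading batch dimension (e.g. per-sample gradients), and sums the
# pairwise dot products over all parameters.
def _example_gradient_dot_product() -> None:
    # two "parameters": a (2, 3) weight and a (3,) bias, with per-sample
    # gradients for an input batch of size 4 and a src batch of size 5
    input_grads = (torch.randn(4, 2, 3), torch.randn(4, 3))
    src_grads = (torch.randn(5, 2, 3), torch.randn(5, 3))
    dots = _gradient_dot_product(input_grads, src_grads)
    assert dots.shape == (4, 5)  # one dot product per (input, src) pair
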
def _jacobian_loss_wrt_inputs(
loss_fn: Union[Module, Callable],
out: Tensor,
targets: Tensor,
vectorize: bool,
reduction_type: str,
) -> Tensor:
r"""
Often, we have a loss function that computes a per-sample loss given a 1D tensor
input, and we want to calculate the jacobian of the loss w.r.t. that input. For
example, the input could be a length K tensor specifying the probability a given
sample belongs to each of K possible classes, and the loss function could be
cross-entropy loss. This function performs that calculation, but does so for a
    *batch* of inputs. We create this helper function for two reasons: 1) to handle
    differences between PyTorch versions in vectorized jacobian calculations, and
    2) because this function does not accept the aforementioned per-sample loss
    function. Instead, it accepts a "reduction" loss function that *reduces* the
    per-sample losses for a batch into a single loss; using a "reduction" loss
    improves speed. We allow this reduction to be either the mean or the sum of the
    per-sample losses, and this function provides a uniform way to handle the
    different possible reductions, while also checking that the reduction used is
    valid. Regardless of the reduction used,
this function returns the jacobian for the per-sample loss (for each sample in the
batch).
Args:
loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
defined loss function is provided, it would be expected to be a
torch.nn.Module. If a custom loss is provided, it can be either type,
but must behave as a library loss function would if `reduction='sum'`
or `reduction='mean'`.
out (tensor): This is a tensor that represents the batch of inputs to
`loss_fn`. In practice, this will be the output of a model; this is
why this argument is named `out`. `out` is a 2D tensor of shape
(batch size, model output dimensionality). We will call `loss_fn` via
`loss_fn(out, targets)`.
targets (tensor): The labels for the batch of inputs.
vectorize (bool): Flag to use experimental vectorize functionality for
`torch.autograd.functional.jacobian`.
        reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn`
                has the "reduction" attribute, we will check that it matches
                `reduction_type`. Can only be "mean" or "sum".
Returns:
jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly
defined by `loss_fn` and `reduction_type`) w.r.t each sample
in the batch represented by `out`. This is a 2D tensor, where the
first dimension is the batch dimension.
"""
# TODO: allow loss_fn to be Callable
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
assert loss_fn.reduction != "none", msg0
        msg1 = (
            f"loss_fn.reduction ({loss_fn.reduction}) does not match "
            f"reduction type ({reduction_type}). Please ensure they are "
            "matching."
        )
assert loss_fn.reduction == reduction_type, msg1
if reduction_type != "sum" and reduction_type != "mean":
raise ValueError(
f"{reduction_type} is not a valid value for reduction_type. "
"Must be either 'sum' or 'mean'."
)
    # compare parsed version numbers rather than raw strings, since e.g. the
    # string "1.10" compares as less than "1.8"
    torch_version = tuple(
        int(v) for v in torch.__version__.split("+")[0].split(".")[:2]
    )
    if torch_version >= (1, 8):
        # the `vectorize` argument to `jacobian` was introduced in torch 1.8
        input_jacobians = torch.autograd.functional.jacobian(
            lambda out: loss_fn(out, targets), out, vectorize=vectorize
        )
    else:
        input_jacobians = torch.autograd.functional.jacobian(
            lambda out: loss_fn(out, targets), out
        )

    if reduction_type == "mean":
        # undo the 1 / batch_size factor from the "mean" reduction, so that the
        # result is the jacobian of the per-sample loss for each sample
        input_jacobians = input_jacobians * len(input_jacobians)
return input_jacobians
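

# Illustrative sketch (a hypothetical helper added for this write-up, not part of
# the original module): with a "sum" reduction, the jacobian of the reduced loss
# w.r.t. the model output already contains one gradient row per sample; with a
# "mean" reduction, the helper rescales by the batch size to recover the same rows.
def _example_jacobian_loss_wrt_inputs() -> None:
    out = torch.randn(4, 3)  # batch of 4 samples, 3-class logits
    targets = torch.tensor([0, 2, 1, 0])
    jac_sum = _jacobian_loss_wrt_inputs(
        nn.CrossEntropyLoss(reduction="sum"), out, targets, False, "sum"
    )
    jac_mean = _jacobian_loss_wrt_inputs(
        nn.CrossEntropyLoss(reduction="mean"), out, targets, False, "mean"
    )
    assert jac_sum.shape == out.shape  # one gradient row per sample
    assert torch.allclose(jac_sum, jac_mean, atol=1e-6)
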
def _load_flexible_state_dict(
    model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
) -> float:
r"""
Helper to load pytorch models. This function attempts to find compatibility for
loading models that were trained on different devices / with DataParallel but are
being loaded in a different environment.
    Assumes that the model has been saved as a state_dict in some capacity. This can
    either be a single state dict, or a nested dictionary which contains the model
    state_dict and other information.

    Args:
        model: The model for which to load a checkpoint
        path: The filepath to the checkpoint
        device_ids: The device onto which the checkpoint is loaded (default "cpu")
        keyname: The key under which the model state_dict is stored, if any.

    The model's state_dict is modified in-place, and the learning rate saved in the
    checkpoint (or a default of 1) is returned.
    """
device = device_ids
checkpoint = torch.load(path, map_location=device)
learning_rate = checkpoint.get("learning_rate", 1)
# can get learning rate from optimizer state_dict?
if keyname is not None:
checkpoint = checkpoint[keyname]
if "module." in next(iter(checkpoint)):
if isinstance(model, nn.DataParallel):
model.load_state_dict(checkpoint)
else:
model = nn.DataParallel(model)
model.load_state_dict(checkpoint)
model = model.module
else:
if isinstance(model, nn.DataParallel):
model = model.module
model.load_state_dict(checkpoint)
model = nn.DataParallel(model)
else:
model.load_state_dict(checkpoint)
return learning_rate
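

# Illustrative sketch (a hypothetical helper added for this write-up; the
# checkpoint path and key below are placeholders, not part of the original
# module): loading a checkpoint saved as
# {"model_state_dict": ..., "learning_rate": ...} updates `model` in-place and
# returns the saved learning rate.
def _example_load_flexible_state_dict(model: Module) -> None:
    learning_rate = _load_flexible_state_dict(
        model, "checkpoint.pt", keyname="model_state_dict"  # hypothetical path/key
    )
    print(f"checkpoint was saved with learning rate {learning_rate}")
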
def _get_k_most_influential_helper(
influence_src_dataloader: DataLoader,
influence_batch_fn: Callable,
inputs: Tuple[Any, ...],
targets: Optional[Tensor],
k: int = 5,
proponents: bool = True,
show_progress: bool = False,
desc: Optional[str] = None,
) -> Tuple[Tensor, Tensor]:
r"""
Helper function that computes the quantities returned by
`TracInCPBase._get_k_most_influential`, using a specific implementation that is
constant memory.
Args:
influence_src_dataloader (DataLoader): The DataLoader, representing training
data, for which we want to compute proponents / opponents.
influence_batch_fn (Callable): A callable that will be called via
`influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
in the `influence_src_dataloader` argument.
inputs (Tuple of Any): A batch of examples. Does not represent labels,
which are passed as `targets`.
targets (Tensor, optional): If computing TracIn scores on a loss function,
these are the labels corresponding to the batch `inputs`.
Default: None
k (int, optional): The number of proponents or opponents to return per test
instance.
Default: 5
proponents (bool, optional): Whether seeking proponents (`proponents=True`)
or opponents (`proponents=False`)
Default: True
        show_progress (bool, optional): To compute the proponents (or opponents)
                for the batch of examples, we perform computation for each batch in
                training dataset `influence_src_dataloader`. If `show_progress` is
                true, the progress of this computation will be displayed. In
                particular, the number of batches for which the computation has
                been performed will be displayed. It will try to use tqdm if
                available for advanced features (e.g. time estimation). Otherwise,
                it will fall back to a simple output of progress.
Default: False
desc (str, optional): If `show_progress` is true, this is the description to
show when displaying progress. If `desc` is none, no description is
shown.
Default: None
Returns:
(indices, influence_scores): `indices` is a torch.long Tensor that contains the
indices of the proponents (or opponents) for each test example. Its
dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the
number of examples in `inputs`. For example, if `proponents==True`,
`indices[i][j]` is the index of the example in training dataset
                `influence_src_dataloader` with the j-th highest influence score for
                the i-th example in `inputs`. `indices` is a `torch.long` tensor so that
it can directly be used to index other tensors. Each row of
`influence_scores` contains the influence scores for a different test
example, in sorted order. In particular, `influence_scores[i][j]` is
the influence score of example `indices[i][j]` in training dataset
`influence_src_dataloader` on example `i` in the test batch represented
by `inputs` and `targets`.
"""
    # for each test instance, maintain the best indices and corresponding tracin
    # scores; initially, these will be empty
topk_indices = torch.Tensor().long()
topk_tracin_scores = torch.Tensor()
multiplier = 1.0 if proponents else -1.0
    # needed to map from the relative index within a batch to the index within the
    # entire `dataloader`
num_instances_processed = 0
# if show_progress, create progress bar
total: Optional[int] = None
if show_progress:
try:
total = len(influence_src_dataloader)
except AttributeError:
pass
influence_src_dataloader = progress(
influence_src_dataloader,
desc=desc,
total=total,
)
for batch in influence_src_dataloader:
# calculate tracin_scores for the batch
batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
batch_tracin_scores *= multiplier
# get the top-k indices and tracin_scores for the batch
batch_size = batch_tracin_scores.shape[1]
batch_topk_tracin_scores, batch_topk_indices = torch.topk(
batch_tracin_scores, min(batch_size, k), dim=1
)
batch_topk_indices = batch_topk_indices + num_instances_processed
num_instances_processed += batch_size
# combine the top-k for the batch with those for previously seen batches
topk_indices = torch.cat([topk_indices, batch_topk_indices], dim=1)
topk_tracin_scores = torch.cat(
[topk_tracin_scores, batch_topk_tracin_scores], dim=1
)
# retain only the top-k in terms of tracin_scores
topk_tracin_scores, topk_argsort = torch.topk(
topk_tracin_scores, min(k, topk_indices.shape[1]), dim=1
)
topk_indices = torch.gather(topk_indices, dim=1, index=topk_argsort)
# if seeking opponents, we were actually keeping track of negative tracin_scores
topk_tracin_scores *= multiplier
return topk_indices, topk_tracin_scores
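

# Illustrative sketch (a hypothetical helper added for this write-up, not part of
# the original module): a toy `influence_batch_fn` that scores each
# (test example, training example) pair by a feature dot product, ignoring
# `targets`. Real callers pass a TracIn-based scorer; the random data here is
# only for shape illustration.
def _example_get_k_most_influential() -> None:
    from torch.utils.data import TensorDataset  # local import, demo only

    def toy_influence_batch_fn(inputs, targets, batch):
        # must return a (test batch size, training batch size) score matrix
        return _tensor_batch_dot(inputs[0], batch[0])

    train_dataloader = DataLoader(TensorDataset(torch.randn(20, 8)), batch_size=4)
    test_inputs = (torch.randn(2, 8),)
    indices, scores = _get_k_most_influential_helper(
        train_dataloader, toy_influence_batch_fn, test_inputs, None, k=3
    )
    assert indices.shape == (2, 3) and scores.shape == (2, 3)
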
class _DatasetFromList(Dataset):
    """
    Wraps a list of examples in a map-style `Dataset`, so that it can be consumed
    by a `DataLoader`.
    """

    def __init__(self, _l: List[Any]) -> None:
        self._l = _l

    def __getitem__(self, i: int) -> Any:
        return self._l[i]

    def __len__(self) -> int:
        return len(self._l)
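

# Illustrative sketch (a hypothetical helper added for this write-up, not part of
# the original module): `_DatasetFromList` lets an in-memory list of examples be
# consumed through the standard `DataLoader` API.
def _example_dataset_from_list() -> None:
    dataset = _DatasetFromList([torch.randn(8) for _ in range(10)])
    loader = DataLoader(dataset, batch_size=4)
    first_batch = next(iter(loader))
    assert first_batch.shape == (4, 8)  # default collate stacks the list items
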