#!/usr/bin/env python3

from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from captum._utils.progress import progress
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset


def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:
    r"""
    Computes pairwise dot product between two tensors.

    Args:
        Tensors t1 and t2 are feature vectors with dimension (batch_size_1, *) and
        (batch_size_2, *). The * dimensions must match in total number of elements.

    Returns:
        Tensor with shape (batch_size_1, batch_size_2) containing the pairwise dot
        products. For example, Tensor[i][j] would be the dot product between
        t1[i] and t2[j].
    """
    msg = (
        "Please ensure each batch member has the same feature dimension. "
        f"First input has {torch.numel(t1) / t1.shape[0]} features, and "
        f"second input has {torch.numel(t2) / t2.shape[0]} features."
    )
    assert torch.numel(t1) / t1.shape[0] == torch.numel(t2) / t2.shape[0], msg

    return torch.mm(
        t1.view(t1.shape[0], -1),
        t2.view(t2.shape[0], -1).T,
    )
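

# Illustrative sketch (not part of the original module): a minimal check of what
# `_tensor_batch_dot` computes. The example function and the tensor shapes below are
# hypothetical.
def _example_tensor_batch_dot() -> None:
    t1 = torch.arange(6.0).reshape(2, 3)  # batch of 2 examples, 3 features each
    t2 = torch.arange(12.0).reshape(4, 3)  # batch of 4 examples, 3 features each
    result = _tensor_batch_dot(t1, t2)  # shape (2, 4)
    # result[i][j] is the dot product between t1[i] and t2[j]
    assert result.shape == (2, 4)
    assert torch.allclose(result[0, 1], torch.dot(t1[0], t2[1]))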


def _gradient_dot_product(
    input_grads: Tuple[Tensor], src_grads: Tuple[Tensor]
) -> Tensor:
    r"""
    Computes the dot product between the gradient vector for a model on an input batch
    and src batch, for each pairwise batch member. Gradients are passed in as a tuple
    corresponding to the trainable parameters returned by model.parameters(). Output
    corresponds to a tensor of size (inputs_batch_size, src_batch_size) with all
    pairwise dot products.
    """
    assert len(input_grads) == len(src_grads), "Mismatching gradient parameters."

    iterator = zip(input_grads, src_grads)
    total = _tensor_batch_dot(*next(iterator))
    for input_grad, src_grad in iterator:
        total += _tensor_batch_dot(input_grad, src_grad)

    return total
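

# Illustrative sketch (not part of the original module): `_gradient_dot_product`
# expects per-sample gradients, one tensor per parameter, with the batch dimension
# first. The parameter shapes and batch sizes below are hypothetical.
def _example_gradient_dot_product() -> None:
    # two parameters (a (3, 4) weight and a (4,) bias); 2 input examples, 5 src examples
    input_grads = (torch.randn(2, 3, 4), torch.randn(2, 4))
    src_grads = (torch.randn(5, 3, 4), torch.randn(5, 4))
    pairwise_scores = _gradient_dot_product(input_grads, src_grads)
    assert pairwise_scores.shape == (2, 5)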


def _jacobian_loss_wrt_inputs(
    loss_fn: Union[Module, Callable],
    out: Tensor,
    targets: Tensor,
    vectorize: bool,
    reduction_type: str,
) -> Tensor:
    r"""
    Often, we have a loss function that computes a per-sample loss given a 1D tensor
    input, and we want to calculate the jacobian of the loss w.r.t. that input. For
    example, the input could be a length K tensor specifying the probability a given
    sample belongs to each of K possible classes, and the loss function could be
    cross-entropy loss. This function performs that calculation, but does so for a
    *batch* of inputs. We create this helper function for two reasons: 1) to handle
    differences between Pytorch versions for vectorized jacobian calculations, and
    2) this function does not accept the aforementioned per-sample loss function.
    Instead, it accepts a "reduction" loss function that *reduces* the per-sample loss
    for a batch into a single loss. Using a "reduction" loss improves speed.
    We allow this reduction to be either the mean or sum of the per-sample losses,
    and this function provides a uniform way to handle different possible reductions,
    and also checks that the reduction used is valid. Regardless of the reduction used,
    this function returns the jacobian for the per-sample loss (for each sample in the
    batch).

    Args:
        loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
                defined loss function is provided, it would be expected to be a
                torch.nn.Module. If a custom loss is provided, it can be either type,
                but must behave as a library loss function would if `reduction='sum'`
                or `reduction='mean'`.
        out (tensor): This is a tensor that represents the batch of inputs to
                `loss_fn`. In practice, this will be the output of a model; this is
                why this argument is named `out`. `out` is a 2D tensor of shape
                (batch size, model output dimensionality). We will call `loss_fn` via
                `loss_fn(out, targets)`.
        targets (tensor): The labels for the batch of inputs.
        vectorize (bool): Flag to use experimental vectorize functionality for
                `torch.autograd.functional.jacobian`.
        reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn`
                has the "reduction" attribute, we will check that they match. Can
                only be "mean" or "sum".

    Returns:
        jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly
                defined by `loss_fn` and `reduction_type`) w.r.t. each sample
                in the batch represented by `out`. This is a 2D tensor, where the
                first dimension is the batch dimension.
    """
    # TODO: allow loss_fn to be Callable
    if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
        msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
        assert loss_fn.reduction != "none", msg0
        msg1 = (
            f"loss_fn.reduction ({loss_fn.reduction}) does not match "
            f"reduction type ({reduction_type}). Please ensure they are "
            "matching."
        )
        assert loss_fn.reduction == reduction_type, msg1

    if reduction_type != "sum" and reduction_type != "mean":
        raise ValueError(
            f"{reduction_type} is not a valid value for reduction_type. "
            "Must be either 'sum' or 'mean'."
        )

    # the `vectorize` option of `torch.autograd.functional.jacobian` is only available
    # in Pytorch 1.8 and later
    if torch.__version__ >= "1.8":
        input_jacobians = torch.autograd.functional.jacobian(
            lambda out: loss_fn(out, targets), out, vectorize=vectorize
        )
    else:
        input_jacobians = torch.autograd.functional.jacobian(
            lambda out: loss_fn(out, targets), out
        )

    # for a "mean" reduction, the jacobian of the reduced loss w.r.t. each sample is
    # the per-sample jacobian divided by the batch size, so multiply it back out
    if reduction_type == "mean":
        input_jacobians = input_jacobians * len(input_jacobians)

    return input_jacobians
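

# Illustrative sketch (not part of the original module): computing per-sample loss
# jacobians with a "sum"-reduction cross-entropy loss. The batch size and number of
# classes below are hypothetical.
def _example_jacobian_loss_wrt_inputs() -> None:
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    out = torch.randn(4, 3)  # batch of 4 samples, 3 classes
    targets = torch.tensor([0, 2, 1, 0])
    jacobians = _jacobian_loss_wrt_inputs(
        loss_fn, out, targets, vectorize=False, reduction_type="sum"
    )
    # one row of gradients per sample in the batch
    assert jacobians.shape == out.shape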


def _load_flexible_state_dict(
    model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
) -> float:
    r"""
    Helper to load pytorch models. This function attempts to find compatibility for
    loading models that were trained on different devices / with DataParallel but are
    being loaded in a different environment.

    Assumes that the model has been saved as a state_dict in some capacity. This can
    either be a single state dict, or a nested dictionary which contains the model
    state_dict and other information.

    Args:
        model: The model for which to load a checkpoint
        path: The filepath to the checkpoint
        device_ids: The device on which to load the checkpoint (passed to
                `torch.load` as `map_location`)
        keyname: The key under which the model state_dict is stored, if any.

    The module state_dict is modified in-place, and the learning rate is returned.
    """
    device = device_ids

    checkpoint = torch.load(path, map_location=device)

    learning_rate = checkpoint.get("learning_rate", 1)
    # can get learning rate from optimizer state_dict?

    if keyname is not None:
        checkpoint = checkpoint[keyname]

    if "module." in next(iter(checkpoint)):
        if isinstance(model, nn.DataParallel):
            model.load_state_dict(checkpoint)
        else:
            model = nn.DataParallel(model)
            model.load_state_dict(checkpoint)
            model = model.module
    else:
        if isinstance(model, nn.DataParallel):
            model = model.module
            model.load_state_dict(checkpoint)
            model = nn.DataParallel(model)
        else:
            model.load_state_dict(checkpoint)

    return learning_rate
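

# Illustrative sketch (not part of the original module): saving a checkpoint in a
# nested format that `_load_flexible_state_dict` understands. The file path, the
# "model_state_dict" key, and the learning rate value are hypothetical.
def _example_load_flexible_state_dict(path: str = "checkpoint.pt") -> None:
    trained_model = nn.Linear(4, 2)
    torch.save(
        {"model_state_dict": trained_model.state_dict(), "learning_rate": 0.01}, path
    )
    fresh_model = nn.Linear(4, 2)
    learning_rate = _load_flexible_state_dict(
        fresh_model, path, keyname="model_state_dict"
    )
    assert learning_rate == 0.01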


def _get_k_most_influential_helper(
    influence_src_dataloader: DataLoader,
    influence_batch_fn: Callable,
    inputs: Tuple[Any, ...],
    targets: Optional[Tensor],
    k: int = 5,
    proponents: bool = True,
    show_progress: bool = False,
    desc: Optional[str] = None,
) -> Tuple[Tensor, Tensor]:
    r"""
    Helper function that computes the quantities returned by
    `TracInCPBase._get_k_most_influential`, using a specific implementation that is
    constant memory.

    Args:
        influence_src_dataloader (DataLoader): The DataLoader, representing training
                data, for which we want to compute proponents / opponents.
        influence_batch_fn (Callable): A callable that will be called via
                `influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
                in the `influence_src_dataloader` argument.
        inputs (Tuple of Any): A batch of examples. Does not represent labels,
                which are passed as `targets`.
        targets (Tensor, optional): If computing TracIn scores on a loss function,
                these are the labels corresponding to the batch `inputs`.
                Default: None
        k (int, optional): The number of proponents or opponents to return per test
                instance.
                Default: 5
        proponents (bool, optional): Whether seeking proponents (`proponents=True`)
                or opponents (`proponents=False`)
                Default: True
        show_progress (bool, optional): To compute the proponents (or opponents)
                for the batch of examples, we perform computation for each batch in
                training dataset `influence_src_dataloader`. If `show_progress` is
                true, the progress of this computation will be displayed. In
                particular, the number of batches for which the computation has
                been performed will be displayed. It will try to use tqdm if
                available for advanced features (e.g. time estimation). Otherwise,
                it will fall back to a simple output of progress.
                Default: False
        desc (str, optional): If `show_progress` is true, this is the description to
                show when displaying progress. If `desc` is None, no description is
                shown.
                Default: None

    Returns:
        (indices, influence_scores): `indices` is a torch.long Tensor that contains the
                indices of the proponents (or opponents) for each test example. Its
                dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the
                number of examples in `inputs`. For example, if `proponents==True`,
                `indices[i][j]` is the index of the example in training dataset
                `influence_src_dataloader` with the j-th highest influence score on
                the i-th example in `inputs`. `indices` is a `torch.long` tensor so that
                it can directly be used to index other tensors. Each row of
                `influence_scores` contains the influence scores for a different test
                example, in sorted order. In particular, `influence_scores[i][j]` is
                the influence score of example `indices[i][j]` in training dataset
                `influence_src_dataloader` on example `i` in the test batch represented
                by `inputs` and `targets`.
    """
    # For each test instance, maintain the best indices and corresponding distances
    # initially, these will be empty
    topk_indices = torch.Tensor().long()
    topk_tracin_scores = torch.Tensor()

    multiplier = 1.0 if proponents else -1.0

    # needed to map from relative index in a batch to index within entire `dataloader`
    num_instances_processed = 0

    # if show_progress, create progress bar
    total: Optional[int] = None
    if show_progress:
        try:
            total = len(influence_src_dataloader)
        except (AttributeError, TypeError):
            # some dataloaders (e.g. over iterable datasets) do not define a length
            pass
        influence_src_dataloader = progress(
            influence_src_dataloader,
            desc=desc,
            total=total,
        )

    for batch in influence_src_dataloader:

        # calculate tracin_scores for the batch
        batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
        batch_tracin_scores *= multiplier

        # get the top-k indices and tracin_scores for the batch
        batch_size = batch_tracin_scores.shape[1]
        batch_topk_tracin_scores, batch_topk_indices = torch.topk(
            batch_tracin_scores, min(batch_size, k), dim=1
        )
        batch_topk_indices = batch_topk_indices + num_instances_processed
        num_instances_processed += batch_size

        # combine the top-k for the batch with those for previously seen batches
        topk_indices = torch.cat([topk_indices, batch_topk_indices], dim=1)
        topk_tracin_scores = torch.cat(
            [topk_tracin_scores, batch_topk_tracin_scores], dim=1
        )

        # retain only the top-k in terms of tracin_scores
        topk_tracin_scores, topk_argsort = torch.topk(
            topk_tracin_scores, min(k, topk_indices.shape[1]), dim=1
        )
        topk_indices = torch.gather(topk_indices, dim=1, index=topk_argsort)

    # if seeking opponents, we were actually keeping track of negative tracin_scores
    topk_tracin_scores *= multiplier

    return topk_indices, topk_tracin_scores
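

# Illustrative sketch (not part of the original module): using the helper with a toy
# `influence_batch_fn` that scores a training example's influence on a test example
# by the dot product of their features. All names and shapes below are hypothetical.
def _example_get_k_most_influential() -> None:
    def influence_batch_fn(inputs, targets, batch):
        # returns a (test_batch_size, train_batch_size) matrix of scores
        return _tensor_batch_dot(inputs[0], batch)

    # any map-style dataset works; here, a plain list of feature vectors
    train_features = [torch.randn(3) for _ in range(10)]
    dataloader = DataLoader(train_features, batch_size=4)
    test_inputs = (torch.randn(2, 3),)
    indices, scores = _get_k_most_influential_helper(
        dataloader, influence_batch_fn, test_inputs, targets=None, k=3
    )
    # 3 proponents per test example, with indices into the 10 training examples
    assert indices.shape == (2, 3) and scores.shape == (2, 3)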


class _DatasetFromList(Dataset):
    def __init__(self, _l: List[Any]):
        self._l = _l

    def __getitem__(self, i: int) -> Any:
        return self._l[i]

    def __len__(self) -> int:
        return len(self._l)
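

# Illustrative sketch (not part of the original module): `_DatasetFromList` wraps an
# in-memory list of examples so it can be consumed through the standard `DataLoader`
# machinery. The example data below is hypothetical.
def _example_dataset_from_list() -> None:
    examples = [(torch.randn(3), torch.tensor(i % 2)) for i in range(6)]
    dataloader = DataLoader(_DatasetFromList(examples), batch_size=2)
    features, labels = next(iter(dataloader))  # default collate stacks each field
    assert features.shape == (2, 3) and labels.shape == (2,)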