#!/usr/bin/env python3
from typing import Callable, Optional, Tuple, Union, Any, List
import torch
import torch.nn as nn
from captum._utils.progress import progress
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor:
r"""
Computes pairwise dot product between two tensors
Args:
t1, t2 (Tensor): Feature vectors with dimension (batch_size_1, *) and
        (batch_size_2, *), respectively. The * dimensions must match in total
        number of elements.
Returns:
Tensor with shape (batch_size_1, batch_size_2) containing the pairwise dot
products. For example, Tensor[i][j] would be the dot product between
t1[i] and t2[j].
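Example (illustrative sketch; the shapes below are arbitrary choices, not part
of the original documentation):

    t1 = torch.randn(4, 3, 8)   # 4 examples, 24 features each
    t2 = torch.randn(6, 24)     # 6 examples, 24 features each
    dots = _tensor_batch_dot(t1, t2)
    # dots.shape == (4, 6); dots[i][j] == torch.dot(t1[i].flatten(), t2[j])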
"""
msg = (
"Please ensure each batch member has the same feature dimension. "
f"First input has {torch.numel(t1) / t1.shape[0]} features, and "
f"second input has {torch.numel(t2) / t2.shape[0]} features."
)
assert torch.numel(t1) / t1.shape[0] == torch.numel(t2) / t2.shape[0], msg
return torch.mm(
t1.view(t1.shape[0], -1),
t2.view(t2.shape[0], -1).T,
)
def _gradient_dot_product(
input_grads: Tuple[Tensor, ...], src_grads: Tuple[Tensor, ...]
) -> Tensor:
r"""
Computes the dot product between the gradient vector for a model on an input batch
and src batch, for each pairwise batch member. Gradients are passed in as a tuple
corresponding to the trainable parameters returned by model.parameters(). Output
corresponds to a tensor of size (inputs_batch_size, src_batch_size) with all
pairwise dot products.
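Example (illustrative sketch; the two "parameter" shapes below are assumptions
made purely for demonstration):

    # per-sample gradients for two parameters, for batches of size 2 and 3
    input_grads = (torch.randn(2, 10), torch.randn(2, 5, 5))
    src_grads = (torch.randn(3, 10), torch.randn(3, 5, 5))
    pairwise = _gradient_dot_product(input_grads, src_grads)
    # pairwise.shape == (2, 3)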
"""
assert len(input_grads) == len(src_grads), "Mismatching gradient parameters."
iterator = zip(input_grads, src_grads)
total = _tensor_batch_dot(*next(iterator))
for input_grad, src_grad in iterator:
total += _tensor_batch_dot(input_grad, src_grad)
    return total
def _jacobian_loss_wrt_inputs(
loss_fn: Union[Module, Callable],
out: Tensor,
targets: Tensor,
vectorize: bool,
reduction_type: str,
) -> Tensor:
r"""
Often, we have a loss function that computes a per-sample loss given a 1D tensor
input, and we want to calculate the jacobian of the loss w.r.t. that input. For
example, the input could be a length K tensor specifying the probability a given
sample belongs to each of K possible classes, and the loss function could be
cross-entropy loss. This function performs that calculation, but does so for a
*batch* of inputs. We create this helper function for two reasons: 1) to handle
differences between PyTorch versions for vectorized jacobian calculations, and
2) because this function does not accept the aforementioned per-sample loss
function. Instead, it accepts a "reduction" loss function that *reduces* the
per-sample losses for a batch into a single loss; using a "reduction" loss
improves speed. We allow this reduction to be either the mean or the sum of the
per-sample losses, and this function provides a uniform way to handle the
different possible reductions, and also checks that the reduction used is
valid. Regardless of the reduction used,
this function returns the jacobian for the per-sample loss (for each sample in the
batch).
Args:
loss_fn (torch.nn.Module or Callable or None): The loss function. If a library
defined loss function is provided, it would be expected to be a
torch.nn.Module. If a custom loss is provided, it can be either type,
but must behave as a library loss function would if `reduction='sum'`
or `reduction='mean'`.
out (tensor): This is a tensor that represents the batch of inputs to
`loss_fn`. In practice, this will be the output of a model; this is
why this argument is named `out`. `out` is a 2D tensor of shape
(batch size, model output dimensionality). We will call `loss_fn` via
`loss_fn(out, targets)`.
targets (tensor): The labels for the batch of inputs.
vectorize (bool): Flag to use experimental vectorize functionality for
`torch.autograd.functional.jacobian`.
reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn`
has the "reduction" attribute, we will check that they match. Can
only be "mean" or "sum".
Returns:
jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly
defined by `loss_fn` and `reduction_type`) w.r.t each sample
in the batch represented by `out`. This is a 2D tensor, where the
first dimension is the batch dimension.
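Example (illustrative sketch; the loss, shapes, and labels below are
assumptions made purely for demonstration):

    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    out = torch.randn(4, 3)                 # batch of 4, 3 classes
    targets = torch.tensor([0, 2, 1, 0])
    jac = _jacobian_loss_wrt_inputs(loss_fn, out, targets, False, "sum")
    # jac.shape == (4, 3); jac[i] is the gradient of sample i's loss
    # with respect to out[i]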
"""
# TODO: allow loss_fn to be Callable
if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"):
msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`"
assert loss_fn.reduction != "none", msg0
msg1 = (
f"loss_fn.reduction ({loss_fn.reduction}) does not match"
f"reduction type ({reduction_type}). Please ensure they are"
" matching."
)
assert loss_fn.reduction == reduction_type, msg1
if reduction_type != "sum" and reduction_type != "mean":
raise ValueError(
f"{reduction_type} is not a valid value for reduction_type. "
"Must be either 'sum' or 'mean'."
)
    # compare parsed version components rather than the raw version string,
    # since a lexicographic comparison mis-orders versions such as "1.10" vs "1.8"
    if tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (1, 8):
input_jacobians = torch.autograd.functional.jacobian(
lambda out: loss_fn(out, targets), out, vectorize=vectorize
)
else:
input_jacobians = torch.autograd.functional.jacobian(
lambda out: loss_fn(out, targets), out
)
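    # a "mean" reduction divides the summed per-sample losses by the batch size,
    # so multiply by the batch size (`len(input_jacobians)`) to recover the
    # jacobians of the per-sample losses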
if reduction_type == "mean":
input_jacobians = input_jacobians * len(input_jacobians)
return input_jacobians
def _load_flexible_state_dict(
model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None
) -> int:
r"""
Helper to load pytorch models. This function attempts to find compatibility for
loading models that were trained on different devices / with DataParallel but are
being loaded in a different environment.
Assumes that the model has been saved as a state_dict in some capacity. This can
either be a single state dict, or a nested dictionary which contains the model
state_dict and other information.
Args:
model: The model for which to load a checkpoint
path: The filepath to the checkpoint
device_ids: The device to map the checkpoint onto when loading (e.g. "cpu");
        passed directly to `torch.load` as `map_location`.
keyname: The key under which the model state_dict is stored, if any.
The module state_dict is modified in-place, and the learning rate is returned.
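Example (illustrative sketch; the checkpoint path, keyname, and stored learning
rate are hypothetical):

    net = nn.Linear(10, 2)
    torch.save(
        {"model_state_dict": net.state_dict(), "learning_rate": 0.01},
        "checkpoint.pt",
    )
    lr = _load_flexible_state_dict(net, "checkpoint.pt", keyname="model_state_dict")
    # lr == 0.01, and `net` now holds the weights stored in the checkpoint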
"""
device = device_ids
checkpoint = torch.load(path, map_location=device)
learning_rate = checkpoint.get("learning_rate", 1)
# can get learning rate from optimizer state_dict?
if keyname is not None:
checkpoint = checkpoint[keyname]
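    # checkpoints saved from an nn.DataParallel-wrapped model prefix every
    # parameter name with "module."; inspect the first key to detect this and
    # reconcile it with how `model` is currently wrapped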
if "module." in next(iter(checkpoint)):
if isinstance(model, nn.DataParallel):
model.load_state_dict(checkpoint)
else:
model = nn.DataParallel(model)
model.load_state_dict(checkpoint)
model = model.module
else:
if isinstance(model, nn.DataParallel):
model = model.module
model.load_state_dict(checkpoint)
model = nn.DataParallel(model)
else:
model.load_state_dict(checkpoint)
return learning_rate
def _get_k_most_influential_helper(
influence_src_dataloader: DataLoader,
influence_batch_fn: Callable,
inputs: Tuple[Any, ...],
targets: Optional[Tensor],
k: int = 5,
proponents: bool = True,
show_progress: bool = False,
desc: Optional[str] = None,
) -> Tuple[Tensor, Tensor]:
r"""
Helper function that computes the quantities returned by
`TracInCPBase._get_k_most_influential`, using a specific implementation that
uses constant memory.
Args:
influence_src_dataloader (DataLoader): The DataLoader, representing training
data, for which we want to compute proponents / opponents.
influence_batch_fn (Callable): A callable that will be called via
`influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch
in the `influence_src_dataloader` argument.
inputs (Tuple of Any): A batch of examples. Does not represent labels,
which are passed as `targets`.
targets (Tensor, optional): If computing TracIn scores on a loss function,
these are the labels corresponding to the batch `inputs`.
Default: None
k (int, optional): The number of proponents or opponents to return per test
instance.
Default: 5
proponents (bool, optional): Whether seeking proponents (`proponents=True`)
or opponents (`proponents=False`)
Default: True
show_progress (bool, optional): To compute the proponents (or opponents)
for the batch of examples, we perform computation for each batch in
training dataset `influence_src_dataloader`. If `show_progress` is
true, the progress of this computation will be displayed. In
particular, the number of batches for which the computation has
been performed will be displayed. It will try to use tqdm if
available for advanced features (e.g. time estimation). Otherwise,
it will fall back to a simple output of progress.
Default: False
desc (str, optional): If `show_progress` is true, this is the description to
show when displaying progress. If `desc` is None, no description is
shown.
Default: None
Returns:
(indices, influence_scores): `indices` is a torch.long Tensor that contains the
indices of the proponents (or opponents) for each test example. Its
dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the
number of examples in `inputs`. For example, if `proponents==True`,
`indices[i][j]` is the index of the example in training dataset
`influence_src_dataloader` with the j-th highest influence score on
the i-th example in `inputs`. `indices` is a `torch.long` tensor so that
it can directly be used to index other tensors. Each row of
`influence_scores` contains the influence scores for a different test
example, in sorted order. In particular, `influence_scores[i][j]` is
the influence score of example `indices[i][j]` in training dataset
`influence_src_dataloader` on example `i` in the test batch represented
by `inputs` and `targets`.
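Example (illustrative sketch; `toy_influence_batch_fn` and all shapes below are
assumptions made purely for demonstration):

    train_examples = [torch.randn(8) for _ in range(20)]
    dataloader = DataLoader(_DatasetFromList(train_examples), batch_size=4)
    test_inputs = (torch.randn(3, 8),)

    def toy_influence_batch_fn(inputs, targets, batch):
        # pairwise dot products stand in for real TracIn influence scores
        return _tensor_batch_dot(inputs[0], batch)

    indices, scores = _get_k_most_influential_helper(
        dataloader, toy_influence_batch_fn, test_inputs, None, k=2
    )
    # indices.shape == scores.shape == (3, 2)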
"""
# For each test instance, maintain the best indices and corresponding tracin scores;
# initially, these will be empty
topk_indices = torch.Tensor().long()
topk_tracin_scores = torch.Tensor()
multiplier = 1.0 if proponents else -1.0
# needed to map from a relative index within a batch to an index within the entire `dataloader`
num_instances_processed = 0
# if show_progress, create progress bar
total: Optional[int] = None
if show_progress:
try:
total = len(influence_src_dataloader)
except AttributeError:
pass
influence_src_dataloader = progress(
influence_src_dataloader,
desc=desc,
total=total,
)
for batch in influence_src_dataloader:
# calculate tracin_scores for the batch
batch_tracin_scores = influence_batch_fn(inputs, targets, batch)
batch_tracin_scores *= multiplier
# get the top-k indices and tracin_scores for the batch
batch_size = batch_tracin_scores.shape[1]
batch_topk_tracin_scores, batch_topk_indices = torch.topk(
batch_tracin_scores, min(batch_size, k), dim=1
)
batch_topk_indices = batch_topk_indices + num_instances_processed
num_instances_processed += batch_size
# combine the top-k for the batch with those for previously seen batches
topk_indices = torch.cat([topk_indices, batch_topk_indices], dim=1)
topk_tracin_scores = torch.cat(
[topk_tracin_scores, batch_topk_tracin_scores], dim=1
)
# retain only the top-k in terms of tracin_scores
topk_tracin_scores, topk_argsort = torch.topk(
topk_tracin_scores, min(k, topk_indices.shape[1]), dim=1
)
topk_indices = torch.gather(topk_indices, dim=1, index=topk_argsort)
# if seeking opponents, we were actually keeping track of negative tracin_scores
topk_tracin_scores *= multiplier
return topk_indices, topk_tracin_scores
class _DatasetFromList(Dataset):
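    r"""
    Wraps a plain Python list in a torch `Dataset` so that it can be consumed
    by a `DataLoader`. Usage sketch (the list contents below are illustrative):

        loader = DataLoader(_DatasetFromList([torch.randn(8) for _ in range(10)]), batch_size=2)
    """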
def __init__(self, _l: List[Any]):
self._l = _l
def __getitem__(self, i: int) -> Any:
return self._l[i]
def __len__(self) -> int:
return len(self._l)