#!/usr/bin/env python3 from typing import Callable, Optional, Tuple, Union, Any, List import torch import torch.nn as nn from captum._utils.progress import progress from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader, Dataset def _tensor_batch_dot(t1: Tensor, t2: Tensor) -> Tensor: r""" Computes pairwise dot product between two tensors Args: Tensors t1 and t2 are feature vectors with dimension (batch_size_1, *) and (batch_size_2, *). The * dimensions must match in total number of elements. Returns: Tensor with shape (batch_size_1, batch_size_2) containing the pairwise dot products. For example, Tensor[i][j] would be the dot product between t1[i] and t2[j]. """ msg = ( "Please ensure each batch member has the same feature dimension. " f"First input has {torch.numel(t1) / t1.shape[0]} features, and " f"second input has {torch.numel(t2) / t2.shape[0]} features." ) assert torch.numel(t1) / t1.shape[0] == torch.numel(t2) / t2.shape[0], msg return torch.mm( t1.view(t1.shape[0], -1), t2.view(t2.shape[0], -1).T, ) def _gradient_dot_product( input_grads: Tuple[Tensor], src_grads: Tuple[Tensor] ) -> Tensor: r""" Computes the dot product between the gradient vector for a model on an input batch and src batch, for each pairwise batch member. Gradients are passed in as a tuple corresponding to the trainable parameters returned by model.parameters(). Output corresponds to a tensor of size (inputs_batch_size, src_batch_size) with all pairwise dot products. """ assert len(input_grads) == len(src_grads), "Mismatching gradient parameters." iterator = zip(input_grads, src_grads) total = _tensor_batch_dot(*next(iterator)) for input_grad, src_grad in iterator: total += _tensor_batch_dot(input_grad, src_grad) total = torch.Tensor(total) return total def _jacobian_loss_wrt_inputs( loss_fn: Union[Module, Callable], out: Tensor, targets: Tensor, vectorize: bool, reduction_type: str, ) -> Tensor: r""" Often, we have a loss function that computes a per-sample loss given a 1D tensor input, and we want to calculate the jacobian of the loss w.r.t. that input. For example, the input could be a length K tensor specifying the probability a given sample belongs to each of K possible classes, and the loss function could be cross-entropy loss. This function performs that calculation, but does so for a *batch* of inputs. We create this helper function for two reasons: 1) to handle differences between Pytorch versiosn for vectorized jacobian calculations, and 2) this function does not accept the aforementioned per-sample loss function. Instead, it accepts a "reduction" loss function that *reduces* the per-sample loss for a batch into a single loss. Using a "reduction" loss improves speed. We will allow this reduction to either be the mean or sum of the per-sample losses, and this function provides an uniform way to handle different possible reductions, and also check if the reduction used is valid. Regardless of the reduction used, this function returns the jacobian for the per-sample loss (for each sample in the batch). Args: loss_fn (torch.nn.Module or Callable or None): The loss function. If a library defined loss function is provided, it would be expected to be a torch.nn.Module. If a custom loss is provided, it can be either type, but must behave as a library loss function would if `reduction='sum'` or `reduction='mean'`. out (tensor): This is a tensor that represents the batch of inputs to `loss_fn`. In practice, this will be the output of a model; this is why this argument is named `out`. `out` is a 2D tensor of shape (batch size, model output dimensionality). We will call `loss_fn` via `loss_fn(out, targets)`. targets (tensor): The labels for the batch of inputs. vectorize (bool): Flag to use experimental vectorize functionality for `torch.autograd.functional.jacobian`. reduction_type (str): The type of reduction used by `loss_fn`. If `loss_fn` has the "reduction" attribute, we will check that they match. Can only be "mean" or "sum". Returns: jacobians (tensor): Returns the jacobian of the per-sample loss (implicitly defined by `loss_fn` and `reduction_type`) w.r.t each sample in the batch represented by `out`. This is a 2D tensor, where the first dimension is the batch dimension. """ # TODO: allow loss_fn to be Callable if isinstance(loss_fn, Module) and hasattr(loss_fn, "reduction"): msg0 = "Please ensure that loss_fn.reduction is set to `sum` or `mean`" assert loss_fn.reduction != "none", msg0 msg1 = ( f"loss_fn.reduction ({loss_fn.reduction}) does not match" f"reduction type ({reduction_type}). Please ensure they are" " matching." ) assert loss_fn.reduction == reduction_type, msg1 if reduction_type != "sum" and reduction_type != "mean": raise ValueError( f"{reduction_type} is not a valid value for reduction_type. " "Must be either 'sum' or 'mean'." ) if torch.__version__ >= "1.8": input_jacobians = torch.autograd.functional.jacobian( lambda out: loss_fn(out, targets), out, vectorize=vectorize ) else: input_jacobians = torch.autograd.functional.jacobian( lambda out: loss_fn(out, targets), out ) if reduction_type == "mean": input_jacobians = input_jacobians * len(input_jacobians) return input_jacobians def _load_flexible_state_dict( model: Module, path: str, device_ids: str = "cpu", keyname: Optional[str] = None ) -> int: r""" Helper to load pytorch models. This function attempts to find compatibility for loading models that were trained on different devices / with DataParallel but are being loaded in a different environment. Assumes that the model has been saved as a state_dict in some capacity. This can either be a single state dict, or a nesting dictionary which contains the model state_dict and other information. Args: model: The model for which to load a checkpoint path: The filepath to the checkpoint keyname: The key under which the model state_dict is stored, if any. The module state_dict is modified in-place, and the learning rate is returned. """ device = device_ids checkpoint = torch.load(path, map_location=device) learning_rate = checkpoint.get("learning_rate", 1) # can get learning rate from optimizer state_dict? if keyname is not None: checkpoint = checkpoint[keyname] if "module." in next(iter(checkpoint)): if isinstance(model, nn.DataParallel): model.load_state_dict(checkpoint) else: model = nn.DataParallel(model) model.load_state_dict(checkpoint) model = model.module else: if isinstance(model, nn.DataParallel): model = model.module model.load_state_dict(checkpoint) model = nn.DataParallel(model) else: model.load_state_dict(checkpoint) return learning_rate def _get_k_most_influential_helper( influence_src_dataloader: DataLoader, influence_batch_fn: Callable, inputs: Tuple[Any, ...], targets: Optional[Tensor], k: int = 5, proponents: bool = True, show_progress: bool = False, desc: Optional[str] = None, ) -> Tuple[Tensor, Tensor]: r""" Helper function that computes the quantities returned by `TracInCPBase._get_k_most_influential`, using a specific implementation that is constant memory. Args: influence_src_dataloader (DataLoader): The DataLoader, representing training data, for which we want to compute proponents / opponents. influence_batch_fn (Callable): A callable that will be called via `influence_batch_fn(inputs, targets, batch)`, where `batch` is a batch in the `influence_src_dataloader` argument. inputs (Tuple of Any): A batch of examples. Does not represent labels, which are passed as `targets`. targets (Tensor, optional): If computing TracIn scores on a loss function, these are the labels corresponding to the batch `inputs`. Default: None k (int, optional): The number of proponents or opponents to return per test instance. Default: 5 proponents (bool, optional): Whether seeking proponents (`proponents=True`) or opponents (`proponents=False`) Default: True show_progress (bool, optional): To compute the proponents (or opponents) for the batch of examples, we perform computation for each batch in training dataset `influence_src_dataloader`, If `show_progress`is true, the progress of this computation will be displayed. In particular, the number of batches for which the computation has been performed will be displayed. It will try to use tqdm if available for advanced features (e.g. time estimation). Otherwise, it will fallback to a simple output of progress. Default: False desc (str, optional): If `show_progress` is true, this is the description to show when displaying progress. If `desc` is none, no description is shown. Default: None Returns: (indices, influence_scores): `indices` is a torch.long Tensor that contains the indices of the proponents (or opponents) for each test example. Its dimension is `(inputs_batch_size, k)`, where `inputs_batch_size` is the number of examples in `inputs`. For example, if `proponents==True`, `indices[i][j]` is the index of the example in training dataset `influence_src_dataloader` with the k-th highest influence score for the j-th example in `inputs`. `indices` is a `torch.long` tensor so that it can directly be used to index other tensors. Each row of `influence_scores` contains the influence scores for a different test example, in sorted order. In particular, `influence_scores[i][j]` is the influence score of example `indices[i][j]` in training dataset `influence_src_dataloader` on example `i` in the test batch represented by `inputs` and `targets`. """ # For each test instance, maintain the best indices and corresponding distances # initially, these will be empty topk_indices = torch.Tensor().long() topk_tracin_scores = torch.Tensor() multiplier = 1.0 if proponents else -1.0 # needed to map from relative index in a batch fo index within entire `dataloader` num_instances_processed = 0 # if show_progress, create progress bar total: Optional[int] = None if show_progress: try: total = len(influence_src_dataloader) except AttributeError: pass influence_src_dataloader = progress( influence_src_dataloader, desc=desc, total=total, ) for batch in influence_src_dataloader: # calculate tracin_scores for the batch batch_tracin_scores = influence_batch_fn(inputs, targets, batch) batch_tracin_scores *= multiplier # get the top-k indices and tracin_scores for the batch batch_size = batch_tracin_scores.shape[1] batch_topk_tracin_scores, batch_topk_indices = torch.topk( batch_tracin_scores, min(batch_size, k), dim=1 ) batch_topk_indices = batch_topk_indices + num_instances_processed num_instances_processed += batch_size # combine the top-k for the batch with those for previously seen batches topk_indices = torch.cat([topk_indices, batch_topk_indices], dim=1) topk_tracin_scores = torch.cat( [topk_tracin_scores, batch_topk_tracin_scores], dim=1 ) # retain only the top-k in terms of tracin_scores topk_tracin_scores, topk_argsort = torch.topk( topk_tracin_scores, min(k, topk_indices.shape[1]), dim=1 ) topk_indices = torch.gather(topk_indices, dim=1, index=topk_argsort) # if seeking opponents, we were actually keeping track of negative tracin_scores topk_tracin_scores *= multiplier return topk_indices, topk_tracin_scores class _DatasetFromList(Dataset): def __init__(self, _l: List[Any]): self._l = _l def __getitem__(self, i: int) -> Any: return self._l[i] def __len__(self) -> int: return len(self._l)