Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
import warnings | |
from functools import partial | |
from typing import Any, Callable, Dict, List, Optional, Tuple, Union | |
import captum._utils.common as common | |
import torch | |
from captum._utils.av import AV | |
from captum.attr import LayerActivation | |
from captum.influence._core.influence import DataInfluence | |
from torch import Tensor | |
from torch.nn import Module | |
from torch.utils.data import DataLoader, Dataset | |
r""" | |
Additional helper functions to calculate similarity metrics. | |
""" | |
def euclidean_distance(test, train) -> Tensor: | |
r""" | |
Calculates the pairwise euclidean distance for batches of feature vectors. | |
Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *). | |
Returns pairwise euclidean distance Tensor of shape (batch_size_1, batch_size_2). | |
""" | |
similarity = torch.cdist( | |
test.view(test.shape[0], -1).unsqueeze(0), | |
train.view(train.shape[0], -1).unsqueeze(0), | |
).squeeze(0) | |
return similarity | |
def cosine_similarity(test, train, replace_nan=0) -> Tensor: | |
r""" | |
Calculates the pairwise cosine similarity for batches of feature vectors. | |
Tensors test and train have shape (batch_size_1, *), and (batch_size_2, *). | |
Returns pairwise cosine similarity Tensor of shape (batch_size_1, batch_size_2). | |
""" | |
test = test.view(test.shape[0], -1) | |
train = train.view(train.shape[0], -1) | |
if torch.__version__ <= "1.6.0": | |
test_norm = torch.norm(test, p=None, dim=1, keepdim=True) | |
train_norm = torch.norm(train, p=None, dim=1, keepdim=True) | |
else: | |
test_norm = torch.linalg.norm(test, ord=2, dim=1, keepdim=True) | |
train_norm = torch.linalg.norm(train, ord=2, dim=1, keepdim=True) | |
test = torch.where(test_norm != 0.0, test / test_norm, Tensor([replace_nan])) | |
train = torch.where(train_norm != 0.0, train / train_norm, Tensor([replace_nan])).T | |
similarity = torch.mm(test, train) | |
return similarity | |
r""" | |
Implements abstract DataInfluence class and provides implementation details for | |
similarity metric-based influence computation. Similarity metrics can be used to compare | |
intermediate or final activation vectors of a model for different sets of input. Then, | |
these can be used to draw conclusions about influential instances. | |
Some standard similarity metrics such as dot product similarity or euclidean distance | |
are provided, but the user can provide any custom similarity metric as well. | |
""" | |
class SimilarityInfluence(DataInfluence): | |
def __init__( | |
self, | |
module: Module, | |
layers: Union[str, List[str]], | |
influence_src_dataset: Dataset, | |
activation_dir: str, | |
model_id: str = "", | |
similarity_metric: Callable = cosine_similarity, | |
similarity_direction: str = "max", | |
batch_size: int = 1, | |
**kwargs: Any, | |
): | |
r""" | |
Args: | |
module (torch.nn.Module): An instance of pytorch model. This model should | |
define all of its layers as attributes of the model. | |
layers (str or List of str): The fully qualified layer(s) for which the | |
activation vectors are computed. | |
influence_src_dataset (torch.utils.data.Dataset): PyTorch Dataset that is | |
used to create a PyTorch Dataloader to iterate over the dataset and | |
its labels. This is the dataset for which we will be seeking for | |
influential instances. In most cases this is the training dataset. | |
activation_dir (str): The directory of the path to store | |
and retrieve activation computations. Best practice would be to use | |
an absolute path. | |
model_id (str): The name/version of the model for which layer | |
activations are being computed. Activations will be stored and | |
loaded under the subdirectory with this name if provided. | |
similarity_metric (Callable): This is a callable function that computes a | |
similarity metric between two representations. For example, the | |
representations pair could be from the training and test sets. | |
This function must adhere to certain standards. The inputs should be | |
torch Tensors with shape (batch_size_i/j, feature dimensions). The | |
output Tensor should have shape (batch_size_i, batch_size_j) with | |
scalar values corresponding to the similarity metric used for each | |
pairwise combination from the two batches. | |
For example, suppose we use `batch_size_1 = 16` for iterating | |
through `influence_src_dataset`, and for the `inputs` argument | |
we pass in a Tensor with 3 examples, i.e. batch_size_2 = 3. Also, | |
suppose that our inputs and intermediate activations throughout the | |
model will have dimension (N, C, H, W). Then, the feature dimensions | |
should be flattened within this function. For example:: | |
>>> av_test.shape | |
torch.Size([3, N, C, H, W]) | |
>>> av_src.shape | |
torch.Size([16, N, C, H, W]) | |
>>> av_test = torch.view(av_test.shape[0], -1) | |
>>> av_test.shape | |
torch.Size([3, N x C x H x W]) | |
and similarly for av_src. The similarity_metric should then use | |
these flattened tensors to return the pairwise similarity matrix. | |
For example, `similarity_metric(av_test, av_src)` should return a | |
tensor of shape (3, 16). | |
batch_size (int): Batch size for iterating through `influence_src_dataset`. | |
**kwargs: Additional key-value arguments that are necessary for specific | |
implementation of `DataInfluence` abstract class. | |
""" | |
self.module = module | |
self.layers = [layers] if isinstance(layers, str) else layers | |
self.influence_src_dataset = influence_src_dataset | |
self.activation_dir = activation_dir | |
self.model_id = model_id | |
self.batch_size = batch_size | |
if similarity_direction == "max" or similarity_direction == "min": | |
self.similarity_direction = similarity_direction | |
else: | |
raise ValueError( | |
f"{similarity_direction} is not a valid value. " | |
"Must be either 'max' or 'min'" | |
) | |
if similarity_metric is cosine_similarity: | |
if "replace_nan" in kwargs: | |
self.replace_nan = kwargs["replace_nan"] | |
else: | |
self.replace_nan = -2 if self.similarity_direction == "max" else 2 | |
similarity_metric = partial(cosine_similarity, replace_nan=self.replace_nan) | |
self.similarity_metric = similarity_metric | |
self.influence_src_dataloader = DataLoader( | |
influence_src_dataset, batch_size, shuffle=False | |
) | |
def influence( # type: ignore[override] | |
self, | |
inputs: Union[Tensor, Tuple[Tensor, ...]], | |
top_k: int = 1, | |
additional_forward_args: Optional[Any] = None, | |
load_src_from_disk: bool = True, | |
**kwargs: Any, | |
) -> Dict: | |
r""" | |
Args: | |
inputs (tensor or tuple of tensors): Batch of examples for which influential | |
instances are computed. They are passed to the forward_func. The | |
first dimension in `inputs` tensor or tuple of tensors corresponds | |
to the batch size. A tuple of tensors is only passed in if this | |
is the input form that `module` accepts. | |
top_k (int): The number of top-matching activations to return | |
additional_forward_args (optional): Additional arguments that will be | |
passed to forward_func after inputs. | |
load_src_from_disk (bool): Loads activations for `influence_src_dataset` | |
where possible. Setting to False would force regeneration of | |
activations. | |
load_input_from_disk (bool): Regenerates activations for inputs by default | |
and removes previous `inputs` activations that are flagged with | |
`inputs_id`. Setting to True will load prior matching inputs | |
activations. Note that this could lead to unexpected behavior if | |
`inputs_id` is not configured properly and activations are loaded | |
for a different, prior `inputs`. | |
inputs_id (str): Used to identify inputs for loading activations. | |
**kwargs: Additional key-value arguments that are necessary for specific | |
implementation of `DataInfluence` abstract class. | |
Returns: | |
influences (dict): Returns the influential instances retrieved from | |
`influence_src_dataset` for each test example represented through a | |
tensor or a tuple of tensor in `inputs`. Returned influential | |
examples are represented as dict, with keys corresponding to | |
the layer names passed in `layers`. Each value in the dict is a | |
tuple containing the indices and values for the top k similarities | |
from `influence_src_dataset` by the chosen metric. The first value | |
in the tuple corresponds to the indices corresponding to the top k | |
most similar examples, and the second value is the similarity score. | |
The batch dimension corresponds to the batch dimension of `inputs`. | |
If inputs.shape[0] == 5, then dict[`layer_name`][0].shape[0] == 5. | |
These tensors will be of shape (inputs.shape[0], top_k). | |
""" | |
inputs_batch_size = ( | |
inputs[0].shape[0] if isinstance(inputs, tuple) else inputs.shape[0] | |
) | |
influences: Dict[str, Any] = {} | |
layer_AVDatasets = AV.generate_dataset_activations( | |
self.activation_dir, | |
self.module, | |
self.model_id, | |
self.layers, | |
DataLoader(self.influence_src_dataset, self.batch_size, shuffle=False), | |
identifier="src", | |
load_from_disk=load_src_from_disk, | |
return_activations=True, | |
) | |
assert layer_AVDatasets is not None and not isinstance( | |
layer_AVDatasets, AV.AVDataset | |
) | |
layer_modules = [ | |
common._get_module_from_name(self.module, layer) for layer in self.layers | |
] | |
test_activations = LayerActivation(self.module, layer_modules).attribute( | |
inputs, additional_forward_args | |
) | |
minmax = self.similarity_direction == "max" | |
# av_inputs shape: (inputs_batch_size, *) e.g. (inputs_batch_size, N, C, H, W) | |
# av_src shape: (self.batch_size, *) e.g. (self.batch_size, N, C, H, W) | |
test_activations = ( | |
test_activations if len(self.layers) > 1 else [test_activations] | |
) | |
for i, (layer, layer_AVDataset) in enumerate( | |
zip(self.layers, layer_AVDatasets) | |
): | |
topk_val, topk_idx = torch.Tensor(), torch.Tensor().long() | |
zero_acts = torch.Tensor().long() | |
av_inputs = test_activations[i] | |
src_loader = DataLoader(layer_AVDataset) | |
for j, av_src in enumerate(src_loader): | |
av_src = av_src.squeeze(0) | |
similarity = self.similarity_metric(av_inputs, av_src) | |
msg = ( | |
"Output of custom similarity does not meet required dimensions. " | |
f"Your output has shape {similarity.shape}.\nPlease ensure the " | |
"output shape matches (inputs_batch_size, src_dataset_batch_size), " | |
f"which should be {(inputs_batch_size, self.batch_size)}." | |
) | |
assert similarity.shape == (inputs_batch_size, av_src.shape[0]), msg | |
if hasattr(self, "replace_nan"): | |
idx = (similarity == self.replace_nan).nonzero() | |
zero_acts = torch.cat((zero_acts, idx)) | |
r""" | |
TODO: For models that can have tuples as activations, we should | |
allow similarity metrics to accept tuples, support topk selection. | |
""" | |
topk_batch = min(top_k, self.batch_size) | |
values, indices = torch.topk( | |
similarity, topk_batch, dim=1, largest=minmax | |
) | |
indices += int(j * self.batch_size) | |
topk_val = torch.cat((topk_val, values), dim=1) | |
topk_idx = torch.cat((topk_idx, indices), dim=1) | |
# can modify how often to sort for efficiency? minor | |
sort_idx = torch.argsort(topk_val, dim=1, descending=minmax) | |
topk_val = torch.gather(topk_val, 1, sort_idx[:, :top_k]) | |
topk_idx = torch.gather(topk_idx, 1, sort_idx[:, :top_k]) | |
influences[layer] = (topk_idx, topk_val) | |
if torch.numel(zero_acts != 0): | |
zero_warning = ( | |
f"Layer {layer} has zero-vector activations for some inputs. This " | |
"may cause undefined behavior for cosine similarity. The indices " | |
"for the offending inputs will be included under the key " | |
f"'zero_acts-{layer}' in the output dictionary. Indices are " | |
"returned as a tensor with [inputs_idx, src_dataset_idx] pairs " | |
"which may have corrupted similarity scores." | |
) | |
warnings.warn(zero_warning, RuntimeWarning) | |
key = "-".join(["zero_acts", layer]) | |
influences[key] = zero_acts | |
return influences | |