strexp / captum /influence /_utils /nearest_neighbors.py
markytools's picture
added strexp
d61b9c7
raw
history blame
10.1 kB
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from torch import Tensor
class NearestNeighbors(ABC):
r"""
An abstract class to define a nearest neighbors data structure. Classes
implementing this interface are intended for computing proponents / opponents in
certain implementations of `TracInCPBase`. In particular, it is for use in
implementations which compute proponents / opponents of a test instance by
1) storing representations of training instances within a nearest neighbors data
structure, and 2) finding within that structure the nearest neighbor of the
representation of a test instance. The assumption is that the data structure
stores the tensors passed to the `setup` method, which we refer to as the "stored
tensors". If this class is used to find proponents / opponents, the nearest
neighbors of a tensor should be the stored tensors that have the largest
dot-product with the query.
"""
@abstractmethod
def get_nearest_neighbors(
self, query: torch.Tensor, k: int
) -> Tuple[Tensor, Tensor]:
r"""
Given a `query`, a tensor of shape (N, *), returns the nearest neighbors in the
"stored tensors" (see above). `query` represents a batch of N tensors, each
of common but arbitrary shape *. We always assume the 0-th dimension indexes
the batch. In use cases of this class for computing proponents / opponents,
the nearest neighbors of a tensor should be the stored tensors with the largest
dot-product with the tensor, and the tensors in `query` will all be 1D,
so that `query` is 2D.
Args:
query (tensor): tensor representing the batch of tensors for which k-nearest
neighbors are desired. `query` is of shape (N, *), where N is the
size of the batch, i.e. the 0-th dimension of `query` indexes the
batch. * denotes an arbitrary shape, so that each tensor in the
batch can be of a common, but arbitrary shape.
k (int): The number of nearest neighbors to return.
Returns:
results (tuple): A tuple of `(indices, distances)` is returned. `indices`
is a 2D tensor where `indices[i,j]` is the index (within the
"stored tensors" passed to the `setup` method) of the `j`-th
nearest neighbor of the `i`-th instance in query, and
`distances[i,j]` is the corresponding distance. `indices` should
be of dtype `torch.long` so that it can be used to index torch
tensors.
"""
pass
@abstractmethod
def setup(self, data: torch.Tensor) -> None:
r"""
`data` denotes the "stored tensors". These are the tensors within which we
want to find the nearest neighbors to each tensor in a batch of tensors, via a
call to the`get_nearest_neighbors` method. Before we can call it, however,
we need to first store the stored tensors, by doing processing that indexes
the stored tensors in a form that enables nearest-neighbors computation.
This method does that preprocessing, and is assumed to be called before any
call to `get_nearest_neighbors`. For example, this method might put the
stored tensors in a K-d tree. The tensors in the "stored tensors" can be of a
common, but arbitrary shape, denoted *, so that `data` is of shape (N, *),
where N is the number of tensors in the stored tensors. Therefore, the 0-th
dimension indexes the tensors in the stored tensors.
Args:
data (tensor): A tensor of shape (N, *) representing the stored tensors.
The 0-th dimension indexes the tensors in the stored tensors,
so that `data[i]` is the tensor with index `i`. The nearest
neighbors of a query will be referred to by their index.
"""
pass
class AnnoyNearestNeighbors(NearestNeighbors):
"""
This is an implementation of `NearestNeighbors` that uses the Annoy module. At a
high level, Annoy finds nearest neighbors by constructing binary trees in which
vectors reside at leaf nodes. Vectors near each other will tend to be in the same
leaf node. See https://tinyurl.com/2p89sb2h and https://github.com/spotify/annoy
for more details. Annoy has 1 key parameter: the number of trees to construct.
Increasing the number of trees leads to more accurate results, but longer time to
create the trees and memory usage. As mentioned in the `NearestNeighbors`
documentation, for the use case of computing proponents / opponents, the nearest
neighbors returned should be those with the largest dot product with the query
vector. The term "vector" is used here because Annoy stores 1D vectors. However
in our wrapper around Annoy, we will allow the stored tensors to be of a common
but arbitrary shape *, and flatten them before storing in the Annoy data structure.
"""
def __init__(self, num_trees: int = 10):
"""
Args:
num_trees (int): The number of trees to use. Increasing this number gives
more accurate computation of nearest neighbors, but requires longer
setup time to create the trees, as well as memory.
"""
try:
import annoy # noqa
except ImportError:
raise ValueError(
(
"Using `AnnoyNearestNeighbors` requires installing the annoy "
"module. If pip is installed, this can be done with "
"`pip install --user annoy`."
)
)
self.num_trees = num_trees
def setup(self, data: torch.Tensor) -> None:
"""
`data` denotes the "stored tensors". These are the tensors within which we
want to find the nearest neighbors to a query tensor, via a call to the
`get_nearest_neighbors` method. Before we can call `get_nearest_neighbors`,
we need to first store the stored tensors, by doing processing that indexes
the stored tensors in a form that enables nearest-neighbors computation.
This method does that preprocessing, and is assumed to be called before any
call to `get_nearest_neighbors`. In particular, it creates the trees used to
index the stored tensors. This index is built to enable computation of
vectors that have the largest dot-product with the query tensors. The tensors
in the "stored tensors" can be of a common, but arbitrary shape, denoted *, so
that `data` is of shape (N, *), where N is the number of tensors in the stored
tensors. Therefore, the 0-th dimension indexes the tensors in the stored
tensors.
Args:
data (tensor): A tensor of shape (N, *) representing the stored tensors.
The 0-th dimension indexes the tensors in the stored tensors,
so that `data[i]` is the tensor with index `i`. The nearest
neighbors of a query will be referred to by their index.
"""
import annoy
data = data.view((len(data), -1))
projection_dim = data.shape[1]
self.knn_index = annoy.AnnoyIndex(projection_dim, "dot")
for (i, projection) in enumerate(data):
self.knn_index.add_item(i, projection)
self.knn_index.build(self.num_trees)
def get_nearest_neighbors(
self, query: torch.Tensor, k: int
) -> Tuple[Tensor, Tensor]:
r"""
Given a `query`, a tensor of shape (N, *), returns the nearest neighbors in the
"stored tensors" (see above). `query` represents a batch of N tensors, each
of common but arbitrary shape *. We always assume the 0-th dimension indexes
the batch. In use cases of this class for computing proponents / opponents,
the nearest neighbors of a tensor should be the stored tensors with the largest
dot-product with the tensor, and the tensors in `query` will all be 1D,
so that `query` is 2D. This implementation returns the stored tensors
that have the largest dot-product with the query tensor, and does not constrain
the tensors in `query` or in the stored tensors to be 1D. If tensors are of
dimension greater than 1D, their dot-product will be defined to be the
dot-product of the flattened version of tensors.
Args:
query (tensor): tensor representing the batch of tensors for which k-nearest
neighbors are desired. `query` is of shape (N, *), where N is the
size of the batch, i.e. the 0-th dimension of `query` indexes the
batch. * denotes an arbitrary shape, so that each tensor in the
batch can be of a common, but arbitrary shape.
k (int): The number of nearest neighbors to return.
Returns:
results (tuple): A tuple of `(indices, distances)` is returned. `indices`
is a 2D tensor where `indices[i,j]` is the index (within the
"stored tensors" passed to the `setup` method) of the `j`-th
nearest neighbor of the `i`-th instance in query, and
`distances[i,j]` is the corresponding distance. `indices` should
be of dtype `torch.long` so that it can be used to index torch
tensors.
"""
query = query.view((len(query), -1))
indices_and_distances = [
self.knn_index.get_nns_by_vector(instance, k, include_distances=True)
for instance in query
]
indices, distances = zip(*indices_and_distances)
indices = torch.Tensor(indices).type(torch.long)
distances = torch.Tensor(distances)
return indices, distances