# NOTE: removed non-Python build-log artifacts that preceded the shebang.
#!/usr/bin/env python3 | |
import random | |
import warnings | |
from abc import ABC, abstractmethod | |
from typing import Any, Dict, List, Tuple, Union | |
import torch | |
from captum._utils.models.linear_model import model | |
from torch import Tensor | |
from torch.utils.data import DataLoader, TensorDataset | |
class Classifier(ABC):
    r"""
    An abstract class definition of any classifier that allows to train a model
    and access trained weights of that model.

    More specifically the classifier can, for instance, be trained on the
    activations of a particular layer. Below we can see an example a sklearn
    linear classifier wrapped by the `CustomClassifier` which extends `Classifier`
    abstract class.

    Example::

        >>> from sklearn import linear_model
        >>>
        >>> class CustomClassifier(Classifier):
        >>>
        >>>     def __init__(self):
        >>>
        >>>         self.lm = linear_model.SGDClassifier(alpha=0.01, max_iter=1000,
        >>>                                              tol=1e-3)
        >>>
        >>>     def train_and_eval(self, dataloader):
        >>>
        >>>         x_train, x_test, y_train, y_test = train_test_split(inputs, labels)
        >>>         self.lm.fit(x_train.detach().numpy(), y_train.detach().numpy())
        >>>
        >>>         preds = torch.tensor(self.lm.predict(x_test.detach().numpy()))
        >>>         return {'accs': (preds == y_test).float().mean()}
        >>>
        >>>     def weights(self):
        >>>
        >>>         if len(self.lm.coef_) == 1:
        >>>             # if there are two concepts, there is only one label.
        >>>             # We split it in two.
        >>>             return torch.tensor([-1 * self.lm.coef_[0], self.lm.coef_[0]])
        >>>         else:
        >>>             return torch.tensor(self.lm.coef_)
        >>>
        >>>     def classes(self):
        >>>         return self.lm.classes_
        >>>
    """

    def __init__(self) -> None:
        pass

    @abstractmethod
    def train_and_eval(
        self, dataloader: DataLoader, **kwargs: Any
    ) -> Union[Dict, None]:
        r"""
        This method is responsible for training a classifier using the data
        provided through `dataloader` input arguments. Based on the specific
        implementation, it may or may not return a statistics about model
        training and evaluation.

        Args:
            dataloader (dataloader): A dataloader that enables batch-wise access to
                    the inputs and corresponding labels. Dataloader allows us to
                    iterate over the dataset by loading the batches in lazy manner.
            kwargs (dict): Named arguments that are used for training and evaluating
                    concept classifier.
                    Default: None

        Returns:
            stats (dict): a dictionary of statistics about the performance of the model.
                    For example the accuracy of the model on the test and/or
                    train dataset(s). The user may decide to return None or an
                    empty dictionary if she/he decides to not return any performance
                    statistics.
        """
        pass

    @abstractmethod
    def weights(self) -> Tensor:
        r"""
        This function returns a C x F tensor weights, where
        C is the number of classes and F is the number of features.

        Returns:
            weights (tensor): A torch Tensor with the weights resulting from
                    the model training.
        """
        pass

    @abstractmethod
    def classes(self) -> List[int]:
        r"""
        This function returns the list of all classes that are used by the
        classifier to train the model in the `train_and_eval` method.
        The order of returned classes has to match the same order used in
        the weights matrix returned by the `weights` method.

        Returns:
            classes (list): The list of classes used by the classifier to train
                    the model in the `train_and_eval` method.
        """
        pass
class DefaultClassifier(Classifier):
    r"""
    A default Linear Classifier based on sklearn's SGDClassifier for
    learning decision boundaries between concepts.

    Note that default implementation slices input dataset into train and test
    splits and keeps them in memory.
    In case concept datasets are large, this can lead to out of memory and we
    recommend to provide a custom Classifier that extends `Classifier` abstract
    class and handles large concept datasets accordingly.
    """

    def __init__(self) -> None:
        warnings.warn(
            "Using default classifier for TCAV which keeps input"
            " both train and test datasets in the memory. Consider defining"
            " your own classifier that doesn't rely heavily on memory, for"
            " large number of concepts, by extending"
            " `Classifier` abstract class"
        )
        self.lm = model.SkLearnSGDClassifier(alpha=0.01, max_iter=1000, tol=1e-3)

    def train_and_eval(
        self, dataloader: DataLoader, test_split_ratio: float = 0.33, **kwargs: Any
    ) -> Union[Dict, None]:
        r"""
        Implements Classifier::train_and_eval abstract method for small concept
        datasets provided by `dataloader`.
        It is assumed that when iterating over `dataloader` we can still
        retain the entire dataset in the memory.
        This method shuffles all examples randomly provided, splits them
        into train and test partitions and trains an SGDClassifier using sklearn
        library. Ultimately, it measures and returns model accuracy using test
        split of the dataset.

        Args:
            dataloader (dataloader): A dataloader that enables batch-wise access to
                    the inputs and corresponding labels. Dataloader allows us to
                    iterate over the dataset by loading the batches in lazy manner.
            test_split_ratio (float): The ratio of test split in the entire dataset
                    served by input data loader `dataloader`.
                    Default: 0.33

        Returns:
            stats (dict): a dictionary of statistics about the performance of the model.
                    In this case stats represents a dictionary of model accuracy
                    measured on the test split of the dataset.
        """
        inputs = []
        labels = []
        # Default to CPU; overwritten by the device of the data as soon as a
        # batch with non-None inputs is seen. Tracking this inside the loop
        # avoids a NameError on an empty dataloader (the original read the
        # loop variable after the loop ended).
        device = "cpu"
        for batch_inputs, batch_labels in dataloader:
            inputs.append(batch_inputs)
            labels.append(batch_labels)
            if batch_inputs is not None:
                device = batch_inputs.device

        x_train, x_test, y_train, y_test = _train_test_split(
            torch.cat(inputs), torch.cat(labels), test_split=test_split_ratio
        )
        self.lm.device = device
        self.lm.fit(DataLoader(TensorDataset(x_train, y_train)))

        # Map argmax over class scores back to the actual class labels.
        predictions = self.lm(x_test)
        predictions = self.lm.classes()[torch.argmax(predictions, dim=1)]
        score = predictions.long() == y_test.long().cpu()

        accs = score.float().mean()
        return {"accs": accs}

    def weights(self) -> Tensor:
        r"""
        This function returns a C x F tensor weights, where
        C is the number of classes and F is the number of features.
        In case of binary classification, C = 2 otherwise it is > 2.

        Returns:
            weights (tensor): A torch Tensor with the weights resulting from
                    the model training.
        """
        assert self.lm.linear is not None, (
            "The weights cannot be obtained because no model was trained."
            "In order to train the model call `train_and_eval` method first."
        )
        weights = self.lm.representation()
        if weights.shape[0] == 1:
            # if there are two concepts, there is only one label. We split it in two.
            return torch.stack([-1 * weights[0], weights[0]])
        else:
            return weights

    def classes(self) -> List[int]:
        r"""
        This function returns the list of all classes that are used by the
        classifier to train the model in the `train_and_eval` method.
        The order of returned classes has to match the same order used in
        the weights matrix returned by the `weights` method.

        Returns:
            classes (list): The list of classes used by the classifier to train
                    the model in the `train_and_eval` method.
                    NOTE(review): this actually returns a numpy array, not a
                    Python list as the annotation suggests; kept as-is since
                    callers may rely on array indexing — confirm before changing.
        """
        return self.lm.classes().detach().numpy()
def _train_test_split( | |
x_list: Tensor, y_list: Tensor, test_split: float = 0.33 | |
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: | |
# Shuffle | |
z_list = list(zip(x_list, y_list)) | |
random.shuffle(z_list) | |
# Split | |
test_size = int(test_split * len(z_list)) | |
z_test, z_train = z_list[:test_size], z_list[test_size:] | |
x_test, y_test = zip(*z_test) | |
x_train, y_train = zip(*z_train) | |
return ( | |
torch.stack(x_train), | |
torch.stack(x_test), | |
torch.stack(y_train), | |
torch.stack(y_test), | |
) | |