Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
from abc import ABC, abstractmethod | |
from typing import Dict, Optional, Union | |
from captum._utils.typing import TensorOrTupleOfTensorsGeneric | |
from torch import Tensor | |
from torch.utils.data import DataLoader | |
class Model(ABC):
    r"""
    Abstract Class to describe the interface of a trainable model to be used
    within the algorithms of captum.

    Concrete subclasses must implement :meth:`fit`, :meth:`representation`
    and :meth:`__call__`. All three are declared abstract, so neither
    ``Model`` itself nor a subclass missing any of them can be instantiated.

    Please note that this is an experimental feature.
    """

    @abstractmethod
    def fit(
        self, train_data: DataLoader, **kwargs
    ) -> Optional[Dict[str, Union[int, float, Tensor]]]:
        r"""
        Override this method to actually train your model.

        The specification of the dataloader will be supplied by the algorithm
        you are using within captum. This will likely be a supervised learning
        task, thus you should expect batched (x, y) pairs or (x, y, w) triples.

        Args:
            train_data (DataLoader):
                The data to train on.
            **kwargs:
                Extra keyword arguments; their meaning is defined by the
                implementing subclass (the base class does not use them).

        Returns:
            Optional statistics about training, e.g. iterations it took to
            train, training loss, etc.
        """
        pass

    @abstractmethod
    def representation(self) -> Tensor:
        r"""
        Returns the underlying representation of the interpretable model. For a
        linear model this is simply a tensor (the concatenation of weights
        and bias). For something slightly more complicated, such as a decision
        tree, this could be the nodes of a decision tree.

        Returns:
            A Tensor describing the representation of the model.
        """
        pass

    @abstractmethod
    def __call__(
        self, x: TensorOrTupleOfTensorsGeneric
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Predicts with the interpretable model.

        Args:
            x (TensorOrTupleOfTensorsGeneric):
                A batched input of tensor(s) to the model to predict.

        Returns:
            The prediction of the input as a TensorOrTupleOfTensorsGeneric.
        """
        pass