#!/usr/bin/env python3

from abc import ABC, abstractmethod
from typing import Dict, Optional, Union

from captum._utils.typing import TensorOrTupleOfTensorsGeneric
from torch import Tensor
from torch.utils.data import DataLoader


class Model(ABC):
    r"""
    Abstract class describing the interface of a trainable model used by
    Captum's algorithms.

    Please note that this is an experimental feature.
    """

    @abstractmethod
    def fit(
        self, train_data: DataLoader, **kwargs
    ) -> Optional[Dict[str, Union[int, float, Tensor]]]:
        r"""
        Override this method to train your model.

        The specification of the dataloader is supplied by the Captum
        algorithm you are using. This will likely be a supervised learning
        task, so expect batched (x, y) pairs or (x, y, w) triples.

        Args:
            train_data (DataLoader):
                The data to train on.
            **kwargs:
                Additional, implementation-specific training arguments.

        Returns:
            Optional statistics about training, e.g. the number of iterations
            it took to train, the training loss, etc.
        """
        pass

    @abstractmethod
    def representation(self) -> Tensor:
        r"""
        Returns the underlying representation of the interpretable model. For a
        linear model this is simply a tensor (the concatenation of weights
        and bias). For something slightly more complicated, such as a decision
        tree, this could be the nodes of the tree.

        Returns:
            A Tensor describing the representation of the model.
        """
        pass

    @abstractmethod
    def __call__(
        self, x: TensorOrTupleOfTensorsGeneric
    ) -> TensorOrTupleOfTensorsGeneric:
        r"""
        Predicts with the interpretable model.

        Args:
            x (TensorOrTupleOfTensorsGeneric):
                A batched input of tensor(s) for the model to predict on.

        Returns:
            The prediction for the input as a TensorOrTupleOfTensorsGeneric.
        """
        pass