File size: 3,206 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python3

from typing import Callable, Union

import torch
from torch.nn import Module


class Concept:

    r"""
    Concepts are human-friendly abstract representations that can be
    numerically encoded into torch tensors. They can be illustrated as
    images, text or any other form of representation. In the case of images,
    for example, the "stripes" concept can be represented through a number
    of example images resembling "stripes" in various different contexts.
    In the case of Natural Language Processing, the concept of "happy", for
    instance, can be illustrated through a number of adjectives and words
    that convey happiness.
    """

    def __init__(
        self, id: int, name: str, data_iter: Union[None, torch.utils.data.DataLoader]
    ) -> None:

        r"""
        Args:
            id (int):   The unique identifier of the concept.
            name (str): A unique name of the concept.
            data_iter (DataLoader or None): A pytorch DataLoader object that
                        combines a dataset and a sampler, and provides an
                        iterable over a given dataset. Only the input batches
                        are provided by `data_iter`. Concept ids can be used
                        as labels if necessary. ``None`` is also accepted, as
                        permitted by the type annotation; it is stored as-is.
                        For more information, please check:
                        https://pytorch.org/docs/stable/data.html

        Example::

            >>> # Creates a Concept object named "striped", with a data_iter
            >>> # object to iterate over all files in "./concepts/striped"
            >>> concept_name = "striped"
            >>> concept_path = os.path.join("./concepts", concept_name) + "/"
            >>> concept_iter = dataset_to_dataloader(
            ...     get_tensor_from_filename, concepts_path=concept_path)
            >>> concept_object = Concept(
            ...     id=0, name=concept_name, data_iter=concept_iter)
        """

        # `id` shadows the builtin, but it is part of the public keyword
        # interface (callers pass `id=...`), so the name is kept.
        self.id = id
        self.name = name
        self.data_iter = data_iter

    @property
    def identifier(self) -> str:
        """Return the string ``"<name>-<id>"``, e.g. ``"striped-0"``."""
        return f"{self.name}-{self.id}"

    def __repr__(self) -> str:
        """Return ``Concept(<id>, <name>)`` using the reprs of both fields."""
        return f"Concept({self.id!r}, {self.name!r})"


class ConceptInterpreter:
    r"""
    Base class for algorithms that perform concept-based model
    interpretability. It holds the model under inspection and declares an
    ``interpret`` entry point that each concrete algorithm must implement.
    """

    # `interpret` is only annotated here, never assigned, so this class does
    # not itself provide an implementation; concrete algorithms define it.
    # Intended contract: perform concept-based model interpretability and
    # return the interpretation results in the form of tensors, dictionaries
    # or other data structures.
    #
    # Args:
    #     inputs (tensor or tuple of tensors): Inputs for which concept-based
    #         interpretation scores are computed. Either a single tensor or a
    #         tuple of tensors; when several tensors are given, their batch
    #         size (the first dimension) must be aligned across all of them.
    interpret: Callable

    def __init__(self, model: Module) -> None:
        r"""
        Args:
            model (torch.nn.Module): The pytorch model instance to interpret.
        """
        self.model = model