File size: 11,895 Bytes
d61b9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
#!/usr/bin/env python3
import base64
import warnings
from collections import namedtuple
from io import BytesIO
from typing import Callable, List, Optional, Union

from captum._utils.common import safe_div
from captum.attr._utils import visualization as viz
from captum.insights.attr_vis._utils.transforms import format_transforms

# Record returned by every Feature's ``visualize`` method:
#   name         - the feature's label
#   base         - the un-attributed rendering of the input
#   modified     - the attribution rendering
#   type         - the string returned by ``visualization_type()``
#   contribution - the contribution fraction passed to ``visualize``
FeatureOutput = namedtuple(
    "FeatureOutput",
    ["name", "base", "modified", "type", "contribution"],
)


def _convert_figure_base64(fig):
    buff = BytesIO()
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        fig.tight_layout()  # removes padding
    fig.savefig(buff, format="png")
    base64img = base64.b64encode(buff.getvalue()).decode("utf-8")
    return base64img


class BaseFeature:
    r"""
    Common parent of every Insights feature visualization.

    Concrete features must override :py:meth:`visualization_type` and
    :py:meth:`visualize`; the implementations here only raise
    ``NotImplementedError``.
    """

    def __init__(
        self,
        name: str,
        baseline_transforms: Optional[Union[Callable, List[Callable]]],
        input_transforms: Optional[Union[Callable, List[Callable]]],
        visualization_transform: Optional[Callable],
    ) -> None:
        r"""
        Args:

            name (str): Display label of the feature, e.g. "Photo" for an
                        ImageFeature.
            baseline_transforms (list, callable, optional): Optional callable,
                        or list of callables (e.g. functions), applied to the
                        input tensor to build baselines. Only a single baseline
                        is currently supported. See
                        :py:class:`.IntegratedGradients` for more
                        information about baselines.
            input_transforms (list, callable, optional): Optional callable, or
                        list of callables (e.g. functions), applied in order to
                        the input tensor to put it into the format the model
                        expects.
            visualization_transform (callable, optional): Optional callable
                        (e.g. function) applied to the original input data
                        (before ``input_transforms``) to produce something the
                        frontend visualizer understands, as described in
                        ``captum/captum/insights/frontend/App.js``.
        """
        self.name = name
        # Both transform arguments go through the same normalizing helper
        # (format_transforms); the visualization transform is stored as given.
        self.baseline_transforms, self.input_transforms = (
            format_transforms(transforms)
            for transforms in (baseline_transforms, input_transforms)
        )
        self.visualization_transform = visualization_transform

    @staticmethod
    def visualization_type() -> str:
        # Subclasses return the frontend type tag (e.g. "image", "text").
        raise NotImplementedError

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        # Subclasses build and return a FeatureOutput for one input.
        raise NotImplementedError


class ImageFeature(BaseFeature):
    r"""
    ImageFeature is used to visualize image features in Insights. It expects a
    single image, either in NCHW format (with N == 1) or in CHW format. If C has a
    dimension of 1, it is assumed to be a greyscale image. If it has a dimension
    of 3, it is expected to be in RGB format.
    """

    def __init__(
        self,
        name: str,
        baseline_transforms: Optional[Union[Callable, List[Callable]]],
        input_transforms: Optional[Union[Callable, List[Callable]]],
        visualization_transform: Optional[Callable] = None,
    ) -> None:
        r"""
        Args:
            name (str): The label of the specific feature. For example, an
                        ImageFeature's name can be "Photo".
            baseline_transforms (list, callable, optional): Optional list of
                        callables (e.g. functions) to be called on the input tensor
                        to construct multiple baselines. Currently only one baseline
                        is supported. See
                        :py:class:`.IntegratedGradients` for more
                        information about baselines.
            input_transforms (list, callable, optional): A list of transforms
                        or transform to be applied to the input. For images,
                        normalization is often applied here.
            visualization_transform (callable, optional): Optional callable (e.g.
                        function) applied as a postprocessing step of the original
                        input data (before input_transforms) to convert it to a
                        format to be visualized.
        """
        super().__init__(
            name,
            baseline_transforms=baseline_transforms,
            input_transforms=input_transforms,
            visualization_transform=visualization_transform,
        )

    @staticmethod
    def visualization_type() -> str:
        return "image"

    @staticmethod
    def _to_hwc_numpy(tensor):
        r"""
        Convert one image tensor from (1, ..., 1, C, H, W) or (C, H, W) layout
        into an (H, W, C) NumPy array on the CPU.

        Only leading singleton (batch) dimensions are removed. A bare
        ``squeeze()`` would also drop a channel dimension of size 1 (greyscale)
        or a spatial dimension of size 1, leaving fewer than three dimensions
        and making ``permute((1, 2, 0))`` fail.
        """
        tensor = tensor.detach()
        while tensor.dim() > 3 and tensor.shape[0] == 1:
            tensor = tensor.squeeze(0)
        return tensor.permute((1, 2, 0)).cpu().numpy()

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        r"""
        Render the input image and an absolute-value attribution heat map.

        Args:
            attribution: attribution tensor for the image; assumed to share the
                image's (1, C, H, W) or (C, H, W) layout -- TODO confirm with
                callers.
            data: the image tensor. ``visualization_transform`` is applied to it
                first when one was configured.
            contribution_frac: passed through unchanged as ``contribution``.

        Returns:
            FeatureOutput: ``base`` is the base64 PNG of the original image and
            ``modified`` is the base64 PNG of the heat map.
        """
        if self.visualization_transform:
            data = self.visualization_transform(data)

        data_t = self._to_hwc_numpy(data)
        attribution_t = self._to_hwc_numpy(attribution)

        orig_fig, _ = viz.visualize_image_attr(
            attribution_t, data_t, method="original_image", use_pyplot=False
        )
        attr_fig, _ = viz.visualize_image_attr(
            attribution_t,
            data_t,
            method="heat_map",
            sign="absolute_value",
            use_pyplot=False,
        )

        img_64 = _convert_figure_base64(orig_fig)
        attr_img_64 = _convert_figure_base64(attr_fig)

        return FeatureOutput(
            name=self.name,
            base=img_64,
            modified=attr_img_64,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )


class TextFeature(BaseFeature):
    r"""
    TextFeature is used to visualize text (e.g. sentences) in Insights.
    It expects the visualization transform to convert the input data (e.g. index to
    string) to the raw text.
    """

    def __init__(
        self,
        name: str,
        baseline_transforms: Union[Callable, List[Callable]],
        input_transforms: Union[Callable, List[Callable]],
        visualization_transform: Optional[Callable] = None,
    ) -> None:
        r"""
        Args:
            name (str): The label of the specific feature. For example, a
                        TextFeature's name can be "Sentence".
            baseline_transforms (list, callable, optional): Optional list of
                        callables (e.g. functions) to be called on the input tensor
                        to construct multiple baselines. Currently only one baseline
                        is supported. See
                        :py:class:`.IntegratedGradients` for more
                        information about baselines.
                        For text features, a common baseline is a tensor of indices
                        corresponding to PAD with the same size as the input
                        tensor. See :py:class:`.TokenReferenceBase` for more
                        information.
            input_transforms (list, callable, optional): A list of transforms
                        or transform to be applied to the input. For text, a common
                        transform is to convert the tokenized input tensor into an
                        interpretable embedding. See
                        :py:class:`.InterpretableEmbeddingBase`
                        and
                        :py:func:`~.configure_interpretable_embedding_layer`
                        for more information.
            visualization_transform (callable, optional): Optional callable (e.g.
                        function) applied as a postprocessing step of the original
                        input data (before ``input_transforms``) to convert it to a
                        suitable format for visualization. For text features,
                        a common function is to convert the token indices to their
                        corresponding (sub)words. When omitted, the raw input data
                        is reported as the ``base`` of the output.
        """
        super().__init__(
            name,
            baseline_transforms=baseline_transforms,
            input_transforms=input_transforms,
            visualization_transform=visualization_transform,
        )

    @staticmethod
    def visualization_type() -> str:
        return "text"

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        r"""
        Produce one attribution score per token, scaled into [-100, 100].

        Args:
            attribution: attribution tensor. A leading dimension of size 1 is
                dropped; any remaining dimensions beyond the first are summed so
                that each position of the first dimension gets one score.
            data: the input data, passed to ``visualization_transform`` (if set)
                to obtain the text shown as ``base``.
            contribution_frac: passed through unchanged as ``contribution``.

        Returns:
            FeatureOutput: ``base`` is the text (or the raw ``data`` when no
            visualization transform is set) and ``modified`` is the list of
            scores.
        """
        if self.visualization_transform:
            text = self.visualization_transform(data)
        else:
            text = data

        attribution = attribution.squeeze(0)
        if attribution.dim() > 1:
            # Collapse every per-token dimension (e.g. the embedding dimension)
            # so each token gets a single scalar; a 2-D input behaves exactly as
            # ``sum(dim=1)``, and deeper inputs no longer yield nested lists.
            attribution = attribution.flatten(start_dim=1).sum(dim=1)

        # L-Infinity norm, if norm is 0, all attr elements are 0
        attr_max = attribution.abs().max()
        normalized_attribution = safe_div(attribution, attr_max)

        modified = [x * 100 for x in normalized_attribution.tolist()]
        return FeatureOutput(
            name=self.name,
            base=text,
            modified=modified,
            type=self.visualization_type(),
            contribution=contribution_frac,
        )


class GeneralFeature(BaseFeature):
    r"""
    Insights visualization for features that are neither images nor text; it
    works for dense as well as sparse features.

    Only 2-d tensors of shape (N, C) are currently supported, where N is the
    number of samples and C the number of categories.
    """

    def __init__(self, name: str, categories: List[str]) -> None:
        r"""
        Args:
            name (str): Display label of the feature, e.g. "Metadata".
            categories (list[str]): Labels for each category. Their order and
                        count should match the second dimension of the ``data``
                        tensor given to ``visualize``.
        """
        # A general feature never transforms its inputs or its visualization.
        super().__init__(name, None, None, None)
        self.categories = categories

    @staticmethod
    def visualization_type() -> str:
        return "general"

    def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
        r"""
        Pair each category with its value and an L2-normalized score.

        Returns:
            FeatureOutput: ``base`` holds ``"<category>: <value>"`` strings with
            values shown to two decimals; ``modified`` holds the attributions
            divided by their L2 norm and multiplied by 100.
        """
        scores = attribution.squeeze(0)
        values = data.squeeze(0).tolist()

        # When the L2 norm is 0 every attribution is 0, so safe_div only has
        # to avoid the division by zero.
        unit_scores = safe_div(scores, scores.norm()).tolist()

        labelled = [
            f"{category}: {value:.2f}"
            for category, value in zip(self.categories, values)
        ]
        return FeatureOutput(
            name=self.name,
            base=labelled,
            modified=[score * 100 for score in unit_scores],
            type=self.visualization_type(),
            contribution=contribution_frac,
        )


class EmptyFeature(BaseFeature):
    r"""
    Placeholder feature with nothing to draw: its ``visualize`` reports only the
    name, the ``"empty"`` type tag and the contribution fraction, leaving
    ``base`` and ``modified`` as ``None``.
    """

    def __init__(
        self,
        name: str = "empty",
        baseline_transforms: Optional[Union[Callable, List[Callable]]] = None,
        input_transforms: Optional[Union[Callable, List[Callable]]] = None,
        visualization_transform: Optional[Callable] = None,
    ) -> None:
        # All arguments are forwarded to BaseFeature unchanged.
        super().__init__(
            name, baseline_transforms, input_transforms, visualization_transform
        )

    @staticmethod
    def visualization_type() -> str:
        return "empty"

    def visualize(self, _attribution, _data, contribution_frac) -> FeatureOutput:
        # Attribution and data are intentionally ignored.
        return FeatureOutput(
            self.name,
            None,
            None,
            self.visualization_type(),
            contribution_frac,
        )