#!/usr/bin/env python3
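"""Prediction and attribution computation for Captum Insights.

``AttributionCalculation`` is the backend helper used by the Insights
``AttributionVisualizer``: it applies the configured feature transforms to
raw inputs, runs the model to obtain predicted classes, and computes
per-feature attributions with a selected Captum attribution method.
"""
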
import inspect
from collections import namedtuple
from typing import (
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,

)

import torch
from captum._utils.common import _run_forward, safe_div
from captum.insights.attr_vis.config import (
    ATTRIBUTION_METHOD_CONFIG,
    ATTRIBUTION_NAMES_TO_METHODS,
)
from captum.insights.attr_vis.features import BaseFeature
from torch import Tensor
from torch.nn import Module

OutputScore = namedtuple("OutputScore", "score index label")


class AttributionCalculation:
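    """Computes predictions and attributions for Captum Insights.

    Holds the models, output class labels, and feature definitions, and
    caches transformed inputs and baselines so repeated requests on the same
    inputs do not recompute them.
    """
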
    def __init__(
        self,
        models: Sequence[Module],
        classes: Sequence[str],
        features: List[BaseFeature],
        score_func: Optional[Callable] = None,
        use_label_for_attr: bool = True,
    ) -> None:
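        """Stores the models, output class labels, and feature definitions.

        ``score_func`` optionally post-processes raw model outputs (e.g. a
        softmax) before ranking; ``use_label_for_attr`` controls whether a
        label is passed as the ``target`` of the attribution method.
        """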
        self.models = models
        self.classes = classes
        self.features = features
        self.score_func = score_func
        self.use_label_for_attr = use_label_for_attr
        self.baseline_cache: dict = {}
        self.transformed_input_cache: dict = {}

    def calculate_predicted_scores(
        self, inputs, additional_forward_args, model
    ) -> Tuple[
        List[OutputScore], Optional[List[Tuple[Tensor, ...]]], Tuple[Tensor, ...]
    ]:
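        """Transforms ``inputs`` with each feature's input transforms, runs
        ``model`` on them, and returns the top predicted classes as
        ``OutputScore`` tuples, together with the baselines and transformed
        inputs (cached per input tuple) that attribution is computed on.
        """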
        # Check if inputs have cached baselines and transformed inputs
        hashable_inputs = tuple(inputs)
        if hashable_inputs in self.baseline_cache:
            baselines_group = self.baseline_cache[hashable_inputs]
            transformed_inputs = self.transformed_input_cache[hashable_inputs]
        else:
            # Initialize baselines
            baseline_transforms_len = 1  # TODO: support multiple baselines
            baselines: List[List[Optional[Tensor]]] = [
                [None] * len(self.features) for _ in range(baseline_transforms_len)
            ]
            transformed_inputs = list(inputs)
            for feature_i, feature in enumerate(self.features):
                transformed_inputs[feature_i] = self._transform(
                    feature.input_transforms, transformed_inputs[feature_i], True
                )
                for baseline_i in range(baseline_transforms_len):
                    if baseline_i > len(feature.baseline_transforms) - 1:
                        baselines[baseline_i][feature_i] = torch.zeros_like(
                            transformed_inputs[feature_i]
                        )
                    else:
                        baselines[baseline_i][feature_i] = self._transform(
                            [feature.baseline_transforms[baseline_i]],
                            transformed_inputs[feature_i],
                            True,
                        )

            baselines = cast(List[List[Optional[Tensor]]], baselines)
            baselines_group = [tuple(b) for b in baselines]
            self.baseline_cache[hashable_inputs] = baselines_group
            self.transformed_input_cache[hashable_inputs] = transformed_inputs

        outputs = _run_forward(
            model,
            tuple(transformed_inputs),
            additional_forward_args=additional_forward_args,
        )

        if self.score_func is not None:
            outputs = self.score_func(outputs)

        # A scalar output (e.g. a single regression or sigmoid score) is
        # rounded to an integer prediction; otherwise take the top-k classes.
        if outputs.nelement() == 1:
            scores = outputs
            predicted = scores.round().to(torch.int)
        else:
            scores, predicted = outputs.topk(min(4, outputs.shape[-1]))

        scores = scores.cpu().squeeze(0)
        predicted = predicted.cpu().squeeze(0)

        predicted_scores = self._get_labels_from_scores(scores, predicted)

        return predicted_scores, baselines_group, tuple(transformed_inputs)

    def calculate_attribution(
        self,
        baselines: Optional[Sequence[Tuple[Tensor, ...]]],
        data: Tuple[Tensor, ...],
        additional_forward_args: Optional[Tuple[Tensor, ...]],
        label: Optional[Tensor],
        attribution_method_name: str,
        attribution_arguments: Dict,
        model: Module,
    ) -> Tuple[Tensor, ...]:
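        """Instantiates the attribution method registered under
        ``attribution_method_name``, applies any configured argument
        post-processing, and returns attributions of ``data`` with respect to
        ``label``. Only a single baseline is currently supported.
        """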
        attribution_cls = ATTRIBUTION_NAMES_TO_METHODS[attribution_method_name]
        attribution_method = attribution_cls(model)
        if attribution_method_name in ATTRIBUTION_METHOD_CONFIG:
            param_config = ATTRIBUTION_METHOD_CONFIG[attribution_method_name]
            if param_config.post_process:
                for k, v in attribution_arguments.items():
                    if k in param_config.post_process:
                        attribution_arguments[k] = param_config.post_process[k](v)

        # TODO support multiple baselines
        baseline = baselines[0] if baselines and len(baselines) > 0 else None
        label = (
            None
            if not self.use_label_for_attr or label is None or label.nelement() == 0
            else label
        )
        if "baselines" in inspect.signature(attribution_method.attribute).parameters:
            attribution_arguments["baselines"] = baseline
        attr = attribution_method.attribute.__wrapped__(
            attribution_method,  # self
            data,
            additional_forward_args=additional_forward_args,
            target=label,
            **attribution_arguments,
        )

        return attr

    def calculate_net_contrib(
        self, attrs_per_input_feature: Tuple[Tensor, ...]
    ) -> List[float]:
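        """Collapses each feature's attributions to a single signed value and
        L1-normalizes across features so the absolute values sum to 1 (all
        zeros when every attribution is zero).
        """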
        # get the net contribution per feature (input)
        net_contrib = torch.stack(
            [attrib.flatten().sum() for attrib in attrs_per_input_feature]
        )

        # normalise the contribution, s.t. sum(abs(x_i)) = 1
        norm = torch.norm(net_contrib, p=1)
        # if norm is 0, all net_contrib elements are 0
        net_contrib = safe_div(net_contrib, norm)

        return net_contrib.tolist()

    def _transform(
        self, transforms: Iterable[Callable], inputs: Tensor, batch: bool = False
    ) -> Tensor:
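        """Applies ``transforms`` in order; when ``batch`` is True the leading
        batch dimension (size 1) is squeezed off before the transforms and
        added back afterwards.
        """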
        transformed_inputs = inputs
        # TODO support batch size > 1
        if batch:
            transformed_inputs = inputs.squeeze(0)

        for t in transforms:
            transformed_inputs = t(transformed_inputs)

        if batch:
            transformed_inputs = transformed_inputs.unsqueeze(0)

        return transformed_inputs

    def _get_labels_from_scores(
        self, scores: Tensor, indices: Tensor
    ) -> List[OutputScore]:
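        """Pairs each predicted index with its score and class label; returns
        an empty list when fewer than two indices are given (the scalar-output
        case handled in ``calculate_predicted_scores``).
        """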
        pred_scores: List[OutputScore] = []
        if indices.nelement() < 2:
            return pred_scores
        for i in range(len(indices)):
            score = scores[i]
            pred_scores.append(
                OutputScore(score, indices[i], self.classes[int(indices[i])])
            )
        return pred_scores
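

if __name__ == "__main__":
    # Illustrative sketch only, not part of the original module: it wires a
    # tiny placeholder classifier into AttributionCalculation to show the
    # expected call shape. The model, class names, and feature categories are
    # assumptions made for this demo; it also assumes GeneralFeature accepts
    # having no explicit input/baseline transforms (normalizing them to empty
    # lists), as in recent Captum releases.
    from captum.insights.attr_vis.features import GeneralFeature

    demo_model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Softmax(dim=1))
    calc = AttributionCalculation(
        models=[demo_model],
        classes=["class_a", "class_b", "class_c"],
        features=[GeneralFeature("vector", ["f0", "f1", "f2", "f3"])],
    )
    sample = (torch.randn(1, 4),)
    scores, demo_baselines, transformed = calc.calculate_predicted_scores(
        sample, additional_forward_args=None, model=demo_model
    )
    print([(s.label, float(s.score)) for s in scores])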