#!/usr/bin/env python3

import copy
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
from captum.attr._utils.stat import Count, Max, Mean, Min, MSE, Stat, StdDev, Sum, Var
from captum.log import log_usage
from torch import Tensor


class Summarizer:
    r"""
    This class simply wraps a given set of SummarizerSingleTensor objects in
    order to summarize multiple input tensors.

    Basic usage:

    >>> import torch
    >>> from captum.attr import Summarizer
    >>> from captum.attr._utils.stat import Mean, StdDev
    >>>
    >>> attrib = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    >>>
    >>> summ = Summarizer([Mean(), StdDev()])
    >>> summ.update(attrib)
    >>>
    >>> print(summ.summary['mean'])
    """

    @log_usage()
    def __init__(self, stats: List[Stat]) -> None:
        r"""
        Args:
            stats (List[Stat]):
                The list of statistics you wish to track
        """
        self._summarizers: List[SummarizerSingleTensor] = []
        self._is_inputs_tuple: Optional[bool] = None
        self._stats, self._summary_stats_indices = _reorder_stats(stats)

    def _copy_stats(self) -> List[Stat]:
        # each SummarizerSingleTensor needs its own Stat instances,
        # so deep-copy the reordered template list
        return copy.deepcopy(self._stats)

    def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]):
        r"""
        Calls `update` on each `Stat` object within the summarizer

        Args:
            x (float, Tensor, or Tuple[Union[float, Tensor], ...]):
                The input(s) you wish to summarize
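
        Example (illustrative; note that once the first call passes a tuple,
        every subsequent call must pass a tuple as well):

        >>> summ = Summarizer([Mean()])
        >>> summ.update(torch.tensor([1.0, 2.0, 3.0]))
        >>> summ.update(torch.tensor([4.0, 5.0, 6.0]))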
        """
        if self._is_inputs_tuple is None:
            self._is_inputs_tuple = isinstance(x, tuple)
        else:
            # the input must consistently be a single input or a tuple
            assert self._is_inputs_tuple == isinstance(x, tuple), (
                "inputs must consistently be single values or tuples"
            )

        from captum._utils.common import _format_float_or_tensor_into_tuples

        x = _format_float_or_tensor_into_tuples(x)

        for i, inp in enumerate(x):
            if i >= len(self._summarizers):
                # _summarizers[i] is a new SummarizerSingleTensor, which
                # aims to summarize input i (i.e. x[i])
                #
                # Thus, we must copy our stats, as otherwise
                # in the best case the statistics for each input will be mangled
                # and in the worst case we will run into an error due to different
                # dimensionality in the input tensors (i.e.
                # x[i].shape != x[j].shape for some pair i, j)
                stats = self._copy_stats()
                self._summarizers.append(
                    SummarizerSingleTensor(
                        stats=stats, summary_stats_indices=self._summary_stats_indices
                    )
                )
            if not isinstance(inp, torch.Tensor):
                inp = torch.tensor(inp, dtype=torch.float)
            self._summarizers[i].update(inp)

    @property
    def summary(
        self,
    ) -> Optional[
        Union[Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
    ]:
        r"""
        Effectively calls `get` on each `Stat` object within this object for each input

        Returns:
            A dict (or a list of dicts, if the inputs were tuples) mapping
            each Stat object's `name` to the associated value of `get`
        """
        if len(self._summarizers) == 0:
            return None

        temp = [summ.summary for summ in self._summarizers]
        return temp if self._is_inputs_tuple else temp[0]


def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
    # We want to store two things:
    # 1. A mapping from a Stat to a Stat object (self._stat_to_stat):
    #    This is to retrieve an existing Stat object for dependency
    #    resolution, e.g.  Mean needs the Count stat - we want to
    #    retrieve it in O(1)
    #
    # 2. All of the necessary stats, in the correct order,
    #    to perform an update for each Stat (self.stats) trivially

    # As a reference, the dependency graph for our stats is as follows:
    # StdDev(x) -> Var(x) -> MSE -> Mean -> Count, for all valid x
    #
    # Step 1:
    #    Ensure we have all the necessary stats
    #    i.e. ensure we have the dependencies
    # Step 2:
    #    Figure out the order to update them
    dep_order = [StdDev, Var, MSE, Mean, Count]
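
    # For example (illustrative, assuming default Stat parameters): an input
    # of [StdDev()] has its missing dependencies added and is reordered so
    # that dependencies are updated first, yielding roughly
    #     [Count(), Mean(), MSE(), Var(), StdDev()]
    # with only the original StdDev()'s index reported as a summary stat.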

    # remove dupe stats
    stats = set(stats)
    summary_stats = set(stats)

    stats_by_module: Dict[Type, List[Stat]] = defaultdict(list)
    for stat in stats:
        stats_by_module[stat.__class__].append(stat)

    # StdDev is an odd case since it is parameterized, thus
    # for each StdDev(order) we must ensure there is an associated Var(order)
    for std_dev in stats_by_module[StdDev]:
        stat_to_add = Var(order=std_dev.order)  # type: ignore
        stats.add(stat_to_add)
        stats_by_module[stat_to_add.__class__].append(stat_to_add)

    # For the remaining deps (dep_order[1:]): if dep i exists,
    # we must ensure deps i..n-1 exist as well
    for i, dep in enumerate(dep_order[1:]):
        if dep in stats_by_module:
            stats.update([mod() for mod in dep_order[i + 1 :]])
            break

    # Step 2: get the correct order
    # NOTE: we are sorting via a given topological order
    sort_order = {mod: i for i, mod in enumerate(dep_order)}
    sort_order[Min] = -1
    sort_order[Max] = -1
    sort_order[Sum] = -1

    stats = list(stats)
    stats.sort(key=lambda x: sort_order[x.__class__], reverse=True)

    # get the summary stat indices
    summary_stat_indices = []
    for i, stat in enumerate(stats):
        if stat in summary_stats:
            summary_stat_indices.append(i)
    return stats, summary_stat_indices


class SummarizerSingleTensor:
    r"""
    A simple class that summarizes a single tensor. The basic functionality
    of this class consists of two operations: `.update` and `.summary`.

    If possible use `Summarizer` instead.
    """

    def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
        r"""
        Args:
            stats (list of Stat): A list of all the Stat objects that
                need to be updated. This must be in the appropriate order for
                updates (see `_reorder_stats`)
            summary_stats_indices (list of int): A list of indices, referencing `stats`,
                which are the stats you want to show in the .summary property. This
                does not require any specific order.
        """
        self._stats = stats
        self._stat_to_stat = {stat: stat for stat in self._stats}
        self._summary_stats = [stats[i] for i in summary_stats_indices]

        for stat in stats:
            stat._other_stats = self
            stat.init()

    def update(self, x: Tensor):
        r"""
        Updates the summary of a given tensor `x`

        Args:
            x (Tensor):
                The tensor to summarize
        """
        for stat in self._stats:
            stat.update(x)

    def get(self, stat: Stat) -> Optional[Stat]:
        r"""
        Retrieves `stat` from cache if this summarizer contains it.

        Note that `Stat` has its hash/equality method overridden, such
        that an object with the same class and parameters will have the
        same hash. Thus, if you call `get` with a `Stat`, an associated
        `Stat` with the same class and parameters belonging to this object
        will be retrieved if it exists.

        If no such object is retrieved then `None` is returned.
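
        For example, `summarizer.get(Mean())` returns this summarizer's own
        `Mean` instance (not the freshly constructed argument), since the two
        compare equal by class and parameters.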

        Args:
            stat (Stat):
                The stat to retrieve
        Returns:
            Optional[Stat]:
                The cached stat object, or `None` if it does not exist
        """
        if stat not in self._stat_to_stat:
            return None

        return self._stat_to_stat[stat]

    @property
    def summary(self) -> Dict[str, Optional[Tensor]]:
        """
        Returns:
            Optional[Dict[str, Optional[Tensor]]]
                The cached stat object
        """
        return {stat.name: stat.get() for stat in self._summary_stats}
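

# A minimal, illustrative smoke test of the classes above (not part of the
# library API). The "mean" key follows Mean's default stat name, as used in
# the Summarizer docstring example above.
if __name__ == "__main__":
    summ = Summarizer([Mean(), StdDev()])
    summ.update(torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]))
    summ.update(torch.tensor([2.0, 3.0, 4.0, 5.0, 6.0]))
    # since update() was called with single tensors (not tuples),
    # `summary` is a single dict
    print(summ.summary["mean"])
    print(summ.summary)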