#!/usr/bin/env python3 | |
from typing import Dict, List, Optional, Tuple, Type, Union | |
import torch | |
from captum.attr._utils.stat import Count, Max, Mean, Min, MSE, Stat, StdDev, Sum, Var | |
from captum.log import log_usage | |
from torch import Tensor | |
class Summarizer:
    r"""
    This class simply wraps a given set of SummarizerSingleTensor's in order
    to summarize multiple input tensors.

    Basic usage:

    >>>from captum.attr.aggregator import Summarizer
    >>>from captum.attr._utils.stats import Mean, StdDev
    >>>
    >>>attrib = torch.tensor([1, 2, 3, 4, 5])
    >>>
    >>>summ = Summarizer([Mean(), StdDev(0)])
    >>>summ.update(attrib)
    >>>
    >>>print(summ.summary()['mean'])
    """

    def __init__(self, stats: List[Stat]) -> None:
        r"""
        Args:
            stats (List[Stat]):
                The list of statistics you wish to track
        """
        # One SummarizerSingleTensor per input tensor; created lazily in
        # `update` since the number of inputs is unknown until the first call.
        self._summarizers: List["SummarizerSingleTensor"] = []
        # Whether the first `update` received a tuple; every subsequent call
        # must be consistent with it.
        self._is_inputs_tuple: Optional[bool] = None
        self._stats, self._summary_stats_indicies = _reorder_stats(stats)

    def _copy_stats(self):
        # Deep copy so that each per-input summarizer owns fully independent
        # Stat state (see the comment in `update`).
        import copy

        return copy.deepcopy(self._stats)

    def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]):
        r"""
        Calls `update` on each `Stat` object within the summarizer

        Args:
            x (Tensor or Tuple[Tensor, ...]):
                The input(s) you wish to summarize
        """
        if self._is_inputs_tuple is None:
            self._is_inputs_tuple = isinstance(x, tuple)
        else:
            # we want input to be consistently a single input or a tuple
            assert not (self._is_inputs_tuple ^ isinstance(x, tuple)), (
                "input must consistently be a single value or a tuple of values"
            )

        from captum._utils.common import _format_float_or_tensor_into_tuples

        x = _format_float_or_tensor_into_tuples(x)

        for i, inp in enumerate(x):
            if i >= len(self._summarizers):
                # _summarizers[i] is a new SummarizerSingleTensor, which
                # aims to summarize input i (i.e. x[i])
                #
                # Thus, we must copy our stats, as otherwise
                # in the best case the statistics for each input will be mangled
                # and in the worst case we will run into an error due to different
                # dimensionality in the input tensors (i.e.
                # x[i].shape != x[j].shape for some pair i, j)
                stats = self._copy_stats()
                self._summarizers.append(
                    SummarizerSingleTensor(
                        stats=stats,
                        summary_stats_indices=self._summary_stats_indicies,
                    )
                )
            if not isinstance(inp, torch.Tensor):
                inp = torch.tensor(inp, dtype=torch.float)

            self._summarizers[i].update(inp)

    def summary(
        self,
    ) -> Optional[
        Union[Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
    ]:
        r"""
        Effectively calls `get` on each `Stat` object within this object for each input

        Returns:
            A dict or list of dict: mapping from the Stat
            object's `name` to the associated value of `get`
        """
        if len(self._summarizers) == 0:
            return None

        # BUG FIX: previously this collected the bound `summary` methods
        # themselves (`summ.summary`) rather than calling them, so the
        # returned structure contained method objects instead of dicts.
        temp = [summ.summary() for summ in self._summarizers]
        return temp if self._is_inputs_tuple else temp[0]
def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
    # Two things need to happen here:
    #
    # 1. Dependency resolution: some stats require others to be tracked
    #    (e.g. Mean needs Count), so any missing prerequisites must be added.
    #    The dependency chain for our stats is:
    #        StdDev(x) -> Var(x) -> MSE -> Mean -> Count, for all valid x
    #
    # 2. Ordering: the returned list must be arranged so that every stat is
    #    updated only after the stats it depends on have been updated.
    from collections import defaultdict

    dep_order = [StdDev, Var, MSE, Mean, Count]

    # De-duplicate (Stat hashes on class + parameters) and remember which
    # stats the caller actually asked to have reported.
    all_stats = set(stats)
    requested = set(all_stats)

    by_class: Dict[Type, List[Stat]] = defaultdict(list)
    for stat in all_stats:
        by_class[stat.__class__].append(stat)

    # StdDev is an odd case since it is parameterized: every StdDev(order)
    # needs a matching Var(order) to be tracked as well.
    for std_dev in by_class[StdDev]:
        paired_var = Var(order=std_dev.order)  # type: ignore
        all_stats.add(paired_var)
        by_class[paired_var.__class__].append(paired_var)

    # For the unparameterized tail of the chain (dep_order[1:]): the first
    # module that is present pulls in itself and everything downstream of it.
    for offset, dep in enumerate(dep_order[1:]):
        if dep in by_class:
            all_stats.update(mod() for mod in dep_order[offset + 1 :])
            break

    # Topological ordering: dependencies (e.g. Count) must come first, so we
    # sort descending by position in dep_order. Min/Max/Sum participate in no
    # dependency chain and are simply placed at the end.
    rank = {mod: pos for pos, mod in enumerate(dep_order)}
    rank[Min] = -1
    rank[Max] = -1
    rank[Sum] = -1

    ordered = sorted(all_stats, key=lambda s: rank[s.__class__], reverse=True)

    # Positions (within `ordered`) of the stats the caller wants in summaries.
    shown = [pos for pos, stat in enumerate(ordered) if stat in requested]

    return ordered, shown
class SummarizerSingleTensor:
    r"""
    A simple class that summarizes a single tensor. The basic functionality
    of this class is two operations .update and .summary

    If possible use `Summarizer` instead.
    """

    def __init__(
        self, stats: List["Stat"], summary_stats_indices: List[int]
    ) -> None:
        r"""
        Args:
            stats (list of Stat): A list of all the Stat objects that
                need to be updated. This must be in the appropriate order for
                updates (see `_reorder_stats`)
            summary_stats_indices (list of int): A list of indices, referencing
                `stats`, which are the stats you want to show in `.summary`.
                This does not require any specific order.
        """
        self._stats = stats
        # Stat overrides hash/equality (same class + parameters compare
        # equal), so this dict lets `get` map an equivalent "query" Stat to
        # the tracked instance in O(1).
        self._stat_to_stat = {stat: stat for stat in self._stats}
        self._summary_stats = [stats[i] for i in summary_stats_indices]

        for stat in stats:
            # Back-reference so each Stat can resolve the stats it depends on
            # (via this object's `get`) during its updates.
            stat._other_stats = self
            stat.init()

    def update(self, x: Tensor) -> None:
        r"""
        Updates the summary of a given tensor `x`

        Args:
            x (Tensor):
                The tensor to summarize
        """
        # `_stats` is ordered so that dependencies are updated first
        # (see `_reorder_stats`).
        for stat in self._stats:
            stat.update(x)

    def get(self, stat: "Stat") -> Optional["Stat"]:
        r"""
        Retrieves `stat` from cache if this summarizer contains it.

        Note that `Stat` has its hash/equality method overridden, such
        that an object with the same class and parameters will have the
        same hash. Thus, if you call `get` with a `Stat`, an associated
        `Stat` with the same class and parameters belonging to this object
        will be retrieved if it exists.

        If no such object is retrieved then `None` is returned.

        Args:
            stat (Stat):
                The stat to retrieve

        Returns:
            Stat
                The cached stat object or `None`
        """
        if stat not in self._stat_to_stat:
            return None

        return self._stat_to_stat[stat]

    def summary(self) -> Dict[str, Optional[Tensor]]:
        """
        Returns:
            Dict[str, Optional[Tensor]]
                Mapping from each summary stat's `name` to the value its
                `get` returns
        """
        return {stat.name: stat.get() for stat in self._summary_stats}