#!/usr/bin/env python3
import copy
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type, Union

import torch
from captum.attr._utils.stat import Count, Max, Mean, Min, MSE, Stat, StdDev, Sum, Var
from captum.log import log_usage
from torch import Tensor


class Summarizer:
    r"""
    This class wraps a given set of SummarizerSingleTensor objects in order
    to summarize multiple input tensors.

    Basic usage:

    >>> from captum.attr import Summarizer
    >>> from captum.attr._utils.stat import Mean, StdDev
    >>>
    >>> attrib = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
    >>>
    >>> summ = Summarizer([Mean(), StdDev(order=0)])
    >>> summ.update(attrib)
    >>>
    >>> print(summ.summary['mean'])
    """

    @log_usage()
    def __init__(self, stats: List[Stat]) -> None:
        r"""
        Args:
            stats (List[Stat]):
                The list of statistics you wish to track
        """
        self._summarizers: List[SummarizerSingleTensor] = []
        self._is_inputs_tuple: Optional[bool] = None
        self._stats, self._summary_stats_indices = _reorder_stats(stats)

    def _copy_stats(self) -> List[Stat]:
        return copy.deepcopy(self._stats)

    def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]):
        r"""
        Calls `update` on each `Stat` object within the summarizer

        Args:
            x (Tensor or Tuple[Tensor, ...]):
                The input(s) you wish to summarize
        """
        if self._is_inputs_tuple is None:
            self._is_inputs_tuple = isinstance(x, tuple)
        else:
            # we want the input to be consistently a single input or a tuple
            assert not (self._is_inputs_tuple ^ isinstance(x, tuple))

        from captum._utils.common import _format_float_or_tensor_into_tuples

        x = _format_float_or_tensor_into_tuples(x)

        for i, inp in enumerate(x):
            if i >= len(self._summarizers):
                # _summarizers[i] is a new SummarizerSingleTensor, which
                # aims to summarize input i (i.e. x[i])
                #
                # Thus, we must copy our stats, as otherwise
                # in the best case the statistics for each input will be
                # mangled and in the worst case we will run into an error
                # due to different dimensionality in the input tensors
                # (i.e. x[i].shape != x[j].shape for some pair i, j)
                stats = self._copy_stats()
                self._summarizers.append(
                    SummarizerSingleTensor(
                        stats=stats,
                        summary_stats_indices=self._summary_stats_indices,
                    )
                )
            if not isinstance(inp, torch.Tensor):
                inp = torch.tensor(inp, dtype=torch.float)
            self._summarizers[i].update(inp)

    @property
    def summary(
        self,
    ) -> Optional[
        Union[Dict[str, Optional[Tensor]], List[Dict[str, Optional[Tensor]]]]
    ]:
        r"""
        Effectively calls `get` on each `Stat` object within this object
        for each input

        Returns:
            A dict or list of dicts: mapping from the `Stat` object's `name`
            to the associated value of `get`
        """
        if len(self._summarizers) == 0:
            return None

        temp = [summ.summary for summ in self._summarizers]
        return temp if self._is_inputs_tuple else temp[0]
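
# The sketch below is illustrative only and is not part of captum's public
# API: it shows how `Summarizer` behaves when `update` receives a tuple,
# i.e. each tuple element gets its own copied set of stats, so tensors of
# different shapes can be summarized side by side.
def _example_tuple_usage() -> None:
    summ = Summarizer([Mean(), Min(), Max()])
    summ.update((torch.arange(4.0), torch.zeros(2, 3)))
    summ.update((torch.ones(4), torch.full((2, 3), 2.0)))

    # With tuple inputs, `summary` is a list of dicts, one per tuple
    # element, each mapping a stat's `name` to its running value.
    per_input = summ.summary
    assert isinstance(per_input, list) and len(per_input) == 2
    print(per_input[0]["mean"])  # elementwise mean of the two (4,) tensors
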
def _reorder_stats(stats: List[Stat]) -> Tuple[List[Stat], List[int]]:
    # We want to store two things:
    # 1. A mapping from a Stat to a Stat object (self._stat_to_stat):
    #    This is to retrieve an existing Stat object for dependency
    #    resolution, e.g. Mean needs the Count stat - we want to
    #    retrieve it in O(1)
    #
    # 2. All of the necessary stats, in the correct order,
    #    to perform an update for each Stat (self.stats) trivially
    #
    # As a reference, the dependency graph for our stats is as follows:
    # StdDev(x) -> Var(x) -> MSE -> Mean -> Count, for all valid x
    #
    # Step 1:
    #     Ensure we have all the necessary stats,
    #     i.e. ensure we have the dependencies
    # Step 2:
    #     Figure out the order to update them
    dep_order = [StdDev, Var, MSE, Mean, Count]

    # remove duplicate stats
    stats = set(stats)
    summary_stats = set(stats)

    stats_by_module: Dict[Type, List[Stat]] = defaultdict(list)
    for stat in stats:
        stats_by_module[stat.__class__].append(stat)

    # StdDev is an odd case since it is parameterized, thus
    # for each StdDev(order) we must ensure there is an associated Var(order)
    for std_dev in stats_by_module[StdDev]:
        stat_to_add = Var(order=std_dev.order)  # type: ignore
        stats.add(stat_to_add)
        stats_by_module[stat_to_add.__class__].append(stat_to_add)

    # For the other modules (deps[1:n-1]): if i exists =>
    # we want to ensure i...n-1 exists
    for i, dep in enumerate(dep_order[1:]):
        if dep in stats_by_module:
            stats.update([mod() for mod in dep_order[i + 1 :]])
            break

    # Step 2: get the correct order
    # NOTE: we are sorting via a given topological order
    sort_order = {mod: i for i, mod in enumerate(dep_order)}
    sort_order[Min] = -1
    sort_order[Max] = -1
    sort_order[Sum] = -1

    stats = list(stats)
    stats.sort(key=lambda x: sort_order[x.__class__], reverse=True)

    # get the summary stat indices
    summary_stat_indices = []
    for i, stat in enumerate(stats):
        if stat in summary_stats:
            summary_stat_indices.append(i)

    return stats, summary_stat_indices


class SummarizerSingleTensor:
    r"""
    A simple class that summarizes a single tensor. The basic functionality
    of this class is the two operations .update and .summary

    If possible use `Summarizer` instead.
    """

    def __init__(self, stats: List[Stat], summary_stats_indices: List[int]) -> None:
        r"""
        Args:
            stats (list of Stat): A list of all the Stat objects that
                need to be updated. This must be in the appropriate order
                for updates (see `_reorder_stats`)
            summary_stats_indices (list of int): A list of indices,
                referencing `stats`, which are the stats you want to show
                in the .summary property. This does not require any
                specific order.
        """
        self._stats = stats
        self._stat_to_stat = {stat: stat for stat in self._stats}
        self._summary_stats = [stats[i] for i in summary_stats_indices]

        for stat in stats:
            stat._other_stats = self
            stat.init()

    def update(self, x: Tensor):
        r"""
        Updates the summary of a given tensor `x`

        Args:
            x (Tensor):
                The tensor to summarize
        """
        for stat in self._stats:
            stat.update(x)

    def get(self, stat: Stat) -> Optional[Stat]:
        r"""
        Retrieves `stat` from cache if this summarizer contains it.

        Note that `Stat` has its hash/equality methods overridden, such
        that an object with the same class and parameters will have the
        same hash. Thus, if you call `get` with a `Stat`, an associated
        `Stat` with the same class and parameters belonging to this object
        will be retrieved if it exists.

        If no such object is retrieved then `None` is returned.

        Args:
            stat (Stat):
                The stat to retrieve

        Returns:
            Stat
                The cached stat object or `None`
        """
        if stat not in self._stat_to_stat:
            return None

        return self._stat_to_stat[stat]

    @property
    def summary(self) -> Dict[str, Optional[Tensor]]:
        r"""
        Returns:
            Dict[str, Optional[Tensor]]:
                a mapping from each summary stat's `name` to the value
                returned by its `get`
        """
        return {stat.name: stat.get() for stat in self._summary_stats}
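

# Minimal sketch of end-to-end usage, mirroring the docstring example above;
# run this module directly to try it. `StdDev(order=0)` is passed by keyword
# to avoid relying on the positional argument order of captum's Stat classes.
if __name__ == "__main__":
    attrib = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

    summ = Summarizer([Mean(), StdDev(order=0)])
    summ.update(attrib)
    print(summ.summary["mean"])  # the running elementwise mean so far

    # `_reorder_stats` pulls in the missing dependencies of StdDev
    # (Var -> MSE -> Mean -> Count) and sorts the stats so that every stat
    # is updated after the stats it depends on.
    ordered, summary_indices = _reorder_stats([StdDev(order=0), Min()])
    print([stat.name for stat in ordered])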