#!/usr/bin/env python3 from typing import Any, Callable, List, Optional, TYPE_CHECKING import torch from torch import Tensor if TYPE_CHECKING: from captum.attr._utils.summarizer import SummarizerSingleTensor class Stat: """ The Stat class represents a statistic that can be updated and retrieved at any point in time. The basic functionality this class provides is: 1. A update/get method to actually compute the statistic 2. A statistic store/cache to retrieve dependent information (e.g. other stat values that are required for computation) 3. The name of the statistic that is used for the user to refer to """ def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None: """ Args: name (str, optional): The name of the statistic. If not provided, the class name will be used alongside it's parameters kwargs (Any): Additional arguments used to construct the statistic """ self.params = kwargs self._name = name self._other_stats: Optional[SummarizerSingleTensor] = None def init(self): pass def _get_stat(self, stat: "Stat") -> Optional["Stat"]: assert self._other_stats is not None return self._other_stats.get(stat) def update(self, x: Tensor): raise NotImplementedError() def get(self) -> Optional[Tensor]: raise NotImplementedError() def __hash__(self): return hash((self.__class__, frozenset(self.params.items()))) def __eq__(self, other: object) -> bool: if isinstance(other, Stat): return self.__class__ == other.__class__ and frozenset( self.params.items() ) == frozenset(other.params.items()) else: return False def __ne__(self, other: object) -> bool: return not self.__eq__(other) @property def name(self): """ The name of the statistic. i.e. it is the key in a .summary This will be the class name or a custom name if provided. See Summarizer or SummarizerSingleTensor """ default_name = self.__class__.__name__.lower() if len(self.params) > 0: default_name += f"({self.params})" return default_name if self._name is None else self._name class Count(Stat): """ Counts the number of elements, i.e. the number of `update`'s called """ def __init__(self, name: Optional[str] = None) -> None: super().__init__(name=name) self.n = None def get(self): return self.n def update(self, x): if self.n is None: self.n = 0 self.n += 1 class Mean(Stat): """ Calculates the average of a tensor """ def __init__(self, name: Optional[str] = None) -> None: super().__init__(name=name) self.rolling_mean: Optional[Tensor] = None self.n: Optional[Count] = None def get(self) -> Optional[Tensor]: return self.rolling_mean def init(self): self.n = self._get_stat(Count()) def update(self, x): n = self.n.get() if self.rolling_mean is None: # Ensures rolling_mean is a float tensor self.rolling_mean = x.clone() if x.is_floating_point() else x.double() else: delta = x - self.rolling_mean self.rolling_mean += delta / n class MSE(Stat): """ Calculates the mean squared error of a tensor """ def __init__(self, name: Optional[str] = None) -> None: super().__init__(name=name) self.prev_mean = None self.mse = None def init(self): self.mean = self._get_stat(Mean()) def get(self) -> Optional[Tensor]: if self.mse is None and self.prev_mean is not None: return torch.zeros_like(self.prev_mean) return self.mse def update(self, x: Tensor): mean = self.mean.get() if mean is not None and self.prev_mean is not None: rhs = (x - self.prev_mean) * (x - mean) if self.mse is None: self.mse = rhs else: self.mse += rhs # do not not clone self.prev_mean = mean.clone() class Var(Stat): """ Calculates the variance of a tensor, with an order. e.g. if `order = 1` then it will calculate sample variance. This is equal to mse / (n - order) """ def __init__(self, name: Optional[str] = None, order: int = 0) -> None: if name is None: if order == 0: name = "variance" elif order == 1: name = "sample_variance" else: name = f"variance({order})" super().__init__(name=name, order=order) self.order = order def init(self): self.mse = self._get_stat(MSE()) self.n = self._get_stat(Count()) def update(self, x: Tensor): pass def get(self) -> Optional[Tensor]: mse = self.mse.get() n = self.n.get() if mse is None: return None if n <= self.order: return torch.zeros_like(mse) # NOTE: The following ensures mse is a float tensor. # torch.true_divide is available in PyTorch 1.5 and later. # This is for compatibility with 1.4. return mse.to(torch.float64) / (n - self.order) class StdDev(Stat): """ The standard deviation, with an associated order. """ def __init__(self, name: Optional[str] = None, order: int = 0) -> None: if name is None: if order == 0: name = "std_dev" elif order == 1: name = "sample_std_dev" else: name = f"std_dev{order})" super().__init__(name=name, order=order) self.order = order def init(self): self.var = self._get_stat(Var(order=self.order)) def update(self, x: Tensor): pass def get(self) -> Optional[Tensor]: var = self.var.get() return var ** 0.5 if var is not None else None class GeneralAccumFn(Stat): """ Performs update(x): result = fn(result, x) where fn is a custom function """ def __init__(self, fn: Callable, name: Optional[str] = None) -> None: super().__init__(name=name) self.result = None self.fn = fn def get(self) -> Optional[Tensor]: return self.result def update(self, x): if self.result is None: self.result = x else: self.result = self.fn(self.result, x) class Min(GeneralAccumFn): def __init__( self, name: Optional[str] = None, min_fn: Callable = torch.min ) -> None: super().__init__(name=name, fn=min_fn) class Max(GeneralAccumFn): def __init__( self, name: Optional[str] = None, max_fn: Callable = torch.max ) -> None: super().__init__(name=name, fn=max_fn) class Sum(GeneralAccumFn): def __init__( self, name: Optional[str] = None, add_fn: Callable = torch.add ) -> None: super().__init__(name=name, fn=add_fn) def CommonStats() -> List[Stat]: r""" Returns common summary statistics, specifically: Mean, Sample Variance, Sample Std Dev, Min, Max """ return [Mean(), Var(order=1), StdDev(order=1), Min(), Max()]