Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
from typing import Any, Callable, List, Optional, TYPE_CHECKING | |
import torch | |
from torch import Tensor | |
if TYPE_CHECKING: | |
from captum.attr._utils.summarizer import SummarizerSingleTensor | |
class Stat:
    """
    A statistic that can be incrementally updated and queried at any time.

    This base class provides:
    1. ``update`` / ``get`` hooks that subclasses implement to feed data in
       and read the current value out;
    2. access to a shared store of other statistics (``_get_stat``) for
       values this statistic depends on;
    3. a ``name`` by which users refer to the statistic.
    """

    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            name (str, optional):
                Custom name for the statistic. When omitted, ``name()``
                falls back to the lower-cased class name followed by its
                parameters.
            kwargs (Any):
                Additional construction parameters. They take part in
                ``__eq__`` / ``__hash__``, so their values must be hashable.
        """
        self._name = name
        self.params = kwargs
        # Filled in from outside before ``init`` is called; ``_get_stat``
        # only requires it to expose ``get(stat)``.
        self._other_stats: Optional["SummarizerSingleTensor"] = None

    def init(self):
        """Hook for resolving dependent statistics; does nothing by default."""
        return None

    def _get_stat(self, stat: "Stat") -> Optional["Stat"]:
        """Return the stored statistic equal to ``stat`` (or ``None``)."""
        store = self._other_stats
        assert store is not None
        return store.get(stat)

    def update(self, x: Tensor):
        raise NotImplementedError()

    def get(self) -> Optional[Tensor]:
        raise NotImplementedError()

    def __hash__(self):
        # Identity of a statistic = its concrete class + its parameters
        # (the custom name is deliberately not part of it).
        return hash((type(self), frozenset(self.params.items())))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Stat):
            return False
        return type(self) == type(other) and frozenset(
            self.params.items()
        ) == frozenset(other.params.items())

    def __ne__(self, other: object) -> bool:
        equal = self.__eq__(other)
        return not equal

    def name(self):
        """
        The name of the statistic, i.e. its key in a ``.summary``.

        Returns the custom name when one was given; otherwise the
        lower-cased class name, suffixed with ``(<params>)`` when the
        statistic has parameters. See Summarizer or SummarizerSingleTensor.
        """
        if self._name is not None:
            return self._name
        label = type(self).__name__.lower()
        return f"{label}({self.params})" if self.params else label
class Count(Stat):
    """
    Counts how many times ``update`` has been called.

    ``get`` returns ``None`` until the first ``update``.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        # None until the first update.
        self.n = None

    def get(self):
        return self.n

    def update(self, x):
        # The value of ``x`` is irrelevant; only the call itself is counted.
        self.n = (self.n or 0) + 1
class Mean(Stat):
    """
    Running average of the tensors passed to ``update``.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.rolling_mean: Optional[Tensor] = None
        # Resolved in ``init`` to the shared Count statistic.
        self.n: Optional[Count] = None

    def get(self) -> Optional[Tensor]:
        return self.rolling_mean

    def init(self):
        self.n = self._get_stat(Count())

    def update(self, x):
        count = self.n.get()
        if self.rolling_mean is not None:
            # In-place incremental mean: m += (x - m) / count.
            # NOTE(review): assumes the Count stat has already counted this x
            # (i.e. it is updated before Mean) — confirm in the caller.
            self.rolling_mean += (x - self.rolling_mean) / count
            return
        # First sample: keep a private copy; non-float input is promoted to
        # float64 so later in-place division stays floating point.
        self.rolling_mean = x.double() if not x.is_floating_point() else x.clone()
class MSE(Stat):
    """
    Accumulates the sum of squared deviations from the running mean of the
    tensors passed to ``update``: each update adds
    ``(x - previous_mean) * (x - current_mean)`` (a Welford-style M2
    accumulator). The result is not divided by the number of samples.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.prev_mean = None
        self.mse = None

    def init(self):
        self.mean = self._get_stat(Mean())

    def get(self) -> Optional[Tensor]:
        if self.mse is not None:
            return self.mse
        if self.prev_mean is None:
            return None
        # One update seen so far: no deviation has been accumulated yet.
        return torch.zeros_like(self.prev_mean)

    def update(self, x: Tensor):
        current_mean = self.mean.get()
        if current_mean is not None and self.prev_mean is not None:
            contribution = (x - self.prev_mean) * (x - current_mean)
            if self.mse is not None:
                self.mse += contribution
            else:
                self.mse = contribution
        # Must clone: the Mean stat updates its tensor in place (``+=``), so
        # holding a reference would make prev_mean silently track the live mean.
        self.prev_mean = current_mean.clone()
class Var(Stat):
    """
    Variance with a configurable order (degrees-of-freedom correction),
    computed as ``mse / (n - order)``. ``order = 1`` gives the sample
    variance; ``order = 0`` divides by the plain count.
    """

    def __init__(self, name: Optional[str] = None, order: int = 0) -> None:
        if name is None:
            name = (
                "variance"
                if order == 0
                else "sample_variance"
                if order == 1
                else f"variance({order})"
            )
        super().__init__(name=name, order=order)
        self.order = order

    def init(self):
        self.mse = self._get_stat(MSE())
        self.n = self._get_stat(Count())

    def update(self, x: Tensor):
        # Derived from the MSE and Count stats; nothing to accumulate here.
        return None

    def get(self) -> Optional[Tensor]:
        sq_dev = self.mse.get()
        count = self.n.get()
        if sq_dev is None:
            return None
        if count > self.order:
            # Cast to float64 first so the division is a true (floating
            # point) division even for integer tensors; torch.true_divide
            # only exists from PyTorch 1.5, this keeps 1.4 compatibility.
            return sq_dev.to(torch.float64) / (count - self.order)
        # Not enough samples for the requested correction.
        return torch.zeros_like(sq_dev)
class StdDev(Stat):
    """
    Standard deviation with a configurable order, computed as the square
    root of ``Var(order=order)``. ``order = 1`` gives the sample standard
    deviation.
    """

    def __init__(self, name: Optional[str] = None, order: int = 0) -> None:
        """
        Args:
            name (str, optional):
                Custom name. Defaults to ``"std_dev"`` for ``order == 0``,
                ``"sample_std_dev"`` for ``order == 1`` and
                ``"std_dev(<order>)"`` otherwise.
            order (int):
                Degrees-of-freedom correction, forwarded to ``Var``.
        """
        if name is None:
            if order == 0:
                name = "std_dev"
            elif order == 1:
                name = "sample_std_dev"
            else:
                # Mirrors Var's "variance(<order>)" naming.
                name = f"std_dev({order})"
        super().__init__(name=name, order=order)
        self.order = order

    def init(self):
        # Resolve the shared Var statistic of the same order.
        self.var = self._get_stat(Var(order=self.order))

    def update(self, x: Tensor):
        # Derived from the Var stat; nothing to accumulate here.
        pass

    def get(self) -> Optional[Tensor]:
        var = self.var.get()
        return var ** 0.5 if var is not None else None
class GeneralAccumFn(Stat):
    """
    Folds updates with a user-supplied binary function: the first
    ``update(x)`` stores ``x``, every later one sets
    ``result = fn(result, x)``.

    NOTE(review): ``fn`` is not passed to ``Stat`` as a parameter, so it does
    not take part in ``__eq__`` / ``__hash__`` (two instances of the same
    class with different ``fn`` compare equal).
    """

    def __init__(self, fn: Callable, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.fn = fn
        self.result = None

    def get(self) -> Optional[Tensor]:
        return self.result

    def update(self, x):
        # NOTE(review): the first x is stored by reference, not copied.
        self.result = x if self.result is None else self.fn(self.result, x)
class Min(GeneralAccumFn):
    """
    Running minimum: ``result = min_fn(result, x)``. The default
    ``torch.min`` is element-wise when given two tensors.
    """

    def __init__(
        self, name: Optional[str] = None, min_fn: Callable = torch.min
    ) -> None:
        super().__init__(fn=min_fn, name=name)
class Max(GeneralAccumFn):
    """
    Running maximum: ``result = max_fn(result, x)``. The default
    ``torch.max`` is element-wise when given two tensors.
    """

    def __init__(
        self, name: Optional[str] = None, max_fn: Callable = torch.max
    ) -> None:
        super().__init__(fn=max_fn, name=name)
class Sum(GeneralAccumFn):
    """
    Running sum: ``result = add_fn(result, x)`` (``torch.add`` by default).
    """

    def __init__(
        self, name: Optional[str] = None, add_fn: Callable = torch.add
    ) -> None:
        super().__init__(fn=add_fn, name=name)
def CommonStats() -> List[Stat]:
    r"""
    Build a fresh list of commonly used summary statistics:
    Mean, Sample Variance, Sample Std Dev, Min and Max (in that order).
    """
    stats: List[Stat] = [Mean()]
    stats += [Var(order=1), StdDev(order=1)]
    stats += [Min(), Max()]
    return stats