Spaces:
Build error
Build error
File size: 7,454 Bytes
d61b9c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 |
#!/usr/bin/env python3
from typing import Any, Callable, List, Optional, TYPE_CHECKING
import torch
from torch import Tensor
if TYPE_CHECKING:
from captum.attr._utils.summarizer import SummarizerSingleTensor
class Stat:
    """
    A statistic that can be incrementally updated and queried at any time.

    Concrete statistics provide:
    1. ``update``/``get`` to feed new values in and read the current value out
    2. access to other statistics they depend on through ``_get_stat``, once a
       store has been attached to ``_other_stats``
    3. a ``name`` the user can refer to the statistic by

    Equality and hashing consider only the concrete class and ``params``;
    ``name`` does not take part in either.
    """

    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            name (str, optional):
                The name of the statistic. If not provided, the lower-cased
                class name, followed by its parameters when there are any,
                is used instead.
            kwargs (Any):
                Additional arguments used to construct the statistic. They
                are used for equality and hashing, so their values must be
                hashable.
        """
        self._name = name
        self.params = kwargs
        # Attached later by the owning summarizer; None until then.
        self._other_stats: Optional["SummarizerSingleTensor"] = None

    def init(self):
        """Dependency-resolution hook; the base class has nothing to resolve."""

    def _get_stat(self, stat: "Stat") -> Optional["Stat"]:
        """Return the store's instance equal to ``stat`` (or None if absent)."""
        store = self._other_stats
        assert store is not None
        return store.get(stat)

    def update(self, x: Tensor):
        raise NotImplementedError()

    def get(self) -> Optional[Tensor]:
        raise NotImplementedError()

    def __hash__(self):
        signature = (self.__class__, frozenset(self.params.items()))
        return hash(signature)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Stat):
            return False
        same_type = self.__class__ == other.__class__
        return same_type and frozenset(self.params.items()) == frozenset(
            other.params.items()
        )

    def __ne__(self, other: object) -> bool:
        equal = self.__eq__(other)
        return not equal

    @property
    def name(self):
        """
        The key under which this statistic appears in a summary: the custom
        name if one was given, otherwise the lower-cased class name with
        ``(params)`` appended when parameters are present.
        See Summarizer or SummarizerSingleTensor.
        """
        if self._name is not None:
            return self._name
        label = self.__class__.__name__.lower()
        if self.params:
            label = f"{label}({self.params})"
        return label
class Count(Stat):
    """
    Tracks how many times ``update`` has been called, i.e. the number of
    elements observed so far. ``get`` returns None before the first update.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.n = None

    def get(self):
        return self.n

    def update(self, x):
        # The value of x is irrelevant; only the number of calls matters.
        self.n = (0 if self.n is None else self.n) + 1
class Mean(Stat):
    """
    Running average of the tensors passed to ``update``, kept as a float tensor.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.rolling_mean: Optional[Tensor] = None
        self.n: Optional[Count] = None

    def get(self) -> Optional[Tensor]:
        return self.rolling_mean

    def init(self):
        # Reuse the Count held by the attached store instead of counting here.
        self.n = self._get_stat(Count())

    def update(self, x):
        # NOTE(review): assumes the shared Count has already been updated for
        # this x (so count includes it) — TODO confirm ordering in the caller.
        count = self.n.get()
        current = self.rolling_mean
        if current is None:
            # First sample: copy it so the in-place update below never touches
            # the caller's tensor, and promote non-float input to float64.
            self.rolling_mean = x.clone() if x.is_floating_point() else x.double()
            return
        # Incremental mean: mean_n = mean_{n-1} + (x - mean_{n-1}) / n
        self.rolling_mean += (x - current) / count
class MSE(Stat):
    """
    Running sum of squared deviations from the mean, accumulated as
    ``(x - previous_mean) * (x - current_mean)`` per sample. It is not divided
    by the count; ``Var`` performs that division.
    """

    def __init__(self, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.prev_mean = None
        self.mse = None

    def init(self):
        self.mean = self._get_stat(Mean())

    def get(self) -> Optional[Tensor]:
        if self.mse is not None or self.prev_mean is None:
            return self.mse
        # Exactly one sample seen: a mean exists but nothing has accumulated.
        return torch.zeros_like(self.prev_mean)

    def update(self, x: Tensor):
        mean_now = self.mean.get()
        mean_before = self.prev_mean
        if mean_now is not None and mean_before is not None:
            term = (x - mean_before) * (x - mean_now)
            if self.mse is None:
                self.mse = term
            else:
                self.mse += term
        # Keep a snapshot (clone): Mean updates its tensor in place, so a bare
        # reference would silently follow the next update.
        self.prev_mean = mean_now.clone()
class Var(Stat):
    """
    Variance with a configurable order, computed as ``mse / (n - order)``.
    ``order=0`` yields mse / n; ``order=1`` yields the sample variance.
    """

    def __init__(self, name: Optional[str] = None, order: int = 0) -> None:
        if name is None:
            name = (
                "variance"
                if order == 0
                else "sample_variance"
                if order == 1
                else f"variance({order})"
            )
        super().__init__(name=name, order=order)
        self.order = order

    def init(self):
        self.mse = self._get_stat(MSE())
        self.n = self._get_stat(Count())

    def update(self, x: Tensor):
        """Nothing to accumulate: the value is derived from MSE and Count."""

    def get(self) -> Optional[Tensor]:
        accumulated = self.mse.get()
        count = self.n.get()
        if accumulated is None:
            return None
        if count <= self.order:
            return torch.zeros_like(accumulated)
        # Cast to float64 before dividing so the result is always floating
        # point; plain `/` instead of torch.true_divide keeps PyTorch 1.4
        # compatibility.
        return accumulated.to(torch.float64) / (count - self.order)
class StdDev(Stat):
    """
    The standard deviation with an associated order, i.e. the square root of
    ``Var(order=order)``. ``order=1`` gives the sample standard deviation.
    """

    def __init__(self, name: Optional[str] = None, order: int = 0) -> None:
        if name is None:
            if order == 0:
                name = "std_dev"
            elif order == 1:
                name = "sample_std_dev"
            else:
                # Fixed: the format string was missing its opening parenthesis
                # ("std_dev2)"); now mirrors Var's "variance(<order>)" naming.
                name = f"std_dev({order})"
        super().__init__(name=name, order=order)
        self.order = order

    def init(self):
        # Depend on the variance of the same order held by the attached store.
        self.var = self._get_stat(Var(order=self.order))

    def update(self, x: Tensor):
        # Nothing to accumulate; the value is derived from Var.
        pass

    def get(self) -> Optional[Tensor]:
        var = self.var.get()
        return var ** 0.5 if var is not None else None
class GeneralAccumFn(Stat):
    """
    Folds every value passed to ``update`` with a user-supplied binary
    function: the first value is stored as-is, and each later one is combined
    as ``result = fn(result, x)``.
    """

    def __init__(self, fn: Callable, name: Optional[str] = None) -> None:
        super().__init__(name=name)
        self.result = None
        self.fn = fn

    def get(self) -> Optional[Tensor]:
        return self.result

    def update(self, x):
        accumulated = self.result
        self.result = x if accumulated is None else self.fn(accumulated, x)
class Min(GeneralAccumFn):
    """Running minimum; values are combined with ``min_fn`` (default ``torch.min``)."""

    def __init__(
        self,
        name: Optional[str] = None,
        min_fn: Callable = torch.min,
    ) -> None:
        super().__init__(fn=min_fn, name=name)
class Max(GeneralAccumFn):
    """Running maximum; values are combined with ``max_fn`` (default ``torch.max``)."""

    def __init__(
        self,
        name: Optional[str] = None,
        max_fn: Callable = torch.max,
    ) -> None:
        super().__init__(fn=max_fn, name=name)
class Sum(GeneralAccumFn):
    """Running total; values are combined with ``add_fn`` (default ``torch.add``)."""

    def __init__(
        self,
        name: Optional[str] = None,
        add_fn: Callable = torch.add,
    ) -> None:
        super().__init__(fn=add_fn, name=name)
def CommonStats() -> List[Stat]:
    r"""
    Build a fresh list of frequently used summary statistics: the mean,
    sample variance, sample standard deviation, minimum and maximum.
    """
    sample_order = 1
    return [
        Mean(),
        Var(order=sample_order),
        StdDev(order=sample_order),
        Min(),
        Max(),
    ]
|