import math
import warnings
from numbers import Number
from typing import Optional, Union

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.multivariate_normal import _precision_to_scale_tril
from torch.distributions.utils import lazy_property

__all__ = ["Wishart"]

_log_2 = math.log(2)


def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    return torch.digamma(
        x.unsqueeze(-1)
        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
    ).sum(-1)


def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
    # We assume positive input for this function
    return x.clamp(min=torch.finfo(x.dtype).eps)
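

# Illustrative sanity check (not part of the public API): for p = 1 the
# multivariate digamma above reduces to the ordinary digamma function, since
# psi_p(x) = sum_{i=1}^{p} psi(x + (1 - i) / 2).
#
# >>> x = torch.tensor([2.0, 3.5])
# >>> torch.allclose(_mvdigamma(x, p=1), torch.digamma(x))
# True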


class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`.

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the dimension of the square matrix minus 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix`, :attr:`precision_matrix`, or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution. [1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """
    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1] - 1}."
            )

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )
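
    # Hedged usage sketch: the three parameterizations accepted above are
    # equivalent ways to describe the same distribution; per the class
    # docstring, passing `scale_tril` directly skips the internal Cholesky
    # factorization.
    #
    # >>> Sigma = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
    # >>> df = torch.tensor(4.0)
    # >>> w1 = Wishart(df, covariance_matrix=Sigma)
    # >>> w2 = Wishart(df, scale_tril=torch.linalg.cholesky(Sigma))
    # >>> w3 = Wishart(df, precision_matrix=torch.linalg.inv(Sigma))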

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )
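
    # For reference: this computes Var(W_ij) = df * (Sigma_ij^2 +
    # Sigma_ii * Sigma_jj) elementwise; the einsum builds the outer product
    # of the diagonal of Sigma.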

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Implemented sampling using Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)
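
    # Bartlett decomposition, for reference (see [4], [5] in the class
    # docstring): with Sigma = L L^T, the method above draws W = (L A)(L A)^T,
    # where A is lower triangular with A_ii = sqrt(c_i), c_i ~ Chi2(df - i + 1)
    # (clamped away from zero), and A_ij ~ N(0, 1) for i > j; then
    # W ~ Wishart(df, Sigma).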

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        .. warning::
            In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may still end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust the `max_try_correction` argument of `.rsample` accordingly.
        """
        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # The part below improves numerical stability temporarily and should be removed in the future
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)
        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample
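
    # Hedged usage sketch (illustrative names): samples are reparameterized,
    # so gradients flow back to the scale parameter.
    #
    # >>> L = torch.eye(2, requires_grad=True)
    # >>> w = Wishart(torch.tensor(5.0), scale_tril=L)
    # >>> w.rsample().diagonal(dim1=-2, dim2=-1).sum().backward()
    # >>> L.grad is not None
    # True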

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )
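
    # For reference, the log-density computed above is
    #
    #     log p(X) = (nu - p - 1) / 2 * log|X| - tr(Sigma^{-1} X) / 2
    #                - nu / 2 * (p * log 2 + log|Sigma|) - log Gamma_p(nu / 2),
    #
    # where log|Sigma| = 2 * sum(log(diag(L))) for Sigma = L L^T and Gamma_p
    # is the multivariate gamma function (`torch.mvlgamma`).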

    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
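

# Exponential-family view, for reference: with natural parameters
# eta1 = -Sigma^{-1} / 2 and eta2 = (nu - p - 1) / 2 (as returned by
# `_natural_params`), the density is proportional to
# exp(tr(eta1 @ X) + eta2 * log|X|), and `_log_normalizer` evaluates the
# corresponding log-partition function at (eta1, eta2).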