import functools
import math
import numbers
import operator
import weakref
from typing import List

import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (
    _sum_rightmost,
    broadcast_all,
    lazy_property,
    tril_matrix_to_vec,
    vec_to_tril_matrix,
)
from torch.nn.functional import pad, softplus

__all__ = [
    "AbsTransform",
    "AffineTransform",
    "CatTransform",
    "ComposeTransform",
    "CorrCholeskyTransform",
    "CumulativeDistributionTransform",
    "ExpTransform",
    "IndependentTransform",
    "LowerCholeskyTransform",
    "PositiveDefiniteTransform",
    "PowerTransform",
    "ReshapeTransform",
    "SigmoidTransform",
    "SoftplusTransform",
    "TanhTransform",
    "SoftmaxTransform",
    "StackTransform",
    "StickBreakingTransform",
    "Transform",
    "identity_transform",
]


class Transform:
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.
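
    For example, a minimal bijective transform can be written as follows (an
    illustrative sketch, not one of the transforms defined in this module; the
    Jacobian of ``y = -x`` has ``|det| = 1``, so its log is zero)::

        class NegateTransform(Transform):
            domain = constraints.real
            codomain = constraints.real
            bijective = True
            sign = -1

            def _call(self, x):
                return -x

            def _inverse(self, y):
                return -y

            def log_abs_det_jacobian(self, x, y):
                return torch.zeros_like(x)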

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether transform is monotone
            increasing or decreasing.
    """

    bijective = False
    domain: constraints.Constraint
    codomain: constraints.Constraint

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError("cache_size must be 0 or 1")
        super().__init__()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_inv"] = None
        return state

    @property
    def event_dim(self):
        if self.domain.event_dim == self.codomain.event_dim:
            return self.domain.event_dim
        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        if type(self).__init__ is Transform.__init__:
            return type(self)(cache_size=cache_size)
        raise NotImplementedError(f"{type(self)}.with_cache is not implemented")

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + "()"

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape


class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """

    def __init__(self, transform: Transform):
        super().__init__(cache_size=transform._cache_size)
        self._inv: Transform = transform

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        assert self._inv is not None
        return self._inv.codomain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        assert self._inv is not None
        return self._inv.domain

    @property
    def bijective(self):
        assert self._inv is not None
        return self._inv.bijective

    @property
    def sign(self):
        assert self._inv is not None
        return self._inv.sign

    @property
    def inv(self):
        return self._inv

    def with_cache(self, cache_size=1):
        assert self._inv is not None
        return self.inv.with_cache(cache_size).inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        assert self._inv is not None
        return self._inv == other._inv

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self._inv)})"

    def __call__(self, x):
        assert self._inv is not None
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        assert self._inv is not None
        return -self._inv.log_abs_det_jacobian(y, x)

    def forward_shape(self, shape):
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape):
        return self._inv.forward_shape(shape)


class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.
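
    Example (an illustrative usage sketch)::

        # y = exp(2 * x + 1), composed from two simpler bijections
        t = ComposeTransform([AffineTransform(loc=1.0, scale=2.0), ExpTransform()])
        x = torch.randn(5)
        y = t(x)
        x_again = t.inv(y)  # the inverse applies the parts in reverse order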

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.
    """

    def __init__(self, parts: List[Transform], cache_size=0):
        if cache_size:
            parts = [part.with_cache(cache_size) for part in parts]
        super().__init__(cache_size=cache_size)
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if not self.parts:
            return constraints.real
        domain = self.parts[0].domain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[-1].codomain.event_dim
        for part in reversed(self.parts):
            event_dim += part.domain.event_dim - part.codomain.event_dim
            event_dim = max(event_dim, part.domain.event_dim)
        assert event_dim >= domain.event_dim
        if event_dim > domain.event_dim:
            domain = constraints.independent(domain, event_dim - domain.event_dim)
        return domain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if not self.parts:
            return constraints.real
        codomain = self.parts[-1].codomain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[0].domain.event_dim
        for part in self.parts:
            event_dim += part.codomain.event_dim - part.domain.event_dim
            event_dim = max(event_dim, part.codomain.event_dim)
        assert event_dim >= codomain.event_dim
        if event_dim > codomain.event_dim:
            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
        return codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ComposeTransform(self.parts, cache_size=cache_size)

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)

        # Compute intermediates. This will be free if parts[:-1] are all cached.
        xs = [x]
        for part in self.parts[:-1]:
            xs.append(part(xs[-1]))
        xs.append(y)

        terms = []
        event_dim = self.domain.event_dim
        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
            terms.append(
                _sum_rightmost(
                    part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim
                )
            )
            event_dim += part.codomain.event_dim - part.domain.event_dim
        return functools.reduce(operator.add, terms)

    def forward_shape(self, shape):
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape):
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def __repr__(self):
        fmt_string = self.__class__.__name__ + "(\n    "
        fmt_string += ",\n    ".join([p.__repr__() for p in self.parts])
        fmt_string += "\n)"
        return fmt_string


identity_transform = ComposeTransform([])


class IndependentTransform(Transform):
    """
    Wrapper around another transform to treat
    ``reinterpreted_batch_ndims``-many extra of the rightmost dimensions as
    dependent. This has no effect on the forward or backward transforms, but
    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
    in :meth:`log_abs_det_jacobian`.
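
    Example (an illustrative usage sketch)::

        t = IndependentTransform(ExpTransform(), reinterpreted_batch_ndims=1)
        x = torch.randn(3, 4)
        # ExpTransform alone would give a (3, 4) Jacobian term; here the last
        # dimension is summed out, so the result has shape (3,).
        ladj = t.log_abs_det_jacobian(x, t(x))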

    Args:
        base_transform (:class:`Transform`): A base transform.
        reinterpreted_batch_ndims (int): The number of extra rightmost
            dimensions to treat as dependent.
    """

    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.base_transform = base_transform.with_cache(cache_size)
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return IndependentTransform(
            self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size
        )

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        return constraints.independent(
            self.base_transform.domain, self.reinterpreted_batch_ndims
        )

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        return constraints.independent(
            self.base_transform.codomain, self.reinterpreted_batch_ndims
        )

    @property
    def bijective(self):
        return self.base_transform.bijective

    @property
    def sign(self):
        return self.base_transform.sign

    def _call(self, x):
        if x.dim() < self.domain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform(x)

    def _inverse(self, y):
        if y.dim() < self.codomain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform.inv(y)

    def log_abs_det_jacobian(self, x, y):
        result = self.base_transform.log_abs_det_jacobian(x, y)
        result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"

    def forward_shape(self, shape):
        return self.base_transform.forward_shape(shape)

    def inverse_shape(self, shape):
        return self.base_transform.inverse_shape(shape)


class ReshapeTransform(Transform):
    """
    Unit Jacobian transform to reshape the rightmost part of a tensor.

    Note that ``in_shape`` and ``out_shape`` must have the same number of
    elements, just as for :meth:`torch.Tensor.reshape`.
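
    Example (an illustrative usage sketch)::

        t = ReshapeTransform(in_shape=(2, 3), out_shape=(6,))
        x = torch.randn(5, 2, 3)  # batch shape (5,), event shape (2, 3)
        y = t(x)                  # shape (5, 6)
        x_again = t.inv(y)        # shape (5, 2, 3)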

    Arguments:
        in_shape (torch.Size): The input event shape.
        out_shape (torch.Size): The output event shape.
    """

    bijective = True

    def __init__(self, in_shape, out_shape, cache_size=0):
        self.in_shape = torch.Size(in_shape)
        self.out_shape = torch.Size(out_shape)
        if self.in_shape.numel() != self.out_shape.numel():
            raise ValueError("in_shape, out_shape have different numbers of elements")
        super().__init__(cache_size=cache_size)

    @constraints.dependent_property
    def domain(self):
        return constraints.independent(constraints.real, len(self.in_shape))

    @constraints.dependent_property
    def codomain(self):
        return constraints.independent(constraints.real, len(self.out_shape))

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)

    def _call(self, x):
        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
        return x.reshape(batch_shape + self.out_shape)

    def _inverse(self, y):
        batch_shape = y.shape[: y.dim() - len(self.out_shape)]
        return y.reshape(batch_shape + self.in_shape)

    def log_abs_det_jacobian(self, x, y):
        batch_shape = x.shape[: x.dim() - len(self.in_shape)]
        return x.new_zeros(batch_shape)

    def forward_shape(self, shape):
        if len(shape) < len(self.in_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.in_shape)
        if shape[cut:] != self.in_shape:
            raise ValueError(
                f"Shape mismatch: expected {self.in_shape} but got {shape[cut:]}"
            )
        return shape[:cut] + self.out_shape

    def inverse_shape(self, shape):
        if len(shape) < len(self.out_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.out_shape)
        if shape[cut:] != self.out_shape:
            raise ValueError(
                f"Shape mismatch: expected {self.out_shape} but got {shape[cut:]}"
            )
        return shape[:cut] + self.in_shape


class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """

    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x


class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """

    domain = constraints.positive
    codomain = constraints.positive
    bijective = True

    def __init__(self, exponent, cache_size=0):
        super().__init__(cache_size=cache_size)
        (self.exponent,) = broadcast_all(exponent)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return PowerTransform(self.exponent, cache_size=cache_size)

    @lazy_property
    def sign(self):
        return self.exponent.sign()

    def __eq__(self, other):
        if not isinstance(other, PowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.pow(self.exponent)

    def _inverse(self, y):
        return y.pow(1 / self.exponent)

    def log_abs_det_jacobian(self, x, y):
        return (self.exponent * y / x).abs().log()

    def forward_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))


def _clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps)


class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """

    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return _clipped_sigmoid(x)

    def _inverse(self, y):
        finfo = torch.finfo(y.dtype)
        y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps)
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -F.softplus(-x) - F.softplus(x)


class SoftplusTransform(Transform):
    r"""
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
    The implementation reverts to the linear function when :math:`x > 20`.
    """

    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SoftplusTransform)

    def _call(self, x):
        return softplus(x)

    def _inverse(self, y):
        return (-y).expm1().neg().log() + y

    def log_abs_det_jacobian(self, x, y):
        return -softplus(-x)


class TanhTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \tanh(x)`.

    It is equivalent to
    ```
    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    ```
    However this might not be numerically stable, thus it is recommended to use `TanhTransform`
    instead.
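
    Example (an illustrative usage sketch; ``Normal`` and
    ``TransformedDistribution`` live in :mod:`torch.distributions`)::

        from torch.distributions import Normal, TransformedDistribution

        base = Normal(torch.zeros(3), torch.ones(3))
        squashed = TransformedDistribution(base, [TanhTransform(cache_size=1)])
        y = squashed.rsample()       # values in (-1, 1)
        logp = squashed.log_prob(y)  # reuses the cached x rather than atanh(y)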

    Note that one should use `cache_size=1` when `NaN`/`Inf` values would
    otherwise arise from inverting values near :math:`\pm 1`.
    """

    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # One should use `cache_size=1` instead.
        return torch.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x))


class AbsTransform(Transform):
    r"""
    Transform via the mapping :math:`y = |x|`.
    """

    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y


class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.
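
    Example (an illustrative usage sketch)::

        t = AffineTransform(loc=0.0, scale=2.0, event_dim=1)
        x = torch.randn(4, 3)
        y = t(x)
        # With event_dim=1 the log-det term is summed over the trailing event
        # dimension, giving one value per batch element: shape (4,).
        ladj = t.log_abs_det_jacobian(x, y)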

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """

    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self._event_dim = event_dim

    @property
    def event_dim(self):
        return self._event_dim

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return AffineTransform(
            self.loc, self.scale, self.event_dim, cache_size=cache_size
        )

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False
        if isinstance(self.loc, numbers.Number) and isinstance(
            other.loc, numbers.Number
        ):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False
        if isinstance(self.scale, numbers.Number) and isinstance(
            other.scale, numbers.Number
        ):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False
        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Real):
            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Real):
            result = torch.full_like(x, math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[: -self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[: -self.event_dim]
        return result.expand(shape)

    def forward_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(
            shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ())
        )


class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
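
    Example (an illustrative usage sketch)::

        t = CorrCholeskyTransform()
        x = torch.randn(6)   # D * (D - 1) / 2 = 6, so D = 4
        L = t(x)             # 4 x 4 lower Cholesky factor with unit-norm rows
        corr = L @ L.mT      # correlation matrix with unit diagonal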

    The transform is processed as follows:

        1. First we convert x into a lower triangular matrix in row order.
        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
           class :class:`StickBreakingTransform` to transform :math:`X_i` into a
           unit Euclidean length vector using the following steps:

           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
           - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """

    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(min=-1 + eps, max=1 - eps)
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r**2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattened lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)


class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.
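
    Example (an illustrative usage sketch)::

        t = SoftmaxTransform()
        x = torch.randn(5)
        p = t(x)       # nonnegative entries summing to 1
        x2 = t.inv(p)  # log(p), which differs from the original x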

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape


class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.
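
    Example (an illustrative usage sketch)::

        t = StickBreakingTransform()
        x = torch.randn(2)   # unconstrained vector of length K = 2
        p = t(x)             # simplex point of length K + 1 = 3, sums to 1
        x_again = t.inv(p)   # back to length 2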

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """

    domain = constraints.real_vector
    codomain = constraints.simplex
    bijective = True

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        z = _clipped_sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
        return y

    def _inverse(self, y):
        y_crop = y[..., :-1]
        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
        sf = 1 - y_crop.cumsum(-1)
        # we clamp to make sure that sf is positive which sometimes does not
        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
        x = y_crop.log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        x = x - offset.log()
        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
        return detJ

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)


class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.
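
    Example (an illustrative usage sketch)::

        t = LowerCholeskyTransform()
        x = torch.randn(3, 3)
        L = t(x)          # lower triangular with positive diagonal
        sigma = L @ L.mT  # a positive-definite matrix parameterized by x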

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call(self, x):
        return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()

    def _inverse(self, y):
        return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()


class PositiveDefiniteTransform(Transform):
    """
    Transform from unconstrained matrices to positive-definite matrices.
    """

    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.positive_definite  # type: ignore[assignment]

    def __eq__(self, other):
        return isinstance(other, PositiveDefiniteTransform)

    def _call(self, x):
        x = LowerCholeskyTransform()(x)
        return x @ x.mT

    def _inverse(self, y):
        y = torch.linalg.cholesky(y)
        return LowerCholeskyTransform().inv(y)


class CatTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`, of length `lengths[dim]`,
    in a way compatible with :func:`torch.cat`.

    Example::

        x0 = torch.cat([torch.arange(1.0, 11.0), torch.arange(1.0, 11.0)], dim=0)
        x = torch.cat([x0, x0], dim=0)
        t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
        t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
        y = t(x)
    """

    transforms: List[Transform]

    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        if lengths is None:
            lengths = [1] * len(self.transforms)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.transforms)
        self.dim = dim

    @lazy_property
    def event_dim(self):
        return max(t.event_dim for t in self.transforms)

    @lazy_property
    def length(self):
        return sum(self.lengths)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CatTransform(self.transforms, self.dim, self.lengths, cache_size)

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        yslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslices.append(trans(xslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        xslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            yslice = y.narrow(self.dim, start, length)
            xslices.append(trans.inv(yslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        logdetjacs = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslice = y.narrow(self.dim, start, length)
            logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
            if trans.event_dim < self.event_dim:
                logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
            logdetjacs.append(logdetjac)
            start = start + length  # avoid += for jit compat
        # Decide whether to concatenate or sum.
        dim = self.dim
        if dim >= 0:
            dim = dim - x.dim()
        dim = dim + self.event_dim
        if dim < 0:
            return torch.cat(logdetjacs, dim=dim)
        else:
            return sum(logdetjacs)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.cat(
            [t.domain for t in self.transforms], self.dim, self.lengths
        )

    @constraints.dependent_property
    def codomain(self):
        return constraints.cat(
            [t.codomain for t in self.transforms], self.dim, self.lengths
        )


class StackTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`
    in a way compatible with :func:`torch.stack`.

    Example::

        x = torch.stack([torch.arange(1.0, 11.0), torch.arange(1.0, 11.0)], dim=1)
        t = StackTransform([ExpTransform(), identity_transform], dim=1)
        y = t(x)
    """

    transforms: List[Transform]

    def __init__(self, tseq, dim=0, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        self.dim = dim

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return StackTransform(self.transforms, self.dim, cache_size)

    def _slice(self, z):
        return [z.select(self.dim, i) for i in range(z.size(self.dim))]

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        yslices = []
        for xslice, trans in zip(self._slice(x), self.transforms):
            yslices.append(trans(xslice))
        return torch.stack(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        xslices = []
        for yslice, trans in zip(self._slice(y), self.transforms):
            xslices.append(trans.inv(yslice))
        return torch.stack(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        logdetjacs = []
        yslices = self._slice(y)
        xslices = self._slice(x)
        for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
            logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
        return torch.stack(logdetjacs, dim=self.dim)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.stack([t.domain for t in self.transforms], self.dim)

    @constraints.dependent_property
    def codomain(self):
        return constraints.stack([t.codomain for t in self.transforms], self.dim)


class CumulativeDistributionTransform(Transform):
    """
    Transform via the cumulative distribution function of a probability distribution.

    Args:
        distribution (Distribution): Distribution whose cumulative distribution function to use for
            the transformation.

    Example::

        # Construct a Gaussian copula from a multivariate normal.
        base_dist = MultivariateNormal(
            loc=torch.zeros(2),
            scale_tril=LKJCholesky(2).sample(),
        )
        transform = CumulativeDistributionTransform(Normal(0, 1))
        copula = TransformedDistribution(base_dist, [transform])
    """

    bijective = True
    codomain = constraints.unit_interval
    sign = +1

    def __init__(self, distribution, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.distribution = distribution

    @property
    def domain(self):
        return self.distribution.support

    def _call(self, x):
        return self.distribution.cdf(x)

    def _inverse(self, y):
        return self.distribution.icdf(y)

    def log_abs_det_jacobian(self, x, y):
        return self.distribution.log_prob(x)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)