from typing import Dict import torch from torch.distributions import constraints from torch.distributions.distribution import Distribution from torch.distributions.independent import Independent from torch.distributions.transforms import ComposeTransform, Transform from torch.distributions.utils import _sum_rightmost __all__ = ["TransformedDistribution"] class TransformedDistribution(Distribution): r""" Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:: X ~ BaseDistribution Y = f(X) ~ TransformedDistribution(BaseDistribution, f) log p(Y) = log p(X) + log |det (dX/dY)| Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events. An example for the usage of :class:`TransformedDistribution` would be:: # Building a Logistic Distribution # X ~ Uniform(0, 1) # f = a + b * logit(X) # Y ~ f(X) ~ Logistic(a, b) base_distribution = Uniform(0, 1) transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] logistic = TransformedDistribution(base_distribution, transforms) For more examples, please look at the implementations of :class:`~torch.distributions.gumbel.Gumbel`, :class:`~torch.distributions.half_cauchy.HalfCauchy`, :class:`~torch.distributions.half_normal.HalfNormal`, :class:`~torch.distributions.log_normal.LogNormal`, :class:`~torch.distributions.pareto.Pareto`, :class:`~torch.distributions.weibull.Weibull`, :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` """ arg_constraints: Dict[str, constraints.Constraint] = {} def __init__(self, base_distribution, transforms, validate_args=None): if isinstance(transforms, Transform): self.transforms = [ transforms, ] elif isinstance(transforms, list): if not all(isinstance(t, Transform) for t in transforms): raise ValueError( "transforms must be a Transform or a list of Transforms" ) self.transforms = transforms else: raise ValueError( f"transforms must be a Transform or list, but was {transforms}" ) # Reshape base_distribution according to transforms. base_shape = base_distribution.batch_shape + base_distribution.event_shape base_event_dim = len(base_distribution.event_shape) transform = ComposeTransform(self.transforms) if len(base_shape) < transform.domain.event_dim: raise ValueError( "base_distribution needs to have shape with size at least {}, but got {}.".format( transform.domain.event_dim, base_shape ) ) forward_shape = transform.forward_shape(base_shape) expanded_base_shape = transform.inverse_shape(forward_shape) if base_shape != expanded_base_shape: base_batch_shape = expanded_base_shape[ : len(expanded_base_shape) - base_event_dim ] base_distribution = base_distribution.expand(base_batch_shape) reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim if reinterpreted_batch_ndims > 0: base_distribution = Independent( base_distribution, reinterpreted_batch_ndims ) self.base_dist = base_distribution # Compute shapes. transform_change_in_event_dim = ( transform.codomain.event_dim - transform.domain.event_dim ) event_dim = max( transform.codomain.event_dim, # the transform is coupled base_event_dim + transform_change_in_event_dim, # the base dist is coupled ) assert len(forward_shape) >= event_dim cut = len(forward_shape) - event_dim batch_shape = forward_shape[:cut] event_shape = forward_shape[cut:] super().__init__(batch_shape, event_shape, validate_args=validate_args) def expand(self, batch_shape, _instance=None): new = self._get_checked_instance(TransformedDistribution, _instance) batch_shape = torch.Size(batch_shape) shape = batch_shape + self.event_shape for t in reversed(self.transforms): shape = t.inverse_shape(shape) base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] new.base_dist = self.base_dist.expand(base_batch_shape) new.transforms = self.transforms super(TransformedDistribution, new).__init__( batch_shape, self.event_shape, validate_args=False ) new._validate_args = self._validate_args return new @constraints.dependent_property(is_discrete=False) def support(self): if not self.transforms: return self.base_dist.support support = self.transforms[-1].codomain if len(self.event_shape) > support.event_dim: support = constraints.independent( support, len(self.event_shape) - support.event_dim ) return support @property def has_rsample(self): return self.base_dist.has_rsample def sample(self, sample_shape=torch.Size()): """ Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from base distribution and applies `transform()` for every transform in the list. """ with torch.no_grad(): x = self.base_dist.sample(sample_shape) for transform in self.transforms: x = transform(x) return x def rsample(self, sample_shape=torch.Size()): """ Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from base distribution and applies `transform()` for every transform in the list. """ x = self.base_dist.rsample(sample_shape) for transform in self.transforms: x = transform(x) return x def log_prob(self, value): """ Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian. """ if self._validate_args: self._validate_sample(value) event_dim = len(self.event_shape) log_prob = 0.0 y = value for transform in reversed(self.transforms): x = transform.inv(y) event_dim += transform.domain.event_dim - transform.codomain.event_dim log_prob = log_prob - _sum_rightmost( transform.log_abs_det_jacobian(x, y), event_dim - transform.domain.event_dim, ) y = x log_prob = log_prob + _sum_rightmost( self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) ) return log_prob def _monotonize_cdf(self, value): """ This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is monotone increasing. """ sign = 1 for transform in self.transforms: sign = sign * transform.sign if isinstance(sign, int) and sign == 1: return value return sign * (value - 0.5) + 0.5 def cdf(self, value): """ Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution. """ for transform in self.transforms[::-1]: value = transform.inv(value) if self._validate_args: self.base_dist._validate_sample(value) value = self.base_dist.cdf(value) value = self._monotonize_cdf(value) return value def icdf(self, value): """ Computes the inverse cumulative distribution function using transform(s) and computing the score of the base distribution. """ value = self._monotonize_cdf(value) value = self.base_dist.icdf(value) for transform in self.transforms: value = transform(value) return value