import torch
from torch.distributions.distribution import Distribution

__all__ = ["ExponentialFamily"]


class ExponentialFamily(Distribution):
    r"""
    ExponentialFamily is the abstract base class for probability distributions belonging to an
    exponential family, whose probability mass/density function has the form defined below

    .. math::

        p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))

    where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic,
    :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier
    measure.

    Note:
        This class is an intermediary between the `Distribution` class and distributions which belong
        to an exponential family, mainly to check the correctness of the `.entropy()` and analytic KL
        divergence methods. We use this class to compute the entropy and KL divergence using the AD
        framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and
        Cross-entropies of Exponential Families). A minimal subclass sketch illustrating the hooks
        below is given at the end of this module.
    """

    @property
    def _natural_params(self):
        """
        Abstract property for natural parameters. Returns a tuple of Tensors based
        on the distribution.
        """
        raise NotImplementedError

    def _log_normalizer(self, *natural_params):
        """
        Abstract method for the log normalizer function. Returns a log normalizer based on
        the distribution and input.
        """
        raise NotImplementedError

    @property
    def _mean_carrier_measure(self):
        """
        Abstract property for the expected carrier measure, which is required for computing
        entropy.
        """
        raise NotImplementedError

    def entropy(self):
        """
        Method to compute the entropy using the Bregman divergence of the log normalizer,
        i.e. H(p) = F(theta) - <theta, grad F(theta)> - E[k(x)].
        """
        result = -self._mean_carrier_measure
        nparams = [p.detach().requires_grad_() for p in self._natural_params]
        lg_normal = self._log_normalizer(*nparams)
        gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True)
        result += lg_normal
        for np, g in zip(nparams, gradients):
            result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1)
        return result
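

# A minimal illustrative sketch (hypothetical names, not part of torch.distributions):
# it assumes an exponential distribution with rate ``lam``, whose exponential-family
# form has natural parameter theta = -lam, sufficient statistic t(x) = x,
# log normalizer F(theta) = -log(-theta) and carrier measure k(x) = 0, so the
# ``entropy()`` method above should evaluate to 1 - log(lam).
class _SketchExponential(ExponentialFamily):
    def __init__(self, lam):
        self.lam = lam
        super().__init__(batch_shape=lam.shape)

    @property
    def _natural_params(self):
        # One tensor per natural parameter; here theta = -lam.
        return (-self.lam,)

    def _log_normalizer(self, theta):
        # F(theta) = -log(-theta), defined for theta < 0.
        return -torch.log(-theta)

    @property
    def _mean_carrier_measure(self):
        # k(x) = 0 for this family, so its expectation is 0.
        return 0


if __name__ == "__main__":
    lam = torch.tensor([2.0])
    # Compare the Bregman-divergence entropy with the closed form 1 - log(lam).
    print(_SketchExponential(lam).entropy(), 1 - torch.log(lam))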