import math
import warnings
from functools import total_ordering
from typing import Callable, Dict, Tuple, Type

import torch
from torch import inf

from .bernoulli import Bernoulli
from .beta import Beta
from .binomial import Binomial
from .categorical import Categorical
from .cauchy import Cauchy
from .continuous_bernoulli import ContinuousBernoulli
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exp_family import ExponentialFamily
from .exponential import Exponential
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .independent import Independent
from .laplace import Laplace
from .lowrank_multivariate_normal import (
    _batch_lowrank_logdet,
    _batch_lowrank_mahalanobis,
    LowRankMultivariateNormal,
)
from .multivariate_normal import _batch_mahalanobis, MultivariateNormal
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .pareto import Pareto
from .poisson import Poisson
from .transformed_distribution import TransformedDistribution
from .uniform import Uniform
from .utils import _sum_rightmost, euler_constant as _euler_gamma

_KL_REGISTRY: Dict[
    Tuple[Type, Type], Callable
] = {}  # Source of truth mapping a few general (type, type) pairs to functions.
_KL_MEMOIZE: Dict[
    Tuple[Type, Type], Callable
] = {}  # Memoized version mapping many specific (type, type) pairs to functions.

__all__ = ["register_kl", "kl_divergence"]


def register_kl(type_p, type_q):
    """
    Decorator to register a pairwise function with :meth:`kl_divergence`.
    Usage::

        @register_kl(Normal, Normal)
        def kl_normal_normal(p, q):
            # insert implementation here

    Lookup returns the most specific (type, type) match ordered by subclass. If
    the match is ambiguous, a `RuntimeWarning` is emitted. For example to
    resolve the ambiguous situation::

        @register_kl(BaseP, DerivedQ)
        def kl_version1(p, q): ...

        @register_kl(DerivedP, BaseQ)
        def kl_version2(p, q): ...

    you should register a third most-specific implementation, e.g.::

        register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.

    Args:
        type_p (type): A subclass of :class:`~torch.distributions.Distribution`.
        type_q (type): A subclass of :class:`~torch.distributions.Distribution`.
    """
    if not (isinstance(type_p, type) and issubclass(type_p, Distribution)):
        raise TypeError(
            f"Expected type_p to be a Distribution subclass but got {type_p}"
        )
    if not (isinstance(type_q, type) and issubclass(type_q, Distribution)):
        raise TypeError(
            f"Expected type_q to be a Distribution subclass but got {type_q}"
        )

    def decorator(fun):
        _KL_REGISTRY[type_p, type_q] = fun
        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
        return fun

    return decorator
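
# A minimal usage sketch (illustration only; `MyNormal` and `kl_mynormal_mynormal`
# are hypothetical names, not part of this module):
#
#     class MyNormal(Normal):
#         pass
#
#     @register_kl(MyNormal, MyNormal)
#     def kl_mynormal_mynormal(p, q):
#         # Delegate to the Normal closed form defined further below.
#         return _kl_normal_normal(p, q)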


@total_ordering
class _Match:
    __slots__ = ["types"]

    def __init__(self, *types):
        self.types = types

    def __eq__(self, other):
        return self.types == other.types

    def __le__(self, other):
        for x, y in zip(self.types, other.types):
            if not issubclass(x, y):
                return False
            if x is not y:
                break
        return True
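
# Specificity-ordering sketch: with the `total_ordering` decorator filling in
# `__lt__`, `min()` over `_Match` keys prefers the most-derived registered pair,
# e.g. a (Normal, Normal) rule over a (Distribution, Distribution) fallback:
#
#     assert _Match(Normal, Normal) <= _Match(Distribution, Distribution)
#     assert not _Match(Distribution, Distribution) <= _Match(Normal, Normal)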


def _dispatch_kl(type_p, type_q):
    """
    Find the most specific approximate match, assuming single inheritance.
    """
    matches = [
        (super_p, super_q)
        for super_p, super_q in _KL_REGISTRY
        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
    ]
    if not matches:
        return NotImplemented
    # Check that the left- and right- lexicographic orders agree.
    # mypy isn't smart enough to know that _Match implements __lt__
    # see: https://github.com/python/typing/issues/760#issuecomment-710670503
    left_p, left_q = min(_Match(*m) for m in matches).types  # type: ignore[type-var]
    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types  # type: ignore[type-var]
    left_fun = _KL_REGISTRY[left_p, left_q]
    right_fun = _KL_REGISTRY[right_p, right_q]
    if left_fun is not right_fun:
        warnings.warn(
            "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
                type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
            ),
            RuntimeWarning,
        )
    return left_fun


def _infinite_like(tensor):
    """
    Helper function for obtaining an all-infinite tensor of the same shape, used
    for KL divergences that are infinite everywhere.
    """
    return torch.full_like(tensor, inf)


def _x_log_x(tensor):
    """
    Utility function for calculating x log x.
    """
    return tensor * tensor.log()


def _batch_trace_XXT(bmat):
    """
    Utility function for calculating the trace of XX^{T} with X having arbitrary
    leading batch dimensions.
    """
    n = bmat.size(-1)
    m = bmat.size(-2)
    flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1)
    return flat_trace.reshape(bmat.shape[:-2])
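
# Equivalence sketch: tr(X @ X.T) is the sum of squared entries of X, which the
# reshape/pow/sum above computes batch-wise; an einsum spells out the same
# contraction:
#
#     X = torch.randn(5, 3, 4)
#     torch.testing.assert_close(
#         _batch_trace_XXT(X), torch.einsum("...ij,...ij->...", X, X)
#     )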


def kl_divergence(p: Distribution, q: Distribution) -> torch.Tensor:
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.

    .. math::

        KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx

    Args:
        p (Distribution): A :class:`~torch.distributions.Distribution` object.
        q (Distribution): A :class:`~torch.distributions.Distribution` object.

    Returns:
        Tensor: A batch of KL divergences of shape `batch_shape`.

    Raises:
        NotImplementedError: If the distribution types have not been registered via
            :meth:`register_kl`.
    """
    try:
        fun = _KL_MEMOIZE[type(p), type(q)]
    except KeyError:
        fun = _dispatch_kl(type(p), type(q))
        _KL_MEMOIZE[type(p), type(q)] = fun
    if fun is NotImplemented:
        raise NotImplementedError(
            f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}"
        )
    return fun(p, q)
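
# Usage sketch: for registered pairs the result has the broadcast batch shape,
# one KL value per batch element:
#
#     p = Normal(torch.zeros(3), torch.ones(3))
#     q = Normal(torch.ones(3), 2 * torch.ones(3))
#     kl = kl_divergence(p, q)  # shape (3,)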


################################################################################
# KL Divergence Implementations
################################################################################

# Same distributions


@register_kl(Bernoulli, Bernoulli)
def _kl_bernoulli_bernoulli(p, q):
    t1 = p.probs * (
        torch.nn.functional.softplus(-q.logits)
        - torch.nn.functional.softplus(-p.logits)
    )
    t1[q.probs == 0] = inf
    t1[p.probs == 0] = 0
    t2 = (1 - p.probs) * (
        torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits)
    )
    t2[q.probs == 1] = inf
    t2[p.probs == 1] = 0
    return t1 + t2
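
# The softplus identities used above: softplus(-logits) = -log(probs) and
# softplus(logits) = -log1p(-probs), so t1 + t2 expands to the textbook form
# p * log(p / q) + (1 - p) * log((1 - p) / (1 - q)).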


@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    sum_params_p = p.concentration1 + p.concentration0
    sum_params_q = q.concentration1 + q.concentration0
    t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma()
    t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma()
    t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1)
    t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0)
    t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p)
    return t1 - t2 + t3 + t4 + t5


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
    # from https://math.stackexchange.com/questions/2214993/
    # kullback-leibler-divergence-for-binomial-distributions-p-and-q
    if (p.total_count < q.total_count).any():
        raise NotImplementedError(
            "KL between Binomials where q.total_count > p.total_count is not implemented"
        )
    kl = p.total_count * (
        p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p()
    )
    inf_idxs = p.total_count > q.total_count
    kl[inf_idxs] = _infinite_like(kl[inf_idxs])
    return kl


@register_kl(Categorical, Categorical)
def _kl_categorical_categorical(p, q):
    t = p.probs * (p.logits - q.logits)
    t[(q.probs == 0).expand_as(t)] = inf
    t[(p.probs == 0).expand_as(t)] = 0
    return t.sum(-1)


@register_kl(ContinuousBernoulli, ContinuousBernoulli)
def _kl_continuous_bernoulli_continuous_bernoulli(p, q):
    t1 = p.mean * (p.logits - q.logits)
    t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs)
    t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs)
    return t1 + t2 + t3


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)
    t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma()
    t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1)
    t3 = p.concentration - q.concentration
    t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1)
    return t1 - t2 + (t3 * t4).sum(-1)


@register_kl(Exponential, Exponential)
def _kl_exponential_exponential(p, q):
    rate_ratio = q.rate / p.rate
    t1 = -rate_ratio.log()
    return t1 + rate_ratio - 1


@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    if type(p) != type(q):
        raise NotImplementedError(
            "The cross KL-divergence between different exponential families "
            "cannot be computed using Bregman divergences"
        )
    p_nparams = [param.detach().requires_grad_() for param in p._natural_params]
    q_nparams = q._natural_params
    lg_normal = p._log_normalizer(*p_nparams)
    gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True)
    result = q._log_normalizer(*q_nparams) - lg_normal
    for pnp, qnp, g in zip(p_nparams, q_nparams, gradients):
        term = (qnp - pnp) * g
        result -= _sum_rightmost(term, len(q.event_shape))
    return result
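
# The loop above evaluates the Bregman-divergence form of the KL for a shared
# exponential family with natural parameters eta and log-normalizer A:
#     KL(p || q) = A(eta_q) - A(eta_p) - <eta_q - eta_p, grad A(eta_p)>.
# A quick consistency sketch against the Normal closed form defined below
# (Normal is an ExponentialFamily, so both paths apply):
#
#     p = Normal(torch.tensor(0.0), torch.tensor(1.0))
#     q = Normal(torch.tensor(1.0), torch.tensor(2.0))
#     torch.testing.assert_close(
#         _kl_expfamily_expfamily(p, q), _kl_normal_normal(p, q)
#     )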


@register_kl(Gamma, Gamma)
def _kl_gamma_gamma(p, q):
    t1 = q.concentration * (p.rate / q.rate).log()
    t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration)
    t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration)
    t4 = (q.rate - p.rate) * (p.concentration / p.rate)
    return t1 + t2 + t3 + t4


@register_kl(Gumbel, Gumbel)
def _kl_gumbel_gumbel(p, q):
    ct1 = p.scale / q.scale
    ct2 = q.loc / q.scale
    ct3 = p.loc / q.scale
    t1 = -ct1.log() - ct2 + ct3
    t2 = ct1 * _euler_gamma
    t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3)
    return t1 + t2 + t3 - (1 + _euler_gamma)


@register_kl(Geometric, Geometric)
def _kl_geometric_geometric(p, q):
    return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
    return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    loc_abs_diff = (p.loc - q.loc).abs()
    t1 = -scale_ratio.log()
    t2 = loc_abs_diff / q.scale
    t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale)
    return t1 + t2 + t3 - 1


@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal)
def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Low Rank Multivariate Normals with "
            "different event shapes cannot be computed"
        )
    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
    #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
    term22 = _batch_trace_XXT(
        p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
    term2 = term21 + term22 - term23 - term24
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, LowRankMultivariateNormal)
def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )
    term1 = _batch_lowrank_logdet(
        q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril
    ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    term3 = _batch_lowrank_mahalanobis(
        q._unbroadcasted_cov_factor,
        q._unbroadcasted_cov_diag,
        q.loc - p.loc,
        q._capacitance_tril,
    )
    # Expands term2 according to
    # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
    #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
    qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2)
    A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False)
    term21 = _batch_trace_XXT(
        p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1)
    )
    term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
    term2 = term21 - term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(LowRankMultivariateNormal, MultivariateNormal)
def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two (Low Rank) Multivariate Normals with "
            "different event shapes cannot be computed"
        )
    term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - _batch_lowrank_logdet(
        p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    # Expands term2 according to
    # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_cov_factor = p._unbroadcasted_cov_factor.expand(
        combined_batch_shape + (n, p.cov_factor.size(-1))
    )
    p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand(
        combined_batch_shape + (n, n)
    )
    term21 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False)
    )
    term22 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False)
    )
    term2 = term21 + term22
    return 0.5 * (term1 + term2 + term3 - p.event_shape[0])


@register_kl(MultivariateNormal, MultivariateNormal)
def _kl_multivariatenormal_multivariatenormal(p, q):
    # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence
    if p.event_shape != q.event_shape:
        raise ValueError(
            "KL-divergence between two Multivariate Normals with "
            "different event shapes cannot be computed"
        )
    half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(
        -1
    ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
    combined_batch_shape = torch._C._infer_size(
        q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2]
    )
    n = p.event_shape[0]
    q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
    term2 = _batch_trace_XXT(
        torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False)
    )
    term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
    return half_term1 + 0.5 * (term2 + term3 - n)
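
# The function above evaluates the standard closed form
#     KL = 0.5 * (log(det(Sq) / det(Sp)) - n + tr(inv(Sq) @ Sp)
#                 + (mq - mp).T @ inv(Sq) @ (mq - mp))
# via Cholesky factors: half_term1 is the 0.5 * log-det ratio from the
# triangular diagonals, term2 the trace via a triangular solve, and term3 the
# Mahalanobis term.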


@register_kl(Normal, Normal)
def _kl_normal_normal(p, q):
    var_ratio = (p.scale / q.scale).pow(2)
    t1 = ((p.loc - q.loc) / q.scale).pow(2)
    return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
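
# Sanity-check sketch against the scalar closed form
# log(s_q / s_p) + (s_p**2 + (m_p - m_q)**2) / (2 * s_q**2) - 0.5:
#
#     p, q = Normal(0.0, 1.0), Normal(1.0, 2.0)
#     expected = math.log(2.0) + (1.0 + 1.0) / 8.0 - 0.5
#     torch.testing.assert_close(_kl_normal_normal(p, q), torch.tensor(expected))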


@register_kl(OneHotCategorical, OneHotCategorical)
def _kl_onehotcategorical_onehotcategorical(p, q):
    return _kl_categorical_categorical(p._categorical, q._categorical)


@register_kl(Pareto, Pareto)
def _kl_pareto_pareto(p, q):
    # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
    scale_ratio = p.scale / q.scale
    alpha_ratio = q.alpha / p.alpha
    t1 = q.alpha * scale_ratio.log()
    t2 = -alpha_ratio.log()
    result = t1 + t2 + alpha_ratio - 1
    result[p.support.lower_bound < q.support.lower_bound] = inf
    return result


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
    return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate)


@register_kl(TransformedDistribution, TransformedDistribution)
def _kl_transformed_transformed(p, q):
    if p.transforms != q.transforms:
        raise NotImplementedError
    if p.event_shape != q.event_shape:
        raise NotImplementedError
    return kl_divergence(p.base_dist, q.base_dist)


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    result = ((q.high - q.low) / (p.high - p.low)).log()
    result[(q.low > p.low) | (q.high < p.high)] = inf
    return result
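
# Example: KL(U[0, 1] || U[0, 2]) = log(2), while the reverse direction is inf
# because p's support is not contained in q's:
#
#     p = Uniform(torch.tensor(0.0), torch.tensor(1.0))
#     q = Uniform(torch.tensor(0.0), torch.tensor(2.0))
#     kl_divergence(p, q)  # log(2) ~= 0.6931
#     kl_divergence(q, p)  # inf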


# Different distributions


@register_kl(Bernoulli, Poisson)
def _kl_bernoulli_poisson(p, q):
    return -p.entropy() - (p.probs * q.rate.log() - q.rate)


@register_kl(Beta, ContinuousBernoulli)
def _kl_beta_continuous_bernoulli(p, q):
    return (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )


@register_kl(Beta, Pareto)
def _kl_beta_infinity(p, q):
    return _infinite_like(p.concentration1)


@register_kl(Beta, Exponential)
def _kl_beta_exponential(p, q):
    return (
        -p.entropy()
        - q.rate.log()
        + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0))
    )


@register_kl(Beta, Gamma)
def _kl_beta_gamma(p, q):
    t1 = -p.entropy()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (q.concentration - 1) * (
        p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma()
    )
    t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0)
    return t1 + t2 - t3 + t4


# TODO: Add Beta-Laplace KL Divergence


@register_kl(Beta, Normal)
def _kl_beta_normal(p, q):
    E_beta = p.concentration1 / (p.concentration1 + p.concentration0)
    var_normal = q.scale.pow(2)
    t1 = -p.entropy()
    t2 = 0.5 * (var_normal * 2 * math.pi).log()
    t3 = (
        E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1)
        + E_beta.pow(2)
    ) * 0.5
    t4 = q.loc * E_beta
    t5 = q.loc.pow(2) * 0.5
    return t1 + t2 + (t3 - t4 + t5) / var_normal


@register_kl(Beta, Uniform)
def _kl_beta_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf
    return result


# Note that the KL between a ContinuousBernoulli and Beta has no closed form


@register_kl(ContinuousBernoulli, Pareto)
def _kl_continuous_bernoulli_infinity(p, q):
    return _infinite_like(p.probs)


@register_kl(ContinuousBernoulli, Exponential)
def _kl_continuous_bernoulli_exponential(p, q):
    return -p.entropy() - torch.log(q.rate) + q.rate * p.mean


# Note that the KL between a ContinuousBernoulli and Gamma has no closed form
# TODO: Add ContinuousBernoulli-Laplace KL Divergence


@register_kl(ContinuousBernoulli, Normal)
def _kl_continuous_bernoulli_normal(p, q):
    t1 = -p.entropy()
    t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log(
        q.scale
    )
    t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / (
        2.0 * torch.square(q.scale)
    )
    return t1 + t2 + t3


@register_kl(ContinuousBernoulli, Uniform)
def _kl_continuous_bernoulli_uniform(p, q):
    result = -p.entropy() + (q.high - q.low).log()
    return torch.where(
        torch.max(
            torch.ge(q.low, p.support.lower_bound),
            torch.le(q.high, p.support.upper_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Exponential, Beta)
@register_kl(Exponential, ContinuousBernoulli)
@register_kl(Exponential, Pareto)
@register_kl(Exponential, Uniform)
def _kl_exponential_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Exponential, Gamma)
def _kl_exponential_gamma(p, q):
    ratio = q.rate / p.rate
    t1 = -q.concentration * torch.log(ratio)
    return (
        t1
        + ratio
        + q.concentration.lgamma()
        + q.concentration * _euler_gamma
        - (1 + _euler_gamma)
    )


@register_kl(Exponential, Gumbel)
def _kl_exponential_gumbel(p, q):
    scale_rate_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = scale_rate_prod.log() - 1
    t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1)
    t3 = scale_rate_prod.reciprocal()
    return t1 - loc_scale_ratio + t2 + t3


# TODO: Add Exponential-Laplace KL Divergence


@register_kl(Exponential, Normal)
def _kl_exponential_normal(p, q):
    var_normal = q.scale.pow(2)
    rate_sqr = p.rate.pow(2)
    t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi)
    t2 = rate_sqr.reciprocal()
    t3 = q.loc / p.rate
    t4 = q.loc.pow(2) * 0.5
    return t1 - 1 + (t2 - t3 + t4) / var_normal


@register_kl(Gamma, Beta)
@register_kl(Gamma, ContinuousBernoulli)
@register_kl(Gamma, Pareto)
@register_kl(Gamma, Uniform)
def _kl_gamma_infinity(p, q):
    return _infinite_like(p.concentration)


@register_kl(Gamma, Exponential)
def _kl_gamma_exponential(p, q):
    return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate


@register_kl(Gamma, Gumbel)
def _kl_gamma_gumbel(p, q):
    beta_scale_prod = p.rate * q.scale
    loc_scale_ratio = q.loc / q.scale
    t1 = (
        (p.concentration - 1) * p.concentration.digamma()
        - p.concentration.lgamma()
        - p.concentration
    )
    t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod
    t3 = (
        torch.exp(loc_scale_ratio)
        * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration)
        - loc_scale_ratio
    )
    return t1 + t2 + t3


# TODO: Add Gamma-Laplace KL Divergence


@register_kl(Gamma, Normal)
def _kl_gamma_normal(p, q):
    var_normal = q.scale.pow(2)
    beta_sqr = p.rate.pow(2)
    t1 = (
        0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi)
        - p.concentration
        - p.concentration.lgamma()
    )
    t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr
    t3 = q.loc * p.concentration / p.rate
    t4 = 0.5 * q.loc.pow(2)
    return (
        t1
        + (p.concentration - 1) * p.concentration.digamma()
        + (t2 - t3 + t4) / var_normal
    )


@register_kl(Gumbel, Beta)
@register_kl(Gumbel, ContinuousBernoulli)
@register_kl(Gumbel, Exponential)
@register_kl(Gumbel, Gamma)
@register_kl(Gumbel, Pareto)
@register_kl(Gumbel, Uniform)
def _kl_gumbel_infinity(p, q):
    return _infinite_like(p.loc)


# TODO: Add Gumbel-Laplace KL Divergence


@register_kl(Gumbel, Normal)
def _kl_gumbel_normal(p, q):
    param_ratio = p.scale / q.scale
    t1 = (param_ratio / math.sqrt(2 * math.pi)).log()
    t2 = (math.pi * param_ratio * 0.5).pow(2) / 3
    t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5
    return -t1 + t2 + t3 - (_euler_gamma + 1)


@register_kl(Laplace, Beta)
@register_kl(Laplace, ContinuousBernoulli)
@register_kl(Laplace, Exponential)
@register_kl(Laplace, Gamma)
@register_kl(Laplace, Pareto)
@register_kl(Laplace, Uniform)
def _kl_laplace_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Laplace, Normal)
def _kl_laplace_normal(p, q):
    var_normal = q.scale.pow(2)
    scale_sqr_var_ratio = p.scale.pow(2) / var_normal
    t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi)
    t2 = 0.5 * p.loc.pow(2)
    t3 = p.loc * q.loc
    t4 = 0.5 * q.loc.pow(2)
    return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1


@register_kl(Normal, Beta)
@register_kl(Normal, ContinuousBernoulli)
@register_kl(Normal, Exponential)
@register_kl(Normal, Gamma)
@register_kl(Normal, Pareto)
@register_kl(Normal, Uniform)
def _kl_normal_infinity(p, q):
    return _infinite_like(p.loc)


@register_kl(Normal, Gumbel)
def _kl_normal_gumbel(p, q):
    mean_scale_ratio = p.loc / q.scale
    var_scale_sqr_ratio = (p.scale / q.scale).pow(2)
    loc_scale_ratio = q.loc / q.scale
    t1 = var_scale_sqr_ratio.log() * 0.5
    t2 = mean_scale_ratio - loc_scale_ratio
    t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio)
    return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi)))


@register_kl(Normal, Laplace)
def _kl_normal_laplace(p, q):
    loc_diff = p.loc - q.loc
    scale_ratio = p.scale / q.scale
    loc_diff_scale_ratio = loc_diff / p.scale
    t1 = torch.log(scale_ratio)
    t2 = (
        math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2))
    )
    t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio)
    return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi)))


@register_kl(Pareto, Beta)
@register_kl(Pareto, ContinuousBernoulli)
@register_kl(Pareto, Uniform)
def _kl_pareto_infinity(p, q):
    return _infinite_like(p.scale)


@register_kl(Pareto, Exponential)
def _kl_pareto_exponential(p, q):
    scale_rate_prod = p.scale * q.rate
    t1 = (p.alpha / scale_rate_prod).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * scale_rate_prod / (p.alpha - 1)
    result = t1 - t2 + t3 - 1
    result[p.alpha <= 1] = inf
    return result


@register_kl(Pareto, Gamma)
def _kl_pareto_gamma(p, q):
    common_term = p.scale.log() + p.alpha.reciprocal()
    t1 = p.alpha.log() - common_term
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (1 - q.concentration) * common_term
    t4 = q.rate * p.alpha * p.scale / (p.alpha - 1)
    result = t1 + t2 + t3 + t4 - 1
    result[p.alpha <= 1] = inf
    return result


# TODO: Add Pareto-Laplace KL Divergence


@register_kl(Pareto, Normal)
def _kl_pareto_normal(p, q):
    var_normal = 2 * q.scale.pow(2)
    common_term = p.scale / (p.alpha - 1)
    t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log()
    t2 = p.alpha.reciprocal()
    t3 = p.alpha * common_term.pow(2) / (p.alpha - 2)
    t4 = (p.alpha * common_term - q.loc).pow(2)
    result = t1 - t2 + (t3 + t4) / var_normal - 1
    result[p.alpha <= 2] = inf
    return result


@register_kl(Poisson, Bernoulli)
@register_kl(Poisson, Binomial)
def _kl_poisson_infinity(p, q):
    return _infinite_like(p.rate)


@register_kl(Uniform, Beta)
def _kl_uniform_beta(p, q):
    common_term = p.high - p.low
    t1 = torch.log(common_term)
    t2 = (
        (q.concentration1 - 1)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t3 = (
        (q.concentration0 - 1)
        * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term)
        / common_term
    )
    t4 = (
        q.concentration1.lgamma()
        + q.concentration0.lgamma()
        - (q.concentration1 + q.concentration0).lgamma()
    )
    result = t3 + t4 - t1 - t2
    result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf
    return result


@register_kl(Uniform, ContinuousBernoulli)
def _kl_uniform_continuous_bernoulli(p, q):
    result = (
        -p.entropy()
        - p.mean * q.logits
        - torch.log1p(-q.probs)
        - q._cont_bern_log_norm()
    )
    return torch.where(
        torch.max(
            torch.ge(p.high, q.support.upper_bound),
            torch.le(p.low, q.support.lower_bound),
        ),
        torch.ones_like(result) * inf,
        result,
    )


@register_kl(Uniform, Exponential)
def _kl_uniform_exponential(p, q):
    result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log()
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gamma)
def _kl_uniform_gamma(p, q):
    common_term = p.high - p.low
    t1 = common_term.log()
    t2 = q.concentration.lgamma() - q.concentration * q.rate.log()
    t3 = (
        (1 - q.concentration)
        * (_x_log_x(p.high) - _x_log_x(p.low) - common_term)
        / common_term
    )
    t4 = q.rate * (p.high + p.low) / 2
    result = -t1 + t2 + t3 + t4
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Uniform, Gumbel)
def _kl_uniform_gumbel(p, q):
    common_term = q.scale / (p.high - p.low)
    high_loc_diff = (p.high - q.loc) / q.scale
    low_loc_diff = (p.low - q.loc) / q.scale
    t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff)
    t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff))
    return t1 - t2


# TODO: Uniform-Laplace KL Divergence


@register_kl(Uniform, Normal)
def _kl_uniform_normal(p, q):
    common_term = p.high - p.low
    t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log()
    t2 = (common_term).pow(2) / 12
    t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2)
    return t1 + 0.5 * (t2 + t3) / q.scale.pow(2)


@register_kl(Uniform, Pareto)
def _kl_uniform_pareto(p, q):
    support_uniform = p.high - p.low
    t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log()
    t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform
    result = t2 * (q.alpha + 1) - t1
    result[p.low < q.support.lower_bound] = inf
    return result


@register_kl(Independent, Independent)
def _kl_independent_independent(p, q):
    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
        raise NotImplementedError
    result = kl_divergence(p.base_dist, q.base_dist)
    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
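
# Sketch: wrapping both distributions in Independent sums the base KL over the
# reinterpreted dimensions, so these two computations agree:
#
#     base_p = Normal(torch.zeros(4), torch.ones(4))
#     base_q = Normal(torch.ones(4), torch.ones(4))
#     kl_joint = kl_divergence(Independent(base_p, 1), Independent(base_q, 1))
#     torch.testing.assert_close(kl_joint, kl_divergence(base_p, base_q).sum(-1))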


@register_kl(Cauchy, Cauchy)
def _kl_cauchy_cauchy(p, q):
    # From https://arxiv.org/abs/1905.10965
    t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log()
    t2 = (4 * p.scale * q.scale).log()
    return t1 - t2


def _add_kl_info():
    """Appends a list of implemented KL functions to the doc for kl_divergence."""
    rows = [
        "KL divergence is currently implemented for the following distribution pairs:"
    ]
    for p, q in sorted(
        _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__)
    ):
        rows.append(
            f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`"
        )
    kl_info = "\n\t".join(rows)
    if kl_divergence.__doc__:
        kl_divergence.__doc__ += kl_info  # type: ignore[operator]
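
# torch.distributions invokes _add_kl_info() once at import time, after all of
# the registrations above have run, so that the kl_divergence docstring lists
# every supported pair.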