mueller-franzes's picture
init
f85e212
raw
history blame
248 Bytes
import torch
def kl_gaussians(mean1, logvar1, mean2, logvar2):
    """Compute the element-wise KL divergence KL(N1 || N2) between two Gaussians.

    Each Gaussian is parameterized by its mean and the natural log of its
    variance. The closed form evaluated here is

        0.5 * (logvar2 - logvar1
               + exp(logvar1 - logvar2)
               + (mean1 - mean2)**2 * exp(-logvar2)
               - 1)

    No reduction is applied: the result has the broadcast shape of the
    inputs, one KL value per element.

    Args:
        mean1: Mean of the first Gaussian (tensor or scalar).
        logvar1: Log-variance of the first Gaussian (tensor or scalar).
        mean2: Mean of the second Gaussian (tensor or scalar).
        logvar2: Log-variance of the second Gaussian (tensor or scalar).

    Returns:
        A tensor holding KL(N1 || N2) per element.
    """
    # torch.exp only accepts tensors, so promote scalar log-variances
    # (e.g. a fixed logvar of 0.0). Existing tensors are returned unchanged.
    logvar1 = torch.as_tensor(logvar1)
    logvar2 = torch.as_tensor(logvar2)

    # `**` behaves like torch.pow for tensors but also works for plain numbers.
    squared_mean_diff = (mean1 - mean2) ** 2

    return 0.5 * (
        logvar2 - logvar1
        + torch.exp(logvar1 - logvar2)
        + squared_mean_diff * torch.exp(-logvar2)
        - 1.0
    )