import math
import torch
import torch.nn as nn
from .encoder import EncoderBase
from ..utils import log_sum_exp
class GaussianEncoderBase(EncoderBase):
    """Base class for encoders that parameterize a diagonal Gaussian posterior q(z|x)"""
    def __init__(self):
        super(GaussianEncoderBase, self).__init__()

    def freeze(self):
        # stop gradient updates for every encoder parameter
        for param in self.parameters():
            param.requires_grad = False
    def forward(self, x):
        """
        Args:
            x: (batch_size, *)

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """
        raise NotImplementedError
    def encode_stats(self, x):
        # return the posterior parameters (mu, logvar) without sampling
        return self.forward(x)
    def sample(self, input, nsamples):
        """sampling from the encoder

        Returns: Tensor1, Tuple
            Tensor1: the tensor of latent z with shape [batch, nsamples, nz]
            Tuple: (mu, logvar), the Gaussian parameters of q(z|x),
                each with shape [batch, nz]
        """
        # (batch_size, nz)
        mu, logvar = self.forward(input)

        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)

        return z, (mu, logvar)
    def encode(self, input, nsamples):
        """perform the encoding and compute the KL term

        Returns: Tensor1, Tensor2
            Tensor1: the tensor of latent z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """
        # (batch_size, nz)
        mu, logvar = self.forward(input)

        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)

        # closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return z, KL
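    # Derivation note for the KL term in `encode`: with q = N(mu, diag(exp(logvar)))
    # and prior p = N(0, I), the per-dimension KL has the closed form
    #     0.5 * (mu_i^2 + exp(logvar_i) - logvar_i - 1),
    # which is summed over the nz latent dimensions above.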
    def reparameterize(self, mu, logvar, nsamples=1):
        """sample from posterior Gaussian family

        Args:
            mu: Tensor
                Mean of gaussian distribution with shape (batch, nz)
            logvar: Tensor
                logvar of gaussian distribution with shape (batch, nz)

        Returns: Tensor
            Sampled z with shape (batch, nsamples, nz)
        """
        batch_size, nz = mu.size()
        std = logvar.mul(0.5).exp()

        mu_expd = mu.unsqueeze(1).expand(batch_size, nsamples, nz)
        std_expd = std.unsqueeze(1).expand(batch_size, nsamples, nz)

        # z = mu + std * eps, with eps ~ N(0, I)
        eps = torch.zeros_like(std_expd).normal_()

        return mu_expd + torch.mul(eps, std_expd)
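    # The sampling in `reparameterize` is differentiable w.r.t. mu and logvar:
    # all randomness is isolated in eps, so gradients flow through the
    # deterministic affine map mu + std * eps (the reparameterization trick).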
    def eval_inference_dist(self, x, z, param=None):
        """compute log q(z | x)

        Args:
            z: Tensor
                different z points that will be evaluated, with
                shape [batch, nsamples, nz]

        Returns: Tensor1
            Tensor1: log q(z|x) with shape [batch, nsamples]
        """
        nz = z.size(2)

        if param is None:
            mu, logvar = self.forward(x)
        else:
            mu, logvar = param

        # (batch_size, 1, nz)
        mu, logvar = mu.unsqueeze(1), logvar.unsqueeze(1)
        var = logvar.exp()

        # (batch_size, nsamples, nz)
        dev = z - mu

        # log density of a diagonal Gaussian: (batch_size, nsamples)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        return log_density
    def calc_mi(self, x):
        """Approximate the mutual information between x and z,
        I(x, z) = E_x E_{q(z|x)} log q(z|x) - E_x E_{q(z|x)} log q(z)

        Returns: Float
        """
        # [x_batch, nz]
        mu, logvar = self.forward(x)

        x_batch, nz = mu.size()

        # E_{q(z|x)} log q(z|x) = -0.5 * nz * log(2*pi) - 0.5 * (1 + logvar).sum(-1)
        neg_entropy = (-0.5 * nz * math.log(2 * math.pi) - 0.5 * (1 + logvar).sum(-1)).mean()

        # [x_batch, 1, nz]
        z_samples = self.reparameterize(mu, logvar, 1)

        # [1, x_batch, nz]
        mu, logvar = mu.unsqueeze(0), logvar.unsqueeze(0)
        var = logvar.exp()

        # (x_batch, x_batch, nz): deviation of each sample from each posterior mean
        dev = z_samples - mu

        # (x_batch, x_batch)
        log_density = -0.5 * ((dev ** 2) / var).sum(dim=-1) - \
            0.5 * (nz * math.log(2 * math.pi) + logvar.sum(-1))

        # log q(z): Monte Carlo estimate of the aggregate posterior, [x_batch]
        log_qz = log_sum_exp(log_density, dim=1) - math.log(x_batch)

        return (neg_entropy - log_qz.mean(-1)).item()
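

# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the original module). The subclass
# and names below (ToyGaussianEncoder, x_dim) are illustrative only, and
# EncoderBase is assumed to behave like a plain nn.Module. Because this file
# uses relative imports, the demo must be run from the package root, e.g.
# `python -m <package>.<module>`.
# ---------------------------------------------------------------------------
class ToyGaussianEncoder(GaussianEncoderBase):
    def __init__(self, x_dim=16, nz=4):
        super(ToyGaussianEncoder, self).__init__()
        # a single linear layer that predicts mu and logvar jointly
        self.net = nn.Linear(x_dim, 2 * nz)

    def forward(self, x):
        # split the output into the two Gaussian parameters
        mu, logvar = self.net(x).chunk(2, dim=-1)
        return mu, logvar


if __name__ == "__main__":
    enc = ToyGaussianEncoder()
    x = torch.randn(8, 16)

    z, KL = enc.encode(x, nsamples=5)        # z: [8, 5, 4], KL: [8]
    log_qzx = enc.eval_inference_dist(x, z)  # [8, 5]
    mi = enc.calc_mi(x)                      # scalar MI estimate

    print(z.shape, KL.shape, log_qzx.shape, mi)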