import math import torch import torch.nn as nn from ..utils import log_sum_exp class EncoderBase(nn.Module): """docstring for EncoderBase""" def __init__(self): super(EncoderBase, self).__init__() def forward(self, x): """ Args: x: (batch_size, *) Returns: the tensors required to parameterize a distribution. E.g. for Gaussian encoder it returns the mean and variance tensors """ raise NotImplementedError def sample(self, input, nsamples): """sampling from the encoder Returns: Tensor1 Tensor1: the tensor latent z with shape [batch, nsamples, nz] """ raise NotImplementedError def encode(self, input, nsamples): """perform the encoding and compute the KL term Returns: Tensor1, Tensor2 Tensor1: the tensor latent z with shape [batch, nsamples, nz] Tensor2: the tenor of KL for each x with shape [batch] """ raise NotImplementedError def eval_inference_dist(self, x, z, param=None): """this function computes log q(z | x) Args: z: tensor different z points that will be evaluated, with shape [batch, nsamples, nz] Returns: Tensor1 Tensor1: log q(z|x) with shape [batch, nsamples] """ raise NotImplementedError def calc_mi(self, x): """Approximate the mutual information between x and z I(x, z) = E_xE_{q(z|x)}log(q(z|x)) - E_xE_{q(z|x)}log(q(z)) Returns: Float """ raise NotImplementedError