from itertools import chain
import math

import torch
import torch.nn as nn

from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

from .gaussian_encoder import GaussianEncoderBase
from ..utils import log_sum_exp


class GaussianLSTMEncoder(GaussianEncoderBase):
    """Gaussian LSTM encoder with constant-length input"""
    def __init__(self, args, vocab_size, model_init, emb_init):
        super(GaussianLSTMEncoder, self).__init__()
        self.ni = args.ni
        self.nh = args.enc_nh
        self.nz = args.nz
        self.args = args

        self.embed = nn.Embedding(vocab_size, args.ni)

        self.lstm = nn.LSTM(input_size=args.ni,
                            hidden_size=args.enc_nh,
                            num_layers=1,
                            batch_first=True,
                            dropout=0)

        # maps the last hidden state to the concatenated [mean, logvar]
        self.linear = nn.Linear(args.enc_nh, 2 * args.nz, bias=False)

        self.reset_parameters(model_init, emb_init)

    def reset_parameters(self, model_init, emb_init):
        # for name, param in self.lstm.named_parameters():
        #     # self.initializer(param)
        #     if 'bias' in name:
        #         nn.init.constant_(param, 0.0)
        #         # model_init(param)
        #     elif 'weight' in name:
        #         model_init(param)

        # model_init(self.linear.weight)
        # emb_init(self.embed.weight)
        for param in self.parameters():
            model_init(param)
        emb_init(self.embed.weight)

    def forward(self, input):
        """
        Args:
            input: (batch_size, seq_len)

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """

        # (batch_size, seq_len, args.ni)
        word_embed = self.embed(input)

        # last_state: (num_layers, batch_size, args.enc_nh)
        _, (last_state, last_cell) = self.lstm(word_embed)

        mean, logvar = self.linear(last_state).chunk(2, -1)

        # fix the variance to a pre-defined value
        if self.args.fix_var > 0:
            logvar = mean.new_tensor([[[math.log(self.args.fix_var)]]]).expand_as(mean)

        # squeeze out the num_layers dimension
        return mean.squeeze(0), logvar.squeeze(0)

    # def eval_inference_mode(self, x):
    #     """compute the mode points in the inference distribution
    #     (in Gaussian case)
    #     Returns: Tensor
    #         Tensor: the posterior mode points with shape (*, nz)
    #     """

    #     # (batch_size, nz)
    #     mu, logvar = self.forward(x)


class VarLSTMEncoder(GaussianLSTMEncoder):
    """Gaussian LSTM encoder with variable-length input"""
    def __init__(self, args, vocab_size, model_init, emb_init):
        super(VarLSTMEncoder, self).__init__(args, vocab_size, model_init, emb_init)

    def forward(self, input):
        """
        Args:
            input: tuple which contains x and sents_len
                x: (batch_size, seq_len)
                sents_len: long tensor of sentence lengths

        Returns: Tensor1, Tensor2
            Tensor1: the mean tensor, shape (batch, nz)
            Tensor2: the logvar tensor, shape (batch, nz)
        """

        input, sents_len = input
        # (batch_size, seq_len, args.ni)
        word_embed = self.embed(input)

        packed_embed = pack_padded_sequence(word_embed, sents_len.tolist(), batch_first=True)

        _, (last_state, last_cell) = self.lstm(packed_embed)

        mean, logvar = self.linear(last_state).chunk(2, -1)

        return mean.squeeze(0), logvar.squeeze(0)

    def encode(self, input, nsamples):
        """perform the encoding and compute the KL term

        Args:
            input: tuple which contains x and sents_len

        Returns: Tensor1, Tensor2
            Tensor1: the latent tensor z with shape [batch, nsamples, nz]
            Tensor2: the tensor of KL for each x with shape [batch]
        """

        # (batch_size, nz)
        mu, logvar = self.forward(input)

        # (batch, nsamples, nz)
        z = self.reparameterize(mu, logvar, nsamples)

        # closed-form KL( N(mu, sigma^2) || N(0, I) ), summed over the latent dimensions
        KL = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1).sum(dim=1)

        return z, KL
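

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It shows how the
# variable-length encoder might be driven end to end. The hyperparameter
# names (ni, enc_nh, nz, fix_var) mirror the fields read above; the
# concrete values, the uniform initializers standing in for `model_init`
# and `emb_init`, and the vocabulary size are illustrative assumptions.
# `reparameterize` is assumed to come from GaussianEncoderBase and to
# return (batch, nsamples, nz), as the `encode` docstring states. Kept as
# comments because the module uses relative imports and cannot be run
# directly as a script.
#
# from argparse import Namespace
# from functools import partial
# import torch
#
# args = Namespace(ni=64, enc_nh=128, nz=32, fix_var=0)
# model_init = partial(torch.nn.init.uniform_, a=-0.01, b=0.01)
# emb_init = partial(torch.nn.init.uniform_, a=-0.1, b=0.1)
#
# encoder = VarLSTMEncoder(args, vocab_size=1000,
#                          model_init=model_init, emb_init=emb_init)
#
# # padded batch of token ids; lengths sorted in decreasing order, as
# # required by pack_padded_sequence with its default enforce_sorted=True
# x = torch.randint(0, 1000, (4, 10))
# sents_len = torch.tensor([10, 9, 7, 5])
#
# z, KL = encoder.encode((x, sents_len), nsamples=5)
# print(z.shape, KL.shape)  # torch.Size([4, 5, 32]) torch.Size([4])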