import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from torch import nn
from torch.nn import functional as F

# from .types_ import *


class VanillaVAE(nn.Module):
    """Convolutional autoencoder built from a VAE skeleton.

    The encoder is a stack of stride-2 Conv/BatchNorm/LeakyReLU stages; the
    decoder mirrors it with transposed convolutions and ends in a Tanh.  The
    reparameterization trick, ``log_var`` head usage, and the KL term are all
    commented out in this version, so the model currently behaves as a plain
    autoencoder: ``forward`` encodes to ``mu`` and decodes it directly, and
    ``loss_function`` returns only the MSE reconstruction loss.

    NOTE(review): the fully connected heads use ``hidden_dims[-1] * 4``, i.e.
    they assume the encoder output has a 2x2 spatial footprint.  With the
    default five stride-2 stages that means 64x64 inputs — confirm against
    the data pipeline.
    """

    def __init__(self, args, in_channels: int, latent_dim: int,
                 hidden_dims=None, **kwargs) -> None:
        """Build encoder/decoder stacks.

        :param args: unused here; kept for caller compatibility.
        :param in_channels: number of channels of the input images.
        :param latent_dim: size of the latent code (defaults to 512 if None).
        :param hidden_dims: per-stage channel widths for the encoder
            (decoder uses them reversed); defaults to [32, 64, 128, 256, 512].
        """
        super(VanillaVAE, self).__init__()

        # Resolve defaults BEFORE storing them, so self.latent_dim is never
        # None (the original assigned self.latent_dim first, leaving it None
        # when the caller passed latent_dim=None).
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            # Copy so the in-place reverse() below cannot mutate the
            # caller's list.
            hidden_dims = list(hidden_dims)
        if latent_dim is None:
            latent_dim = 512
        self.latent_dim = latent_dim

        # Build Encoder: each stage halves the spatial resolution.
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        # * 4 == the assumed 2x2 spatial footprint of the encoder output.
        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
        # fc_var is built but unused while the sampling path is disabled.
        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)

        hidden_dims.reverse()
        # Channel count the decoder starts from (deepest encoder width);
        # decode() uses this instead of a hard-coded 512.
        self.decoder_start_channels = hidden_dims[0]

        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1],
                                       kernel_size=3, stride=2,
                                       padding=1, output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1],
                               kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """Encode the input and return the latent code.

        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent code ``mu`` of shape [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Only the mean head is used while the variance / sampling path
        # is disabled.
        mu = self.fc_mu(result)
        # log_var = self.fc_var(result)

        return mu

    def decode(self, z):
        """Map the given latent codes onto the image space.

        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Un-flatten back to the encoder's output footprint
        # (deepest width x 2 x 2).
        result = result.view(-1, self.decoder_start_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    # def reparameterize(self, mu, logvar):
    #     """
    #     Reparameterization trick to sample from N(mu, var) from
    #     N(0,1).
    #     :param mu: (Tensor) Mean of the latent Gaussian [B x D]
    #     :param logvar: (Tensor) Standard deviation of the latent Gaussian [B x D]
    #     :return: (Tensor) [B x D]
    #     """
    #     std = torch.exp(0.5 * logvar)
    #     eps = torch.randn_like(std)
    #     return eps * std + mu

    def forward(self, input, **kwargs):
        """Encode then decode; no sampling while the VAE path is disabled.

        :param input: (Tensor) [B x C x H x W]
        :return: (Tensor) Reconstruction [B x C x H x W]
        """
        mu = self.encode(input)
        # z = self.reparameterize(mu, log_var)
        return self.decode(mu)

    def loss_function(self, *args, **kwargs) -> dict:
        """
        Computes the VAE loss function.
        KL(N(\\mu, \\sigma), N(0, 1)) = \\log \\frac{1}{\\sigma} + \\frac{\\sigma^2 + \\mu^2}{2} - \\frac{1}{2}

        With the KL term disabled, this is currently just the MSE
        reconstruction loss.

        :param args: args[0] = reconstruction, args[1] = target input.
        :param kwargs: unused while the KL weight is disabled.
        :return: (Tensor) scalar reconstruction loss.
        """
        recons = args[0]
        input = args[1]
        # mu = args[2]
        # log_var = args[3]

        # kld_weight = kwargs['M_N']  # Account for the minibatch samples from the dataset
        recons_loss = F.mse_loss(recons, input)

        # kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

        loss = recons_loss
        return loss  # {'loss': loss, 'Reconstruction_Loss':recons_loss.detach(), 'KLD':recons_loss.detach()}

    def generate(self, x, **kwargs):
        """
        Given an input image x, returns the reconstructed image.

        forward() returns a plain tensor (not the VAE-style list), so no
        [0] indexing — the original sliced off the batch dimension here.

        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        return self.forward(x)