Spaces:
Configuration error
Configuration error
import os, sys | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
# from .types_ import * | |
class VanillaVAE(nn.Module):
    """Convolutional encoder/decoder built on the VanillaVAE layout.

    NOTE: the variational parts (log-variance head, reparameterization and
    the KL term) are not used by any method of this class. ``encode`` returns
    only the mean projection and ``forward`` decodes it directly, so the
    model behaves as a deterministic autoencoder. ``fc_var`` is still
    constructed so that the parameter names in ``state_dict`` stay unchanged.
    """

    def __init__(self, args,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims=None,
                 **kwargs) -> None:
        """
        Builds the encoder, the latent projections and the decoder.

        :param args: Not used by this class; accepted for call compatibility.
        :param in_channels: (int) Number of channels of the input images.
        :param latent_dim: (int) Size of the latent code; ``None`` selects 512.
        :param hidden_dims: (sequence of int) Output channels of each encoder
            stage; ``None`` selects ``[32, 64, 128, 256, 512]``. The sequence
            is copied, never modified.
        :param kwargs: Ignored.
        """
        super().__init__()

        # Work on a private copy: the decoder needs the reversed order and
        # the caller's list must not be reversed behind their back.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)

        # Resolve the default before storing it, so the attribute always
        # matches the width of the layers built below.
        if latent_dim is None:
            latent_dim = 512
        self.latent_dim = latent_dim

        # Build Encoder: each stage is a stride-2 convolution.
        encoder_stages = []
        stage_in = in_channels
        for h_dim in hidden_dims:
            encoder_stages.append(
                nn.Sequential(
                    nn.Conv2d(stage_in, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            stage_in = h_dim
        self.encoder = nn.Sequential(*encoder_stages)

        # The linear heads expect the flattened encoder output to hold
        # hidden_dims[-1] * 4 features, i.e. a 2 x 2 spatial map (a 64 x 64
        # input for the default five stride-2 stages).
        flat_features = hidden_dims[-1] * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder: mirror of the encoder channel progression.
        self.decoder_input = nn.Linear(latent_dim, flat_features)
        decoder_dims = hidden_dims[::-1]
        decoder_stages = [
            nn.Sequential(
                nn.ConvTranspose2d(c_in,
                                   c_out,
                                   kernel_size=3,
                                   stride=2,
                                   padding=1,
                                   output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU())
            for c_in, c_out in zip(decoder_dims, decoder_dims[1:])
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        # Last upsampling stage followed by a projection to 3 output
        # channels squashed into [-1, 1] by Tanh.
        last_dim = decoder_dims[-1]
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(last_dim,
                               last_dim,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(last_dim),
            nn.LeakyReLU(),
            nn.Conv2d(last_dim, out_channels=3,
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """
        Encodes the input by passing it through the encoder network
        and returns the latent code.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (Tensor) Latent code [N x latent_dim] (output of ``fc_mu``)
        """
        features = torch.flatten(self.encoder(input), start_dim=1)
        # Only the mean head is evaluated; ``fc_var`` is intentionally unused.
        return self.fc_mu(features)

    def decode(self, z):
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x 3 x H x W]
        """
        result = self.decoder_input(z)
        # decoder_input emits channels * 2 * 2 features; recover the channel
        # count from the layer instead of assuming 512 so that custom
        # ``hidden_dims`` work as well.
        channels = self.decoder_input.out_features // 4
        result = result.view(-1, channels, 2, 2)
        result = self.decoder(result)
        return self.final_layer(result)

    def forward(self, input, **kwargs):
        """
        Reconstructs ``input`` by encoding it and decoding the latent code.
        :param input: (Tensor) [B x C x H x W]
        :param kwargs: Ignored.
        :return: (Tensor) Reconstruction produced by ``decode``
        """
        return self.decode(self.encode(input))

    def loss_function(self,
                      *args,
                      **kwargs) -> torch.Tensor:
        """
        Computes the reconstruction loss.

        Only the mean-squared error between reconstruction and input is
        returned; no KL-divergence term is computed.
        :param args: ``args[0]`` is the reconstruction, ``args[1]`` the
            original input.
        :param kwargs: Ignored.
        :return: (Tensor) Scalar MSE loss
        """
        recons, target = args[0], args[1]
        return F.mse_loss(recons, target)

    def generate(self, x, **kwargs):
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """
        # ``forward`` returns the reconstruction tensor itself (not a list),
        # so it must not be indexed: ``[0]`` would drop all but one sample.
        return self.forward(x)