# Source: soumickmj / cceVAE upload (commit c3bcb92) — web-page residue converted to a comment.
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dist
from .ae_bases import BasicEncoder, BasicGenerator
class VAE(torch.nn.Module):
def __init__(
self,
d,
input_size,
z_dim=256,
fmap_sizes=(16, 64, 256, 1024),
to_1x1=True,
conv_params=None,
tconv_params=None,
normalization_op=None,
normalization_params=None,
activation_op="leakyrelu",
activation_params=None,
block_op=None,
block_params=None,
*args,
**kwargs
):
"""Basic VAE build up of a symetric BasicEncoder (Encoder) and BasicGenerator (Decoder)
Args:
input_size ((int, int, int): Size of the input in format CxHxW):
z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). Defaults to 256
fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each
int defines the number of feature maps in the layer]. Defaults to (16, 64, 256, 1024).
to_1x1 (bool, optional): [If True, then the last conv layer goes to a latent dimesion is a z_dim x 1 x 1 vector (similar to fully connected)
or if False allows spatial resolution not to be 1x1 (z_dim x H x W, uses the in the conv_params given conv-kernel-size) ].
Defaults to True.
conv_op ([torch.nn.Module], optional): [Convolutioon operation used in the encoder to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d.
conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False).
tconv_op ([torch.nn.Module], optional): [Upsampling/ Transposed Conv operation used in the decoder to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d.
tconv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False).
normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d.
normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None.
activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU.
activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None.
block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp.
block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None.
"""
super(VAE, self).__init__()
if d == 2:
conv_op = nn.Conv2d
tconv_op = nn.ConvTranspose2d
else:
conv_op = nn.Conv3d
tconv_op = nn.ConvTranspose3d
match (activation_op):
case "relu":
activation_op = nn.ReLU
case "prelu":
activation_op = nn.PReLU
case "leakyrelu":
activation_op = nn.LeakyReLU
case _:
raise ValueError(f"Activation function {activation_op} not supported")
input_size_enc = list(input_size)
input_size_dec = list(input_size)
self.enc = BasicEncoder(
input_size=input_size_enc,
fmap_sizes=fmap_sizes,
z_dim=z_dim * 2,
conv_op=conv_op,
conv_params=conv_params,
normalization_op=normalization_op,
normalization_params=normalization_params,
activation_op=activation_op,
activation_params=activation_params,
block_op=block_op,
block_params=block_params,
to_1x1=to_1x1,
)
self.dec = BasicGenerator(
input_size=input_size_dec,
fmap_sizes=fmap_sizes[::-1],
z_dim=z_dim,
upsample_op=tconv_op,
conv_params=tconv_params,
normalization_op=normalization_op,
normalization_params=normalization_params,
activation_op=activation_op,
activation_params=activation_params,
block_op=block_op,
block_params=block_params,
to_1x1=to_1x1,
)
self.hidden_size = self.enc.output_size
def forward(self, inpt, sample=True, no_dist=False, **kwargs):
y1 = self.enc(inpt, **kwargs)
mu, log_std = torch.chunk(y1, 2, dim=1)
std = torch.exp(log_std)
z_dist = dist.Normal(mu, std)
if sample:
z_sample = z_dist.rsample()
else:
z_sample = mu
x_rec = self.dec(z_sample)
if no_dist:
return x_rec
else:
return x_rec, z_dist
def encode(self, inpt, **kwargs):
"""Encodes a sample and returns the paramters for the approx inference dist. (Normal)
Args:
inpt ([tensor]): The input to encode
Returns:
mu : The mean used to parameterized a Normal distribution
std: The standard deviation used to parameterized a Normal distribution
"""
enc = self.enc(inpt, **kwargs)
mu, log_std = torch.chunk(enc, 2, dim=1)
std = torch.exp(log_std)
return mu, std
def decode(self, inpt, **kwargs):
"""Decodes a latent space sample, used the generative model (decode = mu_{gen}(z) as used in p(x|z) = N(x | mu_{gen}(z), 1) ).
Args:
inpt ([type]): A sample from the latent space to decode
Returns:
[type]: [description]
"""
x_rec = self.dec(inpt, **kwargs)
return x_rec
class AE(torch.nn.Module):
    def __init__(
        self,
        input_size,
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_op=torch.nn.Conv2d,
        conv_params=None,
        tconv_op=torch.nn.ConvTranspose2d,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=torch.nn.LeakyReLU,
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Plain autoencoder: a symmetric BasicEncoder / BasicGenerator pair.

        Args:
            input_size ((int, int, int)): Size of the input in CxHxW format.
            z_dim (int, optional): Dimension of the latent (channel dim). Defaults to 1024.
            fmap_sizes (tuple, optional): Feature-map counts per down-/upsampling level;
                the decoder uses the reversed order. Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): If True, the last encoder conv reduces the latent to
                z_dim x 1 x 1 (fully-connected-like); if False, the latent keeps spatial
                extent (z_dim x H x W). Defaults to True.
            conv_op (torch.nn.Module, optional): Conv op used by the encoder to downsample. Defaults to nn.Conv2d.
            conv_params (dict, optional): Init parameters for the conv op. Defaults to None.
            tconv_op (torch.nn.Module, optional): Transposed-conv op used by the decoder to upsample. Defaults to nn.ConvTranspose2d.
            tconv_params (dict, optional): Init parameters for the transposed-conv op. Defaults to None.
            normalization_op (torch.nn.Module, optional): Normalization op (e.g. BatchNorm, InstanceNorm). Defaults to None.
            normalization_params (dict, optional): Init parameters for the normalization op. Defaults to None.
            activation_op (torch.nn.Module, optional): Non-linearity (e.g. ReLU, Sigmoid). Defaults to nn.LeakyReLU.
            activation_params (dict, optional): Init parameters for the activation op. Defaults to None.
            block_op (torch.nn.Module, optional): Block op used per feature-map level. Defaults to None.
            block_params (dict, optional): Init parameters for the block op. Defaults to None.
        """
        super().__init__()

        enc_size = list(input_size)
        dec_size = list(input_size)

        # Shared kwargs for both halves of the symmetric pair.
        common = dict(
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        self.enc = BasicEncoder(
            input_size=enc_size,
            fmap_sizes=fmap_sizes,
            z_dim=z_dim,
            conv_op=conv_op,
            conv_params=conv_params,
            **common,
        )
        self.dec = BasicGenerator(
            input_size=dec_size,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            **common,
        )

        self.hidden_size = self.enc.output_size

    def forward(self, inpt, **kwargs):
        """Encode the input and decode it straight back to a reconstruction."""
        return self.dec(self.enc(inpt, **kwargs))

    def encode(self, inpt, **kwargs):
        """Encodes an input sample to a latent space sample.

        Args:
            inpt (tensor): Input sample.

        Returns:
            tensor: Encoded input sample in the latent space.
        """
        return self.enc(inpt, **kwargs)

    def decode(self, inpt, **kwargs):
        """Decodes a latent space sample back to the input space.

        Args:
            inpt (tensor): Latent space sample.

        Returns:
            tensor: Decoded latent sample back in the input space.
        """
        return self.dec(inpt, **kwargs)