File size: 11,329 Bytes
c3bcb92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 |
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dist
from .ae_bases import BasicEncoder, BasicGenerator
class VAE(torch.nn.Module):
    """Variational autoencoder built from a BasicEncoder and a mirrored BasicGenerator.

    The encoder produces ``2 * z_dim`` channels which are split along dim 1 into
    the mean and the log standard deviation of a diagonal Normal distribution
    over the latent; the decoder maps a latent sample back to the input space.
    """

    # Supported string names for ``activation_op`` and the module classes they map to.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "prelu": nn.PReLU,
        "leakyrelu": nn.LeakyReLU,
    }

    # Spatial dimensionality ``d`` -> (encoder conv op, decoder transposed-conv op).
    _CONV_OPS = {
        2: (nn.Conv2d, nn.ConvTranspose2d),
        3: (nn.Conv3d, nn.ConvTranspose3d),
    }

    def __init__(
        self,
        d,
        input_size,
        z_dim=256,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="leakyrelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Basic VAE built from a symmetric BasicEncoder (encoder) and BasicGenerator (decoder).

        Args:
            d (int): Spatial dimensionality. 2 selects nn.Conv2d / nn.ConvTranspose2d,
                3 selects nn.Conv3d / nn.ConvTranspose3d.
            input_size (sequence of int): Size of one input sample, e.g. (C, H, W) for d=2
                (presumably (C, D, H, W) for d=3 -- confirm against BasicEncoder).
            z_dim (int, optional): Latent dimension. The encoder is built with 2 * z_dim
                output channels (mean and log-std); the decoder is built with z_dim.
                Defaults to 256.
            fmap_sizes (sequence of int, optional): Number of feature maps per encoder
                level; used in reverse order for the decoder. Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): Forwarded to encoder and decoder. Per the original
                docs, True collapses the latent to a z_dim x 1 x 1 vector and False keeps
                a spatial latent. Defaults to True.
            conv_params (dict, optional): Init parameters for the encoder conv op,
                forwarded to BasicEncoder.
            tconv_params (dict, optional): Init parameters for the decoder transposed
                conv op, forwarded to BasicGenerator as ``conv_params``.
            normalization_op, normalization_params, activation_params, block_op,
            block_params (optional): Forwarded unchanged to both BasicEncoder and
                BasicGenerator; presumably None selects their built-in defaults.
            activation_op (str or callable, optional): One of "relu", "prelu",
                "leakyrelu" (mapped to nn.ReLU, nn.PReLU, nn.LeakyReLU), or any
                non-string callable (e.g. an nn.Module subclass) forwarded as-is.
                Defaults to "leakyrelu".
            *args, **kwargs: Accepted for call-site compatibility and ignored.

        Raises:
            ValueError: If ``d`` is not 2 or 3, or ``activation_op`` is an unknown
                name or is neither a string nor callable.
        """
        super().__init__()
        conv_op, tconv_op = self._conv_ops_for(d)
        activation_op = self._resolve_activation(activation_op)

        # Encoder and decoder each receive their own list copy of the input size.
        input_size_enc = list(input_size)
        input_size_dec = list(input_size)
        self.enc = BasicEncoder(
            input_size=input_size_enc,
            fmap_sizes=fmap_sizes,
            # Twice the latent size: first half is mu, second half is log(std).
            z_dim=z_dim * 2,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        self.dec = BasicGenerator(
            input_size=input_size_dec,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        # Output size as reported by the encoder.
        self.hidden_size = self.enc.output_size

    @classmethod
    def _conv_ops_for(cls, d):
        """Return the (conv, transposed-conv) module classes for dimensionality ``d``."""
        ops = cls._CONV_OPS.get(d)
        if ops is None:
            raise ValueError(f"Unsupported spatial dimensionality d={d!r}; expected 2 or 3")
        return ops

    @classmethod
    def _resolve_activation(cls, activation_op):
        """Map an activation name to its module class; pass non-string callables through."""
        if isinstance(activation_op, str):
            try:
                return cls._ACTIVATIONS[activation_op]
            except KeyError:
                raise ValueError(f"Activation function {activation_op} not supported") from None
        if callable(activation_op):
            return activation_op
        raise ValueError(f"Activation function {activation_op} not supported")

    def forward(self, inpt, sample=True, no_dist=False, **kwargs):
        """Encode ``inpt``, pick a latent, and decode it.

        Args:
            inpt (tensor): Input batch, forwarded to the encoder.
            sample (bool, optional): If True, decode a reparameterized sample from the
                latent Normal; otherwise decode its mean. Defaults to True.
            no_dist (bool, optional): If True, return only the reconstruction.
                Defaults to False.
            **kwargs: Forwarded to the encoder.

        Returns:
            ``x_rec`` if ``no_dist`` is True, else ``(x_rec, z_dist)`` where ``z_dist``
            is the ``torch.distributions.Normal`` over the latent.
        """
        mu, std = self.encode(inpt, **kwargs)
        z_dist = dist.Normal(mu, std)
        # rsample() draws a reparameterized sample, so gradients flow to mu and std.
        z_sample = z_dist.rsample() if sample else mu
        x_rec = self.dec(z_sample)
        if no_dist:
            return x_rec
        return x_rec, z_dist

    def encode(self, inpt, **kwargs):
        """Encode a sample and return the parameters of the approximate posterior (Normal).

        Args:
            inpt (tensor): The input to encode.
            **kwargs: Forwarded to the encoder.

        Returns:
            mu: Mean of the Normal distribution (first half of the encoder channels).
            std: Standard deviation, ``exp`` of the second half of the encoder channels.
        """
        enc = self.enc(inpt, **kwargs)
        mu, log_std = torch.chunk(enc, 2, dim=1)
        return mu, torch.exp(log_std)

    def decode(self, inpt, **kwargs):
        """Decode a latent sample with the generative model (mu_gen(z) in p(x|z) = N(x | mu_gen(z), 1)).

        Args:
            inpt (tensor): A latent-space sample to decode.
            **kwargs: Forwarded to the decoder.

        Returns:
            The decoder output for ``inpt``.
        """
        return self.dec(inpt, **kwargs)
class AE(torch.nn.Module):
    """Deterministic autoencoder: a BasicEncoder followed by a mirrored BasicGenerator."""

    def __init__(
        self,
        input_size,
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_op=torch.nn.Conv2d,
        conv_params=None,
        tconv_op=torch.nn.ConvTranspose2d,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=torch.nn.LeakyReLU,
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Basic AE built from a symmetric BasicEncoder (encoder) and BasicGenerator (decoder).

        Args:
            input_size (sequence of int): Size of one input sample, e.g. (C, H, W).
            z_dim (int, optional): Latent dimension, used by both the encoder output and
                the decoder input. Defaults to 1024.
            fmap_sizes (sequence of int, optional): Number of feature maps per encoder
                level; used in reverse order for the decoder. Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): Forwarded to encoder and decoder. Per the original
                docs, True collapses the latent to a z_dim x 1 x 1 vector and False keeps
                a spatial latent. Defaults to True.
            conv_op (type, optional): Convolution used by the encoder to move between
                levels. Defaults to torch.nn.Conv2d.
            conv_params (dict, optional): Init parameters for ``conv_op``, forwarded to
                BasicEncoder.
            tconv_op (type, optional): Upsampling / transposed convolution used by the
                decoder, forwarded as ``upsample_op``. Defaults to torch.nn.ConvTranspose2d.
            tconv_params (dict, optional): Init parameters for ``tconv_op``, forwarded to
                BasicGenerator as ``conv_params``.
            normalization_op (optional): Normalization module class. Defaults to None;
                what None means is decided by BasicEncoder/BasicGenerator (not visible here).
            normalization_params (dict, optional): Init parameters for the normalization op.
            activation_op (type, optional): Activation module class. Defaults to
                torch.nn.LeakyReLU.
            activation_params (dict, optional): Init parameters for the activation op.
            block_op (optional): Block applied per feature-map level. Defaults to None;
                what None means is decided by BasicEncoder/BasicGenerator (not visible here).
            block_params (dict, optional): Init parameters for the block op.
            *args, **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        # Encoder and decoder each receive their own list copy of the input size.
        input_size_enc = list(input_size)
        input_size_dec = list(input_size)
        self.enc = BasicEncoder(
            input_size=input_size_enc,
            fmap_sizes=fmap_sizes,
            z_dim=z_dim,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        self.dec = BasicGenerator(
            input_size=input_size_dec,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        # Output size as reported by the encoder.
        self.hidden_size = self.enc.output_size

    def forward(self, inpt, **kwargs):
        """Reconstruct ``inpt`` by encoding it (``kwargs`` go to the encoder) and decoding the code."""
        y1 = self.enc(inpt, **kwargs)
        return self.dec(y1)

    def encode(self, inpt, **kwargs):
        """Encode an input sample to a latent-space sample.

        Args:
            inpt (tensor): Input sample.
            **kwargs: Forwarded to the encoder.

        Returns:
            The encoder output for ``inpt``.
        """
        return self.enc(inpt, **kwargs)

    def decode(self, inpt, **kwargs):
        """Decode a latent-space sample back to the input space.

        Args:
            inpt (tensor): Latent-space sample.
            **kwargs: Forwarded to the decoder.

        Returns:
            The decoder output (reconstruction) for ``inpt``.
        """
        return self.dec(inpt, **kwargs)
|