import warnings

import numpy as np
import torch
import torch.nn as nn


class NoOp(nn.Module):
    def __init__(self, *args, **kwargs):
        """NoOp PyTorch Module.

        Forwards the given input as is.
        """
        super(NoOp, self).__init__()

    def forward(self, x, *args, **kwargs):
        return x


class ConvModule(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
    ):
        """Basic PyTorch Conv Module.

        Can have a Conv Op, a Normalization Op and a Non-Linearity:
        x = conv(x)
        x = some_norm(x)
        x = nonlin(x)

        Args:
            in_channels ([int]): [Number of input channels/ feature maps]
            out_channels ([int]): [Number of output channels/ feature maps]
            conv_op ([torch.nn.Module], optional): [Conv operation]. Defaults to nn.Conv2d.
            conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None.
            normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...)]. Defaults to None.
            normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None.
            activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...)]. Defaults to nn.LeakyReLU.
            activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None.
        """

        super(ConvModule, self).__init__()

        self.conv_params = conv_params
        if self.conv_params is None:
            self.conv_params = {}
        self.activation_params = activation_params
        if self.activation_params is None:
            self.activation_params = {}
        self.normalization_params = normalization_params
        if self.normalization_params is None:
            self.normalization_params = {}

        self.conv = None
        if conv_op is not None and not isinstance(conv_op, str):
            self.conv = conv_op(in_channels, out_channels, **self.conv_params)

        self.normalization = None
        if normalization_op is not None and not isinstance(normalization_op, str):
            self.normalization = normalization_op(out_channels, **self.normalization_params)

        self.activation = None
        if activation_op is not None and not isinstance(activation_op, str):
            self.activation = activation_op(**self.activation_params)

    def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None):

        x = input

        if self.conv is not None:
            if conv_add_input is None:
                x = self.conv(x)
            else:
                x = self.conv(x, **conv_add_input)

        if self.normalization is not None:
            if normalization_add_input is None:
                x = self.normalization(x)
            else:
                x = self.normalization(x, **normalization_add_input)

        if self.activation is not None:
            if activation_add_input is None:
                x = self.activation(x)
            else:
                x = self.activation(x, **activation_add_input)

        return x


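# Illustrative usage sketch (not part of the original API docs; the shapes and
# parameter choices below are assumptions for demonstration only):
#
#   conv = ConvModule(
#       in_channels=3,
#       out_channels=16,
#       conv_params=dict(kernel_size=3, padding=1),
#       normalization_op=nn.InstanceNorm2d,
#   )
#   y = conv(torch.randn(2, 3, 32, 32))  # -> shape (2, 16, 32, 32)

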
class ConvBlock(nn.Module):
    def __init__(
        self,
        n_convs: int,
        n_featmaps: int,
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=nn.BatchNorm2d,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
    ):
        """Basic Conv block with repeated convs, built up from repeated @ConvModules (with the same/ fixed feature map size).

        Args:
            n_convs ([int]): [Number of convolutions]
            n_featmaps ([int]): [Feature map size of the convs]
            conv_op ([torch.nn.Module], optional): [Convolution operation -> see ConvModule]. Defaults to nn.Conv2d.
            conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None.
            normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d.
            normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None.
            activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU.
            activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None.
        """

        super(ConvBlock, self).__init__()

        self.n_featmaps = n_featmaps
        self.n_convs = n_convs
        self.conv_params = conv_params
        if self.conv_params is None:
            self.conv_params = {}

        self.conv_list = nn.ModuleList()

        for i in range(self.n_convs):
            conv_layer = ConvModule(
                n_featmaps,
                n_featmaps,
                conv_op=conv_op,
                conv_params=conv_params,
                normalization_op=normalization_op,
                normalization_params=normalization_params,
                activation_op=activation_op,
                activation_params=activation_params,
            )
            self.conv_list.append(conv_layer)

    def forward(self, input, **frwd_params):
        x = input
        for conv_layer in self.conv_list:
            x = conv_layer(x)

        return x


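# Illustrative usage sketch (assumed shapes; a padded 3x3 conv keeps the spatial size,
# so channels and resolution stay fixed across the repeated ConvModules):
#
#   block = ConvBlock(n_convs=2, n_featmaps=16, conv_params=dict(kernel_size=3, padding=1))
#   y = block(torch.randn(2, 16, 32, 32))  # -> shape (2, 16, 32, 32)

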
class ResBlock(nn.Module):
    def __init__(
        self,
        n_convs,
        n_featmaps,
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=nn.BatchNorm2d,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
    ):
        """Basic Conv block with repeated convs, built up from repeated @ConvModules (with the same/ fixed feature map size) and a skip/ residual connection:
        x = input
        x = conv_block(x)
        out = x + input

        Args:
            n_convs ([int]): [Number of convolutions in the conv block]
            n_featmaps ([int]): [Feature map size of the conv block]
            conv_op ([torch.nn.Module], optional): [Convolution operation -> see ConvModule]. Defaults to nn.Conv2d.
            conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None.
            normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d.
            normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None.
            activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU.
            activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None.
        """
        super(ResBlock, self).__init__()

        self.n_featmaps = n_featmaps
        self.n_convs = n_convs
        self.conv_params = conv_params
        if self.conv_params is None:
            self.conv_params = {}

        self.conv_block = ConvBlock(
            n_convs=n_convs,
            n_featmaps=n_featmaps,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
        )

    def forward(self, input, **frwd_params):
        x = input
        x = self.conv_block(x)

        out = x + input

        return out


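# Illustrative usage sketch (assumed shapes; the padded 3x3 conv preserves the spatial
# size, which the residual addition requires):
#
#   res = ResBlock(n_convs=2, n_featmaps=16, conv_params=dict(kernel_size=3, padding=1))
#   y = res(torch.randn(2, 16, 32, 32))  # -> shape (2, 16, 32, 32)

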
class BasicGenerator(nn.Module):
    def __init__(
        self,
        input_size,
        z_dim=256,
        fmap_sizes=(256, 128, 64),
        upsample_op=nn.ConvTranspose2d,
        conv_params=None,
        normalization_op=NoOp,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
        block_op=NoOp,
        block_params=None,
        to_1x1=True,
    ):
        """Basic configurable Generator/ Decoder.

        Allows for multiple "feature-map" levels defined by the feature map sizes, where for each feature map size an upsampling operation + optional conv block is used.

        Args:
            input_size ((int, int, int)): [Size of the input in the format CxHxW]
            z_dim (int, optional): [Dimension of the latent/ input (C channel-dim)]. Defaults to 256.
            fmap_sizes (tuple, optional): [Defines the upsampling levels of the generator, list/ tuple of ints, where each
                                           int defines the number of feature maps in the layer]. Defaults to (256, 128, 64).
            upsample_op ([torch.nn.Module], optional): [Upsampling operation used to upsample to a new level/ feature map size]. Defaults to nn.ConvTranspose2d.
            conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=4, stride=2, padding=1, bias=False).
            normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to NoOp.
            normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None.
            activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU.
            activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None.
            block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op, e.g. ConvBlock/ ResBlock]. Defaults to NoOp.
            block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None.
            to_1x1 (bool, optional): [If True, the latent is a z_dim x 1 x 1 vector; if False, a spatial resolution other than 1x1 (z_dim x H x W) is allowed]. Defaults to True.
        """

        super(BasicGenerator, self).__init__()

        if conv_params is None:
            conv_params = dict(kernel_size=4, stride=2, padding=1, bias=False)
        if block_op is None:
            block_op = NoOp
        if block_params is None:
            block_params = {}

        n_channels = input_size[0]
        input_size_ = np.array(input_size[1:])

        if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple):
            raise AttributeError("fmap_sizes has to be either a list or a tuple")
        elif len(fmap_sizes) < 2:
            raise AttributeError("fmap_sizes has to contain at least two elements")
        else:
            h_size_bot = fmap_sizes[0]

        input_size_new = input_size_ // (2 ** len(fmap_sizes))
        if np.min(input_size_new) < 2 and z_dim is not None:
            raise AttributeError("fmap_sizes is too long, one image dimension has already vanished")

        start_block = []

        if not to_1x1:
            kernel_size_start = [min(conv_params["kernel_size"], i) for i in input_size_new]
        else:
            kernel_size_start = input_size_new.tolist()

        if z_dim is not None:
            self.start = ConvModule(
                z_dim,
                h_size_bot,
                conv_op=upsample_op,
                conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False),
                normalization_op=normalization_op,
                normalization_params=normalization_params,
                activation_op=activation_op,
                activation_params=activation_params,
            )

            input_size_new = input_size_new * 2
        else:
            self.start = NoOp()

        self.middle_blocks = nn.ModuleList()

        for h_size_top in fmap_sizes[1:]:

            self.middle_blocks.append(block_op(h_size_bot, **block_params))

            self.middle_blocks.append(
                ConvModule(
                    h_size_bot,
                    h_size_top,
                    conv_op=upsample_op,
                    conv_params=conv_params,
                    normalization_op=normalization_op,
                    normalization_params=normalization_params,
                    activation_op=activation_op,
                    activation_params=activation_params,
                )
            )

            h_size_bot = h_size_top
            input_size_new = input_size_new * 2

        self.end = ConvModule(
            h_size_bot,
            n_channels,
            conv_op=upsample_op,
            conv_params=conv_params,
            normalization_op=None,
            activation_op=None,
        )

    def forward(self, inpt, **kwargs):
        output = self.start(inpt, **kwargs)
        for middle in self.middle_blocks:
            output = middle(output, **kwargs)
        output = self.end(output, **kwargs)
        return output


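# Illustrative usage sketch (assumed shapes; with the default 4x4/ stride-2 transposed
# convs, three fmap levels upsample a z_dim x 1 x 1 latent to a 64x64 image):
#
#   gen = BasicGenerator(input_size=(3, 64, 64), z_dim=128, fmap_sizes=(256, 128, 64))
#   x_rec = gen(torch.randn(2, 128, 1, 1))  # -> shape (2, 3, 64, 64)

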
class BasicEncoder(nn.Module):
    def __init__(
        self,
        input_size,
        z_dim=256,
        fmap_sizes=(64, 128, 256),
        conv_op=nn.Conv2d,
        conv_params=None,
        normalization_op=NoOp,
        normalization_params=None,
        activation_op=nn.LeakyReLU,
        activation_params=None,
        block_op=NoOp,
        block_params=None,
        to_1x1=True,
    ):
        """Basic configurable Encoder.

        Allows for multiple "feature-map" levels defined by the feature map sizes, where for each feature map size a conv operation + optional conv block is used.

        Args:
            input_size ((int, int, int)): [Size of the input in the format CxHxW]
            z_dim (int, optional): [Dimension of the latent/ output (C channel-dim)]. Defaults to 256.
            fmap_sizes (tuple, optional): [Defines the downsampling levels of the encoder, list/ tuple of ints, where each
                                           int defines the number of feature maps in the layer]. Defaults to (64, 128, 256).
            conv_op ([torch.nn.Module], optional): [Convolution operation used to downsample to a new level/ feature map size]. Defaults to nn.Conv2d.
            conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False).
            normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to NoOp.
            normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None.
            activation_op ([torch.nn.Module], optional): [Activation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU.
            activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None.
            block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each downsample op, e.g. ConvBlock/ ResBlock]. Defaults to NoOp.
            block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None.
            to_1x1 (bool, optional): [If True, the last conv layer maps to a z_dim x 1 x 1 latent vector (similar to a fully connected layer); if False, a spatial resolution other than 1x1 (z_dim x H x W, using the conv kernel size given in conv_params) is allowed]. Defaults to True.
        """
        super(BasicEncoder, self).__init__()

        if conv_params is None:
            conv_params = dict(kernel_size=3, stride=2, padding=1, bias=False)
        if block_op is None:
            block_op = NoOp
        if block_params is None:
            block_params = {}

        n_channels = input_size[0]
        input_size_new = np.array(input_size[1:])

        if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple):
            raise AttributeError("fmap_sizes has to be either a list or a tuple")
        else:
            h_size_bot = fmap_sizes[0]

        self.start = ConvModule(
            n_channels,
            h_size_bot,
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
        )
        input_size_new = input_size_new // 2

        self.middle_blocks = nn.ModuleList()

        for h_size_top in fmap_sizes[1:]:

            self.middle_blocks.append(block_op(h_size_bot, **block_params))

            self.middle_blocks.append(
                ConvModule(
                    h_size_bot,
                    h_size_top,
                    conv_op=conv_op,
                    conv_params=conv_params,
                    normalization_op=normalization_op,
                    normalization_params=normalization_params,
                    activation_op=activation_op,
                    activation_params=activation_params,
                )
            )

            h_size_bot = h_size_top
            input_size_new = input_size_new // 2

            if np.min(input_size_new) < 2 and z_dim is not None:
                raise AttributeError("fmap_sizes is too long, one image dimension has already vanished")

        if not to_1x1:
            kernel_size_end = [min(conv_params["kernel_size"], i) for i in input_size_new]
        else:
            kernel_size_end = input_size_new.tolist()

        if z_dim is not None:
            self.end = ConvModule(
                h_size_bot,
                z_dim,
                conv_op=conv_op,
                conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False),
                normalization_op=None,
                activation_op=None,
            )

            if to_1x1:
                self.output_size = (z_dim, 1, 1)
            else:
                self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(input_size_new, kernel_size_end)])
        else:
            self.end = NoOp()
            self.output_size = input_size_new

    def forward(self, inpt, **kwargs):
        output = self.start(inpt, **kwargs)
        for middle in self.middle_blocks:
            output = middle(output, **kwargs)
        output = self.end(output, **kwargs)
        return output
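

# Illustrative usage sketch (assumed shapes; with the default 3x3/ stride-2 convs, three
# fmap levels downsample a 64x64 image to an 8x8 map, which the final conv collapses to
# a z_dim x 1 x 1 latent). The latent can be fed straight into a matching BasicGenerator:
#
#   enc = BasicEncoder(input_size=(3, 64, 64), z_dim=128, fmap_sizes=(64, 128, 256))
#   z = enc(torch.randn(2, 3, 64, 64))  # -> shape (2, 128, 1, 1), i.e. enc.output_size per sample
#   gen = BasicGenerator(input_size=(3, 64, 64), z_dim=128, fmap_sizes=(256, 128, 64))
#   x_rec = gen(z)                      # -> shape (2, 3, 64, 64)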