import warnings import numpy as np import torch import torch.nn as nn class NoOp(nn.Module): def __init__(self, *args, **kwargs): """NoOp Pytorch Module. Forwards the given input as is. """ super(NoOp, self).__init__() def forward(self, x, *args, **kwargs): return x class ConvModule(nn.Module): def __init__( self, in_channels: int, out_channels: int, conv_op=nn.Conv2d, conv_params=None, normalization_op=None, normalization_params=None, activation_op=nn.LeakyReLU, activation_params=None, ): """Basic Conv Pytorch Conv Module Has can have a Conv Op, a Normlization Op and a Non Linearity: x = conv(x) x = some_norm(x) x = nonlin(x) Args: in_channels ([int]): [Number on input channels/ feature maps] out_channels ([int]): [Number of ouput channels/ feature maps] conv_op ([torch.nn.Module], optional): [Conv operation]. Defaults to nn.Conv2d. conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...)]. Defaults to None. normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...)]. Defaults to nn.LeakyReLU. activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. """ super(ConvModule, self).__init__() self.conv_params = conv_params if self.conv_params is None: self.conv_params = {} self.activation_params = activation_params if self.activation_params is None: self.activation_params = {} self.normalization_params = normalization_params if self.normalization_params is None: self.normalization_params = {} self.conv = None if conv_op is not None and not isinstance(conv_op, str): self.conv = conv_op(in_channels, out_channels, **self.conv_params) self.normalization = None if normalization_op is not None and not isinstance(normalization_op, str): self.normalization = normalization_op(out_channels, **self.normalization_params) self.activation = None if activation_op is not None and not isinstance(activation_op, str): self.activation = activation_op(**self.activation_params) def forward(self, input, conv_add_input=None, normalization_add_input=None, activation_add_input=None): x = input if self.conv is not None: if conv_add_input is None: x = self.conv(x) else: x = self.conv(x, **conv_add_input) if self.normalization is not None: if normalization_add_input is None: x = self.normalization(x) else: x = self.normalization(x, **normalization_add_input) if self.activation is not None: if activation_add_input is None: x = self.activation(x) else: x = self.activation(x, **activation_add_input) # nn.functional.dropout(x, p=0.95, training=True) return x class ConvBlock(nn.Module): def __init__( self, n_convs: int, n_featmaps: int, conv_op=nn.Conv2d, conv_params=None, normalization_op=nn.BatchNorm2d, normalization_params=None, activation_op=nn.LeakyReLU, activation_params=None, ): """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) Args: n_convs ([type]): [Number of convolutions] n_featmaps ([type]): [Feature map size of the conv] conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. """ super(ConvBlock, self).__init__() self.n_featmaps = n_featmaps self.n_convs = n_convs self.conv_params = conv_params if self.conv_params is None: self.conv_params = {} self.conv_list = nn.ModuleList() for i in range(self.n_convs): conv_layer = ConvModule( n_featmaps, n_featmaps, conv_op=conv_op, conv_params=conv_params, normalization_op=normalization_op, normalization_params=normalization_params, activation_op=activation_op, activation_params=activation_params, ) self.conv_list.append(conv_layer) def forward(self, input, **frwd_params): x = input for conv_layer in self.conv_list: x = conv_layer(x) return x class ResBlock(nn.Module): def __init__( self, n_convs, n_featmaps, conv_op=nn.Conv2d, conv_params=None, normalization_op=nn.BatchNorm2d, normalization_params=None, activation_op=nn.LeakyReLU, activation_params=None, ): """Basic Conv block with repeated conv, build up from repeated @ConvModules (with same/fixed feature map size) and a skip/ residual connection: x = input x = conv_block(x) out = x + input Args: n_convs ([type]): [Number of convolutions in the conv block] n_featmaps ([type]): [Feature map size of the conv block] conv_op ([torch.nn.Module], optional): [Convulioton operation -> see ConvModule ]. Defaults to nn.Conv2d. conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to None. normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. """ super(ResBlock, self).__init__() self.n_featmaps = n_featmaps self.n_convs = n_convs self.conv_params = conv_params if self.conv_params is None: self.conv_params = {} self.conv_block = ConvBlock( n_featmaps, n_convs, conv_op=conv_op, conv_params=conv_params, normalization_op=normalization_op, normalization_params=normalization_params, activation_op=activation_op, activation_params=activation_params, ) def forward(self, input, **frwd_params): x = input x = self.conv_block(x) out = x + input return out # Basic Generator class BasicGenerator(nn.Module): def __init__( self, input_size, z_dim=256, fmap_sizes=(256, 128, 64), upsample_op=nn.ConvTranspose2d, conv_params=None, normalization_op=NoOp, normalization_params=None, activation_op=nn.LeakyReLU, activation_params=None, block_op=NoOp, block_params=None, to_1x1=True, ): """Basic configureable Generator/ Decoder. Allows for mutilple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. Args: input_size ((int, int, int): Size of the input in format CxHxW): z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each int defines the number of feature maps in the layer]. Defaults to (256, 128, 64). upsample_op ([torch.nn.Module], optional): [Upsampling operation used, to upsample to a new level/ featuremap size]. Defaults to nn.ConvTranspose2d. conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. to_1x1 (bool, optional): [If Latent dimesion is a z_dim x 1 x 1 vector (True) or if allows spatial resolution not to be 1x1 (z_dim x H x W) (False) ]. Defaults to True. """ super(BasicGenerator, self).__init__() if conv_params is None: conv_params = dict(kernel_size=4, stride=2, padding=1, bias=False) if block_op is None: block_op = NoOp if block_params is None: block_params = {} n_channels = input_size[0] input_size_ = np.array(input_size[1:]) if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): raise AttributeError("fmap_sizes has to be either a list or tuple or an int") elif len(fmap_sizes) < 2: raise AttributeError("fmap_sizes has to contain at least three elements") else: h_size_bot = fmap_sizes[0] # We need to know how many layers we will use at the beginning input_size_new = input_size_ // (2 ** len(fmap_sizes)) if np.min(input_size_new) < 2 and z_dim is not None: raise AttributeError("fmap_sizes to long, one image dimension has already perished") ### Start block start_block = [] if not to_1x1: kernel_size_start = [min(conv_params["kernel_size"], i) for i in input_size_new] else: kernel_size_start = input_size_new.tolist() if z_dim is not None: self.start = ConvModule( z_dim, h_size_bot, conv_op=upsample_op, conv_params=dict(kernel_size=kernel_size_start, stride=1, padding=0, bias=False), normalization_op=normalization_op, normalization_params=normalization_params, activation_op=activation_op, activation_params=activation_params, ) input_size_new = input_size_new * 2 else: self.start = NoOp() ### Middle block (Done until we reach ? x input_size/2 x input_size/2) self.middle_blocks = nn.ModuleList() for h_size_top in fmap_sizes[1:]: self.middle_blocks.append(block_op(h_size_bot, **block_params)) self.middle_blocks.append( ConvModule( h_size_bot, h_size_top, conv_op=upsample_op, conv_params=conv_params, normalization_op=normalization_op, normalization_params={}, activation_op=activation_op, activation_params=activation_params, ) ) h_size_bot = h_size_top input_size_new = input_size_new * 2 ### End block self.end = ConvModule( h_size_bot, n_channels, conv_op=upsample_op, conv_params=conv_params, normalization_op=None, activation_op=None, ) def forward(self, inpt, **kwargs): output = self.start(inpt, **kwargs) for middle in self.middle_blocks: output = middle(output, **kwargs) output = self.end(output, **kwargs) return output # Basic Encoder class BasicEncoder(nn.Module): def __init__( self, input_size, z_dim=256, fmap_sizes=(64, 128, 256), conv_op=nn.Conv2d, conv_params=None, normalization_op=NoOp, normalization_params=None, activation_op=nn.LeakyReLU, activation_params=None, block_op=NoOp, block_params=None, to_1x1=True, ): """Basic configureable Encoder. Allows for mutilple "feature-map" levels defined by the feature map size, where for each feature map size a conv operation + optional conv block is used. Args: z_dim (int, optional): [description]. Dimension of the latent / Input dimension (C channel-dim). fmap_sizes (tuple, optional): [Defines the Upsampling-Levels of the generator, list/ tuple of ints, where each int defines the number of feature maps in the layer]. Defaults to (64, 128, 256). conv_op ([torch.nn.Module], optional): [Convolutioon operation used to downsample to a new level/ featuremap size]. Defaults to nn.Conv2d. conv_params ([dict], optional): [Init parameters for the conv operation]. Defaults to dict(kernel_size=3, stride=2, padding=1, bias=False). normalization_op ([torch.nn.Module], optional): [Normalization Operation (e.g. BatchNorm, InstanceNorm,...) -> see ConvModule]. Defaults to nn.BatchNorm2d. normalization_params ([dict], optional): [Init parameters for the normalization operation]. Defaults to None. activation_op ([torch.nn.Module], optional): [Actiovation Operation/ Non-linearity (e.g. ReLU, Sigmoid,...) -> see ConvModule]. Defaults to nn.LeakyReLU. activation_params ([dict], optional): [Init parameters for the activation operation]. Defaults to None. block_op ([torch.nn.Module], optional): [Block operation used for each feature map size after each upsample op of e.g. ConvBlock/ ResidualBlock]. Defaults to NoOp. block_params ([dict], optional): [Init parameters for the block operation]. Defaults to None. to_1x1 (bool, optional): [If True, then the last conv layer goes to a latent dimesion is a z_dim x 1 x 1 vector (similar to fully connected) or if False allows spatial resolution not to be 1x1 (z_dim x H x W, uses the in the conv_params given conv-kernel-size) ]. Defaults to True. """ super(BasicEncoder, self).__init__() if conv_params is None: conv_params = dict(kernel_size=3, stride=2, padding=1, bias=False) if block_op is None: block_op = NoOp if block_params is None: block_params = {} n_channels = input_size[0] input_size_new = np.array(input_size[1:]) if not isinstance(fmap_sizes, list) and not isinstance(fmap_sizes, tuple): raise AttributeError("fmap_sizes has to be either a list or tuple or an int") # elif len(fmap_sizes) < 2: # raise AttributeError("fmap_sizes has to contain at least three elements") else: h_size_bot = fmap_sizes[0] ### Start block self.start = ConvModule( n_channels, h_size_bot, conv_op=conv_op, conv_params=conv_params, normalization_op=normalization_op, normalization_params={}, activation_op=activation_op, activation_params=activation_params, ) input_size_new = input_size_new // 2 ### Middle block (Done until we reach ? x 4 x 4) self.middle_blocks = nn.ModuleList() for h_size_top in fmap_sizes[1:]: self.middle_blocks.append(block_op(h_size_bot, **block_params)) self.middle_blocks.append( ConvModule( h_size_bot, h_size_top, conv_op=conv_op, conv_params=conv_params, normalization_op=normalization_op, normalization_params={}, activation_op=activation_op, activation_params=activation_params, ) ) h_size_bot = h_size_top input_size_new = input_size_new // 2 if np.min(input_size_new) < 2 and z_dim is not None: raise ("fmap_sizes to long, one image dimension has already perished") ### End block if not to_1x1: kernel_size_end = [min(conv_params["kernel_size"], i) for i in input_size_new] else: kernel_size_end = input_size_new.tolist() if z_dim is not None: self.end = ConvModule( h_size_bot, z_dim, conv_op=conv_op, conv_params=dict(kernel_size=kernel_size_end, stride=1, padding=0, bias=False), normalization_op=None, activation_op=None, ) if to_1x1: self.output_size = (z_dim, 1, 1) else: self.output_size = (z_dim, *[i - (j - 1) for i, j in zip(input_size_new, kernel_size_end)]) else: self.end = NoOp() self.output_size = input_size_new def forward(self, inpt, **kwargs): output = self.start(inpt, **kwargs) for middle in self.middle_blocks: output = middle(output, **kwargs) output = self.end(output, **kwargs) return output