File size: 11,329 Bytes
c3bcb92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dist

from .ae_bases import BasicEncoder, BasicGenerator


class VAE(torch.nn.Module):
    """Variational autoencoder made of a BasicEncoder and a mirrored BasicGenerator.

    The encoder is built with ``2 * z_dim`` output channels; its output is split
    along dim 1 into the mean and the log standard deviation of a Normal
    distribution over the latent code.
    """

    # String names accepted for ``activation_op`` and the classes they select.
    _ACTIVATIONS = {
        "relu": nn.ReLU,
        "prelu": nn.PReLU,
        "leakyrelu": nn.LeakyReLU,
    }

    def __init__(
        self,
        d,
        input_size,
        z_dim=256,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="leakyrelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Build the encoder / decoder pair.

        Args:
            d (int): Spatial dimensionality. ``2`` selects ``nn.Conv2d`` /
                ``nn.ConvTranspose2d``; any other value selects ``nn.Conv3d`` /
                ``nn.ConvTranspose3d``.
            input_size (sequence of int): Size of one input sample, channels first
                (e.g. CxHxW for 2D data). It is copied into a fresh list for the
                encoder and another one for the decoder.
            z_dim (int, optional): Dimension of the latent code. The encoder is
                created with ``z_dim * 2`` (mean and log-std), the decoder with
                ``z_dim``. Defaults to 256.
            fmap_sizes (tuple, optional): Number of feature maps per level. Passed
                as given to the encoder and reversed to the decoder.
                Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): Forwarded to both sub-networks; according to
                the original documentation it makes the latent a z_dim x 1 x 1
                vector instead of keeping a spatial extent. Defaults to True.
            conv_params (dict, optional): Init parameters for the encoder's
                convolution. Forwarded unchanged. Defaults to None.
            tconv_params (dict, optional): Init parameters for the decoder's
                transposed convolution. Forwarded unchanged. Defaults to None.
            normalization_op (optional): Normalization operation, forwarded
                unchanged to both sub-networks. Defaults to None.
            normalization_params (dict, optional): Init parameters for the
                normalization operation. Defaults to None.
            activation_op (str or callable, optional): Either one of the names
                ``"relu"``, ``"prelu"``, ``"leakyrelu"`` or an activation class /
                factory, which is forwarded unchanged (the same way ``AE`` accepts
                it). Defaults to "leakyrelu".
            activation_params (dict, optional): Init parameters for the activation
                operation. Defaults to None.
            block_op (optional): Block operation forwarded unchanged to both
                sub-networks. Defaults to None.
            block_params (dict, optional): Init parameters for the block operation.
                Defaults to None.
            *args, **kwargs: Accepted for call compatibility; not used.

        Raises:
            ValueError: If ``activation_op`` is an unknown name or is neither a
                string nor a callable.
        """
        super().__init__()

        if d == 2:
            conv_op, tconv_op = nn.Conv2d, nn.ConvTranspose2d
        else:
            conv_op, tconv_op = nn.Conv3d, nn.ConvTranspose3d

        activation_op = self._resolve_activation(activation_op)

        # Each sub-network gets its own copy of the size so neither can affect
        # the other (or the caller's object) by modifying it.
        input_size_enc = list(input_size)
        input_size_dec = list(input_size)

        self.enc = BasicEncoder(
            input_size=input_size_enc,
            fmap_sizes=fmap_sizes,
            z_dim=z_dim * 2,  # mean and log-std share the encoder output
            conv_op=conv_op,
            conv_params=conv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )
        self.dec = BasicGenerator(
            input_size=input_size_dec,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        self.hidden_size = self.enc.output_size

    @classmethod
    def _resolve_activation(cls, activation_op):
        """Map an activation name to its class; pass callables through unchanged."""
        if isinstance(activation_op, str):
            if activation_op in cls._ACTIVATIONS:
                return cls._ACTIVATIONS[activation_op]
        elif callable(activation_op):
            return activation_op
        raise ValueError(f"Activation function {activation_op} not supported")

    def forward(self, inpt, sample=True, no_dist=False, **kwargs):
        """Encode ``inpt``, pick a latent code and decode it.

        Args:
            inpt (tensor): Input passed to the encoder.
            sample (bool, optional): If True the latent code is drawn with
                ``rsample()`` from the approximate posterior, otherwise the mean
                is used. Defaults to True.
            no_dist (bool, optional): If True only the reconstruction is returned.
                Defaults to False.
            **kwargs: Forwarded to the encoder call.

        Returns:
            The reconstruction, or the tuple ``(reconstruction, z_dist)`` where
            ``z_dist`` is the ``Normal`` posterior when ``no_dist`` is False.
        """
        y1 = self.enc(inpt, **kwargs)

        # First half of dim 1 is the mean, second half the log standard deviation.
        mu, log_std = torch.chunk(y1, 2, dim=1)
        std = torch.exp(log_std)
        z_dist = dist.Normal(mu, std)
        z_sample = z_dist.rsample() if sample else mu

        x_rec = self.dec(z_sample)

        if no_dist:
            return x_rec
        return x_rec, z_dist

    def encode(self, inpt, **kwargs):
        """Encode a sample and return the parameters of the approximate posterior.

        Args:
            inpt (tensor): The input to encode.
            **kwargs: Forwarded to the encoder call.

        Returns:
            mu: The mean used to parameterize a Normal distribution.
            std: The standard deviation used to parameterize a Normal distribution
                (the exponential of the encoder's log-std output).
        """
        enc = self.enc(inpt, **kwargs)
        mu, log_std = torch.chunk(enc, 2, dim=1)
        std = torch.exp(log_std)
        return mu, std

    def decode(self, inpt, **kwargs):
        """Decode a latent space sample with the generative model.

        Args:
            inpt (tensor): A sample from the latent space to decode.
            **kwargs: Forwarded to the decoder call.

        Returns:
            The decoder output for ``inpt``.
        """
        x_rec = self.dec(inpt, **kwargs)
        return x_rec


class AE(torch.nn.Module):
    def __init__(
        self,
        input_size,
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_op=torch.nn.Conv2d,
        conv_params=None,
        tconv_op=torch.nn.ConvTranspose2d,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op=torch.nn.LeakyReLU,
        activation_params=None,
        block_op=None,
        block_params=None,
        *args,
        **kwargs
    ):
        """Plain autoencoder: a BasicEncoder followed by a mirrored BasicGenerator.

        Args:
            input_size (sequence of int): Size of one input sample in format CxHxW.
                A separate list copy is handed to the encoder and to the decoder.
            z_dim (int, optional): Dimension of the latent code (channel dim), used
                for both encoder output and decoder input. Defaults to 1024.
            fmap_sizes (tuple, optional): Number of feature maps per level. Used as
                given by the encoder and in reverse order by the decoder.
                Defaults to (16, 64, 256, 1024).
            to_1x1 (bool, optional): Forwarded to both sub-networks; according to
                the original documentation it makes the latent a z_dim x 1 x 1
                vector instead of keeping a spatial extent. Defaults to True.
            conv_op (optional): Convolution operation used by the encoder.
                Defaults to nn.Conv2d.
            conv_params (dict, optional): Init parameters for the encoder's
                convolution. Defaults to None.
            tconv_op (optional): Upsampling / transposed convolution operation used
                by the decoder. Defaults to nn.ConvTranspose2d.
            tconv_params (dict, optional): Init parameters for the decoder's
                upsampling operation. Defaults to None.
            normalization_op (optional): Normalization operation, forwarded to both
                sub-networks. Defaults to None.
            normalization_params (dict, optional): Init parameters for the
                normalization operation. Defaults to None.
            activation_op (optional): Activation operation, forwarded to both
                sub-networks. Defaults to nn.LeakyReLU.
            activation_params (dict, optional): Init parameters for the activation
                operation. Defaults to None.
            block_op (optional): Block operation, forwarded to both sub-networks.
                Defaults to None.
            block_params (dict, optional): Init parameters for the block operation.
                Defaults to None.
            *args, **kwargs: Accepted for call compatibility; not used.
        """
        super().__init__()

        # Independent copies so encoder and decoder never share the same list.
        enc_shape, dec_shape = list(input_size), list(input_size)

        # Options that are handed to encoder and decoder alike.
        common = dict(
            normalization_op=normalization_op,
            normalization_params=normalization_params,
            activation_op=activation_op,
            activation_params=activation_params,
            block_op=block_op,
            block_params=block_params,
            to_1x1=to_1x1,
        )

        self.enc = BasicEncoder(
            input_size=enc_shape,
            fmap_sizes=fmap_sizes,
            z_dim=z_dim,
            conv_op=conv_op,
            conv_params=conv_params,
            **common,
        )
        self.dec = BasicGenerator(
            input_size=dec_shape,
            fmap_sizes=fmap_sizes[::-1],
            z_dim=z_dim,
            upsample_op=tconv_op,
            conv_params=tconv_params,
            **common,
        )

        self.hidden_size = self.enc.output_size

    def forward(self, inpt, **kwargs):
        """Run ``inpt`` through the encoder (with ``kwargs``) and then the decoder.

        Returns:
            The decoder output for the encoded input.
        """
        latent = self.enc(inpt, **kwargs)
        return self.dec(latent)

    def encode(self, inpt, **kwargs):
        """Encode an input sample to a latent space sample.

        Args:
            inpt (tensor): Input sample.
            **kwargs: Forwarded to the encoder call.

        Returns:
            The encoder output for ``inpt``.
        """
        return self.enc(inpt, **kwargs)

    def decode(self, inpt, **kwargs):
        """Decode a latent space sample back to the input space.

        Args:
            inpt (tensor): Latent space sample.
            **kwargs: Forwarded to the decoder call.

        Returns:
            The decoder output for ``inpt``.
        """
        return self.dec(inpt, **kwargs)