File size: 4,101 Bytes
caa56d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from torch import nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F

import torchvision.models as models

def add_gaussian_noise(ins, mean=0, stddev=0.1):
    """Return ``ins`` plus element-wise Gaussian noise.

    Args:
        ins: input tensor; the noise has the same shape, dtype and device.
        mean: mean of the Gaussian noise.
        stddev: standard deviation of the Gaussian noise (non-negative).

    Returns:
        A new tensor ``ins + noise``. The noise tensor itself does not require
        grad, so gradients flow back to ``ins`` unchanged.
    """
    # empty_like + in-place normal_ replaces the legacy ``ins.data.new(...)``
    # pattern: same shape/dtype/device without reaching through ``.data``.
    noise = torch.empty_like(ins).normal_(mean, stddev)
    return ins + noise

class FlattenLayer(nn.Module):
    """Collapse every dimension after the batch dimension into one.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``.
    """

    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. after
        # a transpose/permute; for contiguous inputs it still returns a view.
        return x.reshape(x.size(0), -1)


class UnflattenLayer(nn.Module):
    """Inverse of a flatten: ``(N, C * width * width) -> (N, C, width, width)``.

    The channel count ``C`` is inferred from the size of the input.
    """

    def __init__(self, width):
        super().__init__()
        # Side length of the square output feature map.
        self.width = width

    def forward(self, x):
        side = self.width
        batch = x.size(0)
        return x.view(batch, -1, side, side)

class VAE_Encoder(nn.Module):
    '''
    VAE_Encoder: encode an image into the mean (mu) and log-variance (logvar)
    of a diagonal Gaussian over a ``latent_dim``-dimensional latent space.

    The backbone is torchvision's ResNet-18 with its last child (the ``fc``
    classifier) removed; the 512-dim pooled features feed two linear heads.

    Args:
        latent_dim: size of the latent vector.
        pretrained: forwarded to ``torchvision.models.resnet18``. Set it to
            False to skip downloading ImageNet weights, e.g. when a full
            checkpoint is loaded right afterwards.
    '''

    def __init__(self, latent_dim=256, pretrained=True):
        super(VAE_Encoder, self).__init__()
        self.resnet = models.resnet18(pretrained=pretrained)
        # Force a (1, 1) global pool so the backbone emits a 512-dim vector
        # regardless of the input resolution.
        self.resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Drop the classifier and flatten the pooled map to (N, 512).
        # NOTE: keep this exact Sequential layout -- checkpoint keys
        # (``resnet.<idx>.*``, ``l_mu.*``, ``l_var.*``) depend on it.
        self.resnet = nn.Sequential(
            *list(self.resnet.children())[:-1],
            FlattenLayer()
        )

        self.l_mu = nn.Linear(512, latent_dim)
        self.l_var = nn.Linear(512, latent_dim)

    def encode(self, x):
        """Return ``(mu, logvar)``, each of shape (N, latent_dim)."""
        hidden = self.resnet(x)
        mu = self.l_mu(hidden)
        logvar = self.l_var(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` in training mode.

        In eval mode the mean ``mu`` is returned unchanged (deterministic).
        """
        if self.training:
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return mu + eps*std

        else:
            return mu

    def forward(self, x):
        """Return ``(z, mu, logvar)`` for an input batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar


class VAE_Decoder(nn.Module):
    '''
    VAE_Decoder: decode a batch of latent vectors into images.

    ``z`` of shape (N, latent_dim) is reshaped to (N, latent_dim, 1, 1) and
    upsampled by transposed convolutions: 1 -> 4 -> 8 -> 16 -> 32 -> 64 ->
    128 -> 256 pixels per side. The output has ``output_dim`` channels and
    values in [-1, 1] (final Tanh).

    Args:
        latent_dim: size of the latent vector (input channels of the first layer).
        output_dim: number of output image channels (default 3).
    '''

    def __init__(self, latent_dim, output_dim=3):
        super(VAE_Decoder, self).__init__()
        # NOTE: layer order defines the checkpoint keys (``convs.<idx>.*``);
        # do not insert or remove layers without migrating saved weights.
        self.convs = nn.Sequential(
            UnflattenLayer(width=1),                                   # (N, latent_dim, 1, 1)
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),  # -> 4x4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 384, 4, 2, 1, bias=False),         # -> 8x8
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(384, 192, 4, 2, 1, bias=False),         # -> 16x16
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),          # -> 32x32
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False),           # -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),           # -> 128x128
            nn.BatchNorm2d(32),
            # NOTE(review): the only LeakyReLU in the stack -- possibly
            # unintentional; left unchanged to preserve existing model outputs.
            nn.LeakyReLU(inplace=True),
            # Previously hard-coded to 3, silently ignoring ``output_dim``.
            nn.ConvTranspose2d(32, output_dim, 4, 2, 1, bias=False),   # -> 256x256
            nn.Tanh()
        )

    def forward(self, z):
        """Map latent vectors (N, latent_dim) to images (N, output_dim, 256, 256)."""
        return self.convs(z)

class ImageAE(nn.Module):
    """VAE-style image auto-encoder: a ``VAE_Encoder`` followed by a ``VAE_Decoder``.

    Args:
        latent_dim: latent size shared by encoder and decoder. The default of
            512 is the value that was previously hard-coded; it must match any
            checkpoint loaded with ``load_ckpt``.
    """

    def __init__(self, latent_dim=512):
        super(ImageAE, self).__init__()
        self.enc = VAE_Encoder(latent_dim)
        self.dec = VAE_Decoder(latent_dim)

    def forward(self, x):
        """Encode ``x`` and decode the encoder's latent ``z`` back to an image."""
        # The encoder returns (z, mu, logvar); only z is decoded.
        z, *_ = self.enc(x)
        out = self.dec(z)

        return out

    def load_ckpt(self, enc_path, dec_path):
        """Load encoder and decoder state dicts from two files, mapped onto CPU."""
        # NOTE: torch.load unpickles its input; only load checkpoints you trust.
        self.enc.load_state_dict(torch.load(enc_path, map_location='cpu'))
        self.dec.load_state_dict(torch.load(dec_path, map_location='cpu'))

    
def get_pretraiend_ae(enc_path='pretrained/ae/vae/enc.pth', dec_path='pretrained/ae/vae/dec1.pth'):
    """Build an ``ImageAE``, load its weights and return it in eval mode.

    Args:
        enc_path: path of the encoder state dict.
        dec_path: path of the decoder state dict.

    Returns:
        The loaded ``ImageAE`` after ``eval()`` has been applied.
    """
    model = ImageAE()
    model.load_ckpt(enc_path, dec_path)
    print('load image auto-encoder')
    # nn.Module.eval() returns the module itself.
    return model.eval()

def get_pretraiend_unet(path='pretrained/ae/unet/ckpt_srm.pth'):
    """Build the U-Net generator, load its weights from ``path``, return it in eval mode.

    Raises:
        ImportError: if ``networks.pix2pix_network`` cannot be imported.
    """
    # Imported lazily: the module-level import was commented out, which left
    # ``UnetGenerator`` undefined so every call raised NameError. A local
    # import fixes the call without making the module import depend on it.
    from networks.pix2pix_network import UnetGenerator

    # NOTE(review): presumably (input_nc=3, output_nc=3, num_downs=8) per the
    # pix2pix UnetGenerator signature -- confirm against networks/pix2pix_network.py.
    unet = UnetGenerator(3, 3, 8)
    # NOTE: torch.load unpickles its input; only load checkpoints you trust.
    unet.load_state_dict(torch.load(path, map_location='cpu'))
    print('load Unet')
    unet.eval()
    return unet

if __name__ == "__main__":
    # Smoke test: build the auto-encoder from the default checkpoints and
    # print its module structure.
    print(get_pretraiend_ae())