|
from torch import nn
|
|
from torch.autograd import Variable
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import torchvision.models as models
|
|
|
|
def add_gaussian_noise(ins, mean=0, stddev=0.1):
    """Return ``ins`` with freshly sampled Gaussian noise added element-wise.

    Args:
        ins: Input tensor; it is not modified.
        mean: Mean of the sampled noise.
        stddev: Standard deviation of the sampled noise.

    Returns:
        A new tensor equal to ``ins`` plus the sampled noise.
    """
    # Scratch tensor with the input's size, dtype and device; it is filled in
    # place with normal samples and carries no autograd history of its own.
    perturbation = ins.new_empty(ins.shape)
    perturbation.normal_(mean, stddev)
    return ins + perturbation
|
|
|
|
class FlattenLayer(nn.Module):
    """Collapse every dimension after the first into a single feature axis."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``(x.size(0), -1)``.

        ``reshape`` is used instead of ``view`` so that inputs whose memory
        layout cannot be re-viewed (for example the result of a transpose or
        permute) are flattened through a copy rather than raising a
        ``RuntimeError``. Inputs that ``view`` could already handle are still
        returned as views of the original storage.
        """
        return x.reshape(x.size(0), -1)
|
|
|
|
|
|
class UnflattenLayer(nn.Module):
    """Reshape a batch into ``(batch, channels, width, width)`` feature maps.

    The channel count is inferred from the number of elements per sample.
    """

    def __init__(self, width):
        """Store ``width``, the side length of the square spatial grid."""
        super().__init__()
        self.width = width

    def forward(self, x):
        """View ``x`` as ``(x.size(0), -1, width, width)``."""
        batch_size = x.size(0)
        side = self.width
        target_shape = (batch_size, -1, side, side)
        return x.view(*target_shape)
|
|
|
|
class VAE_Encoder(nn.Module):
    """
    VAE_Encoder: encode an image batch into ``mu`` and ``logvar`` and draw a
    latent code from them.
    """

    def __init__(self, latent_dim=256):
        """Build the ResNet-18 feature extractor and the two linear heads.

        Args:
            latent_dim: Output width of both the ``mu`` and ``logvar`` heads.
        """
        super().__init__()

        # torchvision ResNet-18 created with pretrained=True. Its last child
        # module is dropped and the remaining output is flattened per sample.
        backbone = models.resnet18(pretrained=True)
        backbone.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        trunk = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*trunk, FlattenLayer())

        # Both heads consume the 512-wide flattened feature vector.
        self.l_mu = nn.Linear(512, latent_dim)
        self.l_var = nn.Linear(512, latent_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` computed from the backbone features of ``x``."""
        features = self.resnet(x)
        return self.l_mu(features), self.l_var(features)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` in training mode and ``mu`` otherwise.

        ``std`` is ``exp(0.5 * logvar)`` and ``eps`` is standard normal noise
        with the same shape as ``std``.
        """
        if not self.training:
            return mu

        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(z, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encode(x)
        return self.reparameterize(mu, logvar), mu, logvar
|
|
|
|
|
|
class VAE_Decoder(nn.Module):
    """
    VAE_Decoder: decode a latent code into an image.

    The latent vector is viewed as a ``latent_dim x 1 x 1`` map and upsampled
    by a stack of transposed convolutions: the first (kernel 4, stride 1,
    padding 0) yields a 4x4 map and each of the six following ones (kernel 4,
    stride 2, padding 1) doubles the spatial size, giving a 256x256 output.
    The final ``Tanh`` bounds the output to (-1, 1).
    """

    def __init__(self, latent_dim, output_dim=3):
        """Build the transposed-convolution stack.

        Args:
            latent_dim: Number of channels of the latent code fed to the
                first transposed convolution.
            output_dim: Number of channels of the generated image. It was
                previously accepted but ignored (the last layer always
                produced 3 channels); the default of 3 keeps that behaviour.
        """
        super().__init__()
        # Layer order is kept exactly as before so that the indices used in
        # state-dict keys ("convs.<idx>...") still match existing checkpoints.
        self.convs = nn.Sequential(
            UnflattenLayer(width=1),
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 384, 4, 2, 1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(384, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            # Honour the requested channel count instead of a hard-coded 3.
            nn.ConvTranspose2d(32, output_dim, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        """Run the latent batch ``z`` through the decoder stack."""
        return self.convs(z)
|
|
|
|
class ImageAE(nn.Module):
    """Image auto-encoder that chains ``VAE_Encoder`` and ``VAE_Decoder``."""

    def __init__(self, latent_dim=512):
        """Create the encoder/decoder pair.

        Args:
            latent_dim: Width of the latent code shared by both halves.
                Defaults to 512, the value that was previously hard-coded.
        """
        super().__init__()
        self.enc = VAE_Encoder(latent_dim)
        self.dec = VAE_Decoder(latent_dim)

    def forward(self, x):
        """Encode ``x`` and decode the resulting latent code."""
        # Only the first element returned by the encoder is decoded; the
        # remaining ones are discarded here.
        z, *_ = self.enc(x)
        return self.dec(z)

    def load_ckpt(self, enc_path, dec_path):
        """Load the encoder and decoder state dicts from two checkpoint files.

        Args:
            enc_path: Path to the encoder checkpoint.
            dec_path: Path to the decoder checkpoint.

        Both files are read with ``map_location='cpu'``.
        """
        # NOTE(review): torch.load unpickles the file contents, so only
        # checkpoints from a trusted source should be passed in.
        self.enc.load_state_dict(torch.load(enc_path, map_location='cpu'))
        self.dec.load_state_dict(torch.load(dec_path, map_location='cpu'))
|
|
|
|
|
|
def get_pretraiend_ae(enc_path='pretrained/ae/vae/enc.pth', dec_path='pretrained/ae/vae/dec1.pth'):
    """Build an ``ImageAE``, load its checkpoints and return it in eval mode.

    Args:
        enc_path: Path to the encoder state-dict file.
        dec_path: Path to the decoder state-dict file.

    Returns:
        The loaded ``ImageAE`` instance, switched to evaluation mode.
    """
    autoencoder = ImageAE()
    autoencoder.load_ckpt(enc_path, dec_path)
    print('load image auto-encoder')
    # nn.Module.eval() returns the module itself, so it can be returned as-is.
    return autoencoder.eval()
|
|
|
|
|
|
def get_pretraiend_unet(path='pretrained/ae/unet/ckpt_srm.pth'):
    """Build a ``UnetGenerator(3, 3, 8)``, load its weights and set eval mode.

    Args:
        path: Path to the state-dict file, read with ``map_location='cpu'``.

    Returns:
        The loaded generator after ``eval()`` has been called on it.
    """
    # NOTE(review): UnetGenerator is neither defined nor imported in the
    # visible part of this file — confirm where it is supposed to come from.
    generator = UnetGenerator(3, 3, 8)
    state = torch.load(path, map_location='cpu')
    generator.load_state_dict(state)
    print('load Unet')
    generator.eval()
    return generator
|
|
|
|
if __name__ == "__main__":
    # Smoke check: load the pretrained auto-encoder and print its module tree.
    print(get_pretraiend_ae())
|
|
|