import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models


def add_gaussian_noise(ins, mean=0, stddev=0.1):
    """Add element-wise Gaussian noise N(mean, stddev) to a tensor."""
    noise = ins.data.new(ins.size()).normal_(mean, stddev)
    return ins + noise
class FlattenLayer(nn.Module):
    """Flatten (B, C, H, W) feature maps into (B, C*H*W) vectors."""
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class UnflattenLayer(nn.Module):
    """Reshape (B, N) vectors back into (B, -1, width, width) feature maps."""
    def __init__(self, width):
        super(UnflattenLayer, self).__init__()
        self.width = width

    def forward(self, x):
        return x.view(x.size(0), -1, self.width, self.width)
class VAE_Encoder(nn.Module):
    '''
    VAE_Encoder: encode an image into the mean (mu) and log-variance (logvar)
    of the latent Gaussian, using a ResNet-18 trunk as the feature extractor.
    '''
    def __init__(self, latent_dim=256):
        super(VAE_Encoder, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Drop the final fully-connected layer and flatten the pooled features.
        self.resnet = nn.Sequential(
            *list(self.resnet.children())[:-1],
            FlattenLayer()
        )
        self.l_mu = nn.Linear(512, latent_dim)
        self.l_var = nn.Linear(512, latent_dim)

    def encode(self, x):
        hidden = self.resnet(x)        # (B, 512)
        mu = self.l_mu(hidden)         # (B, latent_dim)
        logvar = self.l_var(hidden)    # (B, latent_dim)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps, with sigma = exp(0.5 * logvar).
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            # At inference time, use the mean deterministically.
            return mu

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
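

# For reference, a minimal sketch of the VAE objective that (mu, logvar) are
# typically trained with. It is not used anywhere in this inference module and
# is only illustrative: `vae_loss` is a hypothetical helper, and the MSE
# reconstruction term and the `beta` weighting are assumptions, not necessarily
# what the original training code used.
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    recon = F.mse_loss(recon_x, x)                                   # reconstruction term
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())   # KL(q(z|x) || N(0, I))
    return recon + beta * kld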
class VAE_Decoder(nn.Module):
    '''
    VAE_Decoder: decode a latent vector z back into an RGB image.
    '''
    def __init__(self, latent_dim, output_dim=3):
        super(VAE_Decoder, self).__init__()
        self.convs = nn.Sequential(
            UnflattenLayer(width=1),
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 384, 4, 2, 1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(384, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            nn.ConvTranspose2d(32, output_dim, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.convs(z)
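

# Shape note (derived from the transposed-conv chain above): a latent vector of
# size (B, latent_dim) is unflattened to (B, latent_dim, 1, 1) and upsampled
# 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256, so the decoder emits
# (B, 3, 256, 256) images by default, with values in [-1, 1] due to the final Tanh.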
class ImageAE(nn.Module):
    # VAE-style image auto-encoder: ResNet-18 encoder + transposed-conv decoder.
    def __init__(self):
        super(ImageAE, self).__init__()
        latent_dim = 512
        self.enc = VAE_Encoder(latent_dim)
        self.dec = VAE_Decoder(latent_dim)

    def forward(self, x):
        z, *_ = self.enc(x)
        out = self.dec(z)
        return out

    def load_ckpt(self, enc_path, dec_path):
        self.enc.load_state_dict(torch.load(enc_path, map_location='cpu'))
        self.dec.load_state_dict(torch.load(dec_path, map_location='cpu'))


def get_pretraiend_ae(enc_path='pretrained/ae/vae/enc.pth', dec_path='pretrained/ae/vae/dec1.pth'):
    ae = ImageAE()
    ae.load_ckpt(enc_path, dec_path)
    print('loaded image auto-encoder')
    ae.eval()
    return ae
# from networks.pix2pix_network import UnetGenerator

def get_pretraiend_unet(path='pretrained/ae/unet/ckpt_srm.pth'):
    # NOTE: requires UnetGenerator; uncomment the import above to use this helper.
    unet = UnetGenerator(3, 3, 8)
    unet.load_state_dict(torch.load(path, map_location='cpu'))
    print('loaded Unet')
    unet.eval()
    return unet
if __name__ == "__main__":
    ae = get_pretraiend_ae()
    print(ae)
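    # Minimal smoke test (a sketch): assumes the checkpoints above were found
    # and that inputs are 256x256 RGB tensors, matching the decoder's output size.
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out = ae(x)
    print('reconstruction shape:', tuple(out.shape))  # expected: (1, 3, 256, 256)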