from torch import nn
from torch.autograd import Variable
import torch
import torch.nn.functional as F
import torchvision.models as models


def add_gaussian_noise(ins, mean=0, stddev=0.1):
    # Add elementwise Gaussian noise drawn on the same device/dtype as the input.
    noise = ins.data.new(ins.size()).normal_(mean, stddev)
    return ins + noise


class FlattenLayer(nn.Module):
    # Flatten (N, C, H, W) feature maps into (N, C*H*W) vectors.
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        return x.view(x.size(0), -1)


class UnflattenLayer(nn.Module):
    # Reshape (N, D) vectors back into (N, -1, width, width) feature maps.
    def __init__(self, width):
        super(UnflattenLayer, self).__init__()
        self.width = width

    def forward(self, x):
        return x.view(x.size(0), -1, self.width, self.width)


class VAE_Encoder(nn.Module):
    '''
    VAE_Encoder: encode an image into the mean (mu) and log-variance (logvar)
    of the latent Gaussian, and sample z with the reparameterization trick.
    '''
    def __init__(self, latent_dim=256):
        super(VAE_Encoder, self).__init__()
        # ResNet-18 backbone with the final fc layer removed; features are
        # pooled to 1x1 and flattened into a 512-d vector.
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.resnet = nn.Sequential(
            *list(self.resnet.children())[:-1],
            FlattenLayer()
        )
        self.l_mu = nn.Linear(512, latent_dim)
        self.l_var = nn.Linear(512, latent_dim)

    def encode(self, x):
        hidden = self.resnet(x)
        mu = self.l_mu(hidden)
        logvar = self.l_var(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # z = mu + eps * std with eps ~ N(0, I); at eval time return mu directly.
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar
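

# Hedged usage sketch (illustration only, not part of the original module):
# with the default latent_dim=256, a batch of N images yields z, mu and logvar
# each of shape (N, 256), e.g.
#
#   enc = VAE_Encoder(latent_dim=256)
#   z, mu, logvar = enc(torch.randn(2, 3, 256, 256))  # each: torch.Size([2, 256])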


class VAE_Decoder(nn.Module):
    '''
    VAE_Decoder: decode a latent vector into an image.
    '''
    def __init__(self, latent_dim, output_dim=3):
        super(VAE_Decoder, self).__init__()
        # The latent vector is reshaped to (N, latent_dim, 1, 1); the first
        # transposed conv maps 1x1 -> 4x4 and the six stride-2 stages double the
        # resolution to 256x256, with outputs in the Tanh range [-1, 1].
        self.convs = nn.Sequential(
            UnflattenLayer(width=1),
            nn.ConvTranspose2d(latent_dim, 512, 4, 1, 0, bias=False),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 384, 4, 2, 1, bias=False),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(384, 192, 4, 2, 1, bias=False),
            nn.BatchNorm2d(192),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(192, 96, 4, 2, 1, bias=False),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(96, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 4, 2, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(inplace=True),
            # Use output_dim here so the constructor argument is respected
            # (defaults to 3 channels, matching the original hard-coded value).
            nn.ConvTranspose2d(32, output_dim, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.convs(z)
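

# Hedged sanity sketch (illustration only, not part of the original module):
# decoding a random 256-d latent should give a 3x256x256 image, e.g.
#
#   dec = VAE_Decoder(latent_dim=256)
#   img = dec(torch.randn(1, 256))  # torch.Size([1, 3, 256, 256])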


class ImageAE(nn.Module):
    # VAE architecture
    def __init__(self):
        super(ImageAE, self).__init__()
        latent_dim = 512
        self.enc = VAE_Encoder(latent_dim)
        self.dec = VAE_Decoder(latent_dim)

    def forward(self, x):
        z, *_ = self.enc(x)
        out = self.dec(z)
        return out

    def load_ckpt(self, enc_path, dec_path):
        self.enc.load_state_dict(torch.load(enc_path, map_location='cpu'))
        self.dec.load_state_dict(torch.load(dec_path, map_location='cpu'))


def get_pretraiend_ae(enc_path='pretrained/ae/vae/enc.pth', dec_path='pretrained/ae/vae/dec1.pth'):
    ae = ImageAE()
    ae.load_ckpt(enc_path, dec_path)
    print('load image auto-encoder')
    ae.eval()
    return ae


# UnetGenerator comes from the pix2pix network module; import it so that
# get_pretraiend_unet actually works when called.
from networks.pix2pix_network import UnetGenerator


def get_pretraiend_unet(path='pretrained/ae/unet/ckpt_srm.pth'):
    unet = UnetGenerator(3, 3, 8)
    unet.load_state_dict(torch.load(path, map_location='cpu'))
    print('load Unet')
    unet.eval()
    return unet


if __name__ == "__main__":
    ae = get_pretraiend_ae()
    print(ae)
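
    # Hedged smoke test (added for illustration, not part of the original script):
    # run a random batch through a freshly built ImageAE to check shapes. This
    # needs torchvision's pretrained ResNet-18 weights but not the auto-encoder
    # checkpoints loaded above.
    with torch.no_grad():
        model = ImageAE().eval()
        recon = model(torch.randn(2, 3, 256, 256))
        print('reconstruction shape:', tuple(recon.shape))  # (2, 3, 256, 256)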