# phylo-diffusion/ldm/models/M_ModelAE_Cnn.py
# Source: https://github.com/lissomx/MSP/blob/master/M_ModelAE_Cnn.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
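
# Convolutional VAE with Matrix Subspace Projection (MSP): a learned matrix M
# ties part of the latent space to the attribute labels, which allows
# attribute edits by moving z within that subspace (see CnnVae.predict).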

class Encoder(nn.Module):
    # only for square images whose side length is a power of 2
    def __init__(self, image_size, nf, hidden_size=None, nc=3):
        super(Encoder, self).__init__()
self.image_size = image_size
self.hidden_size = hidden_size
sequens = [
nn.Conv2d(nc, nf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
]
        while True:
            image_size = image_size // 2
            if image_size > 4:
                # halve the spatial size and double the channel count
                sequens.append(nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False))
                sequens.append(nn.BatchNorm2d(nf * 2))
                sequens.append(nn.LeakyReLU(0.2, inplace=True))
                nf = nf * 2
            else:
                if hidden_size is None:
                    self.hidden_size = int(nf)
                # final conv collapses the remaining spatial map to 1x1
                sequens.append(nn.Conv2d(nf, self.hidden_size, int(image_size), 1, 0, bias=False))
                break
        self.main = nn.Sequential(*sequens)

    def forward(self, input):
        # (B, nc, H, W) -> (B, hidden_size); the squeezes drop the 1x1 spatial dims
        return self.main(input).squeeze(3).squeeze(2)
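
# A worked shape trace for the encoder (a sketch, assuming image_size=64,
# nf=64, nc=3; other power-of-2 sizes follow the same pattern):
#   input            (B, 3, 64, 64)
#   Conv2d(3, 64)    -> (B, 64, 32, 32)
#   Conv2d(64, 128)  -> (B, 128, 16, 16)
#   Conv2d(128, 256) -> (B, 256, 8, 8)
#   Conv2d(256, 512) -> (B, 512, 4, 4)
#   Conv2d(512, 512, kernel_size=4) -> (B, 512, 1, 1), so hidden_size
#   defaults to 512; the Decoder below mirrors this chain in reverse.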


class Decoder(nn.Module):
    # only for square images whose side length is a power of 2
    def __init__(self, image_size, nf, hidden_size=None, nc=3):
        super(Decoder, self).__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        # Layers are collected in encoder order and reversed at the end, so
        # the assembled stack runs from the latent vector back to the image.
        sequens = [
            nn.Tanh(),
            nn.ConvTranspose2d(nf, nc, 4, 2, 1, bias=False),
        ]
        while True:
            image_size = image_size // 2
            sequens.append(nn.ReLU(True))
            sequens.append(nn.BatchNorm2d(nf))
            if image_size > 4:
                sequens.append(nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1, bias=False))
            else:
                if hidden_size is None:
                    self.hidden_size = int(nf)
                # first layer (after the reverse) expands the 1x1 latent map
                sequens.append(nn.ConvTranspose2d(self.hidden_size, nf, int(image_size), 1, 0, bias=False))
                break
            nf = nf * 2
        sequens.reverse()
        self.main = nn.Sequential(*sequens)

    def forward(self, z):
        # (B, hidden_size) -> (B, hidden_size, 1, 1) for the transposed convs
        z = z.unsqueeze(2).unsqueeze(3)
        output = self.main(z)
        return output

    def loss(self, predict, orig):
        # summed squared reconstruction error over the whole batch
        batch_size = predict.shape[0]
        a = predict.view(batch_size, -1)
        b = orig.view(batch_size, -1)
        L = F.mse_loss(a, b, reduction='sum')
        return L


class CnnVae(nn.Module):
    def __init__(self, learning_rate, image_size, label_size, nf, hidden_size=None, nc=3):
        super(CnnVae, self).__init__()
        self.encoder = Encoder(image_size, nf, hidden_size, nc)
        self.decoder = Decoder(image_size, nf, hidden_size, nc)
        self.image_size = image_size
        self.nc = nc
        self.label_size = label_size
        self.hidden_size = self.encoder.hidden_size
        self.learning_rate = learning_rate
        # heads producing the Gaussian posterior parameters
        self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        # MSP matrix: each row is the latent direction tied to one label
        self.M = nn.Parameter(torch.empty(label_size, self.hidden_size))
        nn.init.xavier_normal_(self.M)

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc1(h)
        logvar = self.fc2(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps with sigma = exp(0.5 * logvar), eps ~ N(0, I)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        prod = self.decoder(z)
        outputs = {'output': prod}  # DISENTANGLER_DECODER_OUTPUT
        outputs['embedding'] = z
        outputs['vae_mu'] = mu
        outputs['vae_logvar'] = logvar
        return outputs

    def _loss_vae(self, mu, logvar):
        # KL divergence to N(0, I); see https://arxiv.org/abs/1312.6114:
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return KLD

    def _loss_msp(self, label, z):
        # targets: one-hot labels remapped from {0, 1} to {-1, +1}
        labels_one_hot = F.one_hot(label, num_classes=self.label_size)
        labels_one_hot = labels_one_hot.to(dtype=torch.float32)
        labels_one_hot[labels_one_hot == 0.0] = -1
        # L1: the projection of z onto the rows of M should match the label code
        L1 = F.mse_loss((z @ self.M.t()).view(-1), labels_one_hot.view(-1), reduction="none").sum()
        # L2: the label code mapped back through M should reconstruct z
        L2 = F.mse_loss((labels_one_hot @ self.M).view(-1), z.view(-1), reduction="none").sum()
        return L1 + L2, L1, L2
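
    # Dimension sketch (an illustration, not from the original source): with
    # batch size B, hidden size H and label_size K, M has shape (K, H),
    # z @ M.t() has shape (B, K), and labels_one_hot @ M has shape (B, H),
    # so L1 compares in label space while L2 compares in latent space.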

    def loss(self, prod, orgi, label, z, mu, logvar):
        L_rec = self.decoder.loss(prod, orgi)
        L_vae = self._loss_vae(mu, logvar)
        L_msp, L1_msp, L2_msp = self._loss_msp(label, z)
        # rescale the MSP term to be commensurate with the pixel-space terms
        _msp_weight = orgi.numel() / (label.numel() + z.numel())
        Loss = L_rec + L_vae + L_msp * _msp_weight
        loss_dict = {'L1': L1_msp, 'L2': L2_msp, 'L_msp': L_msp,
                     'L_rec': L_rec, 'L_vae': L_vae}
        return Loss, loss_dict

    def acc(self, z, l):
        # fraction of label dimensions whose predicted sign matches l,
        # where l is expected in {-1, +1}
        zl = z @ self.M.t()
        a = zl.clamp(-1, 1) * l * 0.5 + 0.5
        return a.round().mean().item()

    def predict(self, x, new_ls=None, weight=1.0):
        # deterministic encoding: use the posterior mean as z
        z, _ = self.encode(x)
        if new_ls is not None:
            # new_ls: iterable of (label_index, target_value) pairs; move z
            # within the label subspace spanned by the rows of M
            zl = z @ self.M.t()
            d = torch.zeros_like(zl)
            for i, v in new_ls:
                d[:, i] = v * weight - zl[:, i]
            z += d @ self.M
        prod = self.decoder(z)
        return prod

    def predict_ex(self, x, label, new_ls=None, weight=1.0):
        return self.predict(x, new_ls, weight)
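
    # Illustrative call (not from the original source): decode x with label
    # index 3 pushed toward +1 in the latent label subspace:
    #   edited = model.predict(x, new_ls=[(3, 1.0)], weight=1.0)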

    def get_U(self, eps=1e-5):
        # Build a null-space matrix N of M such that U = [M; N] is square;
        # the rows of N span the label-free directions of the latent space.
        # Note: modern SciPy no longer exposes `compress` at the top level,
        # so the row selection uses numpy instead.
        from scipy import linalg
        M = self.M.detach().cpu()
        A = torch.zeros(M.shape[1] - M.shape[0], M.shape[1])
        A = torch.cat([M, A])
        u, s, vh = linalg.svd(A.numpy())
        null_mask = (s <= eps)
        null_space = np.compress(null_mask, vh, axis=0)
        N = torch.tensor(null_space, dtype=self.M.dtype)
        return torch.cat([self.M, N.to(self.M.device)])
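

# Minimal smoke test (a sketch; the hyperparameters below are illustrative
# assumptions, not values taken from the phylo-diffusion training configs).
if __name__ == "__main__":
    B, image_size, label_size = 4, 64, 10
    model = CnnVae(learning_rate=1e-4, image_size=image_size,
                   label_size=label_size, nf=64)
    x = torch.randn(B, 3, image_size, image_size)
    label = torch.randint(0, label_size, (B,))
    out = model(x)
    loss, loss_dict = model.loss(out['output'], x, label, out['embedding'],
                                 out['vae_mu'], out['vae_logvar'])
    print(out['output'].shape, {k: round(v.item(), 2) for k, v in loss_dict.items()})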