# Source: https://github.com/lissomx/MSP/blob/master/M_ModelAE_Cnn.py
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np


class Encoder(nn.Module):
    """Convolutional encoder mapping a square image to a flat code.

    Intended only for square images whose side length is a power of two.
    The first stride-2 convolution halves the spatial size. Each further
    stage halves it again and doubles the channel count while the size is
    still above 4. A final convolution whose kernel covers the whole
    remaining map then collapses it to 1x1 with ``hidden_size`` channels.

    Args:
        image_size: side length of the square input image.
        nf: channel count produced by the first convolution.
        hidden_size: size of the output code. When ``None`` it defaults to
            the channel count reached by the last downsampling stage.
        nc: number of input image channels.
    """

    def __init__(self, image_size, nf, hidden_size=None, nc=3):
        super(Encoder, self).__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        sequens = [
            nn.Conv2d(nc, nf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        while True:
            # Spatial size after the previous stride-2 convolution.
            image_size = image_size / 2
            if image_size > 4:
                sequens.append(nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False))
                sequens.append(nn.BatchNorm2d(nf * 2))
                sequens.append(nn.LeakyReLU(0.2, inplace=True))
                nf = nf * 2
            else:
                if hidden_size is None:
                    self.hidden_size = int(nf)
                # Kernel spans the whole remaining map, giving a 1x1 output.
                sequens.append(
                    nn.Conv2d(nf, self.hidden_size, int(image_size), 1, 0, bias=False)
                )
                break
        self.main = nn.Sequential(*sequens)

    def forward(self, input):
        """Encode a batch of images into ``(batch, hidden_size)`` codes."""
        # Drop the trailing 1x1 spatial dimensions left by the last conv.
        return self.main(input).squeeze(3).squeeze(2)


class Decoder(nn.Module):
    """Transposed-convolution decoder mirroring :class:`Encoder`.

    Intended only for square output images whose side length is a power of
    two. The code is first expanded to a small feature map. Each following
    stage doubles the spatial size and halves the channel count. The output
    passes through ``tanh``.

    Args:
        image_size: side length of the square output image.
        nf: channel count feeding the final transposed convolution.
        hidden_size: size of the input code. When ``None`` it defaults to
            the channel count of the deepest stage.
        nc: number of output image channels.
    """

    def __init__(self, image_size, nf, hidden_size=None, nc=3):
        super(Decoder, self).__init__()
        self.image_size = image_size
        self.hidden_size = hidden_size
        # Layers are collected from the output side inwards and reversed
        # once the innermost stage has been added.
        sequens = [
            nn.Tanh(),
            nn.ConvTranspose2d(nf, nc, 4, 2, 1, bias=False),
        ]
        while True:
            image_size = image_size / 2
            sequens.append(nn.ReLU(True))
            sequens.append(nn.BatchNorm2d(nf))
            if image_size > 4:
                sequens.append(nn.ConvTranspose2d(nf * 2, nf, 4, 2, 1, bias=False))
            else:
                if hidden_size is None:
                    self.hidden_size = int(nf)
                sequens.append(
                    nn.ConvTranspose2d(self.hidden_size, nf, int(image_size), 1, 0, bias=False)
                )
                break
            nf = nf * 2
        sequens.reverse()
        self.main = nn.Sequential(*sequens)

    def forward(self, z):
        """Decode ``(batch, hidden_size)`` codes into images."""
        # (batch, hidden) -> (batch, hidden, 1, 1) for the conv stack.
        z = z.unsqueeze(2).unsqueeze(2)
        output = self.main(z)
        return output

    def loss(self, predict, orig):
        """Return the summed squared error between ``predict`` and ``orig``.

        Both tensors are flattened per sample before comparison.
        """
        batch_size = predict.shape[0]
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # tensors, are accepted instead of raising.
        a = predict.reshape(batch_size, -1)
        b = orig.reshape(batch_size, -1)
        return F.mse_loss(a, b, reduction='sum')


class CnnVae(nn.Module):
    """Convolutional VAE with an MSP label-projection matrix ``M``.

    Args:
        learning_rate: stored on the instance; not used inside this class.
        image_size: side length of the square images (see Encoder/Decoder).
        label_size: number of label classes; the number of rows of ``M``.
        nf: base channel count forwarded to Encoder/Decoder.
        hidden_size: latent size. When ``None`` it is derived by the Encoder.
        nc: number of image channels.
    """

    def __init__(self, learning_rate, image_size, label_size, nf, hidden_size=None, nc=3):
        super(CnnVae, self).__init__()
        self.encoder = Encoder(image_size, nf, hidden_size, nc)
        self.decoder = Decoder(image_size, nf, hidden_size, nc)
        self.image_size = image_size
        self.nc = nc
        self.label_size = label_size
        self.hidden_size = self.encoder.hidden_size
        self.learning_rate = learning_rate
        self.fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        # (label_size, hidden_size): z @ M.t() gives per-label projections.
        self.M = nn.Parameter(torch.empty(label_size, self.hidden_size))
        nn.init.xavier_normal_(self.M)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        h = self.encoder(x)
        mu = self.fc1(h)
        logvar = self.fc2(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            dict with keys ``'output'`` (reconstruction), ``'embedding'``
            (sampled z), ``'vae_mu'`` and ``'vae_logvar'``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        prod = self.decoder(z)
        outputs = {'output': prod}
        outputs['embedding'] = z
        outputs['vae_mu'] = mu
        outputs['vae_logvar'] = logvar
        return outputs

    def _loss_vae(self, mu, logvar):
        """KL divergence of N(mu, exp(logvar)) from N(0, I), summed over all elements."""
        # https://arxiv.org/abs/1312.6114
        # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return KLD

    def _loss_msp(self, label, z):
        """MSP loss terms for integer class ``label`` and latent codes ``z``.

        The labels are one-hot encoded with the off-class entries set to -1.
        L1 is the summed squared error between the projections ``z @ M.t()``
        and those targets. L2 is the summed squared error between the
        target-induced codes ``targets @ M`` and ``z``.

        Returns:
            ``(L1 + L2, L1, L2)``.
        """
        # Map {0, 1} one-hot entries to {-1, +1} targets.
        labels_one_hot = F.one_hot(label, num_classes=self.label_size)
        labels_one_hot = labels_one_hot.to(dtype=torch.float32) * 2 - 1
        L1 = F.mse_loss((z @ self.M.t()).view(-1), labels_one_hot.view(-1), reduction="none").sum()
        L2 = F.mse_loss((labels_one_hot @ self.M).view(-1), z.view(-1), reduction="none").sum()
        return L1 + L2, L1, L2

    def loss(self, prod, orgi, label, z, mu, logvar):
        """Total training loss: reconstruction + KL + weighted MSP.

        Returns:
            ``(Loss, loss_dict)``, where ``loss_dict`` holds the tensor
            components under keys ``'L1'``, ``'L2'``, ``'L_msp'``,
            ``'L_rec'`` and ``'L_vae'``.
        """
        L_rec = self.decoder.loss(prod, orgi)
        L_vae = self._loss_vae(mu, logvar)
        L_msp, L1_msp, L2_msp = self._loss_msp(label, z)
        # NOTE(review): if `label` is a 1-D tensor of class indices,
        # label.numel() is the batch size, whereas the MSP terms sum over
        # batch * label_size (L1) and batch * hidden_size (L2) elements.
        # Confirm this weighting is intended.
        _msp_weight = orgi.numel() / (label.numel() + z.numel())
        Loss = L_rec + L_vae + L_msp * _msp_weight
        loss_dict = {'L1': L1_msp, 'L2': L2_msp, 'L_msp': L_msp, 'L_rec': L_rec, 'L_vae': L_vae}
        return Loss, loss_dict

    def acc(self, z, l):
        """Fraction of label projections of ``z`` whose sign agrees with ``l``.

        Assumes ``l`` holds +-1 targets broadcastable to
        ``(batch, label_size)``; TODO confirm against callers.
        """
        zl = z @ self.M.t()
        # Maps agreement to values above 0.5 and disagreement or zero to
        # 0.5 or below, so round() gives 1 or 0.
        a = zl.clamp(-1, 1) * l * 0.5 + 0.5
        return a.round().mean().item()

    def predict(self, x, new_ls=None, weight=1.0):
        """Reconstruct ``x`` from its posterior mean, optionally editing labels.

        Args:
            x: input images.
            new_ls: optional iterable of ``(label_index, value)`` pairs. For
                each pair, the latent code is shifted along ``M`` so that the
                requested projection moves towards ``value * weight``.
            weight: scale applied to every requested value.
        """
        z, _ = self.encode(x)  # posterior mean; no sampling
        if new_ls is not None:
            zl = z @ self.M.t()
            d = torch.zeros_like(zl)
            for i, v in new_ls:
                d[:, i] = v * weight - zl[:, i]
            # Out-of-place update: `zl` above saved `z` for autograd, so an
            # in-place `z +=` would make any backward pass fail.
            z = z + d @ self.M
        prod = self.decoder(z)
        return prod

    def predict_ex(self, x, label, new_ls=None, weight=1.0):
        """Same as :meth:`predict`; ``label`` is accepted but ignored."""
        return self.predict(x, new_ls, weight)

    def get_U(self, eps=1e-5):
        """Return ``U = [M; N]``, where the rows of ``N`` span the null space of ``M``.

        ``N`` holds the right singular vectors of ``M`` (zero-padded to
        square) whose singular values are ``<= eps``. Its rows are
        orthonormal and orthogonal to the rows of ``M``. ``U`` is fully
        orthogonal only when the rows of ``M`` are themselves orthonormal.
        Assumes ``label_size <= hidden_size``.
        """
        from scipy import linalg

        M = self.M.detach().cpu()
        # Zero-pad M to a square matrix so the SVD returns a full set of
        # right singular vectors. The padding does not change the null space.
        padding = torch.zeros(M.shape[1] - M.shape[0], M.shape[1], dtype=M.dtype)
        A = torch.cat([M, padding])
        _, s, vh = linalg.svd(A.numpy())
        # Boolean indexing replaces the deprecated `scipy.compress` alias.
        null_space = vh[s <= eps]
        N = torch.tensor(null_space)
        return torch.cat([self.M, N.to(device=self.M.device, dtype=self.M.dtype)])