# phylo-diffusion/ldm/models/LSFautoencoder.py
# (header reconstructed from Hugging Face raw-page residue; original page chrome removed)
from scripts.constants import BASERECLOSS
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from ldm.models.M_ModelAE_Cnn import CnnVae as LSFDisentangler
from main import instantiate_from_config
# from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.models.autoencoder import VQModel
from torchinfo import summary
import collections
import torchvision.utils as vutils
# Key in the model kwargs dict holding the LSF disentangler's own config (popped
# before the base VQModel is constructed; see LSFVQVAE.__init__).
LSFLOCONFIG_KEY = "LSF_params"
# NOTE(review): not referenced anywhere in this file — presumably consumed by a
# config loader elsewhere; verify before removing.
BASEMODEL_KEY = "basemodel"
# Batch key for the VQGAN image input (also unused in this chunk — TODO confirm caller).
VQGAN_MODEL_INPUT = 'image'
# Keys into the dict returned by the LSFDisentangler forward pass.
DISENTANGLER_DECODER_OUTPUT = 'output'
DISENTANGLER_ENCODER_INPUT = 'in'
DISENTANGLER_CLASS_OUTPUT = 'class'
DISENTANGLER_ATTRIBUTE_OUTPUT = 'attribute'
DISENTANGLER_EMBEDDING = 'embedding'
class LSFVQVAE(VQModel):
    """VQ-VAE with an LSF (latent-space-factorization) disentangler inserted
    between the base encoder and the quantization stage.

    The base ``VQModel`` weights are frozen in ``__init__``; only the
    ``LSF_disentangler`` parameters are optimized (see ``configure_optimizers``).
    """

    def __init__(self, **args):
        """Build the frozen base VQModel, then attach the trainable disentangler.

        ``args`` must contain an ``LSF_params`` dict (``LSFLOCONFIG_KEY``) with
        the ``CnnVae`` kwargs; it may additionally carry ``ckpt_path`` and
        ``verbose`` entries, which are handled here rather than passed through.
        """
        print(args)
        # NOTE(review): called before super().__init__() — relies on the
        # LightningModule mixin already being initialized; confirm intended.
        self.save_hyperparameters()
        LSF_args = args[LSFLOCONFIG_KEY]
        # Remove the disentangler config so VQModel only sees its own kwargs.
        del args[LSFLOCONFIG_KEY]
        super().__init__(**args)
        # Freeze the base autoencoder; only the disentangler trains.
        self.freeze()
        # ckpt_path/verbose are options of this wrapper, not of CnnVae —
        # pop ckpt_path before constructing the disentangler.
        ckpt_path = LSF_args.get('ckpt_path', None)
        if 'ckpt_path' in LSF_args:
            del LSF_args['ckpt_path']
        self.LSF_disentangler = LSFDisentangler(**LSF_args)
        # Restore the key afterwards (NOTE(review): inserts ckpt_path=None even
        # when it was absent; harmless since LSF_args is local from here on).
        LSF_args['ckpt_path'] = ckpt_path
        if ckpt_path is not None:
            # Loads the full combined checkpoint via VQModel's loader.
            self.init_from_ckpt(ckpt_path, ignore_keys=[])
            print('Loaded trained model at', ckpt_path)
        # Verbose mode adds the extra "hypothetical" (disentangler-bypassing)
        # diagnostics in step().
        self.verbose = LSF_args.get('verbose', False)
        # self.verbose = True

    def encode(self, x):
        """Encode ``x`` through encoder -> disentangler -> quantizer.

        Returns ``(quant, base_loss_dic, in_out_disentangler, info)`` where
        ``in_out_disentangler`` bundles the encoder output (under
        ``DISENTANGLER_ENCODER_INPUT``) with everything the disentangler
        returned, for use by the disentangler loss in ``step``.
        """
        encoder_out = self.encoder(x)
        disentangler_outputs = self.LSF_disentangler(encoder_out)
        # The disentangler's reconstruction replaces the raw encoder features.
        disentangler_out = disentangler_outputs[DISENTANGLER_DECODER_OUTPUT]
        h = self.quant_conv(disentangler_out)
        quant, base_quantizer_loss, info = self.quantize(h)
        base_loss_dic = {'quantizer_loss': base_quantizer_loss}
        in_out_disentangler = {
            DISENTANGLER_ENCODER_INPUT: encoder_out,
        }
        in_out_disentangler = {**in_out_disentangler, **disentangler_outputs}
        return quant, base_loss_dic, in_out_disentangler, info

    def forward(self, input):
        """Full pass: encode with disentangler, then decode to an image."""
        quant, base_loss_dic, in_out_disentangler, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, base_loss_dic, in_out_disentangler

    def forward_hypothetical(self, input):
        """Diagnostic pass that BYPASSES the disentangler (encoder -> quantize
        -> decode directly), used in verbose mode for comparison."""
        encoder_out = self.encoder(input)
        h = self.quant_conv(encoder_out)
        quant, base_hypothetical_quantizer_loss, info = self.quantize(h)
        dec = self.decode(quant)
        return dec, base_hypothetical_quantizer_loss

    def step(self, batch, batch_idx, prefix):
        """Shared train/val step: forward pass, metric logging, and the
        disentangler loss (the only loss returned for optimization).

        ``prefix`` is 'train' or 'val' and namespaces every logged metric.
        """
        x = self.get_input(batch, self.image_key)
        xrec, base_loss_dic, in_out_disentangler = self(x)
        if self.verbose:
            # Compare against the disentangler-free path for diagnostics.
            xrec_hypthetical, base_hypothetical_quantizer_loss = self.forward_hypothetical(x)
            hypothetical_rec_loss = torch.mean(torch.abs(x.contiguous() - xrec_hypthetical.contiguous()))
            self.log(prefix+"/base_hypothetical_rec_loss", hypothetical_rec_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
            self.log(prefix+"/base_hypothetical_quantizer_loss", base_hypothetical_quantizer_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        # base losses
        true_rec_loss = torch.mean(torch.abs(x.contiguous() - xrec.contiguous()))
        self.log(prefix+ BASERECLOSS, true_rec_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log(prefix+"/base_quantizer_loss", base_loss_dic['quantizer_loss'], prog_bar=False, logger=True, on_step=False, on_epoch=True)
        # Disentangler loss: reconstruction vs. encoder features, class labels,
        # embedding and the VAE posterior stats produced by encode().
        total_loss, LSF_losses_dict = self.LSF_disentangler.loss(in_out_disentangler[DISENTANGLER_DECODER_OUTPUT], in_out_disentangler[DISENTANGLER_ENCODER_INPUT],
                                                                 batch['class'], in_out_disentangler['embedding'], in_out_disentangler['vae_mu'], in_out_disentangler['vae_logvar'])
        if self.verbose:
            # NOTE(review): these keys are logged AGAIN unconditionally below
            # with different prog_bar flags — duplicate logging per step when
            # verbose; confirm whether intentional.
            self.log(prefix+"/disentangler_total_loss", total_loss, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            for i in LSF_losses_dict:
                if "_f1" in i:
                    self.log(prefix+"/disentangler_LSF_"+i, LSF_losses_dict[i], prog_bar=False, logger=True, on_step=True, on_epoch=True)
        self.log(prefix+"/disentangler_total_loss", total_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        for i in LSF_losses_dict:
            self.log(prefix+"/disentangler_LSF_"+i, LSF_losses_dict[i], prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log(prefix+"/disentangler_learning_rate", self.LSF_disentangler.learning_rate, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        # monitor for checkpoint saving is set on this
        self.log(prefix+"/rec_loss", LSF_losses_dict['L_rec'], prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return total_loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: delegate to the shared step with 'train' prefix."""
        return self.step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        """Lightning hook: delegate to the shared step with 'val' prefix."""
        return self.step(batch, batch_idx, 'val')

    def configure_optimizers(self):
        """Optimize ONLY the disentangler (base VQModel is frozen)."""
        lr = self.LSF_disentangler.learning_rate
        opt_ae = torch.optim.Adam(self.LSF_disentangler.parameters(), lr=lr)
        return [opt_ae], []

    def image2encoding(self, x):
        """Map an image to the disentangler's latent ``(z, mu, logvar)``
        via the frozen encoder and the VAE reparameterization trick."""
        encoder_out = self.encoder(x)
        mu, logvar = self.LSF_disentangler.encode(encoder_out)
        z = self.LSF_disentangler.reparameterize(mu, logvar)
        return z, mu, logvar

    def encoding2image(self, z):
        """Decode a disentangler latent ``z`` back to an image through the
        disentangler decoder, quantizer, and the frozen base decoder."""
        disentangler_out = self.LSF_disentangler.decoder(z)
        h = self.quant_conv(disentangler_out)
        # quantizer loss/info are computed but unused here.
        quant, base_quantizer_loss, info = self.quantize(h)
        rec = self.decode(quant)
        return rec