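"""LSF-VQGAN wrapper: a frozen VQModel whose encoder features are routed
through an LSF disentangler (CnnVae) before quantization. Only the
disentangler is trained; the base VQGAN weights stay frozen."""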
from scripts.constants import BASERECLOSS
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from ldm.models.M_ModelAE_Cnn import CnnVae as LSFDisentangler
from main import instantiate_from_config
# from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.models.autoencoder import VQModel
from torchinfo import summary
import collections
import torchvision.utils as vutils
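
# Config keys and dictionary keys shared between this wrapper and the LSF disentangler.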
LSFLOCONFIG_KEY = "LSF_params"
BASEMODEL_KEY = "basemodel"

VQGAN_MODEL_INPUT = 'image'

DISENTANGLER_DECODER_OUTPUT = 'output'
DISENTANGLER_ENCODER_INPUT = 'in'
DISENTANGLER_CLASS_OUTPUT = 'class'
DISENTANGLER_ATTRIBUTE_OUTPUT = 'attribute'
DISENTANGLER_EMBEDDING = 'embedding'


class LSFVQVAE(VQModel):
    """VQGAN with an LSF disentangler inserted between the encoder and the quantizer."""

    def __init__(self, **args):
        print(args)
        self.save_hyperparameters()

        # Split off the disentangler config before constructing the base VQModel.
        LSF_args = args[LSFLOCONFIG_KEY]
        del args[LSFLOCONFIG_KEY]

        super().__init__(**args)
        # Freeze the base VQGAN; only the LSF disentangler is trained.
        self.freeze()

        # Pop 'ckpt_path' so it is not passed to the disentangler constructor,
        # then restore it into LSF_args afterwards.
        ckpt_path = LSF_args.get('ckpt_path', None)
        if 'ckpt_path' in LSF_args:
            del LSF_args['ckpt_path']
        self.LSF_disentangler = LSFDisentangler(**LSF_args)
        LSF_args['ckpt_path'] = ckpt_path

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=[])
            print('Loaded trained model at', ckpt_path)

        self.verbose = LSF_args.get('verbose', False)
        # self.verbose = True

    def encode(self, x):
        # Route the encoder features through the LSF disentangler, then quantize
        # the disentangler's reconstruction with the frozen VQGAN codebook.
        encoder_out = self.encoder(x)
        disentangler_outputs = self.LSF_disentangler(encoder_out)
        disentangler_out = disentangler_outputs[DISENTANGLER_DECODER_OUTPUT]
        h = self.quant_conv(disentangler_out)
        quant, base_quantizer_loss, info = self.quantize(h)
        base_loss_dic = {'quantizer_loss': base_quantizer_loss}

        # Keep the disentangler's input alongside its outputs for the loss computation.
        in_out_disentangler = {
            DISENTANGLER_ENCODER_INPUT: encoder_out,
            **disentangler_outputs,
        }
        return quant, base_loss_dic, in_out_disentangler, info

    def forward(self, input):
        quant, base_loss_dic, in_out_disentangler, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, base_loss_dic, in_out_disentangler

    def forward_hypothetical(self, input):
        # Baseline path that skips the disentangler entirely: encoder -> quantizer -> decoder.
        # Used only for logging comparison losses when self.verbose is set.
        encoder_out = self.encoder(input)
        h = self.quant_conv(encoder_out)
        quant, base_hypothetical_quantizer_loss, info = self.quantize(h)
        dec = self.decode(quant)
        return dec, base_hypothetical_quantizer_loss

    def step(self, batch, batch_idx, prefix):
        x = self.get_input(batch, self.image_key)
        xrec, base_loss_dic, in_out_disentangler = self(x)

        if self.verbose:
            # Compare against the frozen base VQGAN without the disentangler.
            xrec_hypothetical, base_hypothetical_quantizer_loss = self.forward_hypothetical(x)
            hypothetical_rec_loss = torch.mean(torch.abs(x.contiguous() - xrec_hypothetical.contiguous()))
            self.log(prefix + "/base_hypothetical_rec_loss", hypothetical_rec_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
            self.log(prefix + "/base_hypothetical_quantizer_loss", base_hypothetical_quantizer_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)

        # Base (frozen VQGAN) losses, logged for monitoring only.
        true_rec_loss = torch.mean(torch.abs(x.contiguous() - xrec.contiguous()))
        self.log(prefix + BASERECLOSS, true_rec_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log(prefix + "/base_quantizer_loss", base_loss_dic['quantizer_loss'], prog_bar=False, logger=True, on_step=False, on_epoch=True)

        # Disentangler losses; total_loss is the only term that is backpropagated.
        total_loss, LSF_losses_dict = self.LSF_disentangler.loss(
            in_out_disentangler[DISENTANGLER_DECODER_OUTPUT],
            in_out_disentangler[DISENTANGLER_ENCODER_INPUT],
            batch['class'],
            in_out_disentangler['embedding'],
            in_out_disentangler['vae_mu'],
            in_out_disentangler['vae_logvar'])

        if self.verbose:
            self.log(prefix + "/disentangler_total_loss", total_loss, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            for i in LSF_losses_dict:
                if "_f1" in i:
                    self.log(prefix + "/disentangler_LSF_" + i, LSF_losses_dict[i], prog_bar=False, logger=True, on_step=True, on_epoch=True)

        self.log(prefix + "/disentangler_total_loss", total_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        for i in LSF_losses_dict:
            self.log(prefix + "/disentangler_LSF_" + i, LSF_losses_dict[i], prog_bar=True, logger=True, on_step=True, on_epoch=True)
        self.log(prefix + "/disentangler_learning_rate", self.LSF_disentangler.learning_rate, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        # The checkpoint monitor is set on this key.
        self.log(prefix + "/rec_loss", LSF_losses_dict['L_rec'], prog_bar=True, logger=True, on_step=True, on_epoch=True)
        return total_loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'val')

    def configure_optimizers(self):
        # Only the LSF disentangler is optimized; the base VQGAN stays frozen.
        lr = self.LSF_disentangler.learning_rate
        opt_ae = torch.optim.Adam(self.LSF_disentangler.parameters(), lr=lr)
        return [opt_ae], []

    def image2encoding(self, x):
        # Map an image to the disentangler's latent code (z, mu, logvar).
        encoder_out = self.encoder(x)
        mu, logvar = self.LSF_disentangler.encode(encoder_out)
        z = self.LSF_disentangler.reparameterize(mu, logvar)
        return z, mu, logvar

    def encoding2image(self, z):
        # Decode a disentangler latent back to an image via the frozen VQGAN.
        disentangler_out = self.LSF_disentangler.decoder(z)
        h = self.quant_conv(disentangler_out)
        quant, base_quantizer_loss, info = self.quantize(h)
        rec = self.decode(quant)
        return rec
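

# Minimal usage sketch (an assumption, not part of the training pipeline):
# given an LSFVQVAE `model` restored from its config/checkpoint and an image
# batch `img` shaped like the base VQModel's input, the disentangled latent
# can be read out, edited, and decoded back to pixels:
#
#   z, mu, logvar = model.image2encoding(img)
#   # ... edit z as desired ...
#   rec = model.encoding2image(z)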