from scripts.constants import BASERECLOSS
import torch
from torch import nn
import torch.nn.functional as F
import pytorch_lightning as pl
from ldm.models.M_ModelAE_Cnn import CnnVae as LSFDisentangler
from main import instantiate_from_config
# from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ldm.models.autoencoder import VQModel
from torchinfo import summary
import collections
import torchvision.utils as vutils
LSFLOCONFIG_KEY = "LSF_params"
BASEMODEL_KEY = "basemodel"
VQGAN_MODEL_INPUT = 'image'
DISENTANGLER_DECODER_OUTPUT = 'output'
DISENTANGLER_ENCODER_INPUT = 'in'
DISENTANGLER_CLASS_OUTPUT = 'class'
DISENTANGLER_ATTRIBUTE_OUTPUT = 'attribute'
DISENTANGLER_EMBEDDING = 'embedding'
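

# LSFVQVAE wraps a pretrained (and frozen) VQGAN from ldm.models.autoencoder
# and inserts an LSF disentangler (CnnVae) between the VQGAN encoder and its
# quantization stage. Only the disentangler is trained (see
# configure_optimizers); the VQGAN weights stay fixed.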
class LSFVQVAE(VQModel):
    def __init__(self, **args):
        print(args)
        self.save_hyperparameters()

        LSF_args = args[LSFLOCONFIG_KEY]
        del args[LSFLOCONFIG_KEY]

        super().__init__(**args)
        self.freeze()

        ckpt_path = LSF_args.get('ckpt_path', None)
        if 'ckpt_path' in LSF_args:
            del LSF_args['ckpt_path']
        self.LSF_disentangler = LSFDisentangler(**LSF_args)
        LSF_args['ckpt_path'] = ckpt_path

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=[])
            print('Loaded trained model at', ckpt_path)

        self.verbose = LSF_args.get('verbose', False)
        # self.verbose = True
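
    # encode() differs from the stock VQModel.encode: the encoder features are
    # first passed through the LSF disentangler, and its reconstruction of the
    # features is what gets quantized. The raw encoder input and all
    # disentangler outputs are returned alongside for the loss computation.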
    def encode(self, x):
        encoder_out = self.encoder(x)
        disentangler_outputs = self.LSF_disentangler(encoder_out)
        disentangler_out = disentangler_outputs[DISENTANGLER_DECODER_OUTPUT]
        h = self.quant_conv(disentangler_out)
        quant, base_quantizer_loss, info = self.quantize(h)

        base_loss_dic = {'quantizer_loss': base_quantizer_loss}
        in_out_disentangler = {
            DISENTANGLER_ENCODER_INPUT: encoder_out,
        }
        in_out_disentangler = {**in_out_disentangler, **disentangler_outputs}

        return quant, base_loss_dic, in_out_disentangler, info

    def forward(self, input):
        quant, base_loss_dic, in_out_disentangler, _ = self.encode(input)
        dec = self.decode(quant)
        return dec, base_loss_dic, in_out_disentangler
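
    # Baseline path that skips the disentangler entirely (encoder -> quant_conv
    # -> quantize -> decode). Used only when self.verbose is set, to log how the
    # frozen VQGAN reconstructs on its own for comparison.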
    def forward_hypothetical(self, input):
        encoder_out = self.encoder(input)
        h = self.quant_conv(encoder_out)
        quant, base_hypothetical_quantizer_loss, info = self.quantize(h)
        dec = self.decode(quant)
        return dec, base_hypothetical_quantizer_loss
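
    # Shared train/val step: logs the frozen-VQGAN ("base") reconstruction and
    # quantizer losses for monitoring, but the loss that is actually optimized
    # is the LSF disentangler loss computed on the encoder features.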
    def step(self, batch, batch_idx, prefix):
        x = self.get_input(batch, self.image_key)
        xrec, base_loss_dic, in_out_disentangler = self(x)

        if self.verbose:
            xrec_hypothetical, base_hypothetical_quantizer_loss = self.forward_hypothetical(x)
            hypothetical_rec_loss = torch.mean(torch.abs(x.contiguous() - xrec_hypothetical.contiguous()))
            self.log(prefix+"/base_hypothetical_rec_loss", hypothetical_rec_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
            self.log(prefix+"/base_hypothetical_quantizer_loss", base_hypothetical_quantizer_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)

        # base losses
        true_rec_loss = torch.mean(torch.abs(x.contiguous() - xrec.contiguous()))
        self.log(prefix + BASERECLOSS, true_rec_loss, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        self.log(prefix+"/base_quantizer_loss", base_loss_dic['quantizer_loss'], prog_bar=False, logger=True, on_step=False, on_epoch=True)

        total_loss, LSF_losses_dict = self.LSF_disentangler.loss(
            in_out_disentangler[DISENTANGLER_DECODER_OUTPUT],
            in_out_disentangler[DISENTANGLER_ENCODER_INPUT],
            batch['class'],
            in_out_disentangler['embedding'],
            in_out_disentangler['vae_mu'],
            in_out_disentangler['vae_logvar'])

        if self.verbose:
            self.log(prefix+"/disentangler_total_loss", total_loss, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            for i in LSF_losses_dict:
                if "_f1" in i:
                    self.log(prefix+"/disentangler_LSF_"+i, LSF_losses_dict[i], prog_bar=False, logger=True, on_step=True, on_epoch=True)

        self.log(prefix+"/disentangler_total_loss", total_loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
        for i in LSF_losses_dict:
            self.log(prefix+"/disentangler_LSF_"+i, LSF_losses_dict[i], prog_bar=True, logger=True, on_step=True, on_epoch=True)

        self.log(prefix+"/disentangler_learning_rate", self.LSF_disentangler.learning_rate, prog_bar=False, logger=True, on_step=False, on_epoch=True)
        # monitor for checkpoint saving is set on this
        self.log(prefix+"/rec_loss", LSF_losses_dict['L_rec'], prog_bar=True, logger=True, on_step=True, on_epoch=True)

        return total_loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, 'val')
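
    # Only the disentangler parameters are optimized; the VQGAN itself was
    # frozen in __init__, so its weights never receive gradients.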
    def configure_optimizers(self):
        lr = self.LSF_disentangler.learning_rate
        opt_ae = torch.optim.Adam(self.LSF_disentangler.parameters(), lr=lr)
        return [opt_ae], []
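
    # Convenience helpers for working directly in the disentangler's latent
    # space: image2encoding maps an image to (z, mu, logvar), and
    # encoding2image decodes an (optionally edited) z back to pixel space
    # through the frozen VQGAN decoder.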
    def image2encoding(self, x):
        encoder_out = self.encoder(x)
        mu, logvar = self.LSF_disentangler.encode(encoder_out)
        z = self.LSF_disentangler.reparameterize(mu, logvar)
        return z, mu, logvar

    def encoding2image(self, z):
        disentangler_out = self.LSF_disentangler.decoder(z)
        h = self.quant_conv(disentangler_out)
        quant, base_quantizer_loss, info = self.quantize(h)
        rec = self.decode(quant)
        return rec
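

# Minimal usage sketch (illustrative only, not part of the training code):
# assumes a config dict with an "LSF_params" entry plus the usual VQModel
# kwargs, and a batch of images `imgs` shaped like the VQGAN's training data.
#
#   model = LSFVQVAE(**config)            # config layout is an assumption here
#   z, mu, logvar = model.image2encoding(imgs)
#   # ... edit selected dimensions of z (e.g. the class-specific factors) ...
#   recon = model.encoding2image(z)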