import os
import itertools

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from ldm.models.autoencoder import VQModel, AutoencoderKL
from ldm.models.disentanglement.iterative_normalization import IterNormRotation as cw_layer
from ldm.analysis_utils import get_CosineDistance_matrix, aggregatefrom_specimen_to_species
from ldm.plotting_utils import plot_heatmap_at_path
from ldm.util import instantiate_from_config

# Key under which the concept-dataset config is nested inside the model config.
CONCEPT_DATA_KEY = "concept_data"


class CWmodelVQGAN(VQModel):
    """VQGAN whose encoder output normalization is replaced by a Concept-Whitening
    (CW) rotation layer, trained with periodic rotation-matrix updates driven by a
    separate dataloader of concept examples.
    """

    def __init__(self, **args):
        """Instantiate the concept datamodule, then build the base VQModel.

        Expects ``args[CONCEPT_DATA_KEY]`` to hold an instantiable config for the
        concept datamodule; it is consumed here and removed before the remaining
        kwargs are forwarded to ``VQModel.__init__``.
        """
        print(args)
        # NOTE(review): called before super().__init__ — relies on Lightning's
        # frame inspection; confirm this works on the pinned Lightning version.
        self.save_hyperparameters()

        concept_data_args = args[CONCEPT_DATA_KEY]
        print("Concepts params : ", concept_data_args)
        self.concepts = instantiate_from_config(concept_data_args)
        self.concepts.prepare_data()
        self.concepts.setup()

        del args[CONCEPT_DATA_KEY]
        super().__init__(**args)

        # `cw_module_infer` is set by the base class from the config (not visible
        # here). When False we are training, so swap the encoder's output norm
        # for the CW rotation layer after the base VQGAN weights are loaded.
        if not self.cw_module_infer:
            self.encoder.norm_out = cw_layer(self.encoder.block_in)
            print("Changed to cw layer after loading base VQGAN")

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Standard VQGAN training step, plus a periodic CW alignment pass.

        Every 30th batch (on the autoencoder optimizer only) the model is fed one
        batch of concept images per concept class and the CW rotation matrix is
        updated; mode -1 restores normal whitening behavior afterwards.
        """
        if (batch_idx + 1) % 30 == 0 and optimizer_idx == 0:
            print('cw module')
            self.eval()
            with torch.no_grad():
                for _, concept_batch in enumerate(self.concepts.train_dataloader()):
                    for idx, concept in enumerate(concept_batch['class'].unique()):
                        concept_index = concept.item()
                        # Tell the CW layer which concept's activations follow.
                        self.encoder.norm_out.mode = concept_index
                        X_var = concept_batch['image'][concept_batch['class'] == concept]
                        # HWC -> CHW for the conv encoder.
                        X_var = X_var.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
                        X_var = torch.autograd.Variable(X_var).cuda()
                        X_var = X_var.float()
                        self(X_var)
                    # Only the first concept batch is used per update.
                    break
                self.encoder.norm_out.update_rotation_matrix()
                # -1 = normal whitening mode (no concept accumulation).
                self.encoder.norm_out.mode = -1
            self.train()

        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x, return_pred_indices=False)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                            last_layer=self.get_last_layer(), split="train")
            self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
                                                last_layer=self.get_last_layer(), split="train")
            self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
            return discloss

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        """Encode a test batch and return the pre-quantization CW latents with labels."""
        x = self.get_input(batch, self.image_key)
        h = self.encoder(x)
        h = self.quant_conv(h)
        class_label = batch['class']
        return {'z_cw': h, 'label': class_label, 'class_name': batch['class_name']}

    # NOTE: This is kinda hacky. But ok for now for test purposes.
    def set_test_chkpt_path(self, chkpt_path):
        # Used by test_epoch_end to derive the figures/CSV output folder.
        self.test_chkpt_path = chkpt_path

    @torch.no_grad()
    def test_epoch_end(self, in_out):
        """Aggregate CW latents over the test set, plot and save a cosine-distance matrix.

        Latents are sorted by class label, aggregated from specimen to species
        level, and the pairwise cosine-distance matrix is written as a heatmap
        and a CSV next to the checkpoint.
        """
        postfix_name = 'inference_false'
        z_cw = torch.cat([x['z_cw'] for x in in_out], 0)
        labels = torch.cat([x['label'] for x in in_out], 0)
        sorting_indices = np.argsort(labels.cpu())
        sorted_zq_cw = z_cw[sorting_indices, :]
        classnames = list(itertools.chain.from_iterable([x['class_name'] for x in in_out]))
        sorted_class_names_according_to_class_indx = [classnames[i] for i in sorting_indices]
        z_size = sorted_zq_cw.shape[-1]
        channels = sorted_zq_cw.shape[1]

        # Output folder two levels above the checkpoint file.
        figs_folder = os.path.join('/', *self.test_chkpt_path.split('/')[:-2], 'figs/testset_agg')
        os.makedirs(figs_folder, exist_ok=True)  # race-free vs. exists()+makedirs()

        sorted_zq_cw_aggregated = aggregatefrom_specimen_to_species(
            sorted_class_names_according_to_class_indx, sorted_zq_cw, z_size, channels)
        z_cosine_distances = get_CosineDistance_matrix(sorted_zq_cw_aggregated)
        plot_heatmap_at_path(z_cosine_distances.cpu(), figs_folder, self.test_chkpt_path,
                             title=f'Cosine_distances_{postfix_name}', postfix='testset_agg')

        z_cosine_distancess_np = z_cosine_distances.cpu().numpy()
        df = pd.DataFrame(z_cosine_distancess_np)
        # NOTE(review): rows/columns 5 and 6 are dropped unconditionally —
        # presumably excluded species; confirm against the dataset ordering.
        df = df.drop(columns=[5, 6])
        df = df.drop([5, 6])
        path_to_save = os.path.join(figs_folder, f'CW_z_cosine_distances_{postfix_name}.csv')
        print("saved to path : ", path_to_save)
        df.to_csv(path_to_save)
        return None


class CWmodelInterface(VQModel):
    """Inference-oriented VQGAN with a CW encoder norm; exposes raw encode/decode
    around the quantization bottleneck.
    """

    def __init__(self, **args):
        """Same concept-datamodule setup as CWmodelVQGAN; see that class."""
        print(args)
        # NOTE(review): called before super().__init__ — see CWmodelVQGAN.
        self.save_hyperparameters()

        concept_data_args = args[CONCEPT_DATA_KEY]
        print("Concepts params : ", concept_data_args)
        self.concepts = instantiate_from_config(concept_data_args)
        self.concepts.prepare_data()
        self.concepts.setup()

        del args[CONCEPT_DATA_KEY]
        super().__init__(**args)

        if not self.cw_module_infer:
            self.encoder.norm_out = cw_layer(self.encoder.block_in)
            print("Changed to cw layer after loading base VQGAN")

    def encode(self, x):
        """Return the pre-quantization latent (encoder + quant_conv, no quantizer)."""
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        """Decode a latent; optionally skip the quantization layer."""
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class CWmodelKL(AutoencoderKL):
    """KL-regularized autoencoder with a CW encoder norm; mirrors CWmodelVQGAN's
    periodic concept-alignment pass in its training step.
    """

    def __init__(self, **args):
        """Same concept-datamodule setup as CWmodelVQGAN; see that class."""
        print(args)
        # NOTE(review): called before super().__init__ — see CWmodelVQGAN.
        self.save_hyperparameters()

        concept_data_args = args[CONCEPT_DATA_KEY]
        print("Concepts params : ", concept_data_args)
        self.concepts = instantiate_from_config(concept_data_args)
        self.concepts.prepare_data()
        self.concepts.setup()

        del args[CONCEPT_DATA_KEY]
        super().__init__(**args)

        if not self.cw_module_infer:
            self.encoder.norm_out = cw_layer(self.encoder.block_in)
            print("Changed to cw layer after loading base KL Autoecoder")

    def training_step(self, batch, batch_idx, optimizer_idx):
        """KL-autoencoder training step with the same periodic CW alignment pass
        as CWmodelVQGAN.training_step.
        """
        if (batch_idx + 1) % 30 == 0 and optimizer_idx == 0:
            print('cw module')
            self.eval()
            with torch.no_grad():
                for _, concept_batch in enumerate(self.concepts.train_dataloader()):
                    for idx, concept in enumerate(concept_batch['class'].unique()):
                        concept_index = concept.item()
                        self.encoder.norm_out.mode = concept_index
                        X_var = concept_batch['image'][concept_batch['class'] == concept]
                        X_var = X_var.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
                        X_var = torch.autograd.Variable(X_var).cuda()
                        X_var = X_var.float()
                        self(X_var)
                    # Only the first concept batch is used per update.
                    break
                self.encoder.norm_out.update_rotation_matrix()
                self.encoder.norm_out.mode = -1
            self.train()

        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # autoencoder
            aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx,
                                            self.global_step, last_layer=self.get_last_layer(),
                                            split="train")
            self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx,
                                                self.global_step, last_layer=self.get_last_layer(),
                                                split="train")
            self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
            return discloss