Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import itertools | |
import numpy as np | |
import pandas as pd | |
import pytorch_lightning as pl | |
import torch.nn.functional as F | |
# from contextlib import contextmanager | |
# from ldm.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer | |
# from ldm.modules.diffusionmodules.model import Encoder, Decoder | |
# from ldm.modules.distributions.distributions import DiagonalGaussianDistribution | |
from ldm.models.autoencoder import VQModel, AutoencoderKL | |
from ldm.models.disentanglement.iterative_normalization import IterNormRotation as cw_layer | |
from ldm.analysis_utils import get_CosineDistance_matrix, aggregatefrom_specimen_to_species | |
from ldm.plotting_utils import plot_heatmap_at_path | |
from ldm.util import instantiate_from_config | |
CONCEPT_DATA_KEY = "concept_data" | |
class CWmodelVQGAN(VQModel): | |
def __init__(self, **args): | |
print(args) | |
self.save_hyperparameters() | |
concept_data_args = args[CONCEPT_DATA_KEY] | |
print("Concepts params : ", concept_data_args) | |
self.concepts = instantiate_from_config(concept_data_args) | |
self.concepts.prepare_data() | |
self.concepts.setup() | |
del args[CONCEPT_DATA_KEY] | |
super().__init__(**args) | |
if not self.cw_module_infer: | |
self.encoder.norm_out = cw_layer(self.encoder.block_in) | |
print("Changed to cw layer after loading base VQGAN") | |
def training_step(self, batch, batch_idx, optimizer_idx): | |
if (batch_idx+1)%30==0 and optimizer_idx==0: | |
print('cw module') | |
self.eval() | |
with torch.no_grad(): | |
for _, concept_batch in enumerate(self.concepts.train_dataloader()): | |
for idx, concept in enumerate(concept_batch['class'].unique()): | |
concept_index = concept.item() | |
self.encoder.norm_out.mode = concept_index | |
X_var = concept_batch['image'][concept_batch['class'] == concept] | |
X_var = X_var.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) | |
X_var = torch.autograd.Variable(X_var).cuda() | |
X_var = X_var.float() | |
self(X_var) | |
break | |
self.encoder.norm_out.update_rotation_matrix() | |
self.encoder.norm_out.mode = -1 | |
self.train() | |
# breakpoint() | |
x = self.get_input(batch, self.image_key) | |
xrec, qloss = self(x, return_pred_indices=False) | |
# if optimizer_idx == 0 or (not self.loss.has_discriminator): | |
if optimizer_idx == 0: | |
# autoencode | |
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, | |
last_layer=self.get_last_layer(), split="train") | |
self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) | |
return aeloss | |
# if optimizer_idx == 1 and self.loss.has_discriminator: | |
if optimizer_idx == 1: | |
# discriminator | |
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, | |
last_layer=self.get_last_layer(), split="train") | |
self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) | |
return discloss | |
def test_step(self, batch, batch_idx): | |
x = self.get_input(batch, self.image_key) | |
h = self.encoder(x) | |
h = self.quant_conv(h) | |
class_label = batch['class'] | |
return {'z_cw': h, | |
'label': class_label, | |
'class_name': batch['class_name']} | |
# NOTE: This is kinda hacky. But ok for now for test purposes. | |
def set_test_chkpt_path(self, chkpt_path): | |
self.test_chkpt_path = chkpt_path | |
def test_epoch_end(self, in_out): | |
postfix_name = 'inference_false' | |
z_cw =torch.cat([x['z_cw'] for x in in_out], 0) | |
labels =torch.cat([x['label'] for x in in_out], 0) | |
sorting_indices = np.argsort(labels.cpu()) | |
sorted_zq_cw = z_cw[sorting_indices, :] | |
classnames = list(itertools.chain.from_iterable([x['class_name'] for x in in_out])) | |
sorted_class_names_according_to_class_indx = [classnames[i] for i in sorting_indices] | |
z_size = sorted_zq_cw.shape[-1] | |
channels = sorted_zq_cw.shape[1] | |
# breakpoint() | |
figs_folder = os.path.join('/', *self.test_chkpt_path.split('/')[:-2], 'figs/testset_agg') | |
if not os.path.exists(figs_folder): | |
os.makedirs(figs_folder) | |
sorted_zq_cw_aggregated = aggregatefrom_specimen_to_species(sorted_class_names_according_to_class_indx, sorted_zq_cw, z_size, channels) | |
z_cosine_distances = get_CosineDistance_matrix(sorted_zq_cw_aggregated) | |
plot_heatmap_at_path(z_cosine_distances.cpu(), figs_folder, self.test_chkpt_path, title=f'Cosine_distances_{postfix_name}', postfix='testset_agg') | |
z_cosine_distancess_np = z_cosine_distances.cpu().numpy() | |
df = pd.DataFrame(z_cosine_distancess_np) | |
df = df.drop(columns=[5, 6]) | |
df = df.drop([5, 6]) | |
breakpoint() | |
path_to_save = os.path.join(figs_folder, f'CW_z_cosine_distances_{postfix_name}.csv') | |
print("saved to path : ", path_to_save) | |
df.to_csv(path_to_save) | |
return None | |
class CWmodelInterface(VQModel): | |
def __init__(self, **args): | |
print(args) | |
self.save_hyperparameters() | |
concept_data_args = args[CONCEPT_DATA_KEY] | |
print("Concepts params : ", concept_data_args) | |
self.concepts = instantiate_from_config(concept_data_args) | |
self.concepts.prepare_data() | |
self.concepts.setup() | |
del args[CONCEPT_DATA_KEY] | |
super().__init__(**args) | |
if not self.cw_module_infer: | |
self.encoder.norm_out = cw_layer(self.encoder.block_in) | |
print("Changed to cw layer after loading base VQGAN") | |
def encode(self, x): | |
h = self.encoder(x) | |
h = self.quant_conv(h) | |
return h | |
def decode(self, h, force_not_quantize=False): | |
# also go through quantization layer | |
if not force_not_quantize: | |
quant, emb_loss, info = self.quantize(h) | |
else: | |
quant = h | |
quant = self.post_quant_conv(quant) | |
dec = self.decoder(quant) | |
return dec | |
class CWmodelKL(AutoencoderKL): | |
def __init__(self, **args): | |
print(args) | |
self.save_hyperparameters() | |
concept_data_args = args[CONCEPT_DATA_KEY] | |
print("Concepts params : ", concept_data_args) | |
self.concepts = instantiate_from_config(concept_data_args) | |
self.concepts.prepare_data() | |
self.concepts.setup() | |
del args[CONCEPT_DATA_KEY] | |
super().__init__(**args) | |
if not self.cw_module_infer: | |
self.encoder.norm_out = cw_layer(self.encoder.block_in) | |
print("Changed to cw layer after loading base KL Autoecoder") | |
def training_step(self, batch, batch_idx, optimizer_idx): | |
if (batch_idx+1)%30==0 and optimizer_idx==0: | |
print('cw module') | |
self.eval() | |
with torch.no_grad(): | |
for _, concept_batch in enumerate(self.concepts.train_dataloader()): | |
for idx, concept in enumerate(concept_batch['class'].unique()): | |
concept_index = concept.item() | |
self.encoder.norm_out.mode = concept_index | |
X_var = concept_batch['image'][concept_batch['class'] == concept] | |
X_var = X_var.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) | |
X_var = torch.autograd.Variable(X_var).cuda() | |
X_var = X_var.float() | |
self(X_var) | |
break | |
self.encoder.norm_out.update_rotation_matrix() | |
self.encoder.norm_out.mode = -1 | |
self.train() | |
# breakpoint() | |
inputs = self.get_input(batch, self.image_key) | |
reconstructions, posterior = self(inputs) | |
if optimizer_idx == 0: | |
# autoencode | |
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, | |
last_layer=self.get_last_layer(), split="train") | |
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) | |
return aeloss | |
if optimizer_idx == 1: | |
# discriminator | |
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, | |
last_layer=self.get_last_layer(), split="train") | |
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) | |
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) | |
return discloss | |