from scripts.import_utils import instantiate_from_config
from scripts.modules.losses.phyloloss import get_loss_name
from scripts.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from scripts.models.vqgan import VQModel
from scripts.models.phyloautoencoder import PhyloVQVAE
from scripts.analysis_utils import Embedding_Code_converter
import scripts.constants as CONSTANTS
import torch
from torch import nn
import numpy
from torchinfo import summary
import itertools
import math


class PhyloLDM(PhyloVQVAE):
    """Subclass of ``PhyloVQVAE`` whose weights are frozen right after construction.

    The constructor builds the parent model from the given keyword arguments,
    records them as hyperparameters, and then freezes the model.
    """

    def __init__(self, **args):
        """Build the underlying ``PhyloVQVAE`` and freeze it.

        Args:
            **args: Keyword arguments forwarded unchanged to
                ``PhyloVQVAE.__init__``. These are assumed to be the parent's
                config parameters; TODO confirm against the parent signature.
        """
        print(args)

        # The parent constructor must run first. Without it the PhyloVQVAE /
        # nn.Module state (parameter and submodule registries) never exists.
        # save_hyperparameters() and freeze() below would then operate on an
        # uninitialized module.
        super().__init__(**args)

        # For wandb: record the constructor kwargs as hyperparameters.
        self.save_hyperparameters()

        # NOTE(review): presumably LightningModule.freeze(), which disables
        # gradients and switches to eval mode. Confirm on the PhyloVQVAE base.
        self.freeze()

        # TODO(review): leftover from the parent's __init__. `phylo_args` is not
        # defined in this scope, so either remove these lines or re-enable them
        # with the right source for the arguments.
        # self.phylo_disentangler = PhyloDisentangler(**phylo_args)
        # self.phylo_disentangler = PhyloDisentanglerConv(**phylo_args)
        # self.verbose = phylo_args.get('verbose', False)