import torch; torch.manual_seed(0)
import torch.utils
from torch.utils.data import DataLoader
import torch.distributions
import torch.nn as nn
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.dataset import MyDataset, get_representation_from_ingredient, get_max_n_ingredients
import json
import pandas as pd
import numpy as np
import os
from src.cocktails.representation_learning.vae_model import get_vae_model
from src.cocktails.config import COCKTAILS_CSV_DATA, FULL_COCKTAIL_REP_PATH, EXPERIMENT_PATH
from src.cocktails.utilities.cocktail_utilities import get_bunch_of_rep_keys
from src.cocktails.utilities.ingredients_utilities import ingredient_profiles
from resource import getrusage
from resource import RUSAGE_SELF
import gc
# Run a full (oldest-generation) garbage collection once at import time.
gc.collect(generation=2)

# Torch device string: prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def get_params():
    """Build, persist and return the experiment configuration.

    Reads the cocktail CSV, assembles the hyper-parameter dict below, sizes the
    ingredient representation from a single-ingredient example, creates a fresh
    experiment directory (via ``compute_expe_name_and_save_path``), writes the
    JSON-serialisable part of the config to ``<save_path>params.json`` and finally
    re-attaches the non-serialisable entries with ``complete_params``.

    Side effects: creates directories on disk and writes ``params.json``.

    Returns:
        dict: the complete parameter dictionary consumed by ``run_experiment``.
    """
    data = pd.read_csv(COCKTAILS_CSV_DATA)
    # liquor_set / liqueur_set are returned by the helper but not used in this function.
    max_ingredients, ingredient_set, liquor_set, liqueur_set = get_max_n_ingredients(data)
    num_ingredients = len(ingredient_set)
    rep_keys = get_bunch_of_rep_keys()['custom']
    # Keep the second space-separated token of each representation key, without 'volume'
    # (list.remove raises ValueError if 'volume' is not present).
    ing_keys = [k.split(' ')[1] for k in rep_keys]
    ing_keys.remove('volume')
    nb_ing_categories = len(set(ingredient_profiles['type']))
    # One row of the identity matrix (one-hot vector) per ingredient type, types in sorted order.
    category_encodings = dict(zip(sorted(set(ingredient_profiles['type'])), np.eye(nb_ing_categories)))

    # Active configuration. Every auxiliary head has weight 0, so (through the `weights`
    # dict built in prepare_data_and_loss) none of them contributes to the global loss.
    params = dict(trial_id='test',
                  save_path=EXPERIMENT_PATH + "/deepset_vae/",
                  nb_epochs=2000,
                  print_every=50,
                  plot_every=100,
                  batch_size=64,
                  lr=0.001,
                  dropout=0.,
                  nb_epoch_switch_beta=600,
                  latent_dim=10,
                  beta_vae=0.2,
                  ing_keys=ing_keys,
                  nb_ingredients=len(ingredient_set),
                  hidden_dims_ingredients=[128],
                  hidden_dims_cocktail=[32],
                  hidden_dims_decoder=[32],
                  agg='mean',
                  activation='relu',
                  auxiliaries_dict=dict(categories=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
                                        glasses=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
                                        prep_type=dict(weight=0, type='classif', final_activ=None, dim_output=len(set(data['category']))),
                                        cocktail_reps=dict(weight=0, type='regression', final_activ=None, dim_output=13),
                                        volume=dict(weight=0, type='regression', final_activ='relu',  dim_output=1),
                                        taste_reps=dict(weight=0, type='regression', final_activ='relu', dim_output=2),
                                        ingredients_presence=dict(weight=0, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
                  category_encodings=category_encodings
                  )
    # Alternative (inactive) configuration, kept for reference.
    # params = dict(trial_id='test',
    #               save_path=EXPERIMENT_PATH + "/deepset_vae/",
    #               nb_epochs=1000,
    #               print_every=50,
    #               plot_every=100,
    #               batch_size=64,
    #               lr=0.001,
    #               dropout=0.,
    #               nb_epoch_switch_beta=500,
    #               latent_dim=64,
    #               beta_vae=0.3,
    #               ing_keys=ing_keys,
    #               nb_ingredients=len(ingredient_set),
    #               hidden_dims_ingredients=[128],
    #               hidden_dims_cocktail=[128, 128],
    #               hidden_dims_decoder=[128, 128],
    #               agg='mean',
    #               activation='mish',
    #               auxiliaries_dict=dict(categories=dict(weight=0.5, type='classif', final_activ=None, dim_output=len(set(data['subcategory']))),
    #                                     glasses=dict(weight=0.03, type='classif', final_activ=None, dim_output=len(set(data['glass']))),
    #                                     prep_type=dict(weight=0.02, type='classif', final_activ=None, dim_output=len(set(data['category']))),
    #                                     cocktail_reps=dict(weight=1, type='regression', final_activ=None, dim_output=13),
    #                                     volume=dict(weight=1, type='regression', final_activ='relu',  dim_output=1),
    #                                     taste_reps=dict(weight=1, type='regression', final_activ='relu', dim_output=2),
    #                                     ingredients_presence=dict(weight=1.5, type='multiclassif', final_activ=None, dim_output=num_ingredients)),
    #               category_encodings=category_encodings
    #               )
    # Build the representation of a single ingredient (water, quantity 1). In this function
    # it is only used to measure the size of one ingredient representation and to obtain
    # the indexes that need normalisation. Requires params['category_encodings'] to be set.
    water_rep, indexes_to_normalize = get_representation_from_ingredient(ingredients=['water'], quantities=[1],
                                                                         max_q_per_ing=dict(zip(ingredient_set, [1] * num_ingredients)), index=0,
                                                                         params=params)
    dim_rep_ingredient = water_rep.size
    params['indexes_ing_to_normalize'] = indexes_to_normalize
    # NOTE(review): json.dump below assumes indexes_to_normalize is JSON-serialisable
    # (e.g. a list, not a numpy array) — confirm against get_representation_from_ingredient.
    params['deepset_latent_dim'] = dim_rep_ingredient * max_ingredients
    params['input_dim'] = dim_rep_ingredient
    params['dim_rep_ingredient'] = dim_rep_ingredient
    # Creates the experiment directory and sets params['save_path'] / params['plot_path'].
    params = compute_expe_name_and_save_path(params)
    del params['category_encodings']  # numpy arrays are not JSON-serialisable; complete_params re-adds them
    with open(params['save_path'] + 'params.json', 'w') as f:
        json.dump(params, f)

    # Re-attach data that cannot be stored in params.json (raw data, reps, encodings).
    params = complete_params(params)
    return params

def complete_params(params):
    """Attach the entries that are not stored in ``params.json`` and return ``params``.

    The dict is updated in place with:
      * ``cocktail_reps``: array loaded from ``FULL_COCKTAIL_REP_PATH``;
      * ``raw_data``: DataFrame read from ``COCKTAILS_CSV_DATA``;
      * ``category_encodings``: one identity-matrix row per ingredient type (sorted).
    """
    raw_data = pd.read_csv(COCKTAILS_CSV_DATA)
    full_reps = np.loadtxt(FULL_COCKTAIL_REP_PATH)
    ingredient_types = sorted(set(ingredient_profiles['type']))
    one_hot_rows = np.eye(len(ingredient_types))
    params.update(cocktail_reps=full_reps,
                  raw_data=raw_data,
                  category_encodings=dict(zip(ingredient_types, one_hot_rows)))
    return params

def compute_losses_and_accuracies(loss_functions, auxiliaries, auxiliaries_str, outputs, data):
    """Compute the loss and evaluation metrics of every auxiliary head for one batch.

    Args:
        loss_functions: dict mapping auxiliary name -> criterion; each criterion must be an
            ``nn.CrossEntropyLoss``, ``nn.BCEWithLogitsLoss`` or ``nn.MSELoss``.
        auxiliaries: dict mapping auxiliary name -> ground-truth tensor for the batch.
        auxiliaries_str: auxiliary names, in the same order as ``outputs``.
        outputs: list of head outputs aligned with ``auxiliaries_str``. NOTE: the entry for
            ``'volume'`` is flattened in place.
        data: unused; kept so existing callers keep working.

    Returns:
        ``(losses, accuracies, other_metrics)``:
          * ``losses[k]``: float32 scalar loss tensor of head ``k``;
          * ``accuracies[k]``: classification accuracy, multi-label (per-entry) accuracy,
            or ``np.nan`` for regression heads;
          * ``other_metrics``: row-normalised confusion matrices (``<k>_confusion``) for
            classification heads and false-positive / false-negative rates for multi-label heads.

    Raises:
        ValueError: if a criterion is not one of the three supported types.
    """
    losses = dict()
    accuracies = dict()
    other_metrics = dict()
    for i_k, k in enumerate(auxiliaries_str):
        loss_fn = loss_functions[k]
        if k == 'volume':
            # the volume head returns one value per sample; flatten it to match a 1-D target
            outputs[i_k] = outputs[i_k].flatten()
        output = outputs[i_k]
        ground_truth = auxiliaries[k]

        # --- loss: cast the target to the dtype each criterion expects ---
        if ground_truth.dtype == torch.float64:
            losses[k] = loss_fn(output, ground_truth.float()).float()
        elif ground_truth.dtype == torch.int64:
            if isinstance(loss_fn, nn.BCEWithLogitsLoss):
                losses[k] = loss_fn(output.float(), ground_truth.float()).float()
            else:
                losses[k] = loss_fn(output.float(), ground_truth.long()).float()
        else:
            losses[k] = loss_fn(output, ground_truth).float()

        # --- accuracies and extra metrics ---
        if isinstance(loss_fn, nn.CrossEntropyLoss):
            n_options = output.shape[1]
            predicted = output.argmax(dim=1).detach().cpu().numpy()
            true = ground_truth.int().detach().cpu().numpy()
            confusion_matrix = np.zeros([n_options, n_options])
            # unbuffered accumulation: repeated (true, predicted) pairs are all counted
            np.add.at(confusion_matrix, (true, predicted), 1)
            accuracies[k] = np.mean(predicted == true)
            # row-normalise (rows = true class); rows without samples stay all-zero
            row_sums = confusion_matrix.sum(axis=1, keepdims=True)
            other_metrics[k + '_confusion'] = np.divide(confusion_matrix, row_sums,
                                                        out=np.zeros_like(confusion_matrix),
                                                        where=row_sums != 0)
        elif isinstance(loss_fn, nn.BCEWithLogitsLoss):
            # The head outputs logits: logit > 0 <=> sigmoid(logit) > 0.5. The logits must
            # NOT be rescaled with the ingredient-quantity statistics.
            predicted_presence = output.detach().cpu().numpy() > 0
            presence = ground_truth.detach().cpu().numpy().astype(bool)
            other_metrics[k + '_false_positive'] = np.mean(predicted_presence & ~presence)
            other_metrics[k + '_false_negative'] = np.mean(~predicted_presence & presence)
            accuracies[k] = np.mean(predicted_presence == presence)  # per-entry multi-label accuracy
        elif isinstance(loss_fn, nn.MSELoss):
            accuracies[k] = np.nan
        else:
            raise ValueError(f'Unsupported loss function for auxiliary {k!r}: {loss_fn}')
    return losses, accuracies, other_metrics

def _mean_of_masked_row_means(values, mask):
    """Average, over rows with at least one selected entry, the mean of each row's selected entries.

    Rows whose mask is all False are skipped instead of producing NaN; returns ``np.nan``
    only when no row has any selected entry.
    """
    counts = mask.sum(axis=1)
    sums = np.where(mask, values, 0.).sum(axis=1)
    has_entries = counts > 0
    if not has_entries.any():
        return np.nan
    return np.mean(sums[has_entries] / counts[has_entries])


def compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat):
    """Add reconstruction-error metrics on un-normalised ingredient quantities.

    Both the target ``ingredient_quantities`` and the reconstruction ``x_hat`` are mapped back
    to the original scale with ``data.dataset.std_ing_quantities`` / ``mean_ing_quantities``.
    An ingredient counts as present when its un-normalised target quantity is > 0.

    Adds to ``aux_other_metrics`` (and returns it):
      * ``ing_q_abs_loss_when_present``: mean over samples of the mean absolute error on the
        ingredients present in that sample;
      * ``ing_q_abs_loss_when_absent``: same on the absent ingredients.
    Samples with no present (resp. absent) ingredient are skipped for that metric rather
    than turning the whole batch metric into NaN.
    """
    std = data.dataset.std_ing_quantities
    mean = data.dataset.mean_ing_quantities
    ing_q = ingredient_quantities.detach().cpu().numpy() * std + mean
    x_hat_unnormalized = x_hat.detach().cpu().numpy() * std + mean
    ing_presence = ing_q > 0
    abs_diff = np.abs(ing_q - x_hat_unnormalized)
    aux_other_metrics['ing_q_abs_loss_when_present'] = _mean_of_masked_row_means(abs_diff, ing_presence)
    aux_other_metrics['ing_q_abs_loss_when_absent'] = _mean_of_masked_row_means(abs_diff, ~ing_presence)
    return aux_other_metrics

def run_epoch(opt, train, model, data, loss_functions, weights, params):
    """Run one pass over ``data``, optimising ``model`` when ``train`` is True.

    For each batch: forward pass through ``model.forward_direct``, VAE loss (MSE on the
    normalised ingredient quantities + scaled KL term), auxiliary losses (weighted by
    ``weights``), and, in training mode, one optimiser step on the summed global loss.
    In evaluation mode no autograd graph is built.

    Args:
        opt: optimiser over the model parameters (only used when ``train`` is True).
        train: whether to train (``model.train()`` + backprop) or evaluate (``model.eval()``).
        model: the VAE; must provide ``forward_direct`` and ``get_auxiliary``.
        data: DataLoader whose ``dataset`` exposes the ingredient-quantity statistics.
        loss_functions: dict auxiliary name -> criterion.
        weights: dict auxiliary name -> loss weight.
        params: experiment parameters (reads ``auxiliaries_dict``, ``beta_vae``,
            ``latent_dim``, ``nb_ingredients``).

    Returns:
        ``(model, losses, accuracies, other_metrics)``; each metric dict holds the mean over
        batches of the per-batch values.
    """
    if train:
        model.train()
    else:
        model.eval()

    aux_keys = list(params['auxiliaries_dict'].keys())

    # prepare logging of losses
    losses = dict(kld_loss=[],
                  mse_loss=[],
                  vae_loss=[],
                  volume_loss=[],
                  global_loss=[])
    accuracies = dict()
    other_metrics = dict()
    for aux in aux_keys:
        losses[aux] = []
        accuracies[aux] = []
    if train:
        opt.zero_grad()

    for batch in data:
        # Batch fields as used below (layout defined by the dataset — confirm there):
        # [0] per-sample tensor used for the batch size, [2] normalised ingredient quantities,
        # [4] dict of auxiliary targets, [-1] flags marking samples with valid taste data.
        batch_size = batch[0].shape[0]
        ingredient_quantities = batch[2]
        auxiliaries = batch[4]
        for k in auxiliaries.keys():
            if auxiliaries[k].dtype == torch.float64:
                auxiliaries[k] = auxiliaries[k].float()
        taste_valid = batch[-1]

        # Only record the autograd graph when we will back-propagate through it.
        with torch.set_grad_enabled(train):
            x_hat, z, mean, log_var, outputs, auxiliaries_str = model.forward_direct(ingredient_quantities.float())
            # get auxiliary losses and accuracies
            aux_losses, aux_accuracies, aux_other_metrics = compute_losses_and_accuracies(
                loss_functions, auxiliaries, auxiliaries_str, outputs, data)

            # compute vae loss
            mse_loss = ((ingredient_quantities - x_hat) ** 2).mean().float()
            kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mean ** 2 - log_var.exp(), dim=1)).float()
            vae_loss = mse_loss + params['beta_vae'] * (params['latent_dim'] / params['nb_ingredients']) * kld_loss
            # total-volume loss is currently disabled (kept at 0 for logging)
            volume_loss = torch.FloatTensor([0])

            aux_other_metrics = compute_metric_output(aux_other_metrics, data, ingredient_quantities, x_hat)

            # The taste head is only supervised on samples that have taste data.
            indexes_taste_valid = np.argwhere(taste_valid.detach().numpy()).flatten()
            if indexes_taste_valid.size > 0:
                outputs_taste = model.get_auxiliary(z[indexes_taste_valid], aux_str='taste_reps')
                gt = auxiliaries['taste_reps'][indexes_taste_valid]
                # Rescale by the fraction of valid samples relative to 30% of the batch:
                # factor == 1 at that ratio, smaller with fewer valid samples, larger with more.
                factor_loss = indexes_taste_valid.size / (0.3 * batch_size)
                aux_losses['taste_reps'] = (loss_functions['taste_reps'](outputs_taste, gt) * factor_loss).float()
            else:
                aux_losses['taste_reps'] = torch.FloatTensor([0]).reshape([])
            aux_accuracies['taste_reps'] = 0

            # aggregate losses
            global_loss = torch.sum(torch.cat(
                [torch.atleast_1d(vae_loss), torch.atleast_1d(volume_loss)]
                + [torch.atleast_1d(aux_losses[k] * weights[k]) for k in aux_keys]))

        if train:
            global_loss.backward()
            opt.step()
            opt.zero_grad()

        # logging
        losses['global_loss'].append(float(global_loss))
        losses['mse_loss'].append(float(mse_loss))
        losses['vae_loss'].append(float(vae_loss))
        losses['volume_loss'].append(float(volume_loss))
        losses['kld_loss'].append(float(kld_loss))
        for k in aux_keys:
            losses[k].append(float(aux_losses[k]))
            accuracies[k].append(float(aux_accuracies[k]))
        for k, value in aux_other_metrics.items():
            other_metrics.setdefault(k, []).append(value)

    for k in losses.keys():
        losses[k] = np.mean(losses[k])
    for k in accuracies.keys():
        accuracies[k] = np.mean(accuracies[k])
    for k in other_metrics.keys():
        other_metrics[k] = np.mean(other_metrics[k], axis=0)
    return model, losses, accuracies, other_metrics

def prepare_data_and_loss(params):
    """Build the train/test data loaders and one criterion + weight per auxiliary task.

    Classification tasks use a class-weighted ``nn.CrossEntropyLoss`` (weights taken from the
    train dataset), multi-label tasks ``nn.BCEWithLogitsLoss`` and regression tasks
    ``nn.MSELoss``.

    Returns:
        ``(loss_functions, train_data_loader, test_data_loader, weights)`` where
        ``loss_functions`` and ``weights`` are dicts keyed by auxiliary name.

    Raises:
        ValueError: for an unknown task type or a classification task without class weights.
    """
    train_data = MyDataset(split='train', params=params)
    test_data = MyDataset(split='test', params=params)

    train_data_loader = DataLoader(train_data, batch_size=params['batch_size'], shuffle=True)
    # Evaluation order does not affect the model. A fixed order keeps the batch composition,
    # hence the mean of per-batch eval losses used to select the best checkpoint, identical
    # from one epoch to the next.
    test_data_loader = DataLoader(test_data, batch_size=params['batch_size'], shuffle=False)

    # train-dataset attribute holding the class weights of each classification task
    classif_weight_attrs = dict(glasses='glasses_weights',
                                prep_type='prep_types_weights',
                                categories='categories_weights')

    loss_functions = dict()
    weights = dict()
    for k in sorted(params['auxiliaries_dict'].keys()):
        task_type = params['auxiliaries_dict'][k]['type']
        if task_type == 'classif':
            if k not in classif_weight_attrs:
                raise ValueError(f'No class weights known for classification task {k!r}')
            classif_weights = getattr(train_data, classif_weight_attrs[k])
            loss_functions[k] = nn.CrossEntropyLoss(torch.FloatTensor(classif_weights))
        elif task_type == 'multiclassif':
            loss_functions[k] = nn.BCEWithLogitsLoss()
        elif task_type == 'regression':
            loss_functions[k] = nn.MSELoss()
        else:
            raise ValueError(f'Unknown auxiliary task type {task_type!r} for {k!r}')
        weights[k] = params['auxiliaries_dict'][k]['weight']

    return loss_functions, train_data_loader, test_data_loader, weights

def print_losses(train, losses, accuracies, other_metrics):
    """Print a one-epoch summary: main losses, auxiliary loss/accuracy, scalar metrics.

    Confusion-matrix entries of ``other_metrics`` (keys containing ``'confusion'``) are skipped.
    """
    print('\t{} logs:'.format('Train' if train else 'Eval'))
    for name in ('global_loss', 'vae_loss', 'mse_loss', 'kld_loss', 'volume_loss'):
        print(f'\t\t{name} - Loss: {losses[name]:.2f}')
    for name in sorted(accuracies):
        print(f'\t\t{name} (aux) - Loss: {losses[name]:.2f}, Acc: {accuracies[name]:.2f}')
    scalar_names = [name for name in sorted(other_metrics) if 'confusion' not in name]
    for name in scalar_names:
        print(f'\t\t{name} - {other_metrics[name]:.2f}')


def run_experiment(params, verbose=True):
    """Train a VAE with the configuration in ``params`` and return the final model.

    The untrained model is evaluated once (reported as epoch #0). Then, for
    ``params['nb_epochs']`` epochs, it is trained on the train split and evaluated on the
    test split. Whenever the eval ``global_loss`` improves, the weights are saved to
    ``params['save_path'] + 'checkpoint_best.save'``. Learning curves are plotted every
    ``params['plot_every']`` epochs and, when ``verbose``, logs are printed every
    ``params['print_every']`` epochs.

    Note: sets ``params['filter_decoder_output']`` on the caller's dict.
    """
    loss_functions, train_loader, test_loader, weights = prepare_data_and_loss(params)
    params['filter_decoder_output'] = train_loader.dataset.filter_decoder_output

    # positional arguments of get_vae_model, in order
    model_arg_names = ("input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
                       "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim",
                       "agg", "dropout", "auxiliaries_dict", "filter_decoder_output")
    model = get_vae_model(*[params[name] for name in model_arg_names])
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'])

    def one_epoch(current_model, train):
        # one pass over the train (train=True) or test (train=False) loader
        return run_epoch(opt=optimizer, train=train, model=current_model,
                         data=train_loader if train else test_loader,
                         loss_functions=loss_functions, weights=weights, params=params)

    train_losses_log, train_acc_log, train_metrics_log = [], [], []
    eval_losses_log, eval_acc_log, eval_metrics_log = [], [], []
    lowest_eval_loss = np.inf

    # Baseline evaluation: first point of the eval curves.
    model, eval_losses, eval_accuracies, eval_other_metrics = one_epoch(model, train=False)
    eval_losses_log.append(eval_losses)
    eval_acc_log.append(eval_accuracies)
    eval_metrics_log.append(eval_other_metrics)
    if verbose:
        print('\n--------\nEpoch #0')
        print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)

    for epoch_number in range(1, params['nb_epochs'] + 1):
        report = verbose and epoch_number % params['print_every'] == 0
        if report:
            print(f'\n--------\nEpoch #{epoch_number}')

        model, train_losses, train_accuracies, train_other_metrics = one_epoch(model, train=True)
        if report:
            print_losses(train=True, accuracies=train_accuracies, losses=train_losses, other_metrics=train_other_metrics)

        model, eval_losses, eval_accuracies, eval_other_metrics = one_epoch(model, train=False)
        if report:
            print_losses(train=False, accuracies=eval_accuracies, losses=eval_losses, other_metrics=eval_other_metrics)

        if eval_losses['global_loss'] < lowest_eval_loss:
            lowest_eval_loss = eval_losses['global_loss']
            if verbose:
                print(f'Saving new best model with loss {lowest_eval_loss:.2f}')
            torch.save(model.state_dict(), params['save_path'] + 'checkpoint_best.save')

        train_losses_log.append(train_losses)
        train_acc_log.append(train_accuracies)
        train_metrics_log.append(train_other_metrics)
        eval_losses_log.append(eval_losses)
        eval_acc_log.append(eval_accuracies)
        eval_metrics_log.append(eval_other_metrics)

        if epoch_number % params['plot_every'] == 0:
            plot_results(train_losses_log, train_acc_log, train_metrics_log,
                         eval_losses_log, eval_acc_log, eval_metrics_log, params['plot_path'], weights)

    return model

def _save_curves(title, steps, history, keys, path, ylim=None):
    """Plot ``entry[k]`` against ``steps`` for every ``k`` in ``keys`` and save the figure."""
    fig = plt.figure()
    plt.title(title)
    for k in keys:
        plt.plot(steps, [entry[k] for entry in history], label=k)
    plt.legend()
    if ylim is not None:
        plt.ylim(ylim)
    plt.savefig(path, dpi=200)
    plt.close(fig)


def _save_confusion_matrix(title, matrix, path):
    """Render a confusion matrix (rows = true class, columns = predicted) with a [0, 1] colour scale."""
    fig = plt.figure()
    plt.title(title)
    plt.ylabel('True')
    plt.xlabel('Predicted')
    plt.imshow(matrix, vmin=0, vmax=1)
    plt.colorbar()
    plt.savefig(path, dpi=200)
    plt.close(fig)


def plot_results(all_train_losses, all_train_accuracies, all_train_other_metrics,
                 all_eval_losses, all_eval_accuracies, all_eval_other_metrics, plot_path, weights):
    """Save learning-curve and confusion-matrix figures into ``plot_path``.

    For each split (``train``/``eval``) writes ``<split>_losses.png``, ``<split>_acc.png``,
    ``<split>_ing_presence_errors.png``, ``<split>_ing_q_error.png`` and one
    ``<split>_<metric>.png`` per confusion-matrix metric (from the last logged epoch).
    Auxiliary losses/accuracies whose weight is 0 are not plotted.

    The eval histories contain one more entry than the train histories (the evaluation done
    before training), so train curves start at step 1.
    """
    steps = np.arange(len(all_eval_accuracies))
    train_steps = steps[1:]

    loss_keys = [k for k in sorted(all_train_losses[0].keys()) if k not in weights or weights[k] != 0]
    acc_keys = [k for k in sorted(all_train_accuracies[0].keys()) if weights[k] != 0]
    metrics_keys = sorted(all_train_other_metrics[0].keys())
    presence_keys = [k for k in metrics_keys if 'confusion' not in k and 'presence' in k]
    quantity_keys = [k for k in metrics_keys if 'confusion' not in k and 'presence' not in k]
    confusion_keys = [k for k in metrics_keys if 'confusion' in k]

    splits = (('train', train_steps, all_train_losses, all_train_accuracies, all_train_other_metrics),
              ('eval', steps, all_eval_losses, all_eval_accuracies, all_eval_other_metrics))
    for split, split_steps, losses, accuracies, other_metrics in splits:
        name = split.capitalize()
        _save_curves(f'{name} losses', split_steps, losses, loss_keys,
                     plot_path + f'{split}_losses.png', ylim=[0, 4])
        _save_curves(f'{name} accuracies', split_steps, accuracies, acc_keys,
                     plot_path + f'{split}_acc.png', ylim=[0, 1])
        _save_curves(f'{name} other metrics', split_steps, other_metrics, presence_keys,
                     plot_path + f'{split}_ing_presence_errors.png', ylim=[0, 1])
        _save_curves(f'{name} other metrics', split_steps, other_metrics, quantity_keys,
                     plot_path + f'{split}_ing_q_error.png')
        for k in confusion_keys:
            _save_confusion_matrix(k, other_metrics[-1][k], plot_path + f'{split}_{k}.png')

    plt.close('all')


def get_model(model_path):
    """Rebuild a trained VAE from an experiment directory.

    Args:
        model_path: experiment directory, with trailing separator, containing
            ``params.json``, ``checkpoint_best.save`` and the ``mean_/std_/min_when_present_``
            ``ing_quantities.txt`` statistics files.

    Returns:
        ``(model, filter_decoder_output, params)``: the model in eval mode with the best
        checkpoint loaded; a function mapping a decoder output tensor to un-normalised
        quantities (a numpy array) in which values below the per-ingredient
        ``min_when_present`` threshold are set to 0; and the loaded parameter dict.
    """
    with open(model_path + 'params.json', 'r') as f:
        params = json.load(f)
    params['save_path'] = model_path
    mean_ing_quantities = np.loadtxt(params['save_path'] + 'mean_ing_quantities.txt')
    std_ing_quantities = np.loadtxt(params['save_path'] + 'std_ing_quantities.txt')
    min_when_present_ing_quantities = np.loadtxt(params['save_path'] + 'min_when_present_ing_quantities.txt')

    def filter_decoder_output(output):
        # Un-normalise, then zero every quantity below its ingredient's minimum. Broadcasting
        # over the last axis handles a single vector as well as a batch of rows.
        output_unnormalized = output.detach().cpu().numpy() * std_ing_quantities + mean_ing_quantities
        output_unnormalized[output_unnormalized < min_when_present_ing_quantities] = 0
        return output_unnormalized

    params['filter_decoder_output'] = filter_decoder_output
    model_chkpt = model_path + "checkpoint_best.save"
    model_params = [params[k] for k in ["input_dim", "deepset_latent_dim", "hidden_dims_ingredients", "activation",
                                        "hidden_dims_cocktail", "hidden_dims_decoder", "nb_ingredients", "latent_dim", "agg", "dropout", "auxiliaries_dict",
                                        "filter_decoder_output"]]
    model = get_vae_model(*model_params)
    # map_location lets checkpoints saved from a GPU be loaded on CPU-only machines;
    # load_state_dict copies the tensors onto the model's own device anyway.
    model.load_state_dict(torch.load(model_chkpt, map_location='cpu'))
    model.eval()
    return model, filter_decoder_output, params


def compute_expe_name_and_save_path(params):
    """Derive a unique experiment directory from the hyper-parameters and create it.

    The name is ``<save_path><trial_id>_lr.._betavae.._bs.._latentdim.._hding.._hdcocktail..
    _hddecoder.._agg.._activ.._w[<aux weights>]_<counter>/``, where ``counter`` is the
    smallest non-negative integer for which the directory could be created. The directory
    and its ``plots/`` sub-directory are created and the path is printed.

    Args:
        params: hyper-parameter dict; modified in place.

    Returns:
        The same dict with ``save_path`` set to the new directory (trailing ``/``) and
        ``plot_path`` set to its ``plots/`` sub-directory.
    """
    weights_str = '[' + ', '.join(f'{aux["weight"]}' for aux in params['auxiliaries_dict'].values()) + ']'
    name_parts = [
        params['save_path'] + params["trial_id"],
        f'lr{params["lr"]}',
        f'betavae{params["beta_vae"]}',
        f'bs{params["batch_size"]}',
        f'latentdim{params["latent_dim"]}',
        f'hding{params["hidden_dims_ingredients"]}',
        f'hdcocktail{params["hidden_dims_cocktail"]}',
        f'hddecoder{params["hidden_dims_decoder"]}',
        f'agg{params["agg"]}',
        f'activ{params["activation"]}',
        f'w{weights_str}',
    ]
    base_path = '_'.join(name_parts)

    counter = 0
    while True:
        save_path = f'{base_path}_{counter}/'
        try:
            # Creating the directory is the existence test: if another run claimed this
            # name first, makedirs raises FileExistsError and we try the next counter.
            os.makedirs(save_path)
            break
        except FileExistsError:
            counter += 1

    params["save_path"] = save_path
    params['plot_path'] = save_path + 'plots/'
    os.makedirs(params['plot_path'])
    print(f'logging to {save_path}')
    return params



if __name__ == '__main__':
    # Build (and persist) the configuration, which creates the experiment directory, then train.
    experiment_params = get_params()
    run_experiment(experiment_params)