import streamlit as st
from trainer import Trainer

# Render the app's title text via Streamlit.
st.title("Hello World")

class DrugGENConfig:
    """Static settings object handed to ``Trainer`` by this script.

    Every setting is a plain class attribute, so ``DrugGENConfig()`` yields
    an instance that exposes them all via normal attribute lookup. The
    precise meaning of each field is defined by ``Trainer`` (not visible in
    this file); the grouping comments below are inferred from the attribute
    names only and should be confirmed against ``Trainer``.
    """

    # --- Model / architecture settings (presumably; names only) ---
    submodel = 'CrossLoss'
    act = 'relu'
    z_dim = 16
    max_atom = 45
    lambda_gp = 1
    dim = 128
    depth = 1
    heads = 8
    dec_depth = 1
    dec_heads = 8
    dec_dim = 128
    mlp_ratio = 3
    warm_up_steps = 0
    dis_select = 'mlp'
    init_type = 'normal'

    # --- Optimisation settings (presumably; names only) ---
    batch_size = 128
    epoch = 50
    g_lr = 0.00001
    d_lr = 0.00001
    g2_lr = 0.00001
    # Originally written without the "=" (a syntax error that prevented the
    # file from compiling); value assumed to match the sibling learning rates.
    d2_lr = 0.00001
    dropout = 0.0
    dec_dropout = 0.0
    n_critic = 1
    beta1 = 0.9
    beta2 = 0.999
    resume_iters = None
    clipping_value = 2
    features = False

    # --- Test / inference settings ---
    test_iters = 10_000
    num_test_epoch = 30_000
    inference_sample_num = 1000
    num_workers = 1
    mode = "inference"
    inference_iterations = 100
    inf_batch_size = 1

    # --- Data, log and output locations (relative path strings) ---
    protein_data_dir = 'DrugGEN/data/akt'
    drug_index = 'DrugGEN/data/drug_smiles.index'
    drug_data_dir = 'DrugGEN/data/akt'
    mol_data_dir = 'DrugGEN/data'
    log_dir = 'DrugGEN/experiments/logs'
    model_save_dir = 'DrugGEN/experiments/models'
    # Empty string here; NOTE(review): how Trainer treats an empty
    # inference_model is not visible from this file — confirm.
    inference_model = ""
    sample_dir = 'DrugGEN/experiments/samples'
    result_dir = "DrugGEN/experiments/tboard_output"
    dataset_file = "chembl45_train.pt"
    drug_dataset_file = "akt_train.pt"
    raw_file = 'DrugGEN/data/chembl_train.smi'
    drug_raw_file = "DrugGEN/data/akt_train.smi"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'DrugGEN/data/chembl_test.smi'
    inf_drug_raw_file = "DrugGEN/data/akt_test.smi"

    
# Step 1: build the trainer from the default configuration; a spinner with
# the given message is displayed for the duration of the construction.
with st.spinner('Setting up the trainer class...'):
    inference_config = DrugGENConfig()
    trainer = Trainer(inference_config)

# Step 2: run the trainer's inference routine, again behind a spinner.
with st.spinner('Generating Molecules...'):
    trainer.inference()