Spaces:
Running
Running
File size: 3,864 Bytes
7ab1cfa 019978b f7c5d29 7ab1cfa 019978b 7f96593 ad99e6b eb27d90 ad99e6b 0bb4b87 201412f 0bb4b87 ad99e6b 0bb4b87 ad99e6b 0bb4b87 1e8ab2e f7c5d29 24a14ee 201412f 6e28a65 d66f370 eb3660f 6e28a65 847f6de 6e28a65 847f6de f7c5d29 e01fd9a f7c5d29 6e28a65 019978b 6e28a65 e01fd9a f7c5d29 e01fd9a eb3660f e01fd9a eb3660f a303dc5 7c08841 eb3660f a303dc5 7c08841 eb3660f 26b55ee 019978b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import streamlit as st
from trainer import Trainer
import random
class DrugGENConfig:
    """Shared default settings for every DrugGEN model variant in this app.

    This is a plain attribute bag: every setting is a class attribute, so a
    subclass overrides a value by redefining it, and assigning on an instance
    shadows the class default for that instance only. An instance is handed
    to ``Trainer``; which attributes it reads, and what each one means, is
    decided there. The section headings below group attributes by name only.
    """

    # Model variant and network shape.
    submodel = 'CrossLoss'
    act = 'relu'
    z_dim = 16
    max_atom = 45
    lambda_gp = 1
    dim = 128
    depth = 1
    heads = 8
    dec_depth = 1
    dec_heads = 8
    dec_dim = 128
    mlp_ratio = 3
    warm_up_steps = 0
    dis_select = 'mlp'
    init_type = 'normal'

    # Batching, epochs and optimiser-related values.
    batch_size = 128
    epoch = 50
    g_lr = 1e-5
    d_lr = 1e-5
    g2_lr = 1e-5
    d2_lr = 1e-5
    dropout = 0.0
    dec_dropout = 0.0
    n_critic = 1
    beta1 = 0.9
    beta2 = 0.999
    resume_iters = None
    clipping_value = 2
    features = False

    # Testing and inference.
    test_iters = 10000
    num_test_epoch = 30000
    inference_sample_num = 1000
    num_workers = 1
    mode = "inference"
    inference_iterations = 100
    inf_batch_size = 1

    # Data and output directories.
    protein_data_dir = 'data/akt'
    drug_index = 'data/drug_smiles.index'
    drug_data_dir = 'data/akt'
    mol_data_dir = 'data'
    log_dir = 'experiments/logs'
    model_save_dir = 'experiments/models'
    # ``inference_model`` is deliberately not defined here: each
    # model-specific subclass sets its own ``inference_model`` path.
    sample_dir = 'experiments/samples'
    result_dir = "experiments/tboard_output"

    # Dataset files: the ``*_train`` files first, then their ``inf_*``
    # ``*_test`` counterparts.
    dataset_file = "chembl45_train.pt"
    drug_dataset_file = "akt_train.pt"
    raw_file = 'data/chembl_train.smi'
    drug_raw_file = "data/akt_train.smi"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = 'akt_test.pt'
    inf_raw_file = 'data/chembl_test.smi'
    inf_drug_raw_file = "data/akt_test.smi"

    # Logging, seeding and resume control.
    log_sample_step = 1000
    set_seed = True
    seed = 1
    resume = False
    resume_epoch = None
    resume_iter = None
    resume_directory = None
class ProtConfig(DrugGENConfig):
    """DrugGEN settings for the ``Prot`` variant; all other values are inherited."""

    inference_model = "experiments/models/Prot"
    submodel = "Prot"
class CrossLossConfig(DrugGENConfig):
    """DrugGEN settings for the ``CrossLoss`` variant; all other values are inherited."""

    inference_model = "experiments/models/CrossLoss"
    submodel = "CrossLoss"
class NoTargetConfig(DrugGENConfig):
    """DrugGEN settings for the ``NoTarget`` variant; all other values are inherited."""

    inference_model = "experiments/models/NoTarget"
    submodel = "NoTarget"
# One ready-made config instance per selectable model. The keys must match
# the options offered by the model-selection radio button, because the chosen
# option is used directly as the lookup key.
model_configs = dict(
    Prot=ProtConfig(),
    CrossLoss=CrossLossConfig(),
    NoTarget=NoTargetConfig(),
)
# Streamlit re-executes this script on every interaction, and a widget whose
# parameters (including its default ``value``) change between reruns is
# treated as a brand-new widget, dropping whatever the user entered. Drawing
# the random default seed once per session keeps the seed widget stable, so
# the seed the user submits is the one actually used.
if "default_seed" not in st.session_state:
    st.session_state["default_seed"] = random.randint(1, 1000)

with st.sidebar:
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    # Links need visible text: "[](url)" renders as an empty, unclickable link.
    st.write("[arXiv](https://arxiv.org/abs/2302.07868) [GitHub](https://github.com/HUBioDataLab/DrugGEN)")

    # Inputs inside a form are only sent back to the script when the submit
    # button is pressed, so inference is not triggered by each edit.
    with st.form("model_selection_from"):
        model_name = st.radio(
            "Select a model to make inference",
            ('Prot', 'CrossLoss', 'NoTarget'))
        molecule_num_input = st.number_input('Number of molecules to generate', min_value=1, max_value=100_000, value=1000, step=1)
        seed_input = st.number_input("Input a seed for reproducibility", min_value=0, value=st.session_state["default_seed"], step=1)
        submitted = st.form_submit_button("Start Computing")

if submitted:
    # Override the selected variant's defaults with the user's choices.
    config = model_configs[model_name]
    config.inference_sample_num = molecule_num_input
    config.seed = seed_input

    with st.spinner(f'Creating the trainer class instance for {model_name}...'):
        trainer = Trainer(config)
    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        results = trainer.inference()
    st.success(f"Inference of {model_name} took {results['runtime']:.2f} seconds.")

    with st.expander("Expand to see scores"):
        st.write(f"Fraction valid: {results['fraction_valid']}")
        st.write(f"Uniqueness: {results['uniqueness']}")
        st.write(f"Novelty score: {results['novelty']}")

    # NOTE(review): assumes trainer.inference() writes the generated molecules
    # to this path — confirm against Trainer.
    with open(f'experiments/inference/{model_name}/inference_drugs.txt', encoding="utf-8") as f:
        inference_drugs = f.read()
    st.download_button(label="Click to download generated molecules", data=inference_drugs, file_name=f'{model_name}_inference.smi', mime="text/plain")
else:
    st.warning("Please select a model to make inference")
|