Spaces:
Running
Running
File size: 6,715 Bytes
7ab1cfa 7334270 019978b f7c5d29 7de09a8 5d80aad 67206bd 7ab1cfa 019978b 7f96593 ad99e6b eb27d90 ad99e6b 0bb4b87 201412f 0bb4b87 ad99e6b 0bb4b87 ad99e6b 0bb4b87 1e8ab2e f7c5d29 24a14ee 201412f 6e28a65 d66f370 eb3660f e418463 76a1678 99c16c0 76a1678 e418463 76a1678 99c16c0 0e9e498 99c16c0 847f6de f7c5d29 e01fd9a 5b1e10a f7c5d29 6e28a65 019978b e01fd9a 7334270 e01fd9a f7c5d29 e01fd9a eb3660f e01fd9a eb3660f a303dc5 7c08841 eb3660f 7de09a8 7f7361b eb3660f 7334270 eb3660f 7de09a8 f92e7a9 ff2edc5 0db0a4a ff2edc5 7de09a8 5bf5a53 7e143cc c59033b f92e7a9 7de09a8 5d415a5 091242e 07f378f 808014b 67206bd 808014b 7de09a8 26b55ee 019978b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import streamlit as st
import streamlit_ext as ste
from trainer import Trainer
import random
from rdkit.Chem import Draw
from rdkit import Chem
from rdkit.Chem.Draw import IPythonConsole
import io
from PIL import Image
class DrugGENConfig:
    """Shared default settings for every DrugGEN model variant.

    All settings are plain class attributes. The model-specific subclasses in
    this file override ``submodel`` and add ``inference_model``. The app uses
    an instance as a mutable attribute bag: it assigns per-request values such
    as ``inference_sample_num`` and ``seed`` on the instance before passing it
    to ``Trainer``.
    """

    # Model selection and network shape
    submodel = "CrossLoss"
    act = "relu"
    z_dim = 16
    max_atom = 45
    lambda_gp = 1
    dim = 128
    depth = 1
    heads = 8
    dec_depth = 1
    dec_heads = 8
    dec_dim = 128
    mlp_ratio = 3
    warm_up_steps = 0
    dis_select = "mlp"
    init_type = "normal"

    # Training loop settings
    batch_size = 128
    epoch = 50
    g_lr = 1e-5
    d_lr = 1e-5
    g2_lr = 1e-5
    d2_lr = 1e-5
    dropout = 0.0
    dec_dropout = 0.0
    n_critic = 1
    beta1 = 0.9
    beta2 = 0.999
    resume_iters = None
    clipping_value = 2
    features = False
    test_iters = 10000
    num_test_epoch = 30000

    # Inference settings
    inference_sample_num = 1000
    num_workers = 1
    mode = "inference"
    inference_iterations = 100
    inf_batch_size = 1

    # Data, model and output directories
    protein_data_dir = "data/akt"
    drug_index = "data/drug_smiles.index"
    drug_data_dir = "data/akt"
    mol_data_dir = "data"
    log_dir = "experiments/logs"
    model_save_dir = "experiments/models"
    # No ``inference_model`` default here: each variant subclass sets its own.
    sample_dir = "experiments/samples"
    result_dir = "experiments/tboard_output"

    # Dataset files (training set, then the ``inf_`` test-set counterparts)
    dataset_file = "chembl45_train.pt"
    drug_dataset_file = "akt_train.pt"
    raw_file = "data/chembl_train.smi"
    drug_raw_file = "data/akt_train.smi"
    inf_dataset_file = "chembl45_test.pt"
    inf_drug_dataset_file = "akt_test.pt"
    inf_raw_file = "data/chembl_test.smi"
    inf_drug_raw_file = "data/akt_test.smi"
    log_sample_step = 1000

    # Seeding and resumption
    set_seed = True
    seed = 1
    resume = False
    resume_epoch = None
    resume_iter = None
    resume_directory = None
class ProtConfig(DrugGENConfig):
    """DrugGEN-Prot settings: base defaults plus this variant's model directory."""

    submodel = "Prot"
    inference_model = "experiments/models/Prot"
class CrossLossConfig(DrugGENConfig):
    """DrugGEN-CrossLoss settings: base defaults plus this variant's model directory."""

    submodel = "CrossLoss"
    inference_model = "experiments/models/CrossLoss"
class NoTargetConfig(DrugGENConfig):
    """DrugGEN-NoTarget settings: base defaults plus this variant's model directory."""

    submodel = "NoTarget"
    inference_model = "experiments/models/NoTarget"
# Registry from short model name (the radio label with its "DrugGEN-" prefix
# removed) to a fresh config instance for that variant. Each class's
# ``submodel`` value doubles as its key.
model_configs = {
    variant.submodel: variant()
    for variant in (ProtConfig, CrossLossConfig, NoTargetConfig)
}
# Markdown for the "information about models" expander. It is built from
# adjacent literals so the rendered text is unchanged while the source stays
# indented.
_MODEL_VARIATIONS_MD = (
    "\n"
    "### Model Variations\n"
    "- **DrugGEN-Prot**: composed of two GANs, incorporates protein features to the transformer decoder module of GAN2 (together with the de novo molecules generated by GAN1) to direct the target centric molecule design.\n"
    "- **DrugGEN-CrossLoss**: composed of one GAN, the input of the GAN1 generator is the real molecules dataset and the GAN1 discriminator compares the generated molecules with the real inhibitors of the given target.\n"
    "- **DrugGEN-NoTarget**: composed of one GAN, focuses on learning the chemical properties from the ChEMBL training dataset, no target-specific generation.\n"
)

# Radio-button labels. The "DrugGEN-" prefix is stripped below to obtain the
# key used for ``model_configs``.
_MODEL_CHOICES = ('DrugGEN-Prot', 'DrugGEN-CrossLoss', 'DrugGEN-NoTarget')

with st.sidebar:
    st.title("DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
    st.write("[](https://arxiv.org/abs/2302.07868) [](https://github.com/HUBioDataLab/DrugGEN)")

    with st.expander("Expand to display information about models"):
        st.write(_MODEL_VARIATIONS_MD)

    # All inputs live in one form, so nothing runs until the submit button is pressed.
    with st.form("model_selection_from"):
        selected_label = st.radio(
            'Select a model to make inference (DrugGEN-Prot and DrugGEN-CrossLoss models design molecules to target the AKT1 protein)',
            _MODEL_CHOICES,
        )
        model_name = selected_label.replace("DrugGEN-", "")

        molecule_num_input = st.number_input(
            'Number of molecules to generate',
            min_value=1,
            max_value=100000,
            value=1000,
            step=1,
        )
        seed_input = st.number_input(
            "RNG seed value (can be used for reproducibility):",
            min_value=0,
            value=42,
            step=1,
        )

        submitted = st.form_submit_button("Start Computing")
# How many structures the result grid shows, at most.
_NUM_DISPLAYED_MOLECULES = 12


def _random_valid_molecules(smiles_text, k, seed):
    """Return up to ``k`` randomly chosen, parseable molecules from ``smiles_text``.

    ``smiles_text`` holds one SMILES string per line (the content of the
    inference output file). Blank lines and repeated SMILES are dropped. The
    remaining strings are shuffled with a private ``random.Random(seed)``, so
    the same seed reproduces the same picks without touching the global RNG.
    They are then parsed in that order until ``k`` molecules have been
    collected. Strings rejected by ``Chem.MolFromSmiles`` (it returns ``None``)
    are skipped.

    Sampling is without replacement. The result is shorter than ``k`` when
    fewer valid, distinct SMILES are available, and empty when ``k <= 0``.
    """
    if k <= 0:
        return []
    # dict.fromkeys de-duplicates while keeping first-seen order, which keeps
    # the seeded shuffle below deterministic.
    candidates = list(dict.fromkeys(
        line.strip() for line in smiles_text.splitlines() if line.strip()
    ))
    random.Random(seed).shuffle(candidates)
    molecules = []
    for smiles in candidates:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
        molecules.append(mol)
        if len(molecules) >= k:
            break
    return molecules


if submitted:
    config = model_configs[model_name]
    # Override the class-level defaults on this instance with the form values.
    config.inference_sample_num = molecule_num_input
    config.seed = seed_input

    with st.spinner(f'Creating the trainer class instance for {model_name}...'):
        trainer = Trainer(config)

    with st.spinner(f'Running inference function of {model_name} (this may take a while) ...'):
        results = trainer.inference()

    st.success(f"Inference of {model_name} took {results['runtime']:.2f} seconds.")

    with st.expander("Expand to see the generation performance scores"):
        st.write("### Generation performance scores (novelty is calculated in comparison to the training dataset)")
        st.success(f"Validity: {results['fraction_valid']}")
        st.success(f"Uniqueness: {results['uniqueness']}")
        st.success(f"Novelty: {results['novelty']}")

    with open(f'experiments/inference/{model_name}/inference_drugs.txt', encoding="utf-8") as f:
        inference_drugs = f.read()

    ste.download_button(label="Click to download generated molecules", data=inference_drugs, file_name=f'DrugGEN-{model_name}_denovo_mols.smi', mime="text/plain")

    # Distinct, parseable molecules only. Reusing the user's seed makes the
    # displayed selection reproducible as well.
    display_molecules = _random_valid_molecules(
        inference_drugs, k=_NUM_DISPLAYED_MOLECULES, seed=seed_input
    )

    if display_molecules:
        st.write(f"Structures of randomly selected {len(display_molecules)} de novo molecules from the inference set:")
        molecule_image = Draw.MolsToGridImage(
            display_molecules,
            molsPerRow=3,
            subImgSize=(250, 250),
            maxMols=len(display_molecules),
            returnPNG=False,
            highlightAtomLists=None,
            highlightBondLists=None,
        )
        # NOTE(review): this writes to the working directory on every run and the
        # file is shared across sessions. Nothing visible here reads it back, so
        # confirm whether it is still needed.
        molecule_image.save("result_grid.png")
        st.image(molecule_image)
    else:
        st.warning("None of the generated SMILES could be parsed, so no structures are shown.")
else:
    st.warning("Please select a model to make inference")
|