DrugGEN / app.py
mgyigit's picture
Update app.py
b83af40 verified
raw
history blame
15.9 kB
import gradio as gr
from inference import Inference
import PIL
from PIL import Image
import pandas as pd
import random
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import shutil
import os
import time
class DrugGENConfig:
    """Default settings handed to ``Inference`` for the DrugGEN demo.

    Every value is a class attribute, so the per-target subclasses below only
    restate what they need to change.
    """

    # --- Inference ---
    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN/"
    sample_num: int = 100

    # --- Data ---
    inf_smiles: str = '/home/user/app/data/chembl_test.smi'
    train_smiles: str = '/home/user/app/data/chembl_train.smi'
    inf_batch_size: int = 1
    mol_data_dir: str = '/home/user/app/data'
    features: bool = False

    # --- Model architecture ---
    act: str = 'relu'
    # Maximum molecule size in atoms (the UI text describes 45 heavy atoms
    # for AKT1/NoTarget and 38 for CDK2).
    max_atom: int = 45
    dim: int = 128
    depth: int = 1
    heads: int = 8
    mlp_ratio: int = 3
    dropout: float = 0.0

    # --- Seeding ---
    set_seed: bool = True
    seed: int = 10
    disable_correction: bool = False
class DrugGENAKT1Config(DrugGENConfig):
    """DrugGEN model targeting the human AKT1 protein (UniProt P31749)."""

    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-akt1/"
    # NOTE(review): presumably the known AKT1 inhibitors behind the
    # "Real Inhibitors" metrics - confirm in Inference.
    train_drug_smiles: str = '/home/user/app/data/akt_train.smi'
    # Same limit as the base config; restated for clarity.
    max_atom: int = 45
class DrugGENCDK2Config(DrugGENConfig):
    """DrugGEN model targeting the human CDK2 protein (UniProt P24941)."""

    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-cdk2/"
    # NOTE(review): presumably the known CDK2 inhibitors behind the
    # "Real Inhibitors" metrics - confirm in Inference.
    train_drug_smiles: str = '/home/user/app/data/cdk2_train.smi'
    # Lower than the base default of 45: the CDK2 training data excluded
    # molecules larger than 38 heavy atoms.
    max_atom: int = 38
class NoTargetConfig(DrugGENConfig):
    """Untargeted model that generates general drug-like molecules."""

    # Only the submodel and checkpoint differ; everything else is inherited.
    # NOTE(review): neither this class nor DrugGENConfig defines
    # train_drug_smiles - confirm Inference does not need it for NoTarget.
    submodel: str = "NoTarget"
    inference_model: str = "/home/user/app/experiments/models/NoTarget/"
# Maps each model name offered in the UI radio button to its config instance.
model_configs = {
    display_name: config_cls()
    for display_name, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}
def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
    """
    Depending on the selected mode, either generate new molecules or evaluate provided SMILES.

    Args:
        mode: "Custom Input SMILES" runs the model on the user's SMILES; any
            other value runs classical generation.
        model_name: Key of ``model_configs`` ("DrugGEN-AKT1", "DrugGEN-CDK2"
            or "DrugGEN-NoTarget").
        num_molecules: Number of molecules to generate (classical mode only;
            at most 200).
        seed_num: Optional integer seed given as text (classical mode only);
            a random seed is drawn when it is empty or None.
        custom_smiles: Newline-separated SMILES (custom mode only; at most 100).

    Returns:
        image, file_path, basic_metrics, advanced_metrics
        ``image`` is a grid of up to 9 distinct molecules sampled from the
        output file, or None when none of the sampled SMILES can be parsed.

    Raises:
        gr.Error: When the input is invalid (no or too many SMILES, too many
            molecules requested, or a non-integer seed).
    """
    import copy

    # Work on a per-request copy. The objects in model_configs are shared by
    # every request, so assigning seed / sample_num / inf_smiles on them
    # directly would leak one request's settings (e.g. a user's custom SMILES
    # file) into every later request for the same model.
    config = copy.copy(model_configs[model_name])
    temp_input_file = None

    if mode == "Custom Input SMILES":
        # One SMILES per non-blank line, surrounding whitespace removed.
        smiles_list = [s.strip() for s in (custom_smiles or "").splitlines() if s.strip()]
        if not smiles_list:
            raise gr.Error("Please provide at least one SMILES string.")
        if len(smiles_list) > 100:
            raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")
        # Always use a random seed for custom mode.
        config.seed = random.randint(0, 10000)
        # Write the custom SMILES to a temporary file and point the config at it.
        temp_input_file = f"custom_input{config.seed}.smi"
        with open(temp_input_file, "w") as f:
            f.writelines(s + "\n" for s in smiles_list)
        config.inf_smiles = temp_input_file
        config.sample_num = len(smiles_list)
    else:
        # Classical Generation mode
        config.sample_num = num_molecules
        if config.sample_num > 200:
            raise gr.Error("You have requested to generate more than the allowed limit of 200 molecules. Please reduce your request to 200 or fewer.")
        if seed_num is None or seed_num.strip() == "":
            config.seed = random.randint(0, 10000)
        else:
            try:
                config.seed = int(seed_num)
            except ValueError:
                raise gr.Error("The seed must be an integer value!") from None

    try:
        inferer = Inference(config)
        start_time = time.perf_counter()
        scores = inferer.inference()  # DataFrame; the first row holds the metrics
        et = time.perf_counter() - start_time
    finally:
        # The custom input file is only needed while inference runs. Remove it
        # so the working directory does not accumulate one file per request.
        if temp_input_file is not None:
            try:
                os.remove(temp_input_file)
            except OSError:
                pass

    # Create basic metrics dataframe
    basic_metrics = pd.DataFrame({
        "Validity": [scores["validity"].iloc[0]],
        "Uniqueness": [scores["uniqueness"].iloc[0]],
        "Novelty (Train)": [scores["novelty"].iloc[0]],
        "Novelty (Inference)": [scores["novelty_test"].iloc[0]],
        "Novelty (Real Inhibitors)": [scores["drug_novelty"].iloc[0]],
        "Runtime (s)": [round(et, 2)]
    })
    # Create advanced metrics dataframe
    advanced_metrics = pd.DataFrame({
        "QED": [scores["qed"].iloc[0]],
        "SA Score": [scores["sa"].iloc[0]],
        "Internal Diversity": [scores["IntDiv"].iloc[0]],
        "SNN ChEMBL": [scores["snn_chembl"].iloc[0]],
        "SNN Real Inhibitors": [scores["snn_drug"].iloc[0]],
        "Average Length": [scores["max_len"].iloc[0]]
    })

    # config.submodel is "DrugGEN" for the AKT1/CDK2 models and "NoTarget" for
    # the untargeted one; it names the output directory read below.
    target_model_name = config.submodel
    output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
    new_path = f'{target_model_name}_denovo_mols.smi'
    # os.replace overwrites a file left by a previous request on every platform.
    os.replace(output_file_path, new_path)
    with open(new_path) as f:
        # Skip blank lines instead of assuming exactly one trailing newline.
        generated_molecule_list = [line.strip() for line in f if line.strip()]

    # Randomly select up to 9 distinct molecules for display. dict.fromkeys
    # de-duplicates while keeping file order (so the seeded draw is
    # reproducible), and sample() never picks the same entry twice.
    unique_smiles = list(dict.fromkeys(generated_molecule_list))
    rng = random.Random(config.seed)
    if len(unique_smiles) > 9:
        selected_smiles = rng.sample(unique_smiles, k=9)
    else:
        selected_smiles = unique_smiles
    # Parse each SMILES once; drop the ones RDKit cannot parse.
    selected_molecules = [mol for mol in map(Chem.MolFromSmiles, selected_smiles) if mol is not None]

    # With nothing to draw, return no image rather than an empty grid.
    molecule_image = None
    if selected_molecules:
        drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
        drawOptions.prepareMolsBeforeDrawing = False
        drawOptions.bondLineWidth = 0.5
        molecule_image = Draw.MolsToGridImage(
            selected_molecules,
            molsPerRow=3,
            subImgSize=(400, 400),
            maxMols=len(selected_molecules),
            returnPNG=False,
            drawOptions=drawOptions,
            highlightAtomLists=None,
            highlightBondLists=None,
        )
    return molecule_image, new_path, basic_metrics, advanced_metrics
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Add custom CSS for styling
    gr.HTML("""
<style>
#metrics-container {
border: 1px solid rgba(128, 128, 128, 0.3);
border-radius: 8px;
padding: 15px;
margin-top: 15px;
margin-bottom: 15px;
background-color: rgba(255, 255, 255, 0.05);
}
</style>
""")
    # Two-column layout: inputs and documentation on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            gr.HTML("""
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
<!-- arXiv badge -->
<a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
<div style="
display: inline-block;
background-color: #b31b1b;
color: #ffffff !important;
padding: 5px 10px;
border-radius: 5px;
font-size: 14px;">
<span style="font-weight: bold;">arXiv</span> 2302.07868
</div>
</a>
<!-- GitHub badge -->
<a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
<div style="
display: inline-block;
background-color: #24292e;
color: #ffffff !important;
padding: 5px 10px;
border-radius: 5px;
font-size: 14px;">
<span style="font-weight: bold;">GitHub</span> Repository
</div>
</a>
</div>
""")
            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749). Trained with [2,607 bioactive compounds](https://drive.google.com/file/d/1B2OOim5wrUJalixeBTDKXLHY8BAIvNh-/view?usp=drive_link).
Molecules larger than 45 heavy atoms were excluded.
### DrugGEN-CDK2
This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941). Trained with [1,817 bioactive compounds](https://drive.google.com/file/d/1C0CGFKx0I2gdSfbIEgUO7q3K2S1P9ksT/view?usp=drive_link).
Molecules larger than 38 heavy atoms were excluded.
### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. Trained with a general [ChEMBL dataset](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link).
Molecules larger than 45 heavy atoms were excluded.
- Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.

For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
""")
            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate or evaluate the molecules
### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the [training set](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link). These molecules are used as inputs to
the generator during training.
- **Novelty (Inference)**: Percentage of molecules not found in the [inference set](https://drive.google.com/file/d/1vMGXqK1SQXB3Od3l80gMWvTEOjJ5MFXP/view?usp=share_link). These molecules are used as inputs
to the generator during inference.
- **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein (see "About DrugGEN Models" for details). These molecules are used as inputs to the
discriminator during training.
### Structural Metrics
- **Average Length**: Average number of atoms in the generated molecules, normalized by the maximum number of atoms (e.g., 45 for AKT1/NoTarget, 38 for CDK2)
- **Mean Atom Type**: Average number of distinct atom types in the generated molecules
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)
### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)
### Similarity Metrics
- **SNN ChEMBL**: Similarity to [ChEMBL molecules](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link) (higher means more similar to known drug-like compounds)
- **SNN Real Inhibitors**: Similarity to the real inhibitors of the selected target (higher means more similar to the real inhibitors)
""")
            # The choices must match the keys of model_configs.
            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )
            with gr.Tabs():
                with gr.TabItem("Classical Generation"):
                    num_molecules = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Number of Molecules to Generate",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 200-molecule cap."
                    )
                    seed_num = gr.Textbox(
                        label="Random Seed (Optional)",
                        value="",
                        info="Set a specific seed for reproducible results, or leave empty for random generation"
                    )
                    classical_submit = gr.Button(
                        value="Generate Molecules",
                        variant="primary",
                        size="lg"
                    )
                with gr.TabItem("Custom Input SMILES"):
                    custom_smiles = gr.Textbox(
                        label="Input SMILES (one per line, maximum 100 molecules)",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 100-molecule cap.\n\n Molecules larger than the allowed maximum size (45 heavy atoms for AKT1/NoTarget and 38 for CDK2) or containing unsupported atom types will be filtered out.\n\n The Novelty (Inference) metric will be calculated using these input SMILES.",
                        placeholder="Nc1ccccc1-c1nc(N)c2ccccc2n1\nO=C(O)c1ccccc1C(=O)c1cccc(Cl)c1\n...",
                        lines=10
                    )
                    custom_submit = gr.Button(
                        value="Generate Molecules using Custom SMILES",
                        variant="primary",
                        size="lg"
                    )
        with gr.Column(scale=2):
            # Header names match the column names of the DataFrames that
            # run_inference returns, so the table does not rename its columns
            # after the first run.
            basic_metrics_df = gr.Dataframe(
                headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
                elem_id="basic-metrics"
            )
            advanced_metrics_df = gr.Dataframe(
                headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Real Inhibitors", "Average Length"],
                elem_id="advanced-metrics"
            )
            file_download = gr.File(label="Download All Generated Molecules (SMILES format)")
            image_output = gr.Image(label="Structures of Randomly Selected Generated Molecules",
                                    elem_id="molecule_display")
    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
    # Set up the click actions for each tab. The gr.State values fill the
    # run_inference parameters that the tab does not expose.
    classical_submit.click(
        run_inference,
        inputs=[gr.State("Generate Molecules"), model_name, num_molecules, seed_num, gr.State("")],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference_classical"
    )
    custom_submit.click(
        run_inference,
        inputs=[gr.State("Custom Input SMILES"), model_name, gr.State(0), gr.State(""), custom_smiles],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df
        ],
        api_name="inference_custom"
    )
demo.queue()
demo.launch()