Spaces:
Running
Running
import copy
import os
import random
import shutil
import tempfile
import time

import gradio as gr
import pandas as pd
import PIL
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole

from inference import Inference
class DrugGENConfig:
    """Base inference settings shared by all DrugGEN model variants.

    Every setting is a plain class attribute. The target-specific subclasses
    override the checkpoint path, add a ``train_drug_smiles`` file and, where
    needed, change ``max_atom``.
    """

    # --- Inference ---
    submodel: str = 'DrugGEN'
    inference_model: str = "/home/user/app/experiments/models/DrugGEN/"
    sample_num: int = 100

    # --- Data ---
    inf_smiles: str = '/home/user/app/data/chembl_test.smi'
    train_smiles: str = '/home/user/app/data/chembl_train.smi'
    inf_batch_size: int = 1
    mol_data_dir: str = '/home/user/app/data'
    features: bool = False

    # --- Model architecture ---
    act: str = 'relu'
    max_atom: int = 45
    dim: int = 128
    depth: int = 1
    heads: int = 8
    mlp_ratio: int = 3
    dropout: float = 0.0

    # --- Seeding / post-processing ---
    set_seed: bool = True
    seed: int = 10
    disable_correction: bool = False
class DrugGENAKT1Config(DrugGENConfig):
    """DrugGEN variant for the AKT1 target (molecules capped at 45 heavy atoms)."""

    submodel: str = 'DrugGEN'
    # Checkpoint directory of the AKT1-specific model.
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-akt1/"
    # Target-specific SMILES file for AKT1.
    train_drug_smiles: str = '/home/user/app/data/akt_train.smi'
    max_atom: int = 45
class DrugGENCDK2Config(DrugGENConfig):
    """DrugGEN variant for the CDK2 target (molecules capped at 38 heavy atoms)."""

    submodel: str = 'DrugGEN'
    # Checkpoint directory of the CDK2-specific model.
    inference_model: str = "/home/user/app/experiments/models/DrugGEN-cdk2/"
    # Target-specific SMILES file for CDK2.
    train_drug_smiles: str = '/home/user/app/data/cdk2_train.smi'
    # Smaller cap than the base config's 45.
    max_atom: int = 38
class NoTargetConfig(DrugGENConfig):
    """General-purpose (untargeted) model; every other setting comes from DrugGENConfig."""

    submodel: str = "NoTarget"
    # Checkpoint directory of the untargeted model.
    inference_model: str = "/home/user/app/experiments/models/NoTarget/"
# Registry from the model names offered in the UI to one shared config
# instance per model, instantiated once at import time.
model_configs = {
    display_name: config_cls()
    for display_name, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}
def _resolve_seed(seed_num):
    """Return ``seed_num`` as an int, or a fresh random seed in [0, 10000] when it is blank.

    Raises:
        gr.Error: if a non-blank value is not an integer.
    """
    text = "" if seed_num is None else str(seed_num).strip()
    if not text:
        return random.randint(0, 10000)
    try:
        return int(text)
    except ValueError:
        raise gr.Error("The seed must be an integer value!") from None


def _write_temp_smiles(smiles_list):
    """Write ``smiles_list`` one per line to a new, uniquely named .smi file and return its path.

    The file is created in the current working directory with ``delete=False``.
    The caller is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(
        "w",
        suffix=".smi",
        prefix="custom_input",
        dir=os.getcwd(),
        delete=False,
        encoding="utf-8",
    ) as handle:
        handle.writelines(f"{smiles}\n" for smiles in smiles_list)
        return handle.name


def _build_metric_tables(scores, elapsed):
    """Split row 0 of the inference score table into the basic and advanced tables shown in the UI.

    Args:
        scores: DataFrame returned by ``Inference.inference()``; only row 0 is read.
        elapsed: inference wall-clock time in seconds.

    Returns:
        (basic_metrics, advanced_metrics), each a single-row DataFrame.
    """
    def first(column):
        return [scores[column].iloc[0]]

    basic_metrics = pd.DataFrame({
        "Validity": first("validity"),
        "Uniqueness": first("uniqueness"),
        "Novelty (Train)": first("novelty"),
        "Novelty (Inference)": first("novelty_test"),
        "Novelty (Real Inhibitors)": first("drug_novelty"),
        "Runtime (s)": [round(elapsed, 2)],
    })
    advanced_metrics = pd.DataFrame({
        "QED": first("qed"),
        "SA Score": first("sa"),
        "Internal Diversity": first("IntDiv"),
        "SNN ChEMBL": first("snn_chembl"),
        "SNN Real Inhibitors": first("snn_drug"),
        "Average Length": first("max_len"),
    })
    return basic_metrics, advanced_metrics


def _render_molecule_grid(smiles_list, seed):
    """Draw up to 9 randomly chosen, parseable molecules from ``smiles_list`` in a 3-column grid.

    ``seed`` makes the selection reproducible. Returns None when no SMILES parses,
    which leaves the image component empty.
    """
    # Parse once and keep only valid molecules, so the preview is filled from
    # valid structures rather than shrinking when a sampled SMILES is invalid.
    molecules = [mol for mol in map(Chem.MolFromSmiles, smiles_list) if mol is not None]
    if not molecules:
        return None
    if len(molecules) > 9:
        # Sample without replacement so the same entry is never drawn twice.
        molecules = random.Random(seed).sample(molecules, k=9)

    draw_options = Draw.rdMolDraw2D.MolDrawOptions()
    draw_options.prepareMolsBeforeDrawing = False
    draw_options.bondLineWidth = 0.5
    return Draw.MolsToGridImage(
        molecules,
        molsPerRow=3,
        subImgSize=(400, 400),
        maxMols=len(molecules),
        returnPNG=False,
        drawOptions=draw_options,
        highlightAtomLists=None,
        highlightBondLists=None,
    )


def run_inference(mode: str, model_name: str, num_molecules: int, seed_num: str, custom_smiles: str):
    """Generate new molecules, or evaluate user-provided SMILES, with the selected model.

    Args:
        mode: "Custom Input SMILES" uses ``custom_smiles`` as the inference set;
            any other value runs classical generation.
        model_name: key into ``model_configs``.
        num_molecules: molecules to generate in classical mode (1-200).
        seed_num: optional integer seed as text (classical mode); blank means random.
        custom_smiles: newline-separated SMILES, at most 100 (custom mode).

    Returns:
        image, file_path, basic_metrics, advanced_metrics. These are the preview grid
        (None if no generated SMILES parses), the path of the downloadable .smi file,
        and two single-row metric DataFrames.

    Raises:
        gr.Error: on invalid user input, or when inference leaves no output file.
    """
    # Work on a copy: the objects in ``model_configs`` are shared by every
    # request, and mutating them leaked per-request settings (e.g. a custom
    # input file as ``inf_smiles``) into later classical runs of the same model.
    config = copy.copy(model_configs[model_name])
    temp_input_file = None

    if mode == "Custom Input SMILES":
        smiles_list = [s.strip() for s in (custom_smiles or "").splitlines() if s.strip()]
        if not smiles_list:
            raise gr.Error("Please provide at least one SMILES string.")
        if len(smiles_list) > 100:
            raise gr.Error("You have provided more than the allowed limit of 100 molecules. Please provide 100 or fewer.")
        # Custom mode always uses a random seed.
        config.seed = random.randint(0, 10000)
        temp_input_file = _write_temp_smiles(smiles_list)
        config.inf_smiles = temp_input_file
        config.sample_num = len(smiles_list)
    else:
        # Classical generation mode.
        config.sample_num = int(num_molecules)
        if config.sample_num > 200:
            raise gr.Error("You have requested to generate more than the allowed limit of 200 molecules. Please reduce your request to 200 or fewer.")
        if config.sample_num < 1:
            raise gr.Error("Please request at least 1 molecule.")
        config.seed = _resolve_seed(seed_num)

    try:
        inferer = Inference(config)
        start_time = time.perf_counter()
        scores = inferer.inference()  # DataFrame of metrics; only row 0 is used
        elapsed = time.perf_counter() - start_time
    finally:
        if temp_input_file is not None:
            try:
                os.remove(temp_input_file)
            except OSError:
                pass  # best-effort cleanup; a leftover temp file is harmless

    basic_metrics, advanced_metrics = _build_metric_tables(scores, elapsed)

    # The output directory is named after the submodel: "DrugGEN" for the
    # targeted models, "NoTarget" for the general model.
    target_model_name = config.submodel
    output_file_path = f'/home/user/app/experiments/inference/{target_model_name}/inference_drugs.txt'
    new_path = f'{target_model_name}_denovo_mols.smi'
    try:
        # shutil.move also works when source and destination are on different filesystems.
        shutil.move(output_file_path, new_path)
    except FileNotFoundError as err:
        raise gr.Error("Inference finished without producing an output file.") from err

    # One SMILES per line. Blank lines are skipped, and no line is lost when the
    # file lacks a trailing newline.
    with open(new_path, encoding="utf-8") as handle:
        generated_smiles = [line.strip() for line in handle if line.strip()]

    molecule_image = _render_molecule_grid(generated_smiles, config.seed)
    return molecule_image, new_path, basic_metrics, advanced_metrics
# Gradio interface: model choice and inputs on the left, metrics, download and
# structure preview on the right. Both tabs call run_inference.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Add custom CSS for styling
    gr.HTML("""
<style>
#metrics-container {
border: 1px solid rgba(128, 128, 128, 0.3);
border-radius: 8px;
padding: 15px;
margin-top: 15px;
margin-bottom: 15px;
background-color: rgba(255, 255, 255, 0.05);
}
</style>
""")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            gr.HTML("""
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
<!-- arXiv badge -->
<a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
<div style="
display: inline-block;
background-color: #b31b1b;
color: #ffffff !important;
padding: 5px 10px;
border-radius: 5px;
font-size: 14px;">
<span style="font-weight: bold;">arXiv</span> 2302.07868
</div>
</a>
<!-- GitHub badge -->
<a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
<div style="
display: inline-block;
background-color: #24292e;
color: #ffffff !important;
padding: 5px 10px;
border-radius: 5px;
font-size: 14px;">
<span style="font-weight: bold;">GitHub</span> Repository
</div>
</a>
</div>
""")
            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749). Trained with [2,607 bioactive compounds](https://drive.google.com/file/d/1B2OOim5wrUJalixeBTDKXLHY8BAIvNh-/view?usp=drive_link).
Molecules larger than 45 heavy atoms were excluded.
### DrugGEN-CDK2
This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941). Trained with [1,817 bioactive compounds](https://drive.google.com/file/d/1C0CGFKx0I2gdSfbIEgUO7q3K2S1P9ksT/view?usp=drive_link).
Molecules larger than 38 heavy atoms were excluded.
### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. Trained with a general [ChEMBL dataset](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link).
Molecules larger than 45 heavy atoms were excluded.
- Useful for exploring chemical space, generating diverse scaffolds, and creating molecules with drug-like properties.

For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
""")
            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate or evaluate the molecules
### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the [training set](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link). These molecules are used as inputs to
the generator during training.
- **Novelty (Inference)**: Percentage of molecules not found in the [inference set](https://drive.google.com/file/d/1vMGXqK1SQXB3Od3l80gMWvTEOjJ5MFXP/view?usp=share_link). These molecules are used as inputs
to the generator during inference.
- **Novelty (Real Inhibitors)**: Percentage of molecules not found in known inhibitors of the target protein (look at About DrugGEN Models for details). These molecules are used as inputs to the
discriminator during training.
### Structural Metrics
- **Average Length**: Average number of atoms in the generated molecules, normalized by the maximum number of atoms (e.g., 45 for AKT1/NoTarget, 38 for CDK2)
- **Mean Atom Type**: Average number of distinct atom types in the generated molecules
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)
### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is better)
### Similarity Metrics
- **SNN ChEMBL**: Similarity to [ChEMBL molecules](https://drive.google.com/file/d/1oyybQ4oXpzrme_n0kbwc0-CFjvTFSlBG/view?usp=drive_link) (higher means more similar to known drug-like compounds)
- **SNN Real Inhibitors**: Similarity to the real inhibitors of the selected target (higher means more similar to the real inhibitors)
""")
            model_name = gr.Radio(
                choices=["DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"],
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation",
            )
            with gr.Tabs():
                with gr.TabItem("Classical Generation"):
                    num_molecules = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=100,
                        step=10,
                        label="Number of Molecules to Generate",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 200-molecule cap.",
                    )
                    seed_num = gr.Textbox(
                        label="Random Seed (Optional)",
                        value="",
                        info="Set a specific seed for reproducible results, or leave empty for random generation",
                    )
                    classical_submit = gr.Button(
                        value="Generate Molecules",
                        variant="primary",
                        size="lg",
                    )
                with gr.TabItem("Custom Input SMILES"):
                    custom_smiles = gr.Textbox(
                        label="Input SMILES (one per line, maximum 100 molecules)",
                        info="This space runs on a CPU, which may result in slower performance. Generating 100 molecules takes approximately 6 minutes. Therefore, we set a 100-molecule cap.\n\n Molecules larger than allowed maximum length (45 for AKT1/NoTarget and 38 for CDK2) and allowed atom types are going to be filtered.\n\n Novelty (Inference) metric is going to be calculated using these input smiles.",
                        placeholder="Nc1ccccc1-c1nc(N)c2ccccc2n1\nO=C(O)c1ccccc1C(=O)c1cccc(Cl)c1\n...",
                        lines=10,
                    )
                    custom_submit = gr.Button(
                        value="Generate Molecules using Custom SMILES",
                        variant="primary",
                        size="lg",
                    )
        with gr.Column(scale=2):
            # Header labels match the column names of the DataFrames returned by
            # run_inference, so they do not change once results arrive.
            basic_metrics_df = gr.Dataframe(
                headers=["Validity", "Uniqueness", "Novelty (Train)", "Novelty (Inference)", "Novelty (Real Inhibitors)", "Runtime (s)"],
                elem_id="basic-metrics",
            )
            advanced_metrics_df = gr.Dataframe(
                headers=["QED", "SA Score", "Internal Diversity", "SNN ChEMBL", "SNN Real Inhibitors", "Average Length"],
                elem_id="advanced-metrics",
            )
            file_download = gr.File(label="Download All Generated Molecules (SMILES format)")
            image_output = gr.Image(
                label="Structures of Randomly Selected Generated Molecules",
                elem_id="molecule_display",
            )
    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")

    # Set up the click actions for each tab. gr.State supplies the fixed
    # arguments that the other tab's controls would otherwise provide.
    classical_submit.click(
        run_inference,
        inputs=[gr.State("Generate Molecules"), model_name, num_molecules, seed_num, gr.State("")],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df,
        ],
        api_name="inference_classical",
    )
    custom_submit.click(
        run_inference,
        inputs=[gr.State("Custom Input SMILES"), model_name, gr.State(0), gr.State(""), custom_smiles],
        outputs=[
            image_output,
            file_download,
            basic_metrics_df,
            advanced_metrics_df,
        ],
        api_name="inference_custom",
    )

demo.queue()
demo.launch()