Spaces:
Running
Running
import gradio as gr | |
from inference import Inference | |
import PIL | |
from PIL import Image | |
import pandas as pd | |
import random | |
from rdkit import Chem | |
from rdkit.Chem import Draw | |
from rdkit.Chem.Draw import IPythonConsole | |
import shutil | |
import os | |
import time | |
class DrugGENConfig:
    """Base settings object handed to ``Inference``.

    All values are plain class attributes; the target-specific subclasses
    override the ones that differ.
    """

    # ---- Inference ----
    submodel: str = "DrugGEN"
    inference_model: str = "experiments/models/DrugGEN/"
    sample_num: int = 100
    # Equivalent to ``correct=True`` in the old config.
    disable_correction: bool = False

    # ---- Data ----
    # Equivalent to ``inf_raw_file`` in the old config.
    inf_smiles: str = "data/chembl_test.smi"
    train_smiles: str = "data/chembl_train.smi"
    train_drug_smiles: str = "data/akt1_train.smi"
    inf_batch_size: int = 1
    mol_data_dir: str = "data"
    features: bool = False

    # ---- Model ----
    act: str = "relu"
    max_atom: int = 45
    dim: int = 128
    depth: int = 1
    heads: int = 8
    mlp_ratio: int = 3
    dropout: float = 0.0

    # ---- Seed ----
    set_seed: bool = True
    seed: int = 10
class DrugGENAKT1Config(DrugGENConfig):
    """Settings for the DrugGEN-AKT1 model.

    Points ``inference_model`` at the AKT1 model directory; the other
    attributes restate the base-class values explicitly.
    """

    submodel = "DrugGEN"
    inference_model = "experiments/models/DrugGEN-AKT1/"
    train_drug_smiles = "data/akt1_train.smi"
    max_atom = 45
class DrugGENCDK2Config(DrugGENConfig):
    """Settings for the DrugGEN-CDK2 model.

    Overrides the model directory, the drug SMILES file and ``max_atom``
    (38 instead of the base value).
    """

    submodel = "DrugGEN"
    inference_model = "experiments/models/DrugGEN-CDK2/"
    train_drug_smiles = "data/cdk2_train.smi"
    max_atom = 38
class NoTargetConfig(DrugGENConfig):
    """Settings for the NoTarget model.

    ``max_atom`` is not overridden, so the base-class value applies.
    """

    submodel = "NoTarget"
    inference_model = "experiments/models/NoTarget/"
    # No specific target, so the general ChEMBL data is used.
    train_drug_smiles = "data/chembl_train.smi"
# Registry of one config instance per model, keyed by the name shown in the
# UI radio selector. The instances are created once at import time.
model_configs = {
    ui_name: config_cls()
    for ui_name, config_cls in (
        ("DrugGEN-AKT1", DrugGENAKT1Config),
        ("DrugGEN-CDK2", DrugGENCDK2Config),
        ("DrugGEN-NoTarget", NoTargetConfig),
    )
}
def function(model_name: str, num_molecules: int, seed_num: str):
    """Run inference for the selected model and assemble the UI outputs.

    Args:
        model_name: A key of ``model_configs`` (the value of the model
            radio selector).
        num_molecules: Number of molecules to generate; at most 250.
        seed_num: Seed as entered in the textbox. ``None`` or a blank
            string selects a random seed in ``[0, 10000]``. An integer
            value is accepted as well.

    Returns:
        A 15-tuple, in the order expected by the click handler's outputs:
        ``(image, score_df, file_path, validity, uniqueness, novelty_train,
        novelty_test, drug_novelty, runtime, qed, sa, int_div, snn_chembl,
        snn_drug, max_len)``. ``image`` is ``None`` when no generated SMILES
        string could be parsed.

    Raises:
        gr.Error: If the model name is unknown, more than 250 molecules are
            requested, or the seed is not an integer.
    """
    max_molecules = 250
    preview_count = 12

    # Look the config up under the UI name *before* shortening the name:
    # the registry key is "DrugGEN-NoTarget", while the inference output
    # directory used below is named "NoTarget".
    config = model_configs.get(model_name)
    if config is None:
        raise gr.Error(f"Unknown model: {model_name}")
    output_name = "NoTarget" if model_name == "DrugGEN-NoTarget" else model_name

    if num_molecules > max_molecules:
        raise gr.Error("You have requested to generate more than the allowed limit of 250 molecules. Please reduce your request to 250 or fewer.")
    config.sample_num = num_molecules

    # The textbox delivers a string, but API callers may send an int (or
    # nothing), so normalise to text before checking for "blank".
    seed_text = "" if seed_num is None else str(seed_num).strip()
    if not seed_text:
        config.seed = random.randint(0, 10000)
    else:
        try:
            config.seed = int(seed_text)
        except ValueError:
            raise gr.Error("The seed must be an integer value!") from None

    inferer = Inference(config)
    start_time = time.perf_counter()
    scores = inferer.inference()
    et = time.perf_counter() - start_time

    def first(column):
        """Return the first entry of one column of the inference scores."""
        return scores[column].iloc[0]

    # (table header, score column) pairs, in table order. Each value is
    # read exactly once and reused for both the table and the number boxes.
    metric_columns = (
        ("Validity", "validity"),
        ("Uniqueness", "uniqueness"),
        ("Novelty (Train)", "novelty"),
        ("Novelty (Test)", "novelty_test"),
        ("Drug Novelty", "drug_novelty"),
        ("Max Length", "max_len"),
        ("Mean Atom Type", "mean_atom_type"),
        ("SNN ChEMBL", "snn_chembl"),
        ("SNN Drug", "snn_drug"),
        ("Internal Diversity", "IntDiv"),
        ("QED", "qed"),
        ("SA Score", "sa"),
    )
    metrics = {label: first(column) for label, column in metric_columns}

    table = {"Runtime (seconds)": [et]}
    table.update({label: [value] for label, value in metrics.items()})
    score_df = pd.DataFrame(table)

    # Move the generated SMILES file next to the app so it can be offered
    # as a download.
    output_file_path = f'experiments/inference/{output_name}/inference_drugs.txt'
    new_path = f'{output_name}_denovo_mols.smi'
    os.rename(output_file_path, new_path)

    # Keep every non-blank line; slicing off the last element would lose a
    # molecule whenever the file does not end with a newline.
    with open(new_path, encoding="utf-8") as f:
        generated_molecule_list = [line.strip() for line in f if line.strip()]

    # Preview at most `preview_count` distinct entries. Sampling without
    # replacement avoids showing the same molecule twice.
    rng = random.Random(config.seed)
    if len(generated_molecule_list) > preview_count:
        preview_smiles = rng.sample(generated_molecule_list, k=preview_count)
    else:
        preview_smiles = generated_molecule_list

    # Parse each SMILES once and drop the ones RDKit cannot read.
    selected_molecules = [
        mol for mol in map(Chem.MolFromSmiles, preview_smiles) if mol is not None
    ]

    if selected_molecules:
        drawOptions = Draw.rdMolDraw2D.MolDrawOptions()
        drawOptions.prepareMolsBeforeDrawing = False
        drawOptions.bondLineWidth = 0.5

        molecule_image = Draw.MolsToGridImage(
            selected_molecules,
            molsPerRow=3,
            subImgSize=(400, 400),
            maxMols=len(selected_molecules),
            returnPNG=False,
            drawOptions=drawOptions,
            highlightAtomLists=None,
            highlightBondLists=None,
        )
    else:
        # Nothing drawable: leave the image empty rather than asking RDKit
        # for a zero-row grid; the metrics and the file are still returned.
        molecule_image = None

    return (
        molecule_image,
        score_df,
        new_path,
        metrics["Validity"],
        metrics["Uniqueness"],
        metrics["Novelty (Train)"],
        metrics["Novelty (Test)"],
        metrics["Drug Novelty"],
        et,
        metrics["QED"],
        metrics["SA Score"],
        metrics["Internal Diversity"],
        metrics["SNN ChEMBL"],
        metrics["SNN Drug"],
        metrics["Max Length"],
    )
# Gradio UI: a left column with the description and the inputs, a right
# column with the outputs, and a click handler wiring them to `function`.
# NOTE(review): the nesting below is reconstructed from the order of the
# layout containers — confirm against the deployed layout.
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    # Add custom CSS for styling
    gr.HTML("""
<style>
#metrics-container {
border: 1px solid rgba(128, 128, 128, 0.3);
border-radius: 8px;
padding: 15px;
margin-top: 15px;
margin-bottom: 15px;
background-color: rgba(255, 255, 255, 0.05);
}
</style>
""")
    with gr.Row():
        # Left column: title, links, help accordions and the input controls.
        with gr.Column(scale=1):
            gr.Markdown("# DrugGEN: Target Centric De Novo Design of Drug Candidate Molecules with Graph Generative Deep Adversarial Networks")
            # Badge-style links to the paper and the repository.
            gr.HTML("""
<div style="display: flex; gap: 10px; margin-bottom: 15px;">
<a href="https://arxiv.org/abs/2302.07868" target="_blank" style="text-decoration: none;">
<div style="display: inline-block; background-color: #b31b1b; color: white; padding: 5px 10px; border-radius: 5px; font-size: 14px;">
<span style="font-weight: bold;">arXiv</span> 2302.07868
</div>
</a>
<a href="https://github.com/HUBioDataLab/DrugGEN" target="_blank" style="text-decoration: none;">
<div style="display: inline-block; background-color: #24292e; color: white; padding: 5px 10px; border-radius: 5px; font-size: 14px;">
<span style="font-weight: bold;">GitHub</span> Repository
</div>
</a>
</div>
""")
            with gr.Accordion("About DrugGEN Models", open=False):
                gr.Markdown("""
### DrugGEN-AKT1
This model is designed to generate molecules targeting the human AKT1 protein (UniProt ID: P31749).
### DrugGEN-CDK2
This model is designed to generate molecules targeting the human CDK2 protein (UniProt ID: P24941).
### DrugGEN-NoTarget
This is a general-purpose model that generates diverse drug-like molecules without targeting a specific protein. It's useful for:
- Exploring chemical space
- Generating diverse scaffolds
- Creating molecules with drug-like properties
For more details, see our [paper on arXiv](https://arxiv.org/abs/2302.07868).
""")
            with gr.Accordion("Understanding the Metrics", open=False):
                gr.Markdown("""
## Evaluation Metrics
### Basic Metrics
- **Validity**: Percentage of generated molecules that are chemically valid
- **Uniqueness**: Percentage of unique molecules among valid ones
- **Runtime**: Time taken to generate the requested molecules
### Novelty Metrics
- **Novelty (Train)**: Percentage of molecules not found in the training set
- **Novelty (Test)**: Percentage of molecules not found in the test set
- **Drug Novelty**: Percentage of molecules not found in known drugs
### Structural Metrics
- **Max Length**: Maximum component length in the generated molecules
- **Mean Atom Type**: Average distribution of atom types
- **Internal Diversity**: Diversity within the generated set (higher is more diverse)
### Drug-likeness Metrics
- **QED (Quantitative Estimate of Drug-likeness)**: Score from 0-1 measuring how drug-like a molecule is (higher is better)
- **SA Score (Synthetic Accessibility)**: Score from 1-10 indicating ease of synthesis (lower is easier)
### Similarity Metrics
- **SNN ChEMBL**: Similarity to ChEMBL molecules (higher means more similar to known drug-like compounds)
- **SNN Drug**: Similarity to known drugs (higher means more similar to approved drugs)
""")
            # The choices are the keys of `model_configs`.
            model_name = gr.Radio(
                choices=("DrugGEN-AKT1", "DrugGEN-CDK2", "DrugGEN-NoTarget"),
                value="DrugGEN-AKT1",
                label="Select Target Model",
                info="Choose which protein target or general model to use for molecule generation"
            )
            # The slider maximum matches the 250-molecule limit enforced in `function`.
            num_molecules = gr.Slider(
                minimum=10,
                maximum=250,
                value=100,
                step=10,
                label="Number of Molecules to Generate",
                info="This space runs on a CPU, which may result in slower performance. Generating 200 molecules takes approximately 6 minutes. Therefore, We set a 250-molecule cap. On a GPU, the model can generate 10,000 molecules in the same amount of time. Please check our GitHub repo for running our models on GPU."
            )
            # Free-text seed; `function` treats an empty value as "pick a random seed".
            seed_num = gr.Textbox(
                label="Random Seed (Optional)",
                value="",
                info="Set a specific seed for reproducible results, or leave empty for random generation"
            )
            submit_button = gr.Button(
                value="Generate Molecules",
                variant="primary",
                size="lg"
            )
        # Right column: molecule image, download link and metrics.
        with gr.Column(scale=2):
            image_output = gr.Image(
                label="Sample of Generated Molecules",
                elem_id="molecule_display"
            )
            file_download = gr.File(
                label="Download All Generated Molecules (SMILES format)",
            )
            # `elem_id` ties this group to the #metrics-container CSS rule above.
            with gr.Group(elem_id="metrics-container"):
                gr.Markdown("### Performance Metrics")
                with gr.Row():
                    with gr.Column():
                        validity = gr.Number(label="Validity", precision=3)
                        uniqueness = gr.Number(label="Uniqueness", precision=3)
                        novelty_train = gr.Number(label="Novelty (Train)", precision=3)
                        novelty_test = gr.Number(label="Novelty (Test)", precision=3)
                        drug_novelty = gr.Number(label="Drug Novelty", precision=3)
                        runtime = gr.Number(label="Runtime (seconds)", precision=2)
                    with gr.Column():
                        qed = gr.Number(label="QED Score", precision=3, info="Higher is more drug-like (0-1)")
                        sa = gr.Number(label="SA Score", precision=3, info="Lower is easier to synthesize (1-10)")
                        int_div = gr.Number(label="Internal Diversity", precision=3)
                        snn_chembl = gr.Number(label="SNN ChEMBL", precision=3)
                        snn_drug = gr.Number(label="SNN Drug", precision=3)
                        max_len = gr.Number(label="Max Length", precision=3)
            # Same headers as the DataFrame built in `function`.
            with gr.Accordion("All Metrics (Table View)", open=False):
                scores_df = gr.Dataframe(
                    headers=["Runtime (seconds)", "Validity", "Uniqueness", "Novelty (Train)", "Novelty (Test)",
                             "Drug Novelty", "Max Length", "Mean Atom Type", "SNN ChEMBL", "SNN Drug",
                             "Internal Diversity", "QED", "SA Score"]
                )
    gr.Markdown("### Created by the HUBioDataLab | [GitHub](https://github.com/HUBioDataLab/DrugGEN) | [Paper](https://arxiv.org/abs/2302.07868)")
    # The outputs list must stay in the same order as the tuple returned by `function`.
    submit_button.click(
        function,
        inputs=[model_name, num_molecules, seed_num],
        outputs=[
            image_output,
            scores_df,
            file_download,
            validity,
            uniqueness,
            novelty_train,
            novelty_test,
            drug_novelty,
            runtime,
            qed,
            sa,
            int_div,
            snn_chembl,
            snn_drug,
            max_len
        ],
        api_name="inference"
    )
# Enable the request queue, then start the app; `queue()` returns the same
# Blocks object, so the two calls can be chained.
demo.queue().launch()