|
import gradio as gr |
|
import pandas as pd |
|
from PIL import Image |
|
from rdkit import RDLogger |
|
from molecule_generation_helpers import * |
|
from property_prediction_helpers import * |
|
|
|
# Set the RDKit logger's level to ERROR.
RDLogger.logger().setLevel(RDLogger.ERROR)
|
|
|
|
|
# Built-in datasets; the keys populate the "Select Dataset" dropdown.
# Each value is a single comma-separated string whose first field is the
# training CSV path (read by load_predefined_dataset). The remaining fields
# look like "<test csv>, <input column>, <output column>" -- TODO confirm
# against the helper that consumes them.
# The " " entry is a blank placeholder.
predefined_datasets = {
    " ": " ",
    "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
    "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
}
|
|
|
|
|
# Model names offered in the "Select Model" checkbox group.
models_enabled = [
    "SELFIES-TED",
    "MHG-GED",
    "MolFormer",
    "SMI-TED",
]

# Options offered in the "Fusion Type" radio button.
fusion_available = [
    "Concat",
]
|
|
|
|
|
|
|
def load_predefined_dataset(dataset_name):
    """Load the preview and column choices for a predefined dataset.

    Args:
        dataset_name: Value of the dataset dropdown; looked up in
            ``predefined_datasets``.

    Returns:
        A 4-tuple matching the Gradio outputs wired to this handler:
        the first five rows of the training CSV, two ``gr.update`` objects
        setting the input/output column choices to the CSV's columns, and
        the lower-cased dataset name. When the name is unknown or maps to a
        blank placeholder entry, an empty DataFrame, two empty choice lists
        and ``"Dataset not found"`` are returned instead.

    Raises:
        Whatever ``pd.read_csv`` raises if the configured training CSV
        cannot be read.
    """
    entry = predefined_datasets.get(dataset_name)

    # The blank " " placeholder is a truthy string, so a plain truthiness
    # test would send it to pd.read_csv(" "); strip before deciding.
    if not entry or not entry.strip():
        return (
            pd.DataFrame(),
            gr.update(choices=[]),
            gr.update(choices=[]),
            "Dataset not found",
        )

    # The training CSV path is the first comma-separated field of the entry.
    train_path = entry.split(",")[0].strip()
    preview = pd.read_csv(train_path)
    return (
        preview.head(),
        gr.update(choices=list(preview.columns)),
        gr.update(choices=list(preview.columns)),
        dataset_name.lower(),
    )
|
|
|
|
|
|
|
def handle_dataset_selection(selected_dataset):
    """Return visibility updates for the dataset widgets.

    The eight ``gr.update`` objects are ordered like the outputs this
    handler is wired to: dataset name, train file, train preview, test
    file, test preview, predefined preview, input column selector, output
    column selector.

    Args:
        selected_dataset: Current value of the dataset dropdown.

    Returns:
        A tuple of eight ``gr.update(visible=...)`` objects. The dataset
        name is always visible and the predefined preview always hidden;
        the remaining six widgets are visible only for "Custom Dataset".
    """
    custom = selected_dataset == "Custom Dataset"
    visibility = (True, custom, custom, custom, custom, False, custom, custom)
    return tuple(gr.update(visible=shown) for shown in visibility)
|
|
|
|
|
|
|
def update_hyperparameters(model_name):
    """Show only the hyperparameter widgets relevant to a downstream model.

    Args:
        model_name: Value of the "Select Downstream Model" dropdown.

    Returns:
        A tuple of five ``gr.update(visible=...)`` objects ordered like the
        outputs this handler is wired to:
        (max_depth, n_estimators, alpha, degree, kernel).
        Any model name without an entry in the table below (including
        ``None``) hides all five widgets, so the handler always returns one
        value per output instead of falling through with ``None``.
    """
    all_hidden = (False, False, False, False, False)

    # Flags are in (max_depth, n_estimators, alpha, degree, kernel) order.
    visibility_by_model = {
        "XGBClassifier": (True, True, True, False, False),
        "SVR": (False, False, False, True, True),
        "Kernel Ridge": (False, False, True, True, True),
        "Linear Regression": all_hidden,
        "Default - Auto": all_hidden,
    }

    flags = visibility_by_model.get(model_name, all_hidden)
    return tuple(gr.update(visible=shown) for shown in flags)
|
|
|
|
|
|
|
def select_columns(input_column, output_column, train_data, test_data, dataset_name):
    """Build the comma-separated summary of the current dataset selection.

    Args:
        input_column: Name of the selected input column.
        output_column: Name of the selected output column.
        train_data: Uploaded train file object; its ``name`` attribute is used.
        test_data: Uploaded test file object; its ``name`` attribute is used.
        dataset_name: Name to attach to the dataset.

    Returns:
        ``"<train name>,<test name>,<input>,<output>,<dataset name>"`` when
        both columns are set, otherwise a prompt asking for both columns.

    Note:
        When both columns are set, ``train_data`` and ``test_data`` must not
        be ``None`` because ``.name`` is read from each.
    """
    if not (input_column and output_column):
        return "Please select both input and output columns."
    return "{},{},{},{},{}".format(
        train_data.name,
        test_data.name,
        input_column,
        output_column,
        dataset_name,
    )
|
|
|
|
|
|
|
def set_dataname(dataset_name, dataset_selector):
    """Pick the dataset name to use.

    Returns ``dataset_name`` when the selector is "Custom Dataset";
    otherwise returns the selector value itself.
    """
    if dataset_selector == "Custom Dataset":
        return dataset_name
    return dataset_selector
|
|
|
|
|
|
|
def display_csv_head(file):
    """Preview an uploaded CSV and expose its columns as dropdown choices.

    Args:
        file: Uploaded file object whose ``name`` attribute is passed to
            ``pd.read_csv``, or ``None`` when nothing is uploaded.

    Returns:
        A 3-tuple: the first five rows of the CSV and two ``gr.update``
        objects whose choices are the CSV's column names. With no file, an
        empty DataFrame and two empty choice lists are returned.
    """
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])

    frame = pd.read_csv(file.name)
    input_choices = gr.update(choices=list(frame.columns))
    output_choices = gr.update(choices=list(frame.columns))
    return frame.head(), input_choices, output_choices
|
|
|
|
|
|
|
# Sample molecules for the molecule-generation tab.
# Keys are the labels shown in the sample selector; each entry holds the
# molecule's SMILES string and the path of an image file for it.
smiles_image_mapping = {
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}
|
|
|
|
|
|
|
def load_image(path):
    """Open the image registered for a sample-molecule key.

    Args:
        path: Despite its name, a key of ``smiles_image_mapping``
            (for example ``"Mol 1"``), not a filesystem path.

    Returns:
        The opened PIL image, or ``None`` when the key is missing or
        unusable (including ``None``) or the image file cannot be opened.
    """
    # Best-effort lookup: an unknown or cleared selection yields no image.
    try:
        image_path = smiles_image_mapping[path]["image"]
    except (KeyError, TypeError):
        return None

    # Only file-level failures are tolerated here (missing or unreadable
    # image file); anything else is a real bug and should surface rather
    # than be silently swallowed.
    try:
        return Image.open(image_path)
    except OSError:
        return None
|
|
|
|
|
|
|
def handle_image_selection(image_key):
    """Resolve a sample-molecule key to its SMILES string and rendering.

    Args:
        image_key: Key of ``smiles_image_mapping``, or a falsy value when
            nothing is selected.

    Returns:
        ``(smiles, image)`` where ``image`` is the result of
        ``smiles_to_image(smiles)``; ``(None, None)`` for a falsy key.
    """
    if image_key:
        selected_smiles = smiles_image_mapping[image_key]["smiles"]
        return selected_smiles, smiles_to_image(selected_smiles)
    return None, None
|
|
|
|
|
|
|
# Introduction tab.
# Read the page text with an explicit encoding so decoding does not depend on
# the host locale, and close the file before the UI is built.
with open("INTRODUCTION.md", encoding="utf-8") as f:
    introduction_text = f.read()

with gr.Blocks() as introduction:
    gr.Markdown(introduction_text)

    # Debug section: renders the same kind of content through HTML and
    # Markdown components, once with a local file path and once with a
    # remote URL.
    gr.Markdown("---\n# Debug")
    gr.HTML("HTML text: <img src='file/img/selfies-ted.png'>")
    gr.Markdown("Markdown text: ")
    gr.HTML("HTML text: <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg'>")
    gr.Markdown("Markdown text: ")
|
|
|
|
|
# Property Prediction tab: dataset selection, downstream-model settings,
# training, plotting and result logging.
with gr.Blocks() as property_prediction:
    # Empty results log; it seeds both the per-session state and the table.
    log_df = pd.DataFrame(
        {
            column: []
            for column in ("", "Selected Models", "Dataset", "Task", "Result")
        }
    )
    state = gr.State({"log_df": log_df})

    gr.HTML(
        '''
        <p style="text-align: center">
        Task : Property Prediction
        <br>
        Models are finetuned with different combination of modalities on the uploaded or selected built data set.
        </p>
        '''
    )

    with gr.Row():
        # Left column: inputs and settings.
        with gr.Column():
            # Predefined datasets plus the "Custom Dataset" upload option.
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=[*predefined_datasets, "Custom Dataset"],
            )

            # Hidden textbox carrying the selection summary to the
            # evaluation handler.
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info",
                visible=False,
            )

            with gr.Accordion("Dataset Settings", open=True):
                # Custom-dataset widgets (hidden until selected).
                dataset_name = gr.Textbox(label="Dataset Name", visible=False)
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )

                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )

                # Preview for predefined datasets.
                predefined_display = gr.Dataframe(
                    label="Predefined Dataset Preview (First 5 Rows)",
                    visible=False,
                    interactive=False,
                )

                # Column pickers.
                input_column_selector = gr.Dropdown(
                    label="Select Input Column",
                    choices=[],
                    visible=False,
                )
                output_column_selector = gr.Dropdown(
                    label="Select Output Column",
                    choices=[],
                    visible=False,
                )

                # Dataset change: toggle widget visibility.
                dataset_selector.change(
                    handle_dataset_selection,
                    inputs=dataset_selector,
                    outputs=[
                        dataset_name,
                        train_file,
                        train_display,
                        test_file,
                        test_display,
                        predefined_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )

                # Dataset change: load the predefined preview and columns.
                dataset_selector.change(
                    load_predefined_dataset,
                    inputs=dataset_selector,
                    outputs=[
                        predefined_display,
                        input_column_selector,
                        output_column_selector,
                        selected_columns_message,
                    ],
                )

                # File upload (train first, then test): show the preview and
                # refresh the column pickers.
                for uploaded_file, preview_table in (
                    (train_file, train_display),
                    (test_file, test_display),
                ):
                    uploaded_file.change(
                        display_csv_head,
                        inputs=uploaded_file,
                        outputs=[
                            preview_table,
                            input_column_selector,
                            output_column_selector,
                        ],
                    )

                # Dataset change: keep the dataset name in sync.
                dataset_selector.change(
                    set_dataname,
                    inputs=[dataset_name, dataset_selector],
                    outputs=dataset_name,
                )

                # Column change (input first, then output): rebuild the
                # selection summary.
                for column_selector in (
                    input_column_selector,
                    output_column_selector,
                ):
                    column_selector.change(
                        select_columns,
                        inputs=[
                            input_column_selector,
                            output_column_selector,
                            train_file,
                            test_file,
                            dataset_name,
                        ],
                        outputs=selected_columns_message,
                    )

            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled,
                label="Select Model",
            )

            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"],
                label="Task Type",
            )

            # Downstream model and its hyperparameters.
            model_name = gr.Dropdown(
                choices=[
                    "Default - Auto",
                    "XGBClassifier",
                    "SVR",
                    "Kernel Ridge",
                    "Linear Regression",
                ],
                label="Select Downstream Model",
            )
            with gr.Accordion("Downstream Hyperparameter Settings", open=True):
                # Hyperparameter widgets; all start hidden and are revealed
                # per model by update_hyperparameters.
                max_depth = gr.Slider(
                    1, 20, step=1, visible=False, label="max_depth"
                )
                n_estimators = gr.Slider(
                    100, 5000, step=100, visible=False, label="n_estimators"
                )
                alpha = gr.Slider(
                    0.1, 10.0, step=0.1, visible=False, label="alpha"
                )
                degree = gr.Slider(
                    1, 20, step=1, visible=False, label="degree"
                )
                kernel = gr.Dropdown(
                    choices=["rbf", "poly", "linear"],
                    visible=False,
                    label="kernel",
                )

                # Receives the result of create_downstream_model.
                output = gr.Textbox(label="Loaded Parameters")

                # Show the hyperparameters relevant to the selected model.
                model_name.change(
                    update_hyperparameters,
                    inputs=[model_name],
                    outputs=[max_depth, n_estimators, alpha, degree, kernel],
                )

                submit_button = gr.Button("Create Downstream Model")

                # Build the downstream model from the chosen settings.
                submit_button.click(
                    create_downstream_model,
                    inputs=[
                        model_name,
                        max_depth,
                        n_estimators,
                        alpha,
                        degree,
                        kernel,
                    ],
                    outputs=output,
                )

            fusion_radiobutton = gr.Radio(
                choices=fusion_available,
                label="Fusion Type",
            )

            eval_button = gr.Button("Train downstream model")

        # Right column: results.
        with gr.Column():
            eval_output = gr.Textbox(label="Train downstream model")

            plot_radio = gr.Radio(
                choices=["ROC-AUC", "Parity Plot", "Latent Space"],
                label="Select Plot Type",
            )
            plot_output = gr.Plot(label="Visualization")

            create_log = gr.Button("Store log")

            log_table = gr.Dataframe(
                value=log_df,
                label="Log of Selections and Results",
                interactive=False,
            )

            eval_button.click(
                display_eval,
                inputs=[
                    model_checkbox,
                    selected_columns_message,
                    task_radiobutton,
                    output,
                    fusion_radiobutton,
                    state,
                ],
                outputs=eval_output,
            )

            plot_radio.change(
                display_plot,
                inputs=[plot_radio, state],
                outputs=plot_output,
            )

            # Append the current selections and result to the log table.
            create_log.click(
                evaluate_and_log,
                inputs=[
                    model_checkbox,
                    dataset_name,
                    task_radiobutton,
                    eval_output,
                    state,
                ],
                outputs=log_table,
            )
|
|
|
|
|
|
|
# Molecule Generation tab: pick or type a SMILES string, then generate a
# molecule and compare properties.
with gr.Blocks() as molecule_generation:
    gr.HTML(
        '''
        <p style="text-align: center">
        Task : Molecule Generation
        <br>
        Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
        </p>
        '''
    )

    with gr.Row():
        # Left column: input molecule.
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(
                label="Molecule Image",
                height=250,
                width=250,
            )

            # Sample molecules to choose from.
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    choices=list(smiles_image_mapping),
                    label="Select from sample molecules",
                    value=None,
                )
                # NOTE(review): image_display is also written by the
                # handle_image_selection and smiles_to_image handlers wired
                # below -- confirm which image is meant to win.
                image_selector.change(
                    load_image,
                    inputs=image_selector,
                    outputs=image_display,
                )

            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")

        # Right column: generated molecule.
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image",
                height=250,
                width=250,
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(
                label="Molecular Properties Comparison"
            )

        # Sample selection fills the SMILES box and the image.
        image_selector.change(
            handle_image_selection,
            inputs=image_selector,
            outputs=[smiles_input, image_display],
        )
        # Any change of the SMILES text re-renders the image.
        smiles_input.change(
            smiles_to_image,
            inputs=smiles_input,
            outputs=image_display,
        )

        # Submit: run generation and fill the right-hand column.
        generate_button.click(
            generate_canonical,
            inputs=smiles_input,
            outputs=[property_table, generated_output, gen_image_display],
        )

        # Clear: reset all six widgets to None.
        clear_button.click(
            lambda: (None,) * 6,
            outputs=[
                smiles_input,
                image_display,
                image_selector,
                gen_image_display,
                generated_output,
                property_table,
            ],
        )
|
|
|
|
|
|
|
# Combine the three pages into one tabbed app and start the server.
# NOTE(review): server_name="0.0.0.0" listens on all network interfaces and
# allowed_paths=["./"] permits Gradio to serve files from the working
# directory -- confirm this exposure is intended.
gr.TabbedInterface(
    interface_list=[introduction, property_prediction, molecule_generation],
    tab_names=["Introduction", "Property Prediction", "Molecule Generation"],
).launch(
    server_name="0.0.0.0",
    allowed_paths=["./"],
)
|
|