Spaces:
Running
Running
import os
from itertools import combinations
from tempfile import NamedTemporaryFile

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from rdkit import RDLogger
from sklearn.model_selection import train_test_split

from molecule_generation_helpers import *
from property_prediction_helpers import *
# Debug switch: widgets declared with ``visible=DEBUG_VISIBLE`` stay hidden
# unless this is set to True.
DEBUG_VISIBLE = False

# Set the RDKit logger level to ERROR.
_rdkit_logger = RDLogger.logger()
_rdkit_logger.setLevel(RDLogger.ERROR)
# Predefined datasets, keyed by the name shown in the dataset dropdown.
# Each value is one comma-separated string; the first field is the CSV path
# that load_predefined_dataset() reads. The remaining fields appear to be
# the test CSV path, the input column and the output column -- confirm
# against the helper modules. Adjust the paths to your own files.
predefined_datasets = {
    # Blank placeholder entry (the Clear button resets the dropdown to it).
    " ": " ",
    "BACE": "./data/bace/train.csv, ./data/bace/test.csv, smiles, Class",
    "ESOL": "./data/esol/train.csv, ./data/esol/test.csv, smiles, prop",
}

# Models offered for evaluation. Order matters: the evaluation steps refer to
# the models by their index in this list.
models_enabled = [
    "MorganFingerprint",
    "SMI-TED",
    "SELFIES-TED",
    "MHG-GED",
    "POS-EGNN",
]

# Empty results table used to initialise and to reset the log table.
blank_df = pd.DataFrame({column: [] for column in ("id", "Model", "Score")})
def load_predefined_dataset(dataset_name):
    """Load the preview of a predefined dataset from its local CSV path.

    Args:
        dataset_name: Name chosen in the dataset dropdown.

    Returns:
        A 4-tuple ``(preview, input_update, output_update, name)``:
        the first rows of the training CSV, two ``gr.update`` objects that
        set the column-selector choices, and the lower-cased dataset name.
        When the name is not predefined, the preview is empty, the choices
        are cleared and the name is ``"custom"``. When the name is predefined
        but its CSV cannot be read, the preview is empty and the original
        name is kept.
    """
    spec = predefined_datasets.get(dataset_name)
    if spec:
        # The first comma-separated field of the spec is the training CSV path.
        train_path = spec.split(",")[0].strip()
        # The blank placeholder entry has no path, so there is nothing to read.
        if train_path:
            try:
                df = pd.read_csv(train_path)
            except (OSError, ValueError):
                # Best effort: a missing or unparsable CSV falls through to the
                # empty preview below instead of failing the UI callback.
                # pandas parser errors derive from ValueError.
                pass
            else:
                columns = list(df.columns)
                return (
                    df.head(),
                    gr.update(choices=columns, value=None),
                    gr.update(choices=list(columns), value=None),
                    dataset_name.lower(),
                )
    else:
        dataset_name = "Custom"
    return (
        pd.DataFrame(),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        dataset_name.lower(),
    )
def handle_dataset_selection(selected_dataset, state):
    """Handle a dataset selection (predefined or custom).

    Stores the dataset name in ``state["dataset_name"]`` (``"CUSTOM"`` for
    anything that is not a predefined dataset).

    Returns:
        ``(settings_update, task_type)``: a ``gr.update`` that shows the
        custom-dataset settings only for non-predefined selections (or always
        when ``DEBUG_VISIBLE`` is set), and the task type -- "Classification"
        for BACE, "Regression" for ESOL, otherwise ``None``.
    """
    is_predefined = selected_dataset in predefined_datasets
    state["dataset_name"] = selected_dataset if is_predefined else "CUSTOM"

    known_task_types = {"BACE": "Classification", "ESOL": "Regression"}
    task_type = known_task_types.get(selected_dataset)

    settings_update = gr.update(visible=not is_predefined or DEBUG_VISIBLE)
    return settings_update, task_type
def select_columns(input_column, output_column, train_data, test_data, state):
    """Build the comma-separated dataset description for the selected columns.

    Returns:
        ``"<train path>,<test path>,<input column>,<output column>,<dataset name>"``
        when both files and both columns are set; otherwise a no-op
        ``gr.update()``.
    """
    if not (train_data and test_data and input_column and output_column):
        return gr.update()

    fields = (
        train_data.name,
        test_data.name,
        input_column,
        output_column,
        state["dataset_name"],
    )
    return ",".join(f"{field}" for field in fields)
def display_csv_head(file):
    """Display the head of an uploaded CSV file.

    Returns:
        ``(preview, input_update, output_update)``: the first rows of the CSV
        and two ``gr.update`` objects listing its columns as choices. With no
        file, the preview is an empty DataFrame and the choices are cleared.
    """
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])

    frame = pd.read_csv(file.name)
    column_names = list(frame.columns)
    return (
        frame.head(),
        gr.update(choices=column_names),
        gr.update(choices=list(column_names)),
    )
def _write_split_csv(frame, prefix):
    """Write *frame* to a new persistent temporary CSV file and return its path."""
    with NamedTemporaryFile(prefix=prefix, suffix=".csv", delete=False) as handle:
        path = handle.name
    # Written only after the handle is closed: re-opening a still-open named
    # temporary file by name is not possible on Windows.
    frame.to_csv(path, index=False)
    return path


def process_custom_file(file, selected_dataset):
    """Split an uploaded custom CSV into train and test files.

    The upload is used only when it is smaller than 50 KiB, has both an
    ``input`` and an ``output`` column and contains at least two rows (one
    for each side of the split).

    Args:
        file: Uploaded file object exposing a ``.name`` path, or ``None``.
        selected_dataset: Current value of the dataset dropdown.

    Returns:
        ``(train_path, test_path, "input", "output", task_type)`` on success,
        where ``task_type`` is "Classification" for an int64 ``output``
        column and "Regression" otherwise. When the upload is not usable,
        returns four ``None`` values plus a no-op ``gr.update()`` if a
        predefined dataset is selected, or ``None`` otherwise.
    """
    max_upload_bytes = 50 * 1024

    if file and os.path.getsize(file.name) < max_upload_bytes:
        df = pd.read_csv(file.name)
        has_required_columns = "input" in df.columns and "output" in df.columns
        # train_test_split cannot produce a non-empty train set from fewer
        # than two rows, so such uploads are treated as unusable.
        if has_required_columns and len(df) >= 2:
            train, test = train_test_split(df, test_size=0.2)
            train_path = _write_split_csv(train, "fm4m-train-")
            test_path = _write_split_csv(test, "fm4m-test-")
            task_type = (
                "Classification" if df["output"].dtype == np.int64 else "Regression"
            )
            return train_path, test_path, "input", "output", task_type

    return (
        None,
        None,
        None,
        None,
        gr.update() if selected_dataset in predefined_datasets else None,
    )
def update_plot_choices(current, state):
    """Refresh the plot-type choices from the results held in ``state``.

    A plot type is offered when its state entry is not ``None``. The current
    selection is kept if it is still offered; otherwise the first available
    choice is selected (or ``None`` when nothing is available).
    """
    plot_types_by_key = (
        ("roc_auc", "ROC-AUC"),
        ("RMSE", "Parity Plot"),
        ("x_batch", "Latent Space"),
    )
    choices = [
        label for key, label in plot_types_by_key if state.get(key) is not None
    ]

    if current in choices:
        return gr.update(choices=choices)

    fallback = choices[0] if choices else None
    return gr.update(choices=choices, value=fallback)
def log_selected(df: pd.DataFrame, evt: gr.SelectData, state):
    """Copy the stored results of the selected log-table row into ``state``.

    The row's ``id`` cell is looked up in ``state["results"]``; an unknown id
    leaves ``state`` unchanged.
    """
    stored_results = state["results"]
    selected_id = df.at[evt.index[0], "id"]
    state.update(stored_results.get(selected_id, {}))
# Sample molecules offered in the Molecule Generation tab, keyed by the label
# shown in the selector. Each entry holds the molecule's SMILES string and the
# path of an image for it (replace with your actual image paths).
smiles_image_mapping = {
    # Sample molecule 1
    "Mol 1": {
        "smiles": "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
        "image": "img/img1.png",
    },
    # Sample molecule 2
    "Mol 2": {
        "smiles": "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
        "image": "img/img2.png",
    },
    # Sample molecule 3
    "Mol 3": {
        "smiles": "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
        "image": "img/img3.png",
    },
    # Sample molecule 4
    "Mol 4": {
        "smiles": "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
        "image": "img/img4.png",
    },
    # Sample molecule 5
    "Mol 5": {
        "smiles": "C=CCS[C@@H](C)CC(=O)OCC",
        "image": "img/img5.png",
    },
}
def load_image(path):
    """Load the image of a sample molecule.

    Args:
        path: Key into ``smiles_image_mapping`` (e.g. ``"Mol 1"``); despite
            the parameter name, this is not a file path.

    Returns:
        The opened PIL image, or ``None`` when the key is unknown (including
        ``None``, which arrives when the selector is cleared) or the image
        file cannot be opened.
    """
    entry = smiles_image_mapping.get(path)
    if entry is None:
        return None
    try:
        return Image.open(entry["image"])
    except OSError:
        # Missing or unreadable image file: show no image rather than fail
        # the UI callback.
        return None
def handle_image_selection(image_key):
    """Handle the selection of a sample molecule.

    Returns:
        ``(smiles, image)`` for the selected key, where the image comes from
        ``smiles_to_image``; ``(None, None)`` when nothing is selected.
    """
    if image_key:
        selected_smiles = smiles_image_mapping[image_key]["smiles"]
        return selected_smiles, smiles_to_image(selected_smiles)
    return None, None
# Introduction tab: renders the contents of INTRODUCTION.md as Markdown.
with gr.Blocks() as introduction:
    # Explicit UTF-8 so the text decodes the same way regardless of the
    # host's default locale encoding.
    with open("INTRODUCTION.md", encoding="utf-8") as f:
        gr.Markdown(f.read(), sanitize_html=False)
# Property Prediction tab: pick or upload a dataset, then evaluate every
# non-empty combination of the enabled models and browse the results.
with gr.Blocks() as property_prediction:
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from the widget order and the "Right Column" marker --
    # confirm against the deployed layout.

    # One evaluation step per non-empty combination of the enabled models.
    total_eval_steps = 2 ** len(models_enabled) - 1

    state = gr.State({"model_name": "Default - Auto", "results": {}})
    gr.HTML(
        '''
<p style="text-align: center">
Task : Property Prediction
<br>
Models are finetuned with different combination of modalities on the uploaded or selected built data set.
</p>
'''
    )
    with gr.Row():
        # Left column: dataset selection, custom upload and the action buttons.
        with gr.Column():
            # Dropdown menu for predefined datasets including "Custom Dataset" option
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
            )
            # Display the message for selected columns
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info", visible=DEBUG_VISIBLE
            )
            with gr.Accordion(
                "Custom Dataset Settings", open=True, visible=DEBUG_VISIBLE
            ) as settings:
                # File upload options for custom dataset (train and test)
                custom_file = gr.File(
                    label="Upload Custom Dataset",
                    file_types=[".csv"],
                )
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                # Predefined dataset displays
                predefined_display = gr.Dataframe(
                    label="Predefined Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                # Dropdowns for selecting input and output columns for the custom dataset
                input_column_selector = gr.Dropdown(
                    label="Select Input Column",
                    choices=[],
                    allow_custom_value=True,
                    visible=DEBUG_VISIBLE,
                )
                output_column_selector = gr.Dropdown(
                    label="Select Output Column",
                    choices=[],
                    allow_custom_value=True,
                    visible=DEBUG_VISIBLE,
                )
                # When a custom train file is uploaded, display its head and update column selectors
                train_file.change(
                    display_csv_head,
                    inputs=train_file,
                    outputs=[
                        train_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
                # When a custom test file is uploaded, display its head; this
                # also refreshes the column selectors from the test file.
                test_file.change(
                    display_csv_head,
                    inputs=test_file,
                    outputs=[
                        test_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled, label="Select Model", visible=DEBUG_VISIBLE
            )
            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"],
                label="Task Type",
                visible=DEBUG_VISIBLE,
            )
            # When a dataset is selected: clear any uploaded custom file, show
            # or hide the custom-dataset settings, then load the predefined
            # dataset's head and update the column selectors.
            dataset_selector.change(lambda: None, outputs=custom_file).then(
                handle_dataset_selection,
                inputs=[dataset_selector, state],
                outputs=[settings, task_radiobutton],
            ).then(
                load_predefined_dataset,
                inputs=dataset_selector,
                outputs=[
                    predefined_display,
                    input_column_selector,
                    output_column_selector,
                    selected_columns_message,
                ],
            )
            # A custom upload is split into train/test files, which feed the
            # hidden train/test inputs and the column/task widgets.
            custom_file.change(
                process_custom_file,
                inputs=[custom_file, dataset_selector],
                outputs=[
                    train_file,
                    test_file,
                    input_column_selector,
                    output_column_selector,
                    task_radiobutton,
                ],
            )
            eval_clear_button = gr.Button("Clear")
            eval_button = gr.Button("Submit", variant="primary")
            # Progress indicator, shown only while an evaluation chain runs.
            step_slider = gr.Slider(
                minimum=0,
                maximum=total_eval_steps,
                value=0,
                label="Progress",
                show_label=True,
                interactive=False,
                visible=False,
            )
        # Right Column
        with gr.Column():
            log_table = gr.Dataframe(value=blank_df, interactive=False)
            plot_radio = gr.Radio(choices=[], label="Select Plot Type")
            plot_output = gr.Plot(label="Visualization")
            # Selecting a row loads its stored results into the state, then
            # refreshes the available plot types and redraws the plot.
            log_table.select(log_selected, [log_table, state]).success(
                update_plot_choices, inputs=[plot_radio, state], outputs=plot_radio
            ).then(display_plot, inputs=[plot_radio, state], outputs=plot_output)

    def clear_eval(state):
        """Drop all stored results and reset the plot, plot choices and log table."""
        state["results"] = {}
        return None, gr.update(choices=[], value=None), blank_df

    def eval_part(part, step, selector, show_progress=False):
        """Append one evaluation step to the event chain ``part``.

        The step selects the models at the indices in ``selector``, runs
        ``evaluate_and_log`` on them and then moves the progress slider to
        ``step``. Returns the extended chain.
        """
        return (
            part.then(
                lambda: [models_enabled[x] for x in selector],
                outputs=model_checkbox,
            )
            .then(
                evaluate_and_log,
                inputs=[
                    model_checkbox,
                    selected_columns_message,
                    task_radiobutton,
                    log_table,
                    state,
                ],
                outputs=log_table,
                show_progress=show_progress,
            )
            .then(lambda: step, outputs=step_slider, show_progress=False)
        )

    # Submit: disable both buttons, build the dataset description, and reset
    # the previous results before the evaluation steps start.
    part = (
        eval_button.click(
            lambda: (
                gr.update(interactive=False),
                gr.update(interactive=False),
            ),
            outputs=[eval_clear_button, eval_button],
        )
        .then(
            select_columns,
            inputs=[
                input_column_selector,
                output_column_selector,
                train_file,
                test_file,
                state,
            ],
            outputs=selected_columns_message,
        )
        .then(
            clear_eval,
            inputs=state,
            outputs=[
                plot_output,
                plot_radio,
                log_table,
            ],
        )
    )
    part = part.then(
        lambda: gr.update(value=0, visible=True),
        outputs=step_slider,
        show_progress=False,
    )

    # Chain one evaluation step per non-empty model combination: single
    # models first, then pairs, and so on up to all models together. Only
    # the first step asks Gradio to show progress.
    part_index = 1  # 1-based step number shown on the progress slider
    model_count = len(models_enabled)
    for group_size in range(1, model_count + 1):
        for combo in combinations(range(model_count), group_size):
            part = eval_part(
                part, part_index, list(combo), show_progress=(part_index == 1)
            )
            part_index += 1

    # Finished: hide the progress slider and re-enable both buttons.
    part = part.then(
        lambda: gr.update(visible=False),
        outputs=step_slider,
        show_progress=False,
    )
    part.then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[eval_clear_button, eval_button],
    )
    plot_radio.change(
        display_plot, inputs=[plot_radio, state], outputs=plot_output
    )
    # Clear: reset the results, then set the dataset dropdown back to the
    # blank placeholder entry.
    eval_clear_button.click(
        clear_eval,
        inputs=state,
        outputs=[
            plot_output,
            plot_radio,
            log_table,
        ],
    ).then(lambda: " ", outputs=dataset_selector)
# Molecule Generation tab: enter or pick a SMILES string and generate a
# related molecule.
with gr.Blocks() as molecule_generation:
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from the widget order and the "Right Column" marker --
    # confirm against the deployed layout.
    gr.HTML(
        '''
<p style="text-align: center">
Task : Molecule Generation
<br>
Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
</p>
'''
    )
    with gr.Row():
        # Left column: input molecule and the action buttons.
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(label="Molecule Image", height=250, width=250)
            # Show images for selection
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    label="Select from sample molecules",
                    choices=list(smiles_image_mapping.keys()),
                    value=None,
                )
                image_selector.change(
                    fn=load_image,
                    inputs=image_selector,
                    outputs=image_display,
                )
            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")
        # Right Column
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image", height=250, width=250
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(label="Molecular Properties Comparison")

    # Handle image selection: fill in the SMILES text and the molecule image.
    image_selector.change(
        fn=handle_image_selection,
        inputs=image_selector,
        outputs=[smiles_input, image_display],
    )
    # Redraw the molecule image whenever the SMILES text changes.
    smiles_input.change(
        fn=smiles_to_image,
        inputs=smiles_input,
        outputs=image_display,
    )

    # Submit: disable both buttons, run the generation, then re-enable them.
    buttons_locked = generate_button.click(
        lambda: (
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[clear_button, generate_button],
    )
    generation_done = buttons_locked.then(
        generate_canonical,
        inputs=smiles_input,
        outputs=[property_table, generated_output, gen_image_display],
    )
    generation_done.then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[clear_button, generate_button],
    )

    # Clear: reset every input and output widget of this tab.
    clear_button.click(
        lambda: (None,) * 6,
        outputs=[
            smiles_input,
            image_display,
            image_selector,
            gen_image_display,
            generated_output,
            property_table,
        ],
    )
# Render the three pages as tabs and start the server.
tab_pages = [introduction, property_prediction, molecule_generation]
tab_titles = ["Introduction", "Property Prediction", "Molecule Generation"]
demo = gr.TabbedInterface(tab_pages, tab_titles)
# NOTE(review): server_name="0.0.0.0" listens on all network interfaces, and
# allowed_paths=["./"] presumably lets Gradio serve files from the whole
# working directory -- confirm both are intended for this deployment.
demo.launch(server_name="0.0.0.0", allowed_paths=["./"])