import gradio as gr
import numpy as np
import pandas as pd
from tempfile import NamedTemporaryFile
from PIL import Image
from rdkit import RDLogger
from sklearn.model_selection import train_test_split
from molecule_generation_helpers import *
from property_prediction_helpers import *
# Debug switch: widgets in this file that are declared with
# visible=DEBUG_VISIBLE are only shown when this is True.
DEBUG_VISIBLE = False
# Set the RDKit logger level to ERROR (presumably to keep lower-severity
# RDKit messages out of the output -- confirm against RDKit's logging docs).
RDLogger.logger().setLevel(RDLogger.ERROR)
# Built-in datasets keyed by their dropdown label. Every value is a single
# comma-separated string whose first field is the training CSV path; the
# remaining fields look like "<test csv>, <input column>, <output column>".
# The " " entry is the blank dropdown choice. Adjust the paths to your setup.
predefined_datasets = {" ": " "}
predefined_datasets.update(
    {
        label: (
            f"./data/{label.lower()}/train.csv, "
            f"./data/{label.lower()}/test.csv, smiles, {target}"
        )
        for label, target in (("BACE", "Class"), ("ESOL", "prop"))
    }
)
# Models offered for evaluation. Other code in this file addresses them by
# list position, so the order is significant.
models_enabled = ["MorganFingerprint", "SMI-TED", "SELFIES-TED", "MHG-GED", "POS-EGNN"]
# Empty results table (headers only) shown before/after an evaluation run.
blank_df = pd.DataFrame({column: [] for column in ("id", "Model", "Score")})
# Function to load a predefined dataset from the local path
def load_predefined_dataset(dataset_name):
    """Preview a predefined dataset and reset the column selectors.

    Args:
        dataset_name: Label chosen in the dataset dropdown.

    Returns:
        A 4-tuple ``(preview, input_update, output_update, name)``:

        * on success, the first rows of the dataset's training CSV, two
          ``gr.update`` objects offering the CSV's columns as choices, and
          the lower-cased dataset label;
        * otherwise an empty frame, two ``gr.update`` objects with empty
          choices, and either the lower-cased label (known dataset whose CSV
          could not be read, or the blank entry) or ``"custom"`` (label not
          in ``predefined_datasets``).
    """
    import logging  # local import: only needed on the failure path

    val = predefined_datasets.get(dataset_name)
    if val:
        # The first comma-separated field of the entry is the training CSV.
        train_path = val.split(",")[0].strip()
        # The blank dropdown entry has no path; skip the read instead of
        # relying on read_csv failing.
        if train_path:
            try:
                df = pd.read_csv(train_path)
            except Exception as err:
                # Best effort: a missing or unreadable CSV falls through to
                # the empty preview below. Catching Exception (not a bare
                # except) lets KeyboardInterrupt/SystemExit propagate, and
                # the failure is logged instead of silently discarded.
                logging.getLogger(__name__).warning(
                    "Could not load predefined dataset %r from %s: %s",
                    dataset_name,
                    train_path,
                    err,
                )
            else:
                return (
                    df.head(),
                    gr.update(choices=list(df.columns), value=None),
                    gr.update(choices=list(df.columns), value=None),
                    dataset_name.lower(),
                )
    else:
        dataset_name = "Custom"
    return (
        pd.DataFrame(),
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        dataset_name.lower(),
    )
# Function to handle dataset selection (predefined or custom)
def handle_dataset_selection(selected_dataset, state):
    """Record the chosen dataset in ``state`` and derive the UI updates.

    ``state["dataset_name"]`` is set to ``selected_dataset`` when it is a key
    of ``predefined_datasets`` and to ``"CUSTOM"`` otherwise.

    Returns:
        A pair ``(visibility_update, task_type)``. The update's ``visible``
        flag is true for non-predefined selections (or whenever
        ``DEBUG_VISIBLE`` is set). ``task_type`` is "Classification" for
        BACE, "Regression" for ESOL and ``None`` for anything else.
    """
    is_predefined = selected_dataset in predefined_datasets
    state["dataset_name"] = selected_dataset if is_predefined else "CUSTOM"

    known_tasks = {"BACE": "Classification", "ESOL": "Regression"}
    visibility_update = gr.update(visible=(not is_predefined) or DEBUG_VISIBLE)
    return visibility_update, known_tasks.get(selected_dataset)
# Function to select input and output columns and display a message
def select_columns(input_column, output_column, train_data, test_data, state):
    """Build the comma-separated dataset descriptor for a custom upload.

    When both files and both column names are set, returns
    ``"<train path>,<test path>,<input column>,<output column>,<dataset name>"``
    (paths come from the files' ``.name``, the dataset name from
    ``state["dataset_name"]``). Otherwise returns a no-op ``gr.update()``.
    """
    if not (train_data and test_data and input_column and output_column):
        return gr.update()
    fields = (
        train_data.name,
        test_data.name,
        input_column,
        output_column,
        state["dataset_name"],
    )
    return ",".join(f"{field}" for field in fields)
# Function to display the head of the uploaded CSV file
def display_csv_head(file):
    """Preview an uploaded CSV and offer its columns as selector choices.

    Returns a triple ``(preview, update, update)``: the first rows of the CSV
    at ``file.name`` plus two ``gr.update`` objects listing its columns, or an
    empty frame and empty choice lists when ``file`` is ``None``.
    """
    if file is None:
        return pd.DataFrame(), gr.update(choices=[]), gr.update(choices=[])

    # Load the CSV file into a DataFrame
    frame = pd.read_csv(file.name)
    preview = frame.head()
    first_update = gr.update(choices=list(frame.columns))
    second_update = gr.update(choices=list(frame.columns))
    return preview, first_update, second_update
def process_custom_file(file, selected_dataset):
    """Split an uploaded CSV into temporary train/test CSV files.

    The upload is used only when it is smaller than 50 KiB and contains both
    an ``input`` and an ``output`` column. It is then split 80/20 with
    ``train_test_split`` and each part is written to its own temporary CSV
    (created with ``delete=False``, so the files stay on disk).

    Args:
        file: Uploaded file object exposing ``.name`` (a path), or ``None``.
        selected_dataset: Current value of the dataset dropdown.

    Returns:
        On success ``(train_path, test_path, "input", "output", task_type)``,
        where ``task_type`` is "Classification" if the ``output`` column has
        dtype int64 and "Regression" otherwise. In every other case four
        ``None`` values followed by ``gr.update()`` (leave the task type
        untouched) when ``selected_dataset`` is a predefined dataset, else
        ``None``.
    """
    # Local import: this file has no explicit top-level `import os`, so the
    # name would otherwise only resolve if a star import happens to leak it.
    import os

    if file and os.path.getsize(file.name) < 50 * 1024:
        df = pd.read_csv(file.name)
        if "input" in df.columns and "output" in df.columns:
            train, test = train_test_split(df, test_size=0.2)
            with NamedTemporaryFile(
                prefix="fm4m-train-", suffix=".csv", delete=False
            ) as train_file:
                train.to_csv(train_file.name, index=False)
            with NamedTemporaryFile(
                prefix="fm4m-test-", suffix=".csv", delete=False
            ) as test_file:
                test.to_csv(test_file.name, index=False)
            # Integer (int64) targets are treated as class labels.
            task_type = (
                "Classification" if df["output"].dtype == np.int64 else "Regression"
            )
            return train_file.name, test_file.name, "input", "output", task_type
    # Nothing usable was uploaded: clear the derived fields.
    return (
        None,
        None,
        None,
        None,
        gr.update() if selected_dataset in predefined_datasets else None,
    )
def update_plot_choices(current, state):
    """Refresh the plot-type choices from the results held in ``state``.

    A plot type is offered when its backing entry in ``state`` is not
    ``None``: "ROC-AUC" for ``roc_auc``, "Parity Plot" for ``RMSE`` and
    "Latent Space" for ``x_batch``. The current selection is kept if it is
    still offered; otherwise the first choice (or ``None``) is selected.
    """
    plot_for_key = (
        ("roc_auc", "ROC-AUC"),
        ("RMSE", "Parity Plot"),
        ("x_batch", "Latent Space"),
    )
    choices = [label for key, label in plot_for_key if state.get(key) is not None]
    if current in choices:
        return gr.update(choices=choices)
    fallback = choices[0] if choices else None
    return gr.update(choices=choices, value=fallback)
def log_selected(df: pd.DataFrame, evt: gr.SelectData, state):
    """Copy the stored results of the selected table row into ``state``.

    The row is ``evt.index[0]``; its ``id`` cell is looked up in
    ``state["results"]`` and the matching dict (or nothing, if absent) is
    merged into ``state`` in place.
    """
    selected_row = evt.index[0]
    result_id = df.at[selected_row, 'id']
    stored_result = state["results"].get(result_id, {})
    state.update(stored_result)
# Sample molecules offered in the UI, keyed "Mol 1" ... "Mol 5". Each entry
# holds the molecule's SMILES string and the path of its image,
# "img/img<N>.png" (replace with your actual image paths).
smiles_image_mapping = {
    f"Mol {number}": {"smiles": smiles, "image": f"img/img{number}.png"}
    for number, smiles in enumerate(
        (
            "C=C(C)CC(=O)NC[C@H](CO)NC(=O)C=Cc1ccc(C)c(Cl)c1",
            "C=CC1(CC(=O)NC[C@@H](CCCC)NC(=O)c2cc(Cl)cc(Br)c2)CC1",
            "C=C(C)C[C@H](NC(C)=O)C(=O)N1CC[C@H](NC(=O)[C@H]2C[C@@]2(C)Br)C(C)(C)C1",
            "C=C1CC(CC(=O)N[C@H]2CCN(C(=O)c3ncccc3SC)C23CC3)C1",
            "C=CCS[C@@H](C)CC(=O)OCC",
        ),
        start=1,
    )
}
# Load images for selection
def load_image(path):
    """Open the sample image registered for a molecule key.

    Args:
        path: Despite its name, a key of ``smiles_image_mapping``
            (e.g. ``"Mol 1"``), not a filesystem path.

    Returns:
        The image opened with ``Image.open``, or ``None`` when the key is
        unknown or the image file cannot be opened.
    """
    try:
        return Image.open(smiles_image_mapping[path]["image"])
    except Exception:
        # Best effort: an unknown key or unreadable file simply yields no
        # image. Unlike a bare ``except:``, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
# Function to handle image selection
def handle_image_selection(image_key):
    """Resolve a sample-molecule key to its SMILES string and rendered image.

    Returns ``(smiles, image)`` where ``image`` is produced by
    ``smiles_to_image``; a falsy ``image_key`` yields ``(None, None)``.
    """
    if image_key:
        selected_smiles = smiles_image_mapping[image_key]["smiles"]
        return selected_smiles, smiles_to_image(selected_smiles)
    return None, None
# Introduction
with gr.Blocks() as introduction:
    # Read the page explicitly as UTF-8 so that rendering does not depend on
    # the platform's locale-default encoding.
    with open("INTRODUCTION.md", encoding="utf-8") as f:
        # The file content is rendered as Markdown with HTML sanitizing off.
        gr.Markdown(f.read(), sanitize_html=False)
# Property Prediction
with gr.Blocks() as property_prediction:
    # NOTE(review): the nesting of the layout containers below (in particular
    # where the settings accordion ends) is reconstructed from the statement
    # order -- confirm against the intended layout.

    # Session state: the selected model name plus the results of every
    # evaluation round.
    state = gr.State({"model_name": "Default - Auto", "results": {}})
    gr.HTML(
        '''
Task : Property Prediction
Models are finetuned with different combination of modalities on the uploaded or selected built data set.
'''
    )
    with gr.Row():
        with gr.Column():
            # Dropdown menu for predefined datasets including "Custom Dataset" option
            dataset_selector = gr.Dropdown(
                label="Select Dataset",
                choices=list(predefined_datasets.keys()) + ["Custom Dataset"],
            )
            # Display the message for selected columns
            selected_columns_message = gr.Textbox(
                label="Selected Columns Info", visible=DEBUG_VISIBLE
            )
            with gr.Accordion(
                "Custom Dataset Settings", open=True, visible=DEBUG_VISIBLE
            ) as settings:
                # File upload options for custom dataset (train and test)
                custom_file = gr.File(
                    label="Upload Custom Dataset",
                    file_types=[".csv"],
                )
                train_file = gr.File(
                    label="Upload Custom Train Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                train_display = gr.Dataframe(
                    label="Train Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                test_file = gr.File(
                    label="Upload Custom Test Dataset",
                    file_types=[".csv"],
                    visible=False,
                )
                test_display = gr.Dataframe(
                    label="Test Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                # Predefined dataset displays
                predefined_display = gr.Dataframe(
                    label="Predefined Dataset Preview (First 5 Rows)",
                    interactive=False,
                    visible=DEBUG_VISIBLE,
                )
                # Dropdowns for selecting input and output columns for the custom dataset
                input_column_selector = gr.Dropdown(
                    label="Select Input Column",
                    choices=[],
                    allow_custom_value=True,
                    visible=DEBUG_VISIBLE,
                )
                output_column_selector = gr.Dropdown(
                    label="Select Output Column",
                    choices=[],
                    allow_custom_value=True,
                    visible=DEBUG_VISIBLE,
                )
                # When a custom train file is uploaded, display its head and update column selectors
                train_file.change(
                    display_csv_head,
                    inputs=train_file,
                    outputs=[
                        train_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
                # When a custom test file is uploaded, display its head
                test_file.change(
                    display_csv_head,
                    inputs=test_file,
                    outputs=[
                        test_display,
                        input_column_selector,
                        output_column_selector,
                    ],
                )
            model_checkbox = gr.CheckboxGroup(
                choices=models_enabled, label="Select Model", visible=DEBUG_VISIBLE
            )
            task_radiobutton = gr.Radio(
                choices=["Classification", "Regression"],
                label="Task Type",
                visible=DEBUG_VISIBLE,
            )
            # When a dataset is selected: clear any uploaded custom file, show
            # or hide the custom-dataset settings and set the task type, then
            # load the predefined dataset's head and update the column selectors.
            dataset_selector.change(lambda: None, outputs=custom_file).then(
                handle_dataset_selection,
                inputs=[dataset_selector, state],
                outputs=[settings, task_radiobutton],
            ).then(
                load_predefined_dataset,
                inputs=dataset_selector,
                outputs=[
                    predefined_display,
                    input_column_selector,
                    output_column_selector,
                    selected_columns_message,
                ],
            )
            # An uploaded custom file is split into train/test files, which
            # are fed into the train/test file components.
            custom_file.change(
                process_custom_file,
                inputs=[custom_file, dataset_selector],
                outputs=[
                    train_file,
                    test_file,
                    input_column_selector,
                    output_column_selector,
                    task_radiobutton,
                ],
            )
            eval_clear_button = gr.Button("Clear")
            eval_button = gr.Button("Submit", variant="primary")
            # Progress indicator: one step per evaluation round, i.e. one per
            # non-empty subset of the enabled models.
            step_slider = gr.Slider(
                minimum=0,
                maximum=2 ** len(models_enabled) - 1,
                value=0,
                label="Progress",
                show_label=True,
                interactive=False,
                visible=False,
            )
        # Right Column
        with gr.Column():
            log_table = gr.Dataframe(value=blank_df, interactive=False)
            plot_radio = gr.Radio(choices=[], label="Select Plot Type")
            plot_output = gr.Plot(label="Visualization")
            # Selecting a result row loads its stored results into the state,
            # refreshes the available plot types and redraws the plot.
            log_table.select(log_selected, [log_table, state]).success(
                update_plot_choices, inputs=[plot_radio, state], outputs=plot_radio
            ).then(display_plot, inputs=[plot_radio, state], outputs=plot_output)

    def clear_eval(state):
        """Drop all stored results and return cleared plot/radio/table values."""
        state["results"] = {}
        return None, gr.update(choices=[], value=None), blank_df

    def eval_part(part, step, selector, show_progress=False):
        """Append one evaluation round to the event chain ``part``.

        The round ticks the models at positions ``selector`` of
        ``models_enabled`` in the model checkbox group, runs
        ``evaluate_and_log`` (writing to the log table) and finally moves the
        progress slider to ``step``. Returns the extended event chain.
        """
        return (
            part.then(
                lambda: [models_enabled[x] for x in selector],
                outputs=model_checkbox,
            )
            .then(
                evaluate_and_log,
                inputs=[
                    model_checkbox,
                    selected_columns_message,
                    task_radiobutton,
                    log_table,
                    state,
                ],
                outputs=log_table,
                show_progress=show_progress,
            )
            .then(lambda: step, outputs=step_slider, show_progress=False)
        )

    # Submit: lock both buttons, build the dataset descriptor, clear previous
    # results, then reveal the progress slider at step 0.
    part = (
        eval_button.click(
            lambda: (
                gr.update(interactive=False),
                gr.update(interactive=False),
            ),
            outputs=[eval_clear_button, eval_button],
        )
        .then(
            select_columns,
            inputs=[
                input_column_selector,
                output_column_selector,
                train_file,
                test_file,
                state,
            ],
            outputs=selected_columns_message,
        )
        .then(
            clear_eval,
            inputs=state,
            outputs=[
                plot_output,
                plot_radio,
                log_table,
            ],
        )
    )
    part = part.then(
        lambda: gr.update(value=0, visible=True),
        outputs=step_slider,
        show_progress=False,
    )

    from itertools import combinations

    # Chain one evaluation round per non-empty subset of the enabled models,
    # in order of increasing group size (1 .. number of models). Steps are
    # numbered from 1 so the last one equals the slider maximum.
    model_count = len(models_enabled)
    part_index = 1  # start index
    for r in range(1, model_count + 1):
        for combo in combinations(range(model_count), r):
            # Only the very first round (model 0 on its own) shows the
            # progress indicator on the log table.
            part = eval_part(part, part_index, list(combo), list(combo) == [0])
            part_index += 1

    # All rounds queued: hide the slider again and unlock the buttons.
    part = part.then(
        lambda: gr.update(visible=False),
        outputs=step_slider,
        show_progress=False,
    )
    part.then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[eval_clear_button, eval_button],
    )
    plot_radio.change(
        display_plot, inputs=[plot_radio, state], outputs=plot_output
    )
    # Clear: wipe results and reset the dataset dropdown to the blank entry.
    eval_clear_button.click(
        clear_eval,
        inputs=state,
        outputs=[
            plot_output,
            plot_radio,
            log_table,
        ],
    ).then(lambda: " ", outputs=dataset_selector)
# Molecule Generation
with gr.Blocks() as molecule_generation:
    # NOTE(review): the nesting of the layout containers below is
    # reconstructed from the statement order -- confirm against the intended
    # layout.
    gr.HTML(
        '''
Task : Molecule Generation
Generate a new molecule similar to the initial molecule with better drug-likeness and synthetic accessibility.
'''
    )
    with gr.Row():
        # Left Column: input SMILES, its image, sample picker and buttons
        with gr.Column():
            smiles_input = gr.Textbox(label="Input SMILES String")
            image_display = gr.Image(label="Molecule Image", height=250, width=250)
            # Show images for selection
            with gr.Accordion("Select from sample molecules", open=False):
                image_selector = gr.Radio(
                    choices=list(smiles_image_mapping.keys()),
                    label="Select from sample molecules",
                    value=None,
                )
                # Show the stored sample image for the selected key.
                image_selector.change(load_image, image_selector, image_display)
            clear_button = gr.Button("Clear")
            generate_button = gr.Button("Submit", variant="primary")
        # Right Column
        with gr.Column():
            gen_image_display = gr.Image(
                label="Generated Molecule Image", height=250, width=250
            )
            generated_output = gr.Textbox(label="Generated Output")
            property_table = gr.Dataframe(label="Molecular Properties Comparison")
    # Handle image selection
    # NOTE(review): image_selector already has a change handler (load_image)
    # that also writes image_display, and smiles_input.change below writes it
    # as well -- confirm that all three writers are intended.
    image_selector.change(
        handle_image_selection,
        inputs=image_selector,
        outputs=[smiles_input, image_display],
    )
    # Re-render the molecule image whenever the SMILES text changes.
    smiles_input.change(
        smiles_to_image, inputs=smiles_input, outputs=image_display
    )
    # Generate button to display canonical SMILES and molecule image:
    # lock both buttons, run the generation, then unlock them again.
    generate_button.click(
        lambda: (
            gr.update(interactive=False),
            gr.update(interactive=False),
        ),
        outputs=[clear_button, generate_button],
    ).then(
        generate_canonical,
        inputs=smiles_input,
        outputs=[property_table, generated_output, gen_image_display],
    ).then(
        lambda: (
            gr.update(interactive=True),
            gr.update(interactive=True),
        ),
        outputs=[clear_button, generate_button],
    )
    # Clear button: reset every input and output component to None.
    clear_button.click(
        lambda: (None, None, None, None, None, None),
        outputs=[
            smiles_input,
            image_display,
            image_selector,
            gen_image_display,
            generated_output,
            property_table,
        ],
    )
# Render with tabs: one tab per Blocks app, then start the server.
# NOTE(review): server_name="0.0.0.0" together with allowed_paths=["./"]
# looks like it exposes the working directory to any client -- confirm this
# is intended.
tab_apps = [introduction, property_prediction, molecule_generation]
tab_titles = ["Introduction", "Property Prediction", "Molecule Generation"]
gr.TabbedInterface(tab_apps, tab_titles).launch(
    server_name="0.0.0.0", allowed_paths=["./"]
)