|
import gradio as gr |
|
from huggingface_hub import HfApi, get_collection, list_collections, list_models |
|
|
|
from utils import MolecularGenerationModel |
|
import pandas as pd |
|
import os |
|
import spaces |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared model instance, created once at import time and reused by every
# request handler (required so ZeroGPU-decorated functions can see it).
model = MolecularGenerationModel()
|
|
|
@spaces.GPU(duration=120)
def predict_single_label(logp, tpas, sas, qed, logp_choose, tpsa_choose, sas_choose, qed_choose):
    """Generate candidate molecules conditioned on the selected property targets.

    Args:
        logp: Target logP value (used only when ``logp_choose`` is True).
        tpas: Target TPSA value (historical ``tpas`` typo kept so the
            positional interface wired in ``build_inference`` is unchanged).
        sas: Target synthetic-accessibility score.
        qed: Target QED (drug-likeness) score.
        logp_choose, tpsa_choose, sas_choose, qed_choose: Booleans from the
            UI checkboxes selecting which properties condition the generation.

    Returns:
        tuple: ``(prediction, status_message)`` — the model output (a
        dataframe of generations) on success, or the string ``"NA"`` plus an
        error message on failure.
    """
    conditions = {}
    if logp_choose:
        conditions['logP'] = logp
    if tpsa_choose:
        conditions['TPSA'] = tpas
    if sas_choose:
        conditions['SAS'] = sas
    if qed_choose:
        conditions['QED'] = qed

    # At least one property must be selected to condition the generator.
    if not conditions:
        return "NA", "No input is selected"

    print(conditions)

    try:
        prediction = model.predict_single_smiles(conditions)
        # The model signals an unusable generation request with None.
        if prediction is None:
            return "NA", "Invalid SMILES string"
    except Exception as e:
        # UI boundary: report the failure to the status box instead of
        # letting the exception crash the Gradio worker.
        print(e)
        return "NA", "Generation failed"

    return prediction, "Generation is done"
|
|
|
""" |
|
def get_description(task_name): |
|
task = task_names_to_tasks[task_name] |
|
return task_descriptions[task] |
|
|
|
#@spaces.GPU(duration=10) |
|
""" |
|
|
|
""" |
|
@spaces.GPU(duration=30) |
|
def predict_file(file, property_name): |
|
property_id = dataset_property_names_to_dataset[property_name] |
|
try: |
|
adapter_id = candidate_models[property_id] |
|
info = model.swith_adapter(property_id, adapter_id) |
|
|
|
running_status = None |
|
if info == "keep": |
|
running_status = "Adapter is the same as the current one" |
|
#print("Adapter is the same as the current one") |
|
elif info == "switched": |
|
running_status = "Adapter is switched successfully" |
|
#print("Adapter is switched successfully") |
|
elif info == "error": |
|
running_status = "Adapter is not found" |
|
#print("Adapter is not found") |
|
return None, None, file, running_status |
|
else: |
|
running_status = "Unknown error" |
|
return None, None, file, running_status |
|
|
|
df = pd.read_csv(file) |
|
# we have already checked the file contains the "smiles" column |
|
df = model.predict_file(df, dataset_task_types[property_id]) |
|
# we should save this file to the disk to be downloaded |
|
# rename the file to have "_prediction" suffix |
|
prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv") |
|
print(file, prediction_file) |
|
# save the file to the disk |
|
df.to_csv(prediction_file, index=False) |
|
except Exception as e: |
|
# no matter what the error is, we should return |
|
print(e) |
|
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed" |
|
|
|
return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done" |
|
|
|
def validate_file(file): |
|
try: |
|
if file.endswith(".csv"): |
|
df = pd.read_csv(file) |
|
if "smiles" not in df.columns: |
|
# we should clear the file input |
|
return "Invalid file content. The csv file must contain column named 'smiles'", \ |
|
None, gr.update(visible=False), gr.update(visible=False) |
|
|
|
# check the length of the smiles |
|
length = len(df["smiles"]) |
|
|
|
elif file.endswith(".smi"): |
|
return "Invalid file extension", \ |
|
None, gr.update(visible=False), gr.update(visible=False) |
|
|
|
else: |
|
return "Invalid file extension", \ |
|
None, gr.update(visible=False), gr.update(visible=False) |
|
except Exception as e: |
|
return "Invalid file content.", \ |
|
None, gr.update(visible=False), gr.update(visible=False) |
|
|
|
if length > 100: |
|
return "The space does not support the file containing more than 100 SMILES", \ |
|
None, gr.update(visible=False), gr.update(visible=False) |
|
|
|
return "Valid file", file, gr.update(visible=True), gr.update(visible=False) |
|
""" |
|
|
|
|
|
def raise_error(status):
    """Surface a failed validation status to the Gradio UI.

    Args:
        status: Status message produced by a validation step.

    Returns:
        None when ``status`` is exactly "Valid file".

    Raises:
        gr.Error: For any other status, shown as a popup in the UI.
    """
    if status == "Valid file":
        return None
    raise gr.Error(status)
|
|
|
|
|
""" |
|
def clear_file(download_button): |
|
# we might need to delete the prediction file and uploaded file |
|
prediction_path = download_button |
|
print(prediction_path) |
|
if prediction_path and os.path.exists(prediction_path): |
|
os.remove(prediction_path) |
|
original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv") |
|
original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi") |
|
if os.path.exists(original_data_file_0): |
|
os.remove(original_data_file_0) |
|
if os.path.exists(original_data_file_1): |
|
os.remove(original_data_file_1) |
|
#if os.path.exists(file): |
|
# os.remove(file) |
|
#prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv") |
|
#if os.path.exists(prediction_file): |
|
# os.remove(prediction_file) |
|
|
|
|
|
return gr.update(visible=False), gr.update(visible=False), None |
|
""" |
|
|
|
def toggle_slider(checked):
    """Mirror a checkbox's state onto its slider's ``interactive`` flag.

    Args:
        checked: Current boolean value of the checkbox.

    Returns:
        A ``gr.update`` enabling the slider when checked, disabling otherwise.
    """
    slider_state = gr.update(interactive=checked)
    return slider_state
|
|
|
def toggle_sliders_based_on_checkboxes(checked_values):
    """Enable or disable sliders based on the corresponding checkbox values.

    Generalized from the original hard-coded four entries: one update is
    produced per element of ``checked_values`` (backward compatible — the UI
    passes exactly four).

    Args:
        checked_values: Sequence of booleans, one per checkbox/slider pair.

    Returns:
        list: One ``gr.update(interactive=...)`` per input value, in order.
    """
    return [gr.update(interactive=checked) for checked in checked_values]
|
|
|
def build_inference():
    """Build and return the Gradio Blocks UI for conditional molecule generation.

    Layout: a read-only task-description box, four checkbox+slider pairs
    (logP, TPSA, SAS, QED), a Generate button, a generations dataframe, and a
    running-status textbox. The click handler chain disables all controls,
    runs :func:`predict_single_label`, then restores interactivity (sliders
    according to their checkboxes).

    Returns:
        gr.Blocks: The assembled demo, ready for ``.launch()``.
    """
    with gr.Blocks() as demo:

        # Static usage instructions (f-prefixes are vestigial: no placeholders).
        description = f"This space allows you to generate ten possible molecules based on given conditions. \n" \
            f"1. You can enable or disable specific properties using checkboxes and adjust their values with sliders. \n" \
            f"2. The generated SMILES strings and their corresponding predicted properties will be displayed in the generations section. \n" \
            f"3. The properties include logP, TPSA, SAS, and QED. \n" \
            f"4. Model trained on the GuacaMol dataset for molecular design. "

        description_box = gr.Textbox(label="Task description", lines=5,
                                     interactive=False,
                                     value= description)

        # One column per property: a checkbox choosing whether the property
        # conditions generation, and a slider for its target value. Each
        # checkbox toggles its own slider's interactivity.
        with gr.Row(equal_height=True):
            with gr.Column():
                checkbox_1 = gr.Checkbox(label="logP", value=True)
                slider_1 = gr.Slider(1, 7, value=4, label="logP", info="Choose between 1 and 7")
                checkbox_1.change(toggle_slider, checkbox_1, slider_1)
            with gr.Column():
                checkbox_2 = gr.Checkbox(label="TPSA", value=True)
                slider_2 = gr.Slider(20, 140, value=80, label="TPSA", info="Choose between 20 and 140")
                checkbox_2.change(toggle_slider, checkbox_2, slider_2)
            with gr.Column():
                checkbox_3 = gr.Checkbox(label="SAS", value=True)
                slider_3 = gr.Slider(1, 5, value=3, label="SAS", info="Choose between 1 and 5")
                checkbox_3.change(toggle_slider, checkbox_3, slider_3)
            with gr.Column():
                checkbox_4 = gr.Checkbox(label="QED", value=True)
                slider_4 = gr.Slider(0.1, 0.9, value=0.5, label="QED", info="Choose between 0.1 and 0.9")
                checkbox_4.change(toggle_slider, checkbox_4, slider_4)

        predict_single_smiles_button = gr.Button("Generate", size='sm')

        # Generated SMILES and their predicted properties.
        prediction = gr.Dataframe(label="Generations", type="pandas", interactive=False)

        running_terminal_label = gr.Textbox(label="Running status", type="text", placeholder=None, lines=10, interactive=False)

        # Flag consumed by the ZeroGPU runtime for this handler.
        predict_single_label.zerogpu=True
        # Three-step chain: (1) lock every control while generating,
        # (2) run the generator, (3) re-enable controls — sliders restored
        # according to their checkbox states, the rest unconditionally.
        predict_single_smiles_button.click(lambda:(gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   gr.update(interactive=False),
                                                   ) , outputs=[slider_1, slider_2, slider_3, slider_4,
                                                                checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                                                                predict_single_smiles_button, running_terminal_label])\
            .then(predict_single_label, inputs=[slider_1, slider_2, slider_3, slider_4,
                                                checkbox_1, checkbox_2, checkbox_3, checkbox_4
                                                ], outputs=[prediction, running_terminal_label])\
            .then(lambda a, b, c, d: toggle_sliders_based_on_checkboxes([a, b, c, d]) +
                                     [gr.update(interactive=True)] * 6,
                  inputs=[checkbox_1, checkbox_2, checkbox_3, checkbox_4],
                  outputs=[slider_1, slider_2, slider_3, slider_4,
                           checkbox_1, checkbox_2, checkbox_3, checkbox_4,
                           predict_single_smiles_button, running_terminal_label])

    return demo
|
|
|
|
|
# Build the UI at import time so hosting runtimes (e.g. Hugging Face Spaces)
# that import this module can find `demo` without executing the main guard.
demo = build_inference()


if __name__ == '__main__':
    demo.launch()