import os

import gradio as gr
import pandas as pd
import spaces
from huggingface_hub import get_collection

from utils import (
    MolecularPropertyPredictionModel,
    dataset_task_types,
    dataset_descriptions,
    dataset_property_names,
    dataset_property_names_to_dataset,
)
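
# Discover the available property-prediction adapters from the ChemFM collection on the
# Hugging Face Hub. Each model is keyed by its dataset name (the last component of the
# repo id), which must also appear in the task-type and description tables from utils.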
def get_models():
    collection = get_collection(
        "ChemFM/molecular-property-prediction-6710141ffc31f31a47d6fc0c",
        token=os.environ.get("TOKEN"),
    )
    models = dict()
    for item in collection.items:
        if item.item_type == "model":
            item_name = item.item_id.split("/")[-1]
            models[item_name] = item.item_id
            assert item_name in dataset_task_types, f"{item_name} is not in the task_types"
            assert item_name in dataset_descriptions, f"{item_name} is not in the dataset_descriptions"

    return models
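
# Module-level setup: resolve the adapter catalog once, derive the display names shown in
# the dropdown, and create a single shared prediction model instance.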
candidate_models = get_models()
properties = [dataset_property_names[item] for item in candidate_models.keys()]
property_names = list(candidate_models.keys())
model = MolecularPropertyPredictionModel(candidate_models)


def get_description(property_name):
    property_id = dataset_property_names_to_dataset[property_name]
    return dataset_descriptions[property_id]
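
# Single-molecule prediction. Runs inside a short GPU allocation requested via spaces.GPU;
# the model first switches to the adapter registered for the selected property (adapter
# handling lives in utils.MolecularPropertyPredictionModel), then scores the SMILES string.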
@spaces.GPU(duration=10)
def predict_single_label(smiles, property_name):
    property_id = dataset_property_names_to_dataset[property_name]

    try:
        adapter_id = candidate_models[property_id]
        # `swith_adapter` follows the spelling of the method provided by utils; it reports
        # whether the requested adapter was kept, switched to, or could not be found.
        info = model.swith_adapter(property_id, adapter_id)

        running_status = None
        if info == "keep":
            running_status = "Adapter is the same as the current one"
        elif info == "switched":
            running_status = "Adapter is switched successfully"
        elif info == "error":
            running_status = "Adapter is not found"
            return "NA", running_status
        else:
            running_status = "Unknown error"
            return "NA", running_status

        prediction = model.predict_single_smiles(smiles, dataset_task_types[property_id])
        if prediction is None:
            return "NA", "Invalid SMILES string"

        # Round regression outputs for display; non-float outputs pass through unchanged.
        if isinstance(prediction, float):
            prediction = round(prediction, 3)
    except Exception as e:
        print(e)
        return "NA", "Prediction failed"

    return prediction, "Prediction is done"
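
# Batch prediction over an uploaded CSV of SMILES. Uses a longer GPU allocation and writes
# the results next to the uploaded file as *_prediction.csv.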
@spaces.GPU(duration=30)
def predict_file(file, property_name):
    property_id = dataset_property_names_to_dataset[property_name]
    try:
        adapter_id = candidate_models[property_id]
        info = model.swith_adapter(property_id, adapter_id)

        running_status = None
        if info == "keep":
            running_status = "Adapter is the same as the current one"
        elif info == "switched":
            running_status = "Adapter is switched successfully"
        elif info == "error":
            running_status = "Adapter is not found"
            # Return the same five values as the outputs wired to this callback:
            # predict button, download button, stop button, file, running status.
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, running_status
        else:
            running_status = "Unknown error"
            return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, running_status

        df = pd.read_csv(file)
        df = model.predict_file(df, dataset_task_types[property_id])

        # Save the predictions next to the uploaded file.
        prediction_file = file.replace(".csv", "_prediction.csv") if file.endswith(".csv") else file.replace(".smi", "_prediction.csv")
        print(file, prediction_file)
        df.to_csv(prediction_file, index=False)
    except Exception as e:
        print(e)
        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), file, "Prediction failed"

    # Hide the predict button, expose the download button, and report success.
    return gr.update(visible=False), gr.DownloadButton(label="Download", value=prediction_file, visible=True), gr.update(visible=False), prediction_file, "Prediction is done"
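
# Validate an uploaded file before prediction: only .csv files with a 'smiles' column and at
# most 100 rows are accepted (.smi uploads are currently rejected).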
def validate_file(file):
    try:
        if file.endswith(".csv"):
            df = pd.read_csv(file)
            if "smiles" not in df.columns:
                return "Invalid file content. The csv file must contain a column named 'smiles'", \
                    None, gr.update(visible=False), gr.update(visible=False)

            length = len(df["smiles"])

        elif file.endswith(".smi"):
            return "Invalid file extension", \
                None, gr.update(visible=False), gr.update(visible=False)

        else:
            return "Invalid file extension", \
                None, gr.update(visible=False), gr.update(visible=False)
    except Exception:
        return "Invalid file content.", \
            None, gr.update(visible=False), gr.update(visible=False)

    if length > 100:
        return "The space does not support files containing more than 100 SMILES", \
            None, gr.update(visible=False), gr.update(visible=False)

    return "Valid file", file, gr.update(visible=True), gr.update(visible=False)


def raise_error(status):
    # Surface validation failures to the user as a Gradio error popup.
    if status != "Valid file":
        raise gr.Error(status)
    return None


def clear_file(download_button):
    # When the uploaded file is cleared, delete the generated prediction file and the
    # original upload (either .csv or .smi) from disk, then hide the action buttons.
    prediction_path = download_button
    print(prediction_path)
    if prediction_path and os.path.exists(prediction_path):
        os.remove(prediction_path)
        original_data_file_0 = prediction_path.replace("_prediction.csv", ".csv")
        original_data_file_1 = prediction_path.replace("_prediction.csv", ".smi")
        if os.path.exists(original_data_file_0):
            os.remove(original_data_file_0)
        if os.path.exists(original_data_file_1):
            os.remove(original_data_file_1)

    return gr.update(visible=False), gr.update(visible=False), None
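
# Assemble the Gradio Blocks UI: a property dropdown with its description, a single-SMILES
# prediction panel, and a CSV upload path with validation, batch prediction, and download.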
def build_inference():

    with gr.Blocks() as demo:
        print(property_names[0].lower())
        print(properties)

        gr.Markdown("<span style='color: red;'>If you run out of your GPU quota, you can use the </span> <a href='https://huggingface.co/spaces/ChemFM/molecular_property_prediction'>CPU-powered space</a>, but with much lower performance.")
        dropdown = gr.Dropdown(properties, label="Property", value=dataset_property_names[property_names[0].lower()])
        description_box = gr.Textbox(label="Property description", lines=5,
                                     interactive=False,
                                     value=dataset_descriptions[property_names[0].lower()])

        with gr.Row(equal_height=True):
            with gr.Column():
                textbox = gr.Textbox(label="Molecule SMILES", type="text",
                                     placeholder="Provide a SMILES string here", lines=1)
                predict_single_smiles_button = gr.Button("Predict", size='sm')
                prediction = gr.Label("Prediction will appear here")

                running_terminal_label = gr.Textbox(label="Running status", type="text",
                                                    placeholder=None, lines=10, interactive=False)

                input_file = gr.File(label="Molecule file",
                                     file_count='single',
                                     file_types=[".smi", ".csv"], height=300)
                predict_file_button = gr.Button("Predict", size='sm', visible=False)
                download_button = gr.DownloadButton("Download", size='sm', visible=False)
                stop_button = gr.Button("Stop", size='sm', visible=False)

        dropdown.change(get_description, inputs=dropdown, outputs=description_box)

        # Lock every control while the single-SMILES prediction runs, then unlock afterwards.
        predict_single_smiles_button.click(
            lambda: (gr.update(interactive=False), gr.update(interactive=False),
                     gr.update(interactive=False), gr.update(interactive=False),
                     gr.update(interactive=False), gr.update(interactive=False),
                     gr.update(interactive=False), gr.update(interactive=False)),
            outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button,
                     download_button, stop_button, input_file, running_terminal_label])\
            .then(predict_single_label, inputs=[textbox, dropdown],
                  outputs=[prediction, running_terminal_label])\
            .then(
                lambda: (gr.update(interactive=True), gr.update(interactive=True),
                         gr.update(interactive=True), gr.update(interactive=True),
                         gr.update(interactive=True), gr.update(interactive=True),
                         gr.update(interactive=True), gr.update(interactive=True)),
                outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button,
                         download_button, stop_button, input_file, running_terminal_label])

        file_status = gr.State()
        input_file.upload(fn=validate_file, inputs=input_file,
                          outputs=[file_status, input_file, predict_file_button, download_button])\
            .success(raise_error, inputs=file_status, outputs=file_status)

        input_file.clear(fn=clear_file, inputs=[download_button],
                         outputs=[predict_file_button, download_button, input_file])

        # Same lock/unlock pattern for batch prediction; the event handle is kept so the
        # (currently hidden) stop button could later be wired up to cancel it.
        predict_file_event = predict_file_button.click(
            lambda: (gr.update(interactive=False), gr.update(interactive=False),
                     gr.update(interactive=False), gr.update(interactive=False, visible=True),
                     gr.update(interactive=False), gr.update(interactive=True, visible=False),
                     gr.update(interactive=False), gr.update(interactive=False)),
            outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button,
                     download_button, stop_button, input_file, running_terminal_label])\
            .then(predict_file, inputs=[input_file, dropdown],
                  outputs=[predict_file_button, download_button, stop_button, input_file, running_terminal_label])\
            .then(
                lambda: (gr.update(interactive=True), gr.update(interactive=True),
                         gr.update(interactive=True), gr.update(interactive=True),
                         gr.update(interactive=True), gr.update(interactive=True),
                         gr.update(interactive=True), gr.update(interactive=True)),
                outputs=[dropdown, textbox, predict_single_smiles_button, predict_file_button,
                         download_button, stop_button, input_file, running_terminal_label])

    return demo
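
# Build the interface once at module import; demo.launch() below starts the server when the
# script is run directly.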
demo = build_inference()


if __name__ == '__main__':
    demo.launch()