# eval / inference_app.py
# Hugging Face file-view header captured with this source:
#   last commit 876eb15 "Update inference_app.py" by simonduerr (verified), 5.42 kB
#   (page controls: raw · history · blame)
from __future__ import annotations
from pathlib import Path
import time
from biotite.application.autodock import VinaApp
import gradio as gr
from gradio_molecule3d import Molecule3D
from gradio_molecule2d import molecule2d
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
import pandas as pd
from biotite.structure import centroid, from_template
from biotite.structure.io import load_structure
from biotite.structure.io.mol import MOLFile, SDFile
from biotite.structure.io.pdb import PDBFile
from plinder.eval.docking.write_scores import evaluate
# Column headers for the empty PLINDER (protein-ligand) results table
# that is shown before any evaluation has been run.
EVAL_METRICS = [
    "system",
    "LDDT-PLI",
    "LDDT-LP",
    "BISY-RMSD",
]
# Column headers for the empty PINDER (protein-protein) results table
# that is shown before any evaluation has been run.
EVAL_METRICS_PINDER = [
    "system",
    "L_rms",
    "I_rms",
    "F_nat",
    "DOCKQ",
    "CAPRI_class",
]
def get_metrics(
    system_id: str,
    receptor_file: Path,
    ligand_file: Path,
    flexible: bool = True,
    posebusters: bool = True,
    methodname: str = "",
    store: bool = True,
) -> tuple[pd.DataFrame, float]:
    """Score a predicted protein-ligand complex against a PLINDER system.

    The prediction is scored with ``plinder``'s docking ``evaluate`` using
    ``system_id`` as both the model and the reference system. The scores of
    the first ligand (``"LIG_0"``) are returned as a one-row table with
    display-friendly column names.

    Args:
        system_id: PLINDER system ID, e.g. ``"4neh__1__1.B__1.H"``.
        receptor_file: Path to the predicted receptor structure.
        ligand_file: Path to the predicted ligand structure.
        flexible: Also report receptor ``LDDT`` and ``Backbone LDDT``.
        posebusters: Run the PoseBusters checks and report the number of
            passed checks plus an overall validity flag.
        methodname: Currently unused; accepted so the UI can pass it through.
        store: Currently unused; accepted so the UI can pass it through.

    Returns:
        ``(metrics, run_time)``: the one-row metrics table and the
        wall-clock evaluation time in seconds.

    Raises:
        gr.Error: If an input is missing, or if the evaluation produced no
            scores for the submitted ligand.
    """
    start_time = time.perf_counter()

    # gr.File yields None when nothing was uploaded. Fail with a readable
    # message instead of a TypeError from Path(None).
    if not system_id or receptor_file is None or ligand_file is None:
        raise gr.Error(
            "Please provide a system ID, a receptor file and a ligand file."
        )

    scores = evaluate(
        model_system_id=system_id,
        reference_system_id=system_id,
        receptor_file=Path(receptor_file),
        ligand_file_list=[Path(ligand_file)],
        flexible=flexible,
        posebusters=posebusters,
        posebusters_full=False,
    ).get("LIG_0")
    # Without LIG_0 scores, none of the selected columns below would exist.
    if not scores:
        raise gr.Error(f"Evaluation produced no ligand scores for system {system_id!r}.")
    metrics = pd.DataFrame([scores])

    columns = ["reference", "lddt_pli_ave", "lddt_lp_ave", "bisy_rmsd_ave"]
    mapping = {
        "reference": "system",
        "lddt_pli_ave": "LDDT-PLI",
        "lddt_lp_ave": "LDDT-LP",
        "bisy_rmsd_ave": "BISY-RMSD",
    }
    if flexible:
        columns += ["lddt", "bb_lddt"]
        mapping.update({"lddt": "LDDT", "bb_lddt": "Backbone LDDT"})
    if posebusters:
        check_columns = [col for col in metrics.columns if col.startswith("posebusters_")]
        # Presumably each check column is a pass/fail flag, so the row sum is
        # the number of passed checks. Compute it once and reuse it.
        n_passed = metrics[check_columns].sum(axis=1)
        metrics["posebusters"] = n_passed
        # NOTE(review): 20 is assumed to be the number of non-full PoseBusters
        # checks; confirm if the plinder/posebusters versions change.
        metrics["posebusters_valid"] = n_passed == 20
        # Keeps the individual checks plus the two aggregate columns above.
        columns += [col for col in metrics.columns if col.startswith("posebusters")]
        mapping.update(
            {
                "posebusters": "PoseBusters #checks",
                "posebusters_valid": "PoseBusters valid",
            }
        )

    metrics = metrics[columns].rename(columns=mapping)
    run_time = time.perf_counter() - start_time
    return metrics, run_time
def get_metrics_pinder(
    system_id: str,
    receptor_file: Path,
    ligand_file: Path,
    flexible: bool = True,
    posebusters: bool = True,
    methodname: str = "",
    store: bool = True,
) -> tuple[pd.DataFrame, float]:
    """Placeholder for PINDER (protein-protein) evaluation.

    No scoring is performed yet. The parameters only mirror ``get_metrics``
    so that both tabs share one calling convention.

    Returns:
        An empty table carrying the ``EVAL_METRICS_PINDER`` headers, so the
        results table keeps its header row, and a runtime of ``0.0`` seconds.
    """
    # TODO: implement PINDER scoring and fill in the EVAL_METRICS_PINDER columns.
    return pd.DataFrame([], columns=EVAL_METRICS_PINDER), 0.0
def _run_pinder_evaluation(
    system_id: str,
    receptor_file: Path,
    ligand_file: Path,
    methodname: str,
    store: bool,
) -> tuple[pd.DataFrame, float]:
    """Adapt the PINDER tab's five inputs to ``get_metrics_pinder``.

    Gradio passes component values positionally. Forwarding ``methodname``
    and ``store`` by keyword keeps them from landing in the ``flexible`` and
    ``posebusters`` slots of ``get_metrics_pinder``.
    """
    return get_metrics_pinder(
        system_id,
        receptor_file,
        ligand_file,
        methodname=methodname,
        store=store,
    )


with gr.Blocks() as app:
    with gr.Tab("🧬 PINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id_pinder = gr.Textbox(label="PINDER system ID")
                input_receptor_file_pinder = gr.File(label="Receptor file")
                input_ligand_file_pinder = gr.File(label="Ligand file")
                methodname_pinder = gr.Textbox(label="Name of your method in the format mlsb/spacename")
                store_pinder = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
                eval_btn_pinder = gr.Button("Run Evaluation")
        # Each tab owns the outputs its button fills, so results are visible
        # next to the inputs that produced them.
        eval_run_time_pinder = gr.Textbox(label="Evaluation runtime")
        metric_table_pinder = gr.DataFrame(
            pd.DataFrame([], columns=EVAL_METRICS_PINDER), label="Evaluation metrics"
        )

    with gr.Tab("⚖️ PLINDER evaluation template"):
        with gr.Row():
            with gr.Column():
                input_system_id = gr.Textbox(label="PLINDER system ID")
                input_receptor_file = gr.File(label="Receptor file (CIF)")
                input_ligand_file = gr.File(label="Ligand file (SDF)")
                flexible = gr.Checkbox(label="Flexible docking", value=True)
                posebusters = gr.Checkbox(label="PoseBusters", value=True)
                methodname = gr.Textbox(label="Name of your method in the format mlsb/spacename")
                store = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
                eval_btn = gr.Button("Run Evaluation")
                # The example supplies values for these five inputs only;
                # methodname and store keep their defaults.
                gr.Examples(
                    [
                        [
                            "4neh__1__1.B__1.H",
                            "input_protein_test.cif",
                            "input_ligand_test.sdf",
                            True,
                            True,
                        ],
                    ],
                    [input_system_id, input_receptor_file, input_ligand_file, flexible, posebusters],
                )
        eval_run_time = gr.Textbox(label="Evaluation runtime")
        metric_table = gr.DataFrame(
            pd.DataFrame([], columns=EVAL_METRICS), label="Evaluation metrics"
        )

    # Event listeners must be registered inside the Blocks context. Input
    # order matches get_metrics' positional parameters.
    eval_btn.click(
        get_metrics,
        inputs=[
            input_system_id,
            input_receptor_file,
            input_ligand_file,
            flexible,
            posebusters,
            methodname,
            store,
        ],
        outputs=[metric_table, eval_run_time],
    )
    eval_btn_pinder.click(
        _run_pinder_evaluation,
        inputs=[
            input_system_id_pinder,
            input_receptor_file_pinder,
            input_ligand_file_pinder,
            methodname_pinder,
            store_pinder,
        ],
        outputs=[metric_table_pinder, eval_run_time_pinder],
    )

app.launch()