Update inference_app.py
Browse files- inference_app.py +69 -3
inference_app.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
from pathlib import Path
|
| 3 |
import time
|
| 4 |
-
from biotite.application.autodock import VinaApp
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
|
|
@@ -24,6 +23,31 @@ EVAL_METRICS = ["system", "LDDT-PLI", "LDDT-LP", "BISY-RMSD"]
|
|
| 24 |
EVAL_METRICS_PINDER = ["system","L_rms", "I_rms", "F_nat", "DOCKQ", "CAPRI_class"]
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
|
|
@@ -85,13 +109,55 @@ def get_metrics(
|
|
| 85 |
return gr.DataFrame(metrics, visible=True), run_time
|
| 86 |
|
| 87 |
|
|
|
|
| 88 |
def get_metrics_pinder(
|
| 89 |
system_id: str,
|
| 90 |
complex_file: Path,
|
| 91 |
methodname: str = "",
|
| 92 |
store:bool =True
|
| 93 |
) -> tuple[pd.DataFrame, float]:
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
with gr.Blocks() as app:
|
| 97 |
with gr.Tab("🧬 PINDER evaluation template"):
|
|
@@ -127,7 +193,7 @@ with gr.Blocks() as app:
|
|
| 127 |
posebusters = gr.Checkbox(label="PoseBusters", value=True)
|
| 128 |
methodname = gr.Textbox(label="Name of your method in the format mlsb/spacename")
|
| 129 |
store = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
|
| 130 |
-
|
| 131 |
[
|
| 132 |
[
|
| 133 |
"4neh__1__1.B__1.H",
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
from pathlib import Path
|
| 3 |
import time
|
|
|
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
|
|
|
|
| 23 |
EVAL_METRICS_PINDER = ["system","L_rms", "I_rms", "F_nat", "DOCKQ", "CAPRI_class"]
|
| 24 |
|
| 25 |
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
from huggingface_hub import HfApi
|
| 29 |
+
|
| 30 |
+
# Info to change for your repository
|
| 31 |
+
# ----------------------------------
|
| 32 |
+
TOKEN = os.environ.get("HF_TOKEN") # A read/write token for your org
|
| 33 |
+
|
| 34 |
+
OWNER = "MLSB" # Change to your org - don't forget to create a results and request dataset, with the correct format!
|
| 35 |
+
# ----------------------------------
|
| 36 |
+
|
| 37 |
+
REPO_ID = f"{OWNER}/leaderboard2024"
|
| 38 |
+
QUEUE_REPO = f"{OWNER}/requests"
|
| 39 |
+
RESULTS_REPO = f"{OWNER}/results"
|
| 40 |
+
|
| 41 |
+
# If you setup a cache later, just change HF_HOME
|
| 42 |
+
CACHE_PATH=os.getenv("HF_HOME", ".")
|
| 43 |
+
|
| 44 |
+
# Local caches
|
| 45 |
+
EVAL_REQUESTS_PATH = os.path.join(CACHE_PATH, "eval-queue")
|
| 46 |
+
EVAL_RESULTS_PATH = os.path.join(CACHE_PATH, "eval-results")
|
| 47 |
+
EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
|
| 48 |
+
EVAL_RESULTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-results-bk")
|
| 49 |
+
|
| 50 |
+
API = HfApi(token=TOKEN)
|
| 51 |
|
| 52 |
|
| 53 |
|
|
|
|
| 109 |
return gr.DataFrame(metrics, visible=True), run_time
|
| 110 |
|
| 111 |
|
| 112 |
+
|
| 113 |
def get_metrics_pinder(
|
| 114 |
system_id: str,
|
| 115 |
complex_file: Path,
|
| 116 |
methodname: str = "",
|
| 117 |
store:bool =True
|
| 118 |
) -> tuple[pd.DataFrame, float]:
|
| 119 |
+
start_time = time.time()
|
| 120 |
+
|
| 121 |
+
if not isinstance(prediction, Path):
|
| 122 |
+
prediction = Path(prediction)
|
| 123 |
+
# Infer the ground-truth name from prediction filename or directory where its stored
|
| 124 |
+
# We need to figure out how we plan to consistently map predictions to systems so that eval metrics can be calculated
|
| 125 |
+
# I assume we won't distribute the ground-truth structures (though they are already accessible if we don't blind system IDs)
|
| 126 |
+
native = Path(f"./ground_truth/{system_id}.pdb")
|
| 127 |
+
# alternatively
|
| 128 |
+
# native = Path(f"./ground_truth/{prediction.parent.parent.stem}.pdb")
|
| 129 |
+
# OR we need the user to provide prediction + system name
|
| 130 |
+
try:
|
| 131 |
+
# Get eval metrics for the prediction
|
| 132 |
+
bdq = BiotiteDockQ(native, complex_file.name, parallel_io=False)
|
| 133 |
+
metrics = bdq.calculate()
|
| 134 |
+
metrics = metrics[["system", "LRMS", "iRMS", "Fnat", "DockQ", "CAPRI"]].copy()
|
| 135 |
+
metrics.rename(columns={"LRMS": "L_rms", "iRMS": "I_rms", "Fnat": "F_nat", "DockQ": "DOCKQ", "CAPRI": "CAPRI_class"}, inplace=True)
|
| 136 |
+
except Exception as e:
|
| 137 |
+
failed_metrics = {"L_rms": 100.0, "I_rms": 100.0, "F_nat": 0.0, "DOCKQ": 0.0, "CAPRI_class": "Incorrect"}
|
| 138 |
+
metrics = pd.DataFrame([failed_metrics])
|
| 139 |
+
metrics["system"] = native.stem
|
| 140 |
+
gr.Error(f"Failed to evaluate prediction [{prediction}]:\n{e}")
|
| 141 |
+
# Upload to hub
|
| 142 |
+
with tempfile.NamedTemporaryFile as temp:
|
| 143 |
+
metrics.to_csv(temp.name)
|
| 144 |
+
API.upload_file(
|
| 145 |
+
path_or_fileobj=temp.name,
|
| 146 |
+
path_in_repo=f"{dataset}/{methodname}/{system_id}/",
|
| 147 |
+
repo_id=QUEUE_REPO,
|
| 148 |
+
repo_type="dataset",
|
| 149 |
+
commit_message=f"Add {model_name} to eval queue",
|
| 150 |
+
)
|
| 151 |
+
API.upload_file(
|
| 152 |
+
path_or_fileobj=complex_file.name,
|
| 153 |
+
path_in_repo=f"{dataset}/{methodname}/{system_id}/",
|
| 154 |
+
repo_id=QUEUE_REPO,
|
| 155 |
+
repo_type="dataset",
|
| 156 |
+
commit_message=f"Add {model_name} to eval queue",
|
| 157 |
+
)
|
| 158 |
+
end_time = time.time()
|
| 159 |
+
run_time = end_time - start_time
|
| 160 |
+
return gr.DataFrame(metrics, visible=True), run_time
|
| 161 |
|
| 162 |
with gr.Blocks() as app:
|
| 163 |
with gr.Tab("🧬 PINDER evaluation template"):
|
|
|
|
| 193 |
posebusters = gr.Checkbox(label="PoseBusters", value=True)
|
| 194 |
methodname = gr.Textbox(label="Name of your method in the format mlsb/spacename")
|
| 195 |
store = gr.Checkbox(label="Store on huggingface for leaderboard", value=False)
|
| 196 |
+
gr.Examples(
|
| 197 |
[
|
| 198 |
[
|
| 199 |
"4neh__1__1.B__1.H",
|