# PROBE / app.py
# (Hugging Face Space file-view header: last commit 8e77a26 "Update app.py" by mgyigit, verified; 23 kB)
import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
sys.path.append('./src')
sys.path.append('.')
from huggingface_hub import HfApi
repo_id = "HUBioDataLab/PROBE"
api = HfApi()
from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------
def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate a submission, run the PROBE evaluation and optionally save it.

    The selector values arriving from the UI are user-facing labels. They are
    translated to the original codes through ``benchmark_type_map``,
    ``similarity_tasks_map``, ``function_prediction_aspect_map`` and
    ``family_prediction_dataset_map``. The translated codes are what gets
    validated, passed to ``run_probe`` and handed to ``save_results``.

    Args:
        human_file: Path of the uploaded Human representation CSV, or None.
        skempi_file: Path of the uploaded SKEMPI representation CSV, or None.
        model_name_textbox: Method name the results are reported under;
            surrounding whitespace is stripped and an empty name is rejected.
        benchmark_types: Selected benchmark-type labels (None = none).
        similarity_tasks: Selected similarity-task labels (None = none).
        function_prediction_aspect: Selected aspect label, or None when
            nothing is selected.
        function_prediction_dataset: Function-prediction dataset value,
            passed through to ``run_probe`` unchanged.
        family_prediction_dataset: Selected family-dataset labels (None = none).
        save: When truthy, persist the results via ``save_results``.

    Returns:
        0 on success, -1 when validation or evaluation fails. A Gradio warning
        toast explains the failure.
    """
    import logging

    # Map the user-facing labels back to the original codes. Empty/None
    # selections become empty lists; an unselected aspect stays None instead
    # of being reported as an "Unrecognized option".
    try:
        benchmark_types_mapped = [benchmark_type_map[b] for b in (benchmark_types or [])]
        similarity_tasks_mapped = [similarity_tasks_map[s] for s in (similarity_tasks or [])]
        function_prediction_aspect_mapped = (
            None
            if function_prediction_aspect is None
            else function_prediction_aspect_map[function_prediction_aspect]
        )
        family_prediction_dataset_mapped = [
            family_prediction_dataset_map[f] for f in (family_prediction_dataset or [])
        ]
    except KeyError as e:
        gr.Warning(f"Unrecognized option: {e.args[0]}")
        return -1

    representation_name = (model_name_textbox or "").strip()
    if not representation_name:
        gr.Warning("Please provide a method name!")
        return -1

    # Validate against the internal codes, not the display labels.
    needs_human = any(
        task in benchmark_types_mapped for task in ("similarity", "family", "function")
    )
    if needs_human and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if "affinity" in benchmark_types_mapped and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed…")
    try:
        results = run_probe(
            benchmark_types_mapped,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks_mapped,
            function_prediction_aspect_mapped,
            function_prediction_dataset,
            family_prediction_dataset_mapped,
        )
    except Exception:
        # Keep the toast generic for the user, but record the traceback in the
        # Space logs; otherwise the real cause of the failure is lost.
        logging.getLogger(__name__).exception(
            "PROBE evaluation failed for %r", representation_name
        )
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        # NOTE(review): codes (not labels) are saved; refresh_data's cache file
        # names (/tmp/<code>_results.csv) suggest save_results expects codes.
        save_results(representation_name, benchmark_types_mapped, results)
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")
    return 0
def refresh_data():
    """Restart the Space, clear the cached result CSVs and re-download them.

    Side effects:
        * requests a restart of the ``repo_id`` Space through the Hub API;
        * deletes ``/tmp/<name>_results.csv`` for the four benchmark types and
          for ``leaderboard`` (missing files are ignored);
        * calls ``download_from_hub`` for the four benchmark types (the
          leaderboard cache is cleared but not downloaded).

    Returns:
        A freshly built, unfiltered leaderboard Styler. The Refresh button
        routes this return value into the leaderboard table, so returning
        None (as before) would blank the table.
    """
    api.restart_space(repo_id=repo_id)

    downloadable = ["similarity", "function", "family", "affinity"]
    for cache_name in downloadable + ["leaderboard"]:
        try:
            os.remove(f"/tmp/{cache_name}_results.csv")
        except FileNotFoundError:
            pass  # nothing cached yet for this benchmark

    download_from_hub(downloadable)
    return build_leaderboard_styler()
# ------- Leaderboard helpers -----------------------------------------------
def update_metrics(selected_benchmarks, mapping=None):
    """Return the metric columns that belong to the selected benchmark types.

    Args:
        selected_benchmarks: Benchmark-type labels ticked in the UI. None is
            treated as an empty selection; labels absent from *mapping*
            contribute nothing.
        mapping: Benchmark-type label -> list of metric column names. Defaults
            to the module-level ``benchmark_metric_mapping`` that the
            leaderboard tab builds.

    Returns:
        Metric names without duplicates. They are ordered by *mapping*'s
        benchmark order, then by each benchmark's own metric order. The order
        is deterministic; ``list(set(...))`` is not. That matters because this
        list becomes the metric selector's value, which presumably fixes the
        leaderboard's column order.
    """
    if mapping is None:
        mapping = benchmark_metric_mapping
    wanted = set(selected_benchmarks or [])
    ordered = {}  # dict used as an insertion-ordered set
    for benchmark, metrics in mapping.items():
        if benchmark in wanted:
            ordered.update(dict.fromkeys(metrics))
    return list(ordered)
def update_leaderboard(selected_methods, selected_metrics):
    """Rebuild the styled leaderboard for the given method/metric selection."""
    leaderboard_styler = build_leaderboard_styler(
        selected_methods=selected_methods,
        selected_metrics=selected_metrics,
    )
    return leaderboard_styler
def colour_method_html(name: str) -> str:
    """Wrap a method name in a bold, coloured ``<span>``.

    The colour is looked up in ``color_dict`` under the bare method name.
    Square brackets and the first-closing parenthesised part are removed
    first, so a markdown link such as ``[T5](https://…)`` is looked up as
    ``T5``. Names without an entry fall back to ``black``. The span wraps the
    original *name*, link syntax included.
    """
    bare_name = re.sub(r"\[|\]|\(.*?\)", "", name)  # strip markdown link syntax
    method_colour = color_dict.get(bare_name, "black")
    return f"<span style='color:{method_colour}; font-weight:bold;'>{name}</span>"
# darkest → lightest green (index 0 = rank 1)
TOP5_GREENS = ["#006400", "#228B22", "#32CD32", "#7CFC00", "#ADFF2F"]


def shade_top5(col: pd.Series) -> list[str]:
    """Build per-cell CSS for one leaderboard column.

    Numeric columns are ranked from highest to lowest value, with ties broken
    by row order. The five highest cells get the matching ``TOP5_GREENS``
    background; every other cell, including NaN, gets an empty string.
    Non-numeric columns come back entirely unstyled.
    """
    if pd.api.types.is_numeric_dtype(col):
        styles = []
        for position in col.rank(ascending=False, method="first"):
            if position <= 5:
                styles.append(f"background-color:{TOP5_GREENS[int(position)-1]};")
            else:
                styles.append("")
        return styles
    return ["" for _ in col]
def build_leaderboard_styler(selected_methods=None, selected_metrics=None):
    """Build the styled leaderboard table.

    Args:
        selected_methods: Method names to show, passed straight to
            ``get_baseline_df`` (start-up calls it with None to obtain the
            full table).
        selected_metrics: Metric columns to show, passed the same way.

    Returns:
        A pandas Styler. Values are rounded to 4 decimals. "Method" cells hold
        coloured inline HTML from ``colour_method_html``. Every other column
        is shaded by ``shade_top5``, and numbers display with precision 4.
    """
    frame = get_baseline_df(selected_methods, selected_metrics).round(4)

    # Step 1: swap each method name for its coloured <span> markup.
    frame["Method"] = frame["Method"].apply(colour_method_html)

    # Step 2: shade the top five cells of every non-"Method" column
    # (shade_top5 itself leaves non-numeric columns unstyled).
    value_columns = [name for name in frame.columns if name != "Method"]
    return frame.style.apply(shade_top5, axis=0, subset=value_columns).format(precision=4)
# ------- Visualisation helpers ---------------------------------------------
def generate_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    """Delegate to ``benchmark_plot`` and return its result unchanged.

    The Visualization tab's Plot button routes this return value into a
    ``gr.Image`` output. It is presumably an image file path — confirm in
    ``benchmark_plot``.
    """
    plot_args = (benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric)
    return benchmark_plot(*plot_args)
# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
# Passed to gr.Blocks(css=...). Every rule is scoped to the element with
# elem_id "leaderboard-table" (the leaderboard gr.Dataframe).
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
position: sticky;
left: 0;
z-index: 2;
/* wider “Method” column */
min-width: 190px;
width: 190px;
white-space: nowrap;
}
/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
text-align: center;
}
/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
max-height: 1200px;
overflow-y: auto;
overflow-x: auto;
}
"""
# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
# Components are laid out in the order they are created inside the `with`
# contexts below, so statements here should not be reordered casually.
block = gr.Blocks(css=CUSTOM_CSS)
with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)
    # NOTE(review): `tabs` is bound but never referenced in this file.
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣ Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("πŸ… PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # ── header ────────────────────────────────────────────────────
            gr.Image(
                value="./src/data/PROBE_workflow_figure.jpg",
                show_label=False,
                height=1000,
                container=False,
            )
            gr.Markdown(
                "## For detailed explanations of the metrics and benchmarks, please refer to the πŸ“ About tab.",
                elem_classes="leaderboard-note",
            )
            # ── data prep ────────────────────────────────────────────────
            # Load the unfiltered leaderboard once, at build time, to derive
            # the selector choices. Every column except "Method" is treated
            # as a metric.
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard["Method"].unique().tolist()
            metric_names = leaderboard.columns.tolist(); metric_names.remove("Method")
            # Methods listed in `base_methods` (star-imported from src.*) vs.
            # everything else (shown under "User-defined Methods").
            base_method_names = [m for m in method_names if m in base_methods]
            user_method_names = [m for m in method_names if m not in base_methods]
            # Benchmark label -> metric columns, grouped by column-name prefix.
            # A `with` block creates no new scope, so this is a module global;
            # the module-level update_metrics() reads it.
            # NOTE(review): "func" has no trailing underscore, unlike the other
            # prefixes — confirm that is intended.
            benchmark_metric_mapping = {
                "Semantic Similarity Inference": [m for m in metric_names if m.startswith("sim_")],
                "Ontology-based Protein Function Prediction": [m for m in metric_names if m.startswith("func")],
                "Drug Target Protein Family Classification": [m for m in metric_names if m.startswith("fam_")],
                "Protein-Protein Binding Affinity Estimation": [m for m in metric_names if m.startswith("aff_")],
            }
            # ── callback helper ──────────────────────────────────────────
            def update_leaderboard_combined(selected_base, selected_user, selected_metrics):
                """Merge the base and user method selections (None counts as an
                empty list) and rebuild the styled leaderboard table."""
                selected_methods = (selected_base or []) + (selected_user or [])
                return build_leaderboard_styler(selected_methods, selected_metrics)
            # ── collapsible selectors ────────────────────────────────────
            with gr.Accordion("πŸ“¦ Base Methods", open=False):
                leaderboard_method_selector_base = gr.CheckboxGroup(
                    choices=base_method_names,
                    label="Base Methods",
                    value=base_method_names,  # ← all selected
                    interactive=True,
                )
            with gr.Accordion("πŸ› οΈ User-defined Methods", open=False):
                leaderboard_method_selector_user = gr.CheckboxGroup(
                    choices=user_method_names,
                    label="User Methods",
                    value=[],  # ← none selected
                    interactive=True,
                )
            with gr.Accordion("πŸ§ͺ Benchmark Types", open=False):
                benchmark_type_selector_lb = gr.CheckboxGroup(
                    choices=list(benchmark_metric_mapping.keys()),
                    label="Benchmark Types",
                    value=list(benchmark_metric_mapping.keys()),  # all selected
                    interactive=True,
                )
            with gr.Accordion("πŸ“ Metrics", open=False):
                leaderboard_metric_selector = gr.CheckboxGroup(
                    choices=metric_names,
                    label="Select Metrics",
                    value=metric_names,  # ← all selected
                    interactive=True,
                )
            # ── colour / shading legend (unchanged) ──────────────────────
            # NOTE(review): the <span> tags in the colour legend are never
            # closed, so each entry nests inside the previous one.
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(
                        """
## Method-name colours
<span style='color:green; font-weight:bold; font-size:1.1rem;'>πŸŸ’β€‚ Classical representations
<span style='color:blue; font-weight:bold; font-size:1.1rem;'>πŸ”΅β€‚ Small-scale Protein LMs
<span style='color:red; font-weight:bold; font-size:1.1rem;'>πŸ”΄β€‚ Large-scale Protein LMs
<span style='color:orange;font-weight:bold; font-size:1.1rem;'>πŸŸ β€‚ Multimodal Protein LMs
""",
                        elem_classes="leaderboard-note",
                    )
                with gr.Column(scale=1):
                    # Rank badges; colours mirror TOP5_GREENS used by shade_top5.
                    gr.Markdown(
                        """
## Metric-cell shading
<span style='background-color:#006400; color:white; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>1</span>
<span style='background-color:#228B22; color:white; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>2</span>
<span style='background-color:#32CD32; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>3</span>
<span style='background-color:#7CFC00; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>4</span>
<span style='background-color:#ADFF2F; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center;'>5</span>
<br>
<span style='font-size:1.1rem;'> top-five scores (darker β†’ better)</span>
""",
                        elem_classes="leaderboard-note",
                    )
            # ── dataframe ────────────────────────────────────────────────
            # Initial table: all base methods and all metrics, matching the
            # selectors' default values above (user methods start unselected).
            styler = build_leaderboard_styler(base_method_names, metric_names)
            # "Method" is declared as markdown, presumably so the inline <span>
            # HTML from colour_method_html is rendered rather than shown as text.
            data_component = gr.Dataframe(
                value=styler,
                headers=["Method"] + metric_names,
                type="pandas",
                datatype=["markdown"] + ["number"] * len(metric_names),
                interactive=False,
                elem_id="leaderboard-table",  # target of the CUSTOM_CSS rules
                pinned_columns=1,
                max_height=1000,
                show_fullscreen_button=True,
            )
            gr.Markdown("#### If a method name ends with **^**, it suggests potential suspicions of data leakage related to ***similarity***, ***function***, or ***family*** benchmarks.")
            # ── callbacks ────────────────────────────────────────────────
            # Any change to the method or metric selections rebuilds the table
            # from all three selectors' current values.
            leaderboard_method_selector_base.change(
                update_leaderboard_combined,
                inputs=[leaderboard_method_selector_base, leaderboard_method_selector_user, leaderboard_metric_selector],
                outputs=data_component,
            )
            leaderboard_method_selector_user.change(
                update_leaderboard_combined,
                inputs=[leaderboard_method_selector_base, leaderboard_method_selector_user, leaderboard_metric_selector],
                outputs=data_component,
            )
            leaderboard_metric_selector.change(
                update_leaderboard_combined,
                inputs=[leaderboard_method_selector_base, leaderboard_method_selector_user, leaderboard_metric_selector],
                outputs=data_component,
            )
            # Choosing benchmark types overwrites the metric selection with the
            # metrics of those benchmarks. The metric selector's own .change
            # handler above then presumably rebuilds the table.
            benchmark_type_selector_lb.change(
                lambda selected: update_metrics(selected),
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("πŸ“Š Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            gr.Markdown(
                """## **Interactive Visualizations**
Choose a benchmark type; context-specific options will appear.""",
                elem_classes="markdown-text",
            )
            # ── benchmark-type selector ──────────────────────────────────
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="πŸ§ͺ Benchmark Type",
                value=None,
            )
            # ── metric / dataset selectors (appear contextually) ─────────
            # All start hidden with no choices; update_metric_choices fills
            # them in when a benchmark type is picked (see callbacks below).
            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="X-axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Y-axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)
            # ── method selectors (two accordions) ───────────────────────
            # NOTE(review): these recompute exactly the same lists as in the
            # leaderboard tab; the reassignment is redundant.
            base_method_names = [m for m in method_names if m in base_methods]
            user_method_names = [m for m in method_names if m not in base_methods]
            with gr.Accordion("πŸ“¦ Base methods", open=False):
                vis_method_selector_base = gr.CheckboxGroup(
                    choices=base_method_names,
                    label="Base Methods",
                    value=base_method_names,  # default: all selected
                    interactive=True,
                )
            with gr.Accordion("πŸ› οΈ User-defined methods", open=False):
                vis_method_selector_user = gr.CheckboxGroup(
                    choices=user_method_names,
                    label="User Methods",
                    value=[],  # default: none selected
                    interactive=True,
                )
            # ── plot button & output ────────────────────────────────────
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            gr.Markdown("#### If a method name ends with **^**, it suggests potential suspicions of data leakage related to ***similarity***, ***function***, or ***family*** benchmarks.")
            # ── callbacks ───────────────────────────────────────────────
            # update_metric_choices (star-imported from src.*) receives the
            # chosen benchmark type; its return value feeds the five dropdowns.
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )
            # combine the two method lists, then call the original helper
            plot_button.click(
                lambda bt, base_sel, user_sel, xm, ym, asp, ds, sm: generate_plot(
                    bt,
                    (base_sel or []) + (user_sel or []),  # merged method list
                    xm, ym, asp, ds, sm,
                ),
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector_base,
                    vis_method_selector_user,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output],
            )
        # ------------------------------------------------------------------
        # 3️⃣ About tab
        # ------------------------------------------------------------------
        # NOTE(review): this tab and the Submit tab reuse the leaderboard tab's
        # elem_id "probe-benchmark-tab-table", producing duplicate DOM ids.
        with gr.TabItem("πŸ“ About", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )
        # ------------------------------------------------------------------
        # 4️⃣ Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("πŸš€ Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# βœ‰οΈβœ¨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
                    similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Tasks", interactive=True)
                    function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
                    family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
                    # Hidden, fixed value: every submission uses "All_Data_Sets"
                    # as its function-prediction dataset.
                    function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
                    save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)
            with gr.Row():
                # type='filepath': add_new_eval receives file paths (or None).
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
            submit_button = gr.Button("Submit Eval")
            # NOTE(review): submission_result is never used as an output, and
            # the click below has no outputs, so add_new_eval's return code is
            # discarded; feedback reaches the user only via gr.Info/gr.Warning.
            submission_result = gr.Markdown()
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )
    # global refresh + citation ---------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
    # The leaderboard table is overwritten with refresh_data's return value;
    # confirm refresh_data returns something the Dataframe can display.
    data_run.click(refresh_data, outputs=[data_component])
    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )
# ---------------------------------------------------------------------------
# Starts the Gradio server at import time (module-level call).
block.launch()