PROBE / app.py
mgyigit's picture
Update app.py
7d7e727 verified
raw
history blame
18.5 kB
import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
# Make the local ``src`` directory and the current directory importable.
# Both paths are relative, so they resolve against the current working
# directory (presumably the repository root when the Space starts — confirm).
sys.path.append('./src')
sys.path.append('.')
from huggingface_hub import HfApi
# Hugging Face Space hosting this app; ``refresh_data`` restarts it via ``api``.
repo_id = "HUBioDataLab/PROBE"
api = HfApi()
# Project helpers are star-imported. Names used below such as
# ``get_baseline_df``, ``save_results``, ``download_from_hub``,
# ``benchmark_plot``, ``color_dict`` and the UI text constants are
# presumably defined in these modules — verify against ``src``.
from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
# ------------------------------------------------------------------
# Helper functions --------------------------------------------------
# ------------------------------------------------------------------
def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
    """Validate a submission, run the PROBE evaluation and optionally save the results.

    User-facing labels coming from the UI are first translated to internal
    codes via ``benchmark_type_map``, ``similarity_tasks_map``,
    ``function_prediction_aspect_map`` and ``family_prediction_dataset_map``.
    The translated codes are then used both for validation (which compares
    against codes such as ``'similarity'`` / ``'affinity'``) and for the calls
    to ``run_probe`` / ``save_results``.

    Args:
        human_file: Filepath of the Human-dataset representation CSV, or ``None``.
        skempi_file: Filepath of the SKEMPI-dataset representation CSV, or ``None``.
        model_name_textbox: Method name under which the representation is evaluated.
        benchmark_types: Selected benchmark-type labels (``None`` treated as empty).
        similarity_tasks: Selected similarity-task labels (``None`` treated as empty).
        function_prediction_aspect: Selected aspect label, or ``None``/empty when
            the radio button was left unset.
        function_prediction_dataset: Passed through to ``run_probe`` unchanged.
        family_prediction_dataset: Selected family-dataset labels (``None`` treated as empty).
        save: If truthy, results are persisted with ``save_results``.

    Returns:
        ``0`` on success, ``-1`` on any validation, evaluation or saving
        failure. The user is informed through ``gr.Info`` / ``gr.Warning`` toasts.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Untouched checkbox groups may arrive as ``None``; treat them as "nothing selected".
    benchmark_types = benchmark_types or []
    similarity_tasks = similarity_tasks or []
    family_prediction_dataset = family_prediction_dataset or []

    if not model_name_textbox or not model_name_textbox.strip():
        gr.Warning("Please provide a method name for your submission!")
        return -1

    # Map the user-facing labels back to the original codes.
    try:
        benchmark_types_mapped = [benchmark_type_map[b] for b in benchmark_types]
        similarity_tasks_mapped = [similarity_tasks_map[s] for s in similarity_tasks]
        # The aspect radio may legitimately be empty when no function benchmark
        # is requested; looking up ``None`` would wrongly reject the submission.
        function_prediction_aspect_mapped = (
            function_prediction_aspect_map[function_prediction_aspect]
            if function_prediction_aspect
            else None
        )
        family_prediction_dataset_mapped = [family_prediction_dataset_map[f] for f in family_prediction_dataset]
    except KeyError as e:
        gr.Warning(f"Unrecognized option: {e.args[0]}")
        return -1

    # Validate inputs against the internal benchmark codes.
    if not benchmark_types_mapped:
        gr.Warning("Please select at least one benchmark type!")
        return -1
    needs_human = any(task in benchmark_types_mapped for task in ("similarity", "family", "function"))
    if needs_human and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if "affinity" in benchmark_types_mapped and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1
    if "function" in benchmark_types_mapped and function_prediction_aspect_mapped is None:
        gr.Warning("Please select a function prediction aspect for the function benchmark!")
        return -1

    gr.Info("Your submission is being processed…")
    representation_name = model_name_textbox
    try:
        results = run_probe(
            benchmark_types_mapped,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks_mapped,
            function_prediction_aspect_mapped,
            function_prediction_dataset,
            family_prediction_dataset_mapped,
        )
    except Exception:
        # Keep the UI responsive, but record the traceback so failures can be diagnosed.
        logger.exception("run_probe failed for representation %r", representation_name)
        gr.Warning("Your submission has not been processed. Please check your representation files!")
        return -1

    if save:
        try:
            save_results(representation_name, benchmark_types_mapped, results)
        except Exception:
            logger.exception("save_results failed for representation %r", representation_name)
            gr.Warning("Your submission has been processed, but the results could not be saved!")
            return -1
        gr.Info("Your submission has been processed and results are saved!")
    else:
        gr.Info("Your submission has been processed!")
    return 0
def refresh_data():
    """Restart the Space, re-download the benchmark result CSVs and return a fresh leaderboard.

    Deletes the cached ``/tmp/<benchmark>_results.csv`` files, including the
    aggregated ``leaderboard`` file, and downloads the per-benchmark files
    again via ``download_from_hub``.

    Returns:
        A leaderboard ``Styler`` built from the refreshed data. This callback
        is wired with ``outputs=[data_component]``, and returning ``None``
        would blank the table instead.
    """
    api.restart_space(repo_id=repo_id)
    downloadable_types = ["similarity", "function", "family", "affinity"]
    # The aggregated leaderboard file is only deleted, not downloaded
    # (presumably rebuilt from the per-benchmark files — confirm).
    for benchmark_type in downloadable_types + ["leaderboard"]:
        path = f"/tmp/{benchmark_type}_results.csv"
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # nothing cached for this benchmark yet
    download_from_hub(downloadable_types)
    return build_leaderboard_styler()
# ------- Leaderboard helpers -----------------------------------------------
def update_metrics(selected_benchmarks, mapping=None):
    """Return the metric columns that belong to the selected benchmark types.

    Metrics are de-duplicated but keep a stable order. They follow the order of
    ``selected_benchmarks``, then the order within each mapping entry. A plain
    ``set`` would yield string-hash order, which changes between interpreter runs.

    Args:
        selected_benchmarks: Iterable of benchmark-type names; ``None`` is treated as empty.
        mapping: Benchmark type -> list of metric names. Defaults to the
            module-level ``benchmark_metric_mapping`` built in the UI section.

    Returns:
        List of metric names. Unknown benchmark types contribute nothing.
    """
    if mapping is None:
        mapping = benchmark_metric_mapping
    ordered = dict.fromkeys(
        metric
        for benchmark in selected_benchmarks or []
        for metric in mapping.get(benchmark, [])
    )
    return list(ordered)
def update_leaderboard(selected_methods, selected_metrics):
    """Rebuild the styled leaderboard for the chosen methods and metrics.

    Callback adapter for the leaderboard selectors; all work is delegated
    to ``build_leaderboard_styler``.
    """
    styled_table = build_leaderboard_styler(
        selected_methods=selected_methods,
        selected_metrics=selected_metrics,
    )
    return styled_table
def colour_method_html(name: str) -> str:
    """Wrap *name* in a bold, coloured ``<span>`` for the leaderboard table.

    The colour is looked up in ``color_dict`` using a key with square
    brackets removed and every parenthesised ``(...)`` group deleted. As a
    result, ``'T5'`` and ``'[T5](https://…)'`` resolve to the same entry.
    Note that this also strips parentheses from plain names. Unknown names
    fall back to black. The unmodified *name* (link included) is what goes
    inside the span.
    """
    lookup_key = re.sub(r"\[|\]|\(.*?\)", "", name)  # strip md link
    colour = color_dict.get(lookup_key, "black")
    return f"<span style='color:{colour}; font-weight:bold;'>{name}</span>"
# Background colours for ranks 1-5, ordered darkest (best) → lightest green.
TOP5_GREENS = ["#006400", "#228B22", "#32CD32", "#7CFC00", "#ADFF2F"]


def shade_top5(col: pd.Series) -> list[str]:
    """Build per-cell CSS for one leaderboard column.

    In a numeric column, the cells holding the five highest values get a
    green background taken from ``TOP5_GREENS`` (darker = better rank).
    Ties are broken by position (``method="first"``), and NaN cells have no
    rank, so they stay unshaded. Every other cell, and every cell of a
    non-numeric column, gets an empty string.
    """
    if pd.api.types.is_numeric_dtype(col):
        positions = col.rank(ascending=False, method="first")
        css = []
        for position in positions:
            if position <= 5:
                css.append(f"background-color:{TOP5_GREENS[int(position)-1]};")
            else:
                css.append("")
        return css
    return [""] * len(col)
def build_leaderboard_styler(selected_methods=None, selected_metrics=None):
    """Return a pandas ``Styler`` for the leaderboard table.

    Steps:
      1. Load the results for the requested methods/metrics via
         ``get_baseline_df`` and round them to 4 decimals.
      2. Replace each ``Method`` entry with coloured inline HTML
         (``colour_method_html``).
      3. Shade the top-5 cells of every other column with ``shade_top5``
         (non-numeric columns are left unshaded by that helper) and format
         numbers with 4-digit precision.
    """
    table = get_baseline_df(selected_methods, selected_metrics)
    table = table.round(4)
    table["Method"] = table["Method"].apply(colour_method_html)

    score_columns = [column for column in table.columns if column != "Method"]
    styled = table.style.apply(shade_top5, axis=0, subset=score_columns)
    return styled.format(precision=4)  # keep numbers tidy
# ------- Visualisation helpers ---------------------------------------------
def generate_plot(benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric):
    """Render the requested benchmark plot by delegating to ``benchmark_plot``.

    Returns whatever ``benchmark_plot`` returns. It feeds a ``gr.Image``
    output, so it is presumably a path to an image file — TODO confirm.
    """
    return benchmark_plot(
        benchmark_type, methods_selected, x_metric, y_metric, aspect, dataset, single_metric
    )
# ---------------------------------------------------------------------------
# Custom CSS for frozen first column and clearer table styles
# ---------------------------------------------------------------------------
# Stylesheet passed to ``gr.Blocks(css=...)``. Every rule is scoped to the
# element with id ``leaderboard-table`` (the leaderboard ``gr.Dataframe``):
# sticky first column, striped odd rows, centred non-first cells and a
# scrollable wrapper.
# NOTE(review): ``.dataframe-wrap`` is assumed to be a Gradio-internal class
# name — confirm it exists in the installed Gradio version.
CUSTOM_CSS = """
/* freeze first column */
#leaderboard-table table tr th:first-child,
#leaderboard-table table tr td:first-child {
position: sticky;
left: 0;
background: white;
z-index: 2;
}
/* striped rows for readability */
#leaderboard-table table tr:nth-child(odd) {
background: #fafafa;
}
/* centre numeric cells */
#leaderboard-table td:not(:first-child) {
text-align: center;
}
/* scrollable and taller table */
#leaderboard-table .dataframe-wrap {
max-height: 1200px;
overflow-y: auto;
overflow-x: auto;
}
"""
# ---------------------------------------------------------------------------
# UI definition
# ---------------------------------------------------------------------------
# Build the whole Gradio UI. Names such as LEADERBOARD_INTRODUCTION,
# get_baseline_df, benchmark_specific_metrics, update_metric_choices,
# TASK_INFO and the *_options lists are presumably provided by the star
# imports from ``src`` — verify against those modules.
block = gr.Blocks(css=CUSTOM_CSS)

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        # ------------------------------------------------------------------
        # 1️⃣ Leaderboard tab
        # ------------------------------------------------------------------
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # workflow figure at top
            gr.Image(
                value="./src/data/PROBE_workflow_figure.jpg",
                show_label=False,
                height=1000,
                container=False,
            )
            gr.Markdown(
                "## For detailed explanations of the metrics and benchmarks, please refer to the 📝 About tab.",
                elem_classes="leaderboard-note",
            )

            # Unfiltered results table (presumably all methods/metrics); its
            # columns define the selectable metrics.
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            metric_names.remove('Method')

            # Benchmark type -> metric columns, grouped by column-name prefix.
            # ``with`` does not create a scope, so this is a module-level global;
            # update_metrics() reads it.
            # NOTE(review): the function prefix is 'func' (no underscore), unlike
            # the other prefixes — confirm this is intended.
            benchmark_metric_mapping = {
                "Semantic Similarity Inference": [m for m in metric_names if m.startswith('sim_')],
                "Ontology-based Protein Function Prediction": [m for m in metric_names if m.startswith('func')],
                "Drug Target Protein Family Classification": [m for m in metric_names if m.startswith('fam_')],
                "Protein-Protein Binding Affinity Estimation": [m for m in metric_names if m.startswith('aff_')],
            }

            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods",
                value=method_names,
                interactive=True,
            )
            benchmark_type_selector_lb = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True,
            )
            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics",
                value=None,
                interactive=True,
            )

            # Headers/datatypes for the table. The Method column holds inline HTML
            # from colour_method_html, so it is declared as markdown.
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
            # Initial table with coloured method names and top-5 shading.
            styler = build_leaderboard_styler()

            # Legends explaining the colour coding used in the table.
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown(
                        """
## Method-name colours
<span style='color:green; font-weight:bold; font-size:1.1rem;'>🟢  Classical representations</span>
<span style='color:blue; font-weight:bold; font-size:1.1rem;'>🔵  Small-scale learned representations</span>
<span style='color:red; font-weight:bold; font-size:1.1rem;'>🔴  Large-scale learned representations</span>
<span style='color:orange;font-weight:bold; font-size:1.1rem;'>🟠  Multimodal learned representations</span>
""",
                        elem_classes="leaderboard-note",
                    )
                with gr.Column(scale=1):
                    gr.Markdown(
                        """
## Metric-cell shading
<span style='background-color:#006400; color:white; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>1</span>
<span style='background-color:#228B22; color:white; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>2</span>
<span style='background-color:#32CD32; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>3</span>
<span style='background-color:#7CFC00; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center; margin-right:0.2rem;'>4</span>
<span style='background-color:#ADFF2F; color:black; padding:0.4rem 0.8rem; border-radius:0.4rem;
font-size:1.1rem; display:inline-block; text-align:center;'>5</span>
<br>
<span style='font-size:1.1rem;'> top-five scores (darker → better)</span>
""",
                        elem_classes="leaderboard-note",
                    )

            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.Dataframe(
                    value=styler,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    elem_id="leaderboard-table",
                    pinned_columns=1,
                    max_height=1000,
                    show_fullscreen_button=True
                )

            # callbacks: method/metric selection rebuilds the table; choosing
            # benchmark types replaces the metric selection.
            leaderboard_method_selector.change(
                update_leaderboard,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )
            benchmark_type_selector_lb.change(
                lambda selected: update_metrics(selected),
                inputs=[benchmark_type_selector_lb],
                outputs=leaderboard_metric_selector,
            )
            leaderboard_metric_selector.change(
                update_leaderboard,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component,
            )

        # ------------------------------------------------------------------
        # 2️⃣ Visualisation tab
        # ------------------------------------------------------------------
        with gr.TabItem("📊 Visualization", elem_id="probe-benchmark-tab-visualization", id=2):
            gr.Markdown(
                """## **Interactive Visualizations**
Choose a benchmark type; context‑specific options will appear.""",
                elem_classes="markdown-text",
            )
            vis_benchmark_type_selector = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Benchmark Type",
                value=None,
            )
            # Option dropdowns start hidden; update_metric_choices fills/reveals
            # them when a benchmark type is chosen.
            with gr.Row():
                vis_x_metric_selector = gr.Dropdown(choices=[], label="X‑axis Metric", visible=False)
                vis_y_metric_selector = gr.Dropdown(choices=[], label="Y‑axis Metric", visible=False)
                vis_aspect_type_selector = gr.Dropdown(choices=[], label="Aspect", visible=False)
                vis_dataset_selector = gr.Dropdown(choices=[], label="Dataset", visible=False)
                vis_single_metric_selector = gr.Dropdown(choices=[], label="Metric", visible=False)
            vis_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Methods",
                value=method_names,
                interactive=True,
            )
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
                plot_explanation = gr.Markdown(visible=False)

            # callbacks
            vis_benchmark_type_selector.change(
                update_metric_choices,
                inputs=[vis_benchmark_type_selector],
                outputs=[
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
            )
            plot_button.click(
                generate_plot,
                inputs=[
                    vis_benchmark_type_selector,
                    vis_method_selector,
                    vis_x_metric_selector,
                    vis_y_metric_selector,
                    vis_aspect_type_selector,
                    vis_dataset_selector,
                    vis_single_metric_selector,
                ],
                outputs=[plot_output],
            )

        # ------------------------------------------------------------------
        # 3️⃣ About tab
        # ------------------------------------------------------------------
        # NOTE(review): this tab and the Submit tab reuse the elem_id of the
        # Leaderboard tab, producing duplicate HTML ids — confirm intended.
        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        # ------------------------------------------------------------------
        # 4️⃣ Submit tab
        # ------------------------------------------------------------------
        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=4):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    benchmark_types = gr.CheckboxGroup(choices=TASK_INFO, label="Benchmark Types", interactive=True)
                    similarity_tasks = gr.CheckboxGroup(choices=similarity_tasks_options, label="Similarity Tasks", interactive=True)
                    function_prediction_aspect = gr.Radio(choices=function_prediction_aspect_options, label="Function Prediction Aspects", interactive=True)
                    family_prediction_dataset = gr.CheckboxGroup(choices=family_prediction_dataset_options, label="Family Prediction Datasets", interactive=True)
                    # Hidden constant forwarded to add_new_eval as the function-prediction dataset.
                    function_dataset = gr.Textbox(label="Function Prediction Datasets", visible=False, value="All_Data_Sets")
                    save_checkbox = gr.Checkbox(label="Save results for leaderboard and visualization", value=True)
            with gr.Row():
                human_file = gr.File(label="Representation file (CSV) for Human dataset", file_count="single", type='filepath')
                skempi_file = gr.File(label="Representation file (CSV) for SKEMPI dataset", file_count="single", type='filepath')
            submit_button = gr.Button("Submit Eval")
            submission_result = gr.Markdown()
            # add_new_eval reports to the user via gr.Info/gr.Warning toasts, so no outputs are wired.
            submit_button.click(
                add_new_eval,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                ],
            )

    # global refresh + citation ---------------------------------------------
    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])
    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )
# ---------------------------------------------------------------------------
# Start the Gradio server. This runs at import time (there is no
# ``if __name__ == "__main__"`` guard).
block.launch()