__all__ = ['block', 'make_clickable_model', 'make_clickable_user', 'get_submissions']

import gradio as gr
import pandas as pd
import re
import os
import json
import yaml
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine as p9
import sys
import zipfile
import tempfile
sys.path.append('./src')
sys.path.append('.')

from src.about import *
from src.saving_utils import *
from src.vis_utils import *
from src.bin.PROBE import run_probe
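# run_probe executes the selected PROBE benchmarks on the uploaded representation
# files and returns their results, which are then saved and shown on the leaderboard.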


def add_new_eval(
    human_file,
    skempi_file,
    model_name_textbox: str,
    revision_name_textbox: str,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save,
):
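    """
    Run the selected PROBE benchmarks on the uploaded representation files.

    Returns 0 on success and -1 if a required file is missing or the evaluation
    fails; callers use this status to decide whether to produce downloadable outputs.
    """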
    # Validate required files based on selected benchmarks
    if any(task in benchmark_types for task in ['similarity', 'family', 'function']) and human_file is None:
        gr.Warning("Human representations are required for similarity, family, or function benchmarks!")
        return -1
    if 'affinity' in benchmark_types and skempi_file is None:
        gr.Warning("SKEMPI representations are required for affinity benchmark!")
        return -1

    gr.Info("Your submission is being processed...")

    representation_name = model_name_textbox if revision_name_textbox == '' else revision_name_textbox

    try:
        results = run_probe(
            benchmark_types,
            representation_name,
            human_file,
            skempi_file,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
    except Exception as e:
        gr.Warning(f"Your submission could not be processed ({e}). Please check your representation files!")
        return -1

    # If save is unchecked, the results are still stored temporarily so the current leaderboard and plots can include this submission.
    if save:
        save_results(representation_name, benchmark_types, results)
    else:
        save_results(representation_name, benchmark_types, results, temporary=True)

    return 0


def refresh_data():
    """Clear cached benchmark results, re-download them from the hub, and return the refreshed leaderboard."""
    benchmark_types = ["similarity", "function", "family", "affinity", "leaderboard"]
    for benchmark_type in benchmark_types:
        path = f"/tmp/{benchmark_type}_results.csv"
        if os.path.exists(path):
            os.remove(path)
    benchmark_types.remove("leaderboard")
    download_from_hub(benchmark_types)
    # Return the rebuilt leaderboard so the Refresh button can update the table component.
    return get_baseline_df(None, None)


def download_leaderboard_csv():
    """Generates a CSV file for the updated leaderboard."""
    df = get_baseline_df(None, None)
    tmp_csv = os.path.join(tempfile.gettempdir(), "leaderboard_download.csv")
    df.to_csv(tmp_csv, index=False)
    return tmp_csv


def generate_plots_based_on_submission(benchmark_types, similarity_tasks, function_prediction_aspect, function_prediction_dataset, family_prediction_dataset):
    """
    For each benchmark type selected during submission, generate a plot based on the corresponding extra parameters.
    """
    tmp_dir = tempfile.mkdtemp()
    plot_files = []
    # Get the current leaderboard to retrieve available method names.
    leaderboard = get_baseline_df(None, None)
    method_names = leaderboard['Method'].unique().tolist()

    for btype in benchmark_types:
        # Choose plotting parameters for each benchmark type based on the submission's extra selections.
        x_metric, y_metric, aspect, dataset, single_metric = None, None, None, None, None
        if btype == "similarity":
            # Use the user-selected similarity tasks (if provided) as the plot metrics.
            x_metric = similarity_tasks[0] if similarity_tasks else None
            y_metric = similarity_tasks[1] if similarity_tasks and len(similarity_tasks) > 1 else None
        elif btype == "function":
            # Function prediction plots are driven by the selected aspect and dataset.
            aspect = function_prediction_aspect or None
            dataset = function_prediction_dataset or None
        elif btype == "family":
            # family_prediction_dataset is a list of datasets; use up to the first two.
            x_metric = family_prediction_dataset[0] if family_prediction_dataset else None
            y_metric = family_prediction_dataset[1] if family_prediction_dataset and len(family_prediction_dataset) > 1 else None
        # For "affinity" (and any unrecognized type) the defaults (all None) are used.

        # Generate the plot with benchmark_plot, mirroring the argument order used by the plotting tab below.
        plot_img = benchmark_plot(btype, method_names, x_metric, y_metric, aspect, dataset, single_metric)
        plot_file = os.path.join(tmp_dir, f"{btype}.png")
        if isinstance(plot_img, plt.Figure):
            plot_img.savefig(plot_file)
            plt.close(plot_img)
        else:
            # If benchmark_plot already returns a file path, use it directly.
            plot_file = plot_img
        plot_files.append(plot_file)

    # Zip all plot images
    zip_path = os.path.join(tmp_dir, "submission_plots.zip")
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for file in plot_files:
            zipf.write(file, arcname=os.path.basename(file))
    return zip_path
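
# Minimal usage sketch (hypothetical selections; the real choices come from the
# option lists defined in src.about and are wired up in the submission tab below):
#   zip_path = generate_plots_based_on_submission(
#       benchmark_types=["similarity", "affinity"],
#       similarity_tasks=["task_a", "task_b"],   # hypothetical task names
#       function_prediction_aspect=None,
#       function_prediction_dataset=None,
#       family_prediction_dataset=None,
#   )
#   # -> path to a ZIP archive containing similarity.png and affinity.png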


def submission_callback(
    human_file,
    skempi_file,
    model_name_textbox,
    revision_name_textbox,
    benchmark_types,
    similarity_tasks,
    function_prediction_aspect,
    function_prediction_dataset,
    family_prediction_dataset,
    save_checkbox,
    return_option,  # Radio selection: "Leaderboard CSV" or "Plot Results"
):
    """
    Runs the evaluation and then returns either a downloadable CSV of the leaderboard
    (which includes the new submission) or a ZIP file of plots generated based on the submission's selections.
    """
    eval_status = add_new_eval(
        human_file,
        skempi_file,
        model_name_textbox,
        revision_name_textbox,
        benchmark_types,
        similarity_tasks,
        function_prediction_aspect,
        function_prediction_dataset,
        family_prediction_dataset,
        save_checkbox,
    )

    if eval_status == -1:
        return "Submission failed. Please check your files and selections.", None

    if return_option == "Leaderboard CSV":
        csv_path = download_leaderboard_csv()
        return "Your leaderboard CSV (including your submission) is ready for download.", csv_path
    elif return_option == "Plot Results":
        zip_path = generate_plots_based_on_submission(
            benchmark_types,
            similarity_tasks,
            function_prediction_aspect,
            function_prediction_dataset,
            family_prediction_dataset,
        )
        return "Your plots are ready for download.", zip_path
    else:
        return "Submission processed, but no output option was selected.", None


# --------------------------
# Build the Gradio interface
# --------------------------
block = gr.Blocks()

with block:
    gr.Markdown(LEADERBOARD_INTRODUCTION)

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏅 PROBE Leaderboard", elem_id="probe-benchmark-tab-table", id=1):
            # Leaderboard tab
            leaderboard = get_baseline_df(None, None)
            method_names = leaderboard['Method'].unique().tolist()
            metric_names = leaderboard.columns.tolist()
            metric_names.remove('Method')
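
            # Leaderboard metric columns are grouped by name prefix: sim_* (similarity),
            # func* (function), fam_* (family), and aff_* (affinity).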

            benchmark_metric_mapping = {
                "similarity": [metric for metric in metric_names if metric.startswith('sim_')],
                "function": [metric for metric in metric_names if metric.startswith('func')],
                "family": [metric for metric in metric_names if metric.startswith('fam_')],
                "affinity": [metric for metric in metric_names if metric.startswith('aff_')],
            }

            leaderboard_method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods for the Leaderboard",
                value=method_names,
                interactive=True
            )
            benchmark_type_selector = gr.CheckboxGroup(
                choices=list(benchmark_metric_mapping.keys()),
                label="Select Benchmark Types",
                value=None,
                interactive=True
            )
            leaderboard_metric_selector = gr.CheckboxGroup(
                choices=metric_names,
                label="Select Metrics for the Leaderboard",
                value=None,
                interactive=True
            )

            baseline_value = get_baseline_df(method_names, metric_names)
            baseline_value = baseline_value.applymap(lambda x: round(x, 4) if isinstance(x, (int, float)) else x)
            baseline_header = ["Method"] + metric_names
            baseline_datatype = ['markdown'] + ['number'] * len(metric_names)
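            # Render the Method column as markdown (e.g., for clickable method names) and all metric columns as numbers.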

            with gr.Row(show_progress=True, variant='panel'):
                data_component = gr.components.Dataframe(
                    value=baseline_value,
                    headers=baseline_header,
                    type="pandas",
                    datatype=baseline_datatype,
                    interactive=False,
                    visible=True,
                )

            leaderboard_method_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component
            )
            benchmark_type_selector.change(
                update_metrics,
                inputs=[benchmark_type_selector],
                outputs=leaderboard_metric_selector
            )
            leaderboard_metric_selector.change(
                get_baseline_df,
                inputs=[leaderboard_method_selector, leaderboard_metric_selector],
                outputs=data_component
            )

            with gr.Row():
                gr.Markdown(
                    """
                    ## **Visualize the Leaderboard Results**
                    Select options to update the visualization.
                    """
                )
            # Plotting controls for visualizing the leaderboard results.
            benchmark_type_selector_plot = gr.Dropdown(
                choices=list(benchmark_specific_metrics.keys()),
                label="Select Benchmark Type for Plotting",
                value=None
            )
            with gr.Row():
                x_metric_selector = gr.Dropdown(choices=[], label="Select X-axis Metric", visible=False)
                y_metric_selector = gr.Dropdown(choices=[], label="Select Y-axis Metric", visible=False)
                aspect_type_selector = gr.Dropdown(choices=[], label="Select Aspect Type", visible=False)
                dataset_selector = gr.Dropdown(choices=[], label="Select Dataset", visible=False)
                single_metric_selector = gr.Dropdown(choices=[], label="Select Metric", visible=False)
            method_selector = gr.CheckboxGroup(
                choices=method_names,
                label="Select Methods to Visualize",
                interactive=True,
                value=method_names
            )
            plot_button = gr.Button("Plot")
            with gr.Row(show_progress=True, variant='panel'):
                plot_output = gr.Image(label="Plot")
            benchmark_type_selector_plot.change(
                update_metric_choices,
                inputs=[benchmark_type_selector_plot],
                outputs=[x_metric_selector, y_metric_selector, aspect_type_selector, dataset_selector, single_metric_selector]
            )
            plot_button.click(
                benchmark_plot,
                inputs=[benchmark_type_selector_plot, method_selector, x_metric_selector, y_metric_selector, aspect_type_selector, dataset_selector, single_metric_selector],
                outputs=plot_output
            )

        with gr.TabItem("📝 About", elem_id="probe-benchmark-tab-table", id=2):
            with gr.Row():
                gr.Markdown(LLM_BENCHMARKS_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Image(
                    value="./src/data/PROBE_workflow_figure.jpg",
                    label="PROBE Workflow Figure",
                    elem_classes="about-image",
                )

        with gr.TabItem("🚀 Submit here! ", elem_id="probe-benchmark-tab-table", id=3):
            with gr.Row():
                gr.Markdown(EVALUATION_QUEUE_TEXT, elem_classes="markdown-text")
            with gr.Row():
                gr.Markdown("# ✉️✨ Submit your model's representation files here!", elem_classes="markdown-text")
            with gr.Row():
                with gr.Column():
                    model_name_textbox = gr.Textbox(label="Method name")
                    revision_name_textbox = gr.Textbox(label="Revision Method Name")
                    benchmark_types = gr.CheckboxGroup(
                        choices=TASK_INFO,
                        label="Benchmark Types",
                        interactive=True,
                    )
                    similarity_tasks = gr.CheckboxGroup(
                        choices=similarity_tasks_options,
                        label="Similarity Tasks (if selected)",
                        interactive=True,
                    )
                    function_prediction_aspect = gr.Radio(
                        choices=function_prediction_aspect_options,
                        label="Function Prediction Aspects (if selected)",
                        interactive=True,
                    )
                    family_prediction_dataset = gr.CheckboxGroup(
                        choices=family_prediction_dataset_options,
                        label="Family Prediction Datasets (if selected)",
                        interactive=True,
                    )
                    function_dataset = gr.Textbox(
                        label="Function Prediction Datasets",
                        visible=False,
                        value="All_Data_Sets"
                    )
                    save_checkbox = gr.Checkbox(
                        label="Save results for leaderboard and visualization",
                        value=True
                    )
            with gr.Row():
                human_file = gr.components.File(
                    label="The representation file (csv) for Human dataset",
                    file_count="single",
                    type='filepath'
                )
                skempi_file = gr.components.File(
                    label="The representation file (csv) for SKEMPI dataset",
                    file_count="single",
                    type='filepath'
                )
            # Radio button for choosing which output the submission returns.
            return_option = gr.Radio(
                choices=["Leaderboard CSV", "Plot Results"],
                label="Return Output",
                value="Leaderboard CSV",
                interactive=True,
            )
            submit_button = gr.Button("Submit Eval")
            submission_result_msg = gr.Markdown()
            submission_result_file = gr.File()
            submit_button.click(
                submission_callback,
                inputs=[
                    human_file,
                    skempi_file,
                    model_name_textbox,
                    revision_name_textbox,
                    benchmark_types,
                    similarity_tasks,
                    function_prediction_aspect,
                    function_dataset,
                    family_prediction_dataset,
                    save_checkbox,
                    return_option,
                ],
                outputs=[submission_result_msg, submission_result_file]
            )

    with gr.Row():
        data_run = gr.Button("Refresh")
        data_run.click(refresh_data, outputs=[data_component])

    with gr.Accordion("Citation", open=False):
        citation_button = gr.Textbox(
            value=CITATION_BUTTON_TEXT,
            label=CITATION_BUTTON_LABEL,
            elem_id="citation-button",
            show_copy_button=True,
        )

block.launch()