Mark Duppenthaler
Combined leaderboard, simplified filters
b087e88
import pandas as pd
from pathlib import Path
def plot_data(metric, selected_attack, all_attacks_df):
    """Return the rows of ``all_attacks_df`` whose ``attack`` is ``selected_attack``.

    Previously this subset was built and then discarded (the Gradio plot that
    consumed it is disabled), so the function had no observable effect. The
    subset is now returned so callers can plot or inspect it.

    Args:
        metric: Name of the metric to plot. Currently unused; kept so existing
            call sites keep working.
        selected_attack: Value of the ``attack`` column to keep.
        all_attacks_df: ``pandas.DataFrame`` that must contain an ``attack``
            column.

    Returns:
        A ``pandas.DataFrame`` with only the matching rows; the original index
        labels are preserved (not reset).
    """
    # Bracket indexing rather than ``df.attack``: it cannot be shadowed by a
    # DataFrame attribute/method of the same name.
    return all_attacks_df[all_attacks_df["attack"] == selected_attack]
def mk_variations(
    all_attacks_df,
    attacks_with_variations: list[str],
):
    """Build the serializable payload for the attack-variations view.

    Args:
        all_attacks_df: ``pandas.DataFrame`` of per-attack results. It is not
            modified; a filled copy is converted to records.
        attacks_with_variations: Attack names offered for selection; returned
            as-is.

    Returns:
        A dict with keys:
          - ``"metrics"``: metric column names offered for plotting,
          - ``"attacks_with_variations"``: the list passed in,
          - ``"all_attacks_df"``: ``all_attacks_df`` as a list of row dicts
            (``orient="records"``) in which missing values are replaced by the
            string ``"NaN"``.
    """
    # Standard JSON has no NaN literal, so missing values are replaced with the
    # *string* "NaN" (not None) before conversion to records. Columns that
    # contained NaN become object dtype as a result.
    # NOTE(review): presumably the consumer of this payload expects the "NaN"
    # string; confirm before switching to None/null.
    all_attacks_df = all_attacks_df.fillna(value="NaN")
    attacks_plot_metrics = [
        "bit_acc",
        "log10_p_value",
        "TPR",
        "FPR",
        "watermark_det_score",
    ]
    return {
        "metrics": attacks_plot_metrics,
        "attacks_with_variations": attacks_with_variations,
        "all_attacks_df": all_attacks_df.to_dict(orient="records"),
    }