import pandas as pd
from pathlib import Path


def plot_data(metric, selected_attack, all_attacks_df):
    """Select the rows of ``all_attacks_df`` whose ``attack`` equals ``selected_attack``.

    NOTE(review): this is currently a stub. The code that plotted the selected
    rows (a ``gr.LinePlot`` of ``metric`` against ``"strength"``, coloured by
    ``"model"``) was commented out, so the selection is computed and discarded
    and the function always returns ``None``. ``metric`` is currently unused.

    Args:
        metric: Name of the metric column that would be plotted (unused).
        selected_attack: Value compared against the ``attack`` column.
        all_attacks_df: DataFrame with (at least) an ``attack`` column.

    Returns:
        None.
    """
    # Still evaluated so a frame lacking an ``attack`` column fails here, as it
    # always has; the result is unused until plotting is restored.
    attack_df = all_attacks_df[all_attacks_df.attack == selected_attack]  # noqa: F841
    # TODO: restore plotting, previously:
    #     if metric == "None":
    #         return gr.LinePlot(x_bin=None)
    #     return gr.LinePlot(attack_df, x="strength", y=metric, color="model")


def mk_variations(
    all_attacks_df,
    attacks_with_variations: list[str],
):
    """Bundle attack-variation results into a JSON-serializable dict.

    Args:
        all_attacks_df: DataFrame of per-attack results. It is not modified;
            missing values are filled on a new frame.
        attacks_with_variations: Attack names to return alongside the data.

    Returns:
        A dict with three keys:

        * ``"metrics"``: the metric column names offered for plotting.
        * ``"attacks_with_variations"``: the ``attacks_with_variations``
          argument, passed through as-is.
        * ``"all_attacks_df"``: the frame as a list of row dicts
          (``orient="records"``), with every missing value replaced by the
          string ``"NaN"``.
    """
    # A float NaN has no valid JSON representation, so missing values are
    # replaced before serialization. They become the *string* "NaN", not
    # None; consumers may depend on that sentinel, so it is kept as-is.
    # fillna returns a new frame, leaving the caller's DataFrame untouched.
    filled_df = all_attacks_df.fillna(value="NaN")

    attacks_plot_metrics = [
        "bit_acc",
        "log10_p_value",
        "TPR",
        "FPR",
        "watermark_det_score",
    ]

    return {
        "metrics": attacks_plot_metrics,
        "attacks_with_variations": attacks_with_variations,
        "all_attacks_df": filled_df.to_dict(orient="records"),
    }