Mark Duppenthaler commited on
Commit
98847a8
·
1 Parent(s): a1f1bf8
Dockerfile CHANGED
@@ -38,4 +38,4 @@ EXPOSE 7860
38
  WORKDIR /app
39
 
40
  # Command to run the application
41
- CMD ["/bin/bash", "-c", "conda run --no-capture-output -n omniseal-benchmark-backend gunicorn --chdir /app/backend -b 0.0.0.0:7870 app:app"]
 
38
  WORKDIR /app
39
 
40
  # Command to run the application
41
+ CMD ["/bin/bash", "-c", "conda run --no-capture-output -n omniseal-benchmark-backend gunicorn --chdir /app/backend -b 0.0.0.0:7860 app:app --reload"]
backend/app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from flask import Flask, Response, send_from_directory
2
  from flask_cors import CORS
3
  import os
@@ -5,9 +7,12 @@ import logging
5
  import pandas as pd
6
  import json
7
  from io import StringIO
8
- from tools import get_old_format_dataframe # Import your function
 
 
 
9
  import typing as tp
10
- import collections
11
 
12
  logger = logging.getLogger(__name__)
13
  if not logger.hasHandlers():
@@ -24,6 +29,7 @@ CORS(app)
24
 
25
  @app.route("/")
26
  def index():
 
27
  return send_from_directory(app.static_folder, "index.html")
28
 
29
 
@@ -35,76 +41,13 @@ def data_files(filename):
35
  data_dir = os.path.join(os.path.dirname(__file__), "data")
36
  file_path = os.path.join(data_dir, filename)
37
  if os.path.isfile(file_path):
38
- # Determine file type and handle accordingly
39
  df = pd.read_csv(file_path)
40
- # Modify the dataframe - you'll need to define first_cols and attack_scores
41
- first_cols = [
42
- "snr",
43
- "sisnr",
44
- "stoi",
45
- "pesq",
46
- ] # Define appropriate values based on your needs
47
- attack_scores = [
48
- "bit_acc",
49
- "log10_p_value",
50
- "TPR",
51
- "FPR",
52
- ] # Define appropriate values based on your needs
53
- categories = {
54
- "speed": "Time",
55
- "updownresample": "Time",
56
- "echo": "Time",
57
- "random_noise": "Amplitude",
58
- "lowpass_filter": "Amplitude",
59
- "highpass_filter": "Amplitude",
60
- "bandpass_filter": "Amplitude",
61
- "smooth": "Amplitude",
62
- "boost_audio": "Amplitude",
63
- "duck_audio": "Amplitude",
64
- "shush": "Amplitude",
65
- "pink_noise": "Amplitude",
66
- "aac_compression": "Compression",
67
- "mp3_compression": "Compression",
68
- }
69
-
70
- # This part adds on all the columns
71
- df = get_old_format_dataframe(df, first_cols, attack_scores)
72
-
73
- # Create groups based on categories
74
- groups = collections.OrderedDict({"Overall": set()})
75
- for k in categories.values():
76
- groups[k] = set()
77
-
78
- default_selection = set()
79
- for k, v in categories.items():
80
- if v not in default_selection:
81
- for k in list(df.columns):
82
- if k.startswith(v):
83
- groups["Overall"].add(k)
84
- default_selection.add(k)
85
-
86
- for col in list(df.columns):
87
- for k in categories.keys():
88
- if col.startswith(k):
89
- cat = categories[k]
90
- groups[cat].add(col)
91
- break
92
-
93
- # Replace NaN values with None for JSON serialization
94
- df = df.fillna(value="NaN")
95
-
96
- # Transpose the DataFrame so each column becomes a row and column is the model
97
- df = df.set_index("model").T.reset_index()
98
- df = df.rename(columns={"index": "metric"})
99
-
100
- # Convert DataFrame to JSON
101
- result = {
102
- "groups": {group: list(metrics) for group, metrics in groups.items()},
103
- "selected": list(default_selection),
104
- "rows": df.to_dict(orient="records"),
105
- }
106
-
107
- return Response(json.dumps(result), mimetype="application/json")
108
  # return Response(json.dumps(result), mimetype="application/json")
109
 
110
  # Unreachable code - this section will never execute
@@ -115,5 +58,90 @@ def data_files(filename):
115
  return "File not found", 404
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  if __name__ == "__main__":
119
- app.run(host="0.0.0.0", port=7870, debug=True, use_reloader=True)
 
1
+ from backend.chart import mk_variations
2
+ from backend.examples import image_examples_tab
3
  from flask import Flask, Response, send_from_directory
4
  from flask_cors import CORS
5
  import os
 
7
  import pandas as pd
8
  import json
9
  from io import StringIO
10
+ from tools import (
11
+ get_leaderboard_filters,
12
+ get_old_format_dataframe,
13
+ ) # Import your function
14
  import typing as tp
15
+
16
 
17
  logger = logging.getLogger(__name__)
18
  if not logger.hasHandlers():
 
29
 
30
  @app.route("/")
31
  def index():
32
+ logger.warning("Serving index.html")
33
  return send_from_directory(app.static_folder, "index.html")
34
 
35
 
 
41
  data_dir = os.path.join(os.path.dirname(__file__), "data")
42
  file_path = os.path.join(data_dir, filename)
43
  if os.path.isfile(file_path):
 
44
  df = pd.read_csv(file_path)
45
+ logger.info(f"Processing file: {filename}")
46
+ if filename.endswith("benchmark.csv"):
47
+ # If the file is a CSV, process it to get the leaderboard
48
+ return get_leaderboard(df)
49
+ elif filename.endswith("attacks_variations.csv"):
50
+ return get_chart(df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # return Response(json.dumps(result), mimetype="application/json")
52
 
53
  # Unreachable code - this section will never execute
 
58
  return "File not found", 404
59
 
60
 
61
+ @app.route("/examples/<path:filename>")
62
+ def example_files(filename):
63
+ """
64
+ Serve example files from the examples directory.
65
+ """
66
+ # examples_dir = os.path.join(os.path.dirname(__file__), "examples")
67
+ # file_path = os.path.join(examples_dir, filename)
68
+ # if os.path.isfile(file_path):
69
+ # return send_from_directory(examples_dir, filename)
70
+ abs_path = "https://dl.fbaipublicfiles.com/omnisealbench/"
71
+
72
+ result = image_examples_tab(abs_path)
73
+ return Response(json.dumps(result), mimetype="application/json")
74
+
75
+ return "File not found", 404
76
+
77
+
78
+ def get_leaderboard(df):
79
+ # Determine file type and handle accordingly
80
+
81
+ # Modify the dataframe - you'll need to define first_cols and attack_scores
82
+ first_cols = [
83
+ "snr",
84
+ "sisnr",
85
+ "stoi",
86
+ "pesq",
87
+ ] # Define appropriate values based on your needs
88
+ attack_scores = [
89
+ "bit_acc",
90
+ "log10_p_value",
91
+ "TPR",
92
+ "FPR",
93
+ ] # Define appropriate values based on your needs
94
+ categories = {
95
+ "speed": "Time",
96
+ "updownresample": "Time",
97
+ "echo": "Time",
98
+ "random_noise": "Amplitude",
99
+ "lowpass_filter": "Amplitude",
100
+ "highpass_filter": "Amplitude",
101
+ "bandpass_filter": "Amplitude",
102
+ "smooth": "Amplitude",
103
+ "boost_audio": "Amplitude",
104
+ "duck_audio": "Amplitude",
105
+ "shush": "Amplitude",
106
+ "pink_noise": "Amplitude",
107
+ "aac_compression": "Compression",
108
+ "mp3_compression": "Compression",
109
+ }
110
+
111
+ # This part adds on all the columns
112
+ df = get_old_format_dataframe(df, first_cols, attack_scores)
113
+
114
+ groups, default_selection = get_leaderboard_filters(df, categories)
115
+
116
+ # Replace NaN values with None for JSON serialization
117
+ df = df.fillna(value="NaN")
118
+
119
+ # Transpose the DataFrame so each column becomes a row and column is the model
120
+ df = df.set_index("model").T.reset_index()
121
+ df = df.rename(columns={"index": "metric"})
122
+
123
+ # Convert DataFrame to JSON
124
+ result = {
125
+ "groups": {group: list(metrics) for group, metrics in groups.items()},
126
+ "default_selected_metrics": list(default_selection),
127
+ "rows": df.to_dict(orient="records"),
128
+ }
129
+
130
+ return Response(json.dumps(result), mimetype="application/json")
131
+
132
+
133
+ def get_chart(df):
134
+ # This function should return the chart data based on the DataFrame
135
+ # For now, we will just return a placeholder response
136
+ chart_data = mk_variations(
137
+ df,
138
+ # attacks_plot_metrics,
139
+ # audio_attacks_with_variations,
140
+ )
141
+ print(chart_data)
142
+
143
+ return Response(json.dumps(chart_data), mimetype="application/json")
144
+
145
+
146
  if __name__ == "__main__":
147
+ app.run(host="0.0.0.0", port=7860, debug=True, use_reloader=True)
backend/chart.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ from pathlib import Path
4
+
5
+ audio_attacks_with_variations = [
6
+ "random_noise",
7
+ "lowpass_filter",
8
+ "highpass_filter",
9
+ "boost_audio",
10
+ "duck_audio",
11
+ "shush",
12
+ ]
13
+
14
+ attacks_plot_metrics = ["bit_acc", "log10_p_value", "TPR", "FPR", "watermark_det_score"]
15
+
16
+ image_attacks_with_variations = [
17
+ "center_crop",
18
+ "jpeg",
19
+ "brightness",
20
+ "contrast",
21
+ "saturation",
22
+ "sharpness",
23
+ "resize",
24
+ "perspective",
25
+ "median_filter",
26
+ "hue",
27
+ "gaussian_blur",
28
+ ]
29
+
30
+
31
+ video_attacks_with_variations = [
32
+ "Rotate",
33
+ "Resize",
34
+ "Crop",
35
+ "Brightness",
36
+ "Contrast",
37
+ "Saturation",
38
+ "H264",
39
+ "H264rgb",
40
+ "H265",
41
+ ]
42
+
43
+
44
+ def plot_data(metric, selected_attack, all_attacks_df):
45
+ attack_df = all_attacks_df[all_attacks_df.attack == selected_attack]
46
+
47
+ # if metric == "None":
48
+ # return gr.LinePlot(x_bin=None)
49
+
50
+ # return gr.LinePlot(
51
+ # attack_df,
52
+ # x="strength",
53
+ # y=metric,
54
+ # color="model",
55
+ # )
56
+
57
+
58
+ def mk_variations(
59
+ all_attacks_df,
60
+ metrics: list[str] = attacks_plot_metrics,
61
+ attacks_with_variations: list[str] = audio_attacks_with_variations,
62
+ ):
63
+ # all_attacks_df = pd.read_csv(csv_file)
64
+ # print(all_attacks_df)
65
+ # print(csv_file)
66
+
67
+ # with gr.Row():
68
+ # group_by = gr.Radio(metrics, value=metrics[0], label="Choose metric")
69
+ # attacks_dropdown = gr.Dropdown(
70
+ # attacks_with_variations,
71
+ # label=attacks_with_variations[0],
72
+ # info="Select attack",
73
+ # )
74
+
75
+ # attacks_by_strength = plot_data(
76
+ # group_by.value, attacks_dropdown.value, all_attacks_df
77
+ # )
78
+
79
+ # all_graphs = [
80
+ # attacks_by_strength,
81
+ # ]
82
+
83
+ # group_by.change(
84
+ # lambda group: plot_data(group, attacks_dropdown.value, all_attacks_df),
85
+ # group_by,
86
+ # all_graphs,
87
+ # )
88
+
89
+ # attacks_dropdown.change(
90
+ # lambda attack: plot_data(group_by.value, attack, all_attacks_df),
91
+ # attacks_dropdown,
92
+ # all_graphs,
93
+ # )
94
+
95
+ return {
96
+ "metrics": metrics,
97
+ "attacks_with_variations": attacks_with_variations,
98
+ "all_attacks_df": all_attacks_df.to_dict(orient="records"),
99
+ }
backend/environment.yml CHANGED
@@ -8,6 +8,7 @@ dependencies:
8
  - flask=3.0.3
9
  - werkzeug=3.0.3
10
  - pandas
 
11
  - pip
12
  - pip:
13
  - watchdog
 
8
  - flask=3.0.3
9
  - werkzeug=3.0.3
10
  - pandas
11
+ - requests
12
  - pip
13
  - pip:
14
  - watchdog
backend/examples.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import re
3
+ from pathlib import Path
4
+
5
+ import requests
6
+
7
+
8
+ def group_files_by_index(file_paths, data_type="audio"):
9
+ # Regular expression pattern to extract the key from each image path
10
+ if data_type == "audio":
11
+ pattern = r"audio_(\d+).(png|wav)"
12
+ elif data_type == "video":
13
+ pattern = r"video_(\d+).(png|mkv)"
14
+ else:
15
+ pattern = r"img_(\d+).png"
16
+ # Dictionary to store the grouped files
17
+ grouped_files = {}
18
+ # Iterate over each file path
19
+ for file_path in file_paths:
20
+ # Extract the key using the regular expression pattern
21
+ match = re.search(pattern, file_path)
22
+ if match:
23
+ key = int(match.group(1))
24
+
25
+ # Add the file path to the corresponding group in the dictionary
26
+ if key not in grouped_files:
27
+ grouped_files[key] = []
28
+ grouped_files[key].append(file_path)
29
+ # Sort the dictionary by keys
30
+ sorted_grouped_files = dict(sorted(grouped_files.items()))
31
+ return sorted_grouped_files
32
+
33
+
34
+ def build_description(
35
+ i, data_none, data_attack, quality_metrics=["psnr", "ssim", "lpips"]
36
+ ):
37
+ # TODO: handle this at data generation
38
+ if isinstance(data_none["fake_det"], str):
39
+ data_none["fake_det"] = ast.literal_eval(data_none["fake_det"])
40
+ if isinstance(data_none["watermark_det"], str):
41
+ data_none["watermark_det"] = ast.literal_eval(data_none["watermark_det"])
42
+
43
+ if isinstance(data_attack["fake_det"], str):
44
+ data_attack["fake_det"] = ast.literal_eval(data_attack["fake_det"])
45
+ if isinstance(data_attack["watermark_det"], str):
46
+ data_attack["watermark_det"] = ast.literal_eval(data_attack["watermark_det"])
47
+
48
+ if i == 0:
49
+ fake_det = data_none["fake_det"]
50
+
51
+ return f"detected: {fake_det}"
52
+ elif i == 1:
53
+ # Fixed metrics
54
+ det = data_none["watermark_det"]
55
+ p_value = float(data_none["p_value"])
56
+ bit_acc = data_none["bit_acc"]
57
+
58
+ # Dynamic metrics
59
+ metrics_output = []
60
+ for metric in quality_metrics:
61
+ value = float(data_none[metric])
62
+ metrics_output.append(f"{metric}: {value:.2f}")
63
+
64
+ # Fixed metrics output
65
+ fixed_metrics_output = (
66
+ f" detected: {det} p_value: {p_value:.2f} bit_acc: {bit_acc:.2f}"
67
+ )
68
+
69
+ # Combine all outputs
70
+ return " ".join(metrics_output) + f"{fixed_metrics_output}"
71
+ elif i == 2:
72
+ fake_det = data_attack["fake_det"]
73
+
74
+ return f"det: {fake_det}"
75
+ elif i == 3:
76
+ det = data_attack["watermark_det"]
77
+
78
+ p_value = float(data_attack["p_value"])
79
+ word_acc = data_attack["word_acc"]
80
+ bit_acc = data_attack["bit_acc"]
81
+
82
+ return f"word_acc: {word_acc:.2f} detected: {det} p_value: {p_value:.2f} bit_acc: {bit_acc:.2f}"
83
+
84
+
85
+ def build_infos(abs_path: Path, datatype: str, dataset_name: str, db_key: str):
86
+ def generate_file_patterns(prefixes, extensions):
87
+ indices = [0, 1, 3, 4, 5]
88
+ return [
89
+ f"{prefix}_{index:05d}.{ext}"
90
+ for prefix in prefixes
91
+ for index in indices
92
+ for ext in extensions
93
+ ]
94
+
95
+ if datatype == "audio":
96
+ quality_metrics = ["snr", "sisnr", "stoi", "pesq"]
97
+ extensions = ["png", "wav"]
98
+ datatype_abbr = "audio"
99
+ eval_results_path = abs_path + f"{dataset_name}_1k/examples_eval_results.json"
100
+ elif datatype == "image":
101
+ quality_metrics = ["psnr", "ssim", "lpips"]
102
+ extensions = ["png"]
103
+ datatype_abbr = "img"
104
+ eval_results_path = abs_path + f"{dataset_name}_1k/examples_eval_results.json"
105
+ elif datatype == "video":
106
+ quality_metrics = ["psnr", "ssim", "lpips", "msssim", "vmaf"]
107
+ extensions = ["mkv"]
108
+ datatype_abbr = "video"
109
+ eval_results_path = abs_path + f"{dataset_name}/examples_eval_results.json"
110
+
111
+ response = requests.get(eval_results_path)
112
+ if response.status_code == 200:
113
+ results_data = response.json()
114
+ else:
115
+ return {}
116
+
117
+ dataset = results_data["eval"][db_key]
118
+
119
+ prefixes = [
120
+ f"attacked_{datatype_abbr}",
121
+ f"attacked_wmd_{datatype_abbr}",
122
+ f"{datatype_abbr}",
123
+ f"wmd_{datatype_abbr}",
124
+ ]
125
+
126
+ file_patterns = generate_file_patterns(prefixes, extensions)
127
+
128
+ infos = {}
129
+ for model_name in dataset.keys():
130
+ model_infos = {}
131
+
132
+ default_attack_name = "none"
133
+ if datatype == "audio":
134
+ default_attack_name = "identity"
135
+ elif datatype == "video":
136
+ default_attack_name = "Identity"
137
+
138
+ identity_attack_rows = dataset[model_name][default_attack_name]["default"]
139
+
140
+ for attack_name, attack_variants_data in dataset[model_name].items():
141
+ for attack_variant, attack_rows in attack_variants_data.items():
142
+ if attack_variant == "default":
143
+ attack = attack_name
144
+ else:
145
+ attack = f"{attack_name}_{attack_variant}"
146
+
147
+ if len(attack_rows) == 0:
148
+ model_infos[attack] = []
149
+ continue
150
+
151
+ if datatype == "video":
152
+ file_paths = [
153
+ f"{abs_path}{dataset_name}/examples/{datatype}/{model_name}/{attack}/{pattern}"
154
+ for pattern in file_patterns
155
+ ]
156
+ else:
157
+ file_paths = [
158
+ f"{abs_path}{dataset_name}_1k/examples/{datatype}/{model_name}/{attack}/{pattern}"
159
+ for pattern in file_patterns
160
+ ]
161
+
162
+ all_files = []
163
+
164
+ for i, files in group_files_by_index(
165
+ file_paths,
166
+ data_type=datatype,
167
+ ).items():
168
+ data_none = [e for e in identity_attack_rows if e["idx"] == i][0]
169
+ data_attack = [e for e in attack_rows if e["idx"] == i][0]
170
+
171
+ files = sorted(
172
+ [(f, Path(f).stem) for f in files], key=lambda x: x[1]
173
+ )
174
+ files = files[2:] + files[:2]
175
+
176
+ files = [
177
+ {
178
+ "url": f,
179
+ "description": f"{n}\n{build_description(i, data_none, data_attack, quality_metrics)}",
180
+ }
181
+ for i, (f, n) in enumerate(files)
182
+ ]
183
+
184
+ all_files.extend(files)
185
+
186
+ model_infos[attack] = all_files
187
+
188
+ infos[model_name] = model_infos
189
+
190
+ return infos
191
+
192
+
193
+ def image_examples_tab(abs_path: Path):
194
+ dataset_name = "coco_val2014"
195
+ datatype = "image"
196
+ db_key = "coco_val2014"
197
+
198
+ image_infos = build_infos(
199
+ abs_path,
200
+ datatype=datatype,
201
+ dataset_name=dataset_name,
202
+ db_key=db_key,
203
+ )
204
+
205
+ print(image_infos)
206
+
207
+ # First combo box (category selection)
208
+ # model_choice = gr.Dropdown(
209
+ # choices=list(image_infos.keys()),
210
+ # label="Select a Model",
211
+ # value=None,
212
+ # )
213
+ # Second combo box (subcategory selection)
214
+ # Initialize with options from the first category by default
215
+ # attack_choice = gr.Dropdown(
216
+ # choices=list(image_infos["wam"].keys()),
217
+ # label="Select an Attack",
218
+ # value=None,
219
+ # )
220
+
221
+ # # Gallery component to display images
222
+ # gallery = gr.Gallery(
223
+ # label="Image Gallery",
224
+ # columns=4,
225
+ # rows=1,
226
+ # )
227
+
228
+ # Update options for the second combo box when the first one changes
229
+ # def update_subcategories(selected_category):
230
+ # values = list(image_infos[selected_category].keys())
231
+ # values = [(v, v) for v in values]
232
+ # attack_choice.choices = values
233
+ # # return gr.Dropdown.update(choices=list(image_infos[selected_category].keys()))
234
+
235
+ # # Function to load images based on selections from both combo boxes
236
+ # def load_images(category, subcategory):
237
+ # return image_infos.get(category, {}).get(subcategory, [])
238
+
239
+ # # Update gallery based on both combo box selections
240
+ # model_choice.change(
241
+ # fn=update_subcategories, inputs=model_choice, outputs=attack_choice
242
+ # )
243
+ # attack_choice.change(
244
+ # fn=load_images, inputs=[model_choice, attack_choice], outputs=gallery
245
+ # )
246
+ return image_infos
247
+
248
+
249
+ def video_examples_tab(abs_path: Path):
250
+ dataset_name = "sav_val_full"
251
+ datatype = "video"
252
+ db_key = "sa-v_sav_val_videos"
253
+
254
+ image_infos = build_infos(
255
+ abs_path,
256
+ datatype=datatype,
257
+ dataset_name=dataset_name,
258
+ db_key=db_key,
259
+ )
260
+
261
+ # First combo box (category selection)
262
+ model_choice = gr.Dropdown(
263
+ choices=list(image_infos.keys()),
264
+ label="Select a Model",
265
+ value=None,
266
+ )
267
+ # Second combo box (subcategory selection)
268
+ # Initialize with options from the first category by default
269
+ attack_choice = gr.Dropdown(
270
+ choices=list(image_infos["videoseal_0.0"].keys()),
271
+ label="Select an Attack",
272
+ value=None,
273
+ )
274
+
275
+ # Gallery component to display images
276
+ gallery = gr.Gallery(
277
+ label="Video Gallery",
278
+ columns=4,
279
+ rows=1,
280
+ )
281
+
282
+ # Update options for the second combo box when the first one changes
283
+ def update_subcategories(selected_category):
284
+ values = list(image_infos[selected_category].keys())
285
+ values = [(v, v) for v in values]
286
+ attack_choice.choices = values
287
+ # return gr.Dropdown.update(choices=list(image_infos[selected_category].keys()))
288
+
289
+ # Function to load images based on selections from both combo boxes
290
+ def load_images(category, subcategory):
291
+ return image_infos.get(category, {}).get(subcategory, [])
292
+
293
+ # Update gallery based on both combo box selections
294
+ model_choice.change(
295
+ fn=update_subcategories, inputs=model_choice, outputs=attack_choice
296
+ )
297
+ attack_choice.change(
298
+ fn=load_images, inputs=[model_choice, attack_choice], outputs=gallery
299
+ )
300
+
301
+
302
+ def audio_examples_tab(abs_path: Path):
303
+ dataset_name = "voxpopuli"
304
+ datatype = "audio"
305
+ db_key = "voxpopuli"
306
+
307
+ audio_infos = build_infos(
308
+ abs_path,
309
+ datatype=datatype,
310
+ dataset_name=dataset_name,
311
+ db_key=db_key,
312
+ )
313
+
314
+ # First combo box (category selection)
315
+ model_choice = gr.Dropdown(
316
+ choices=list(audio_infos.keys()),
317
+ label="Select a Model",
318
+ value=None,
319
+ )
320
+ # Second combo box (subcategory selection)
321
+ # Initialize with options from the first category by default
322
+ attack_choice = gr.Dropdown(
323
+ choices=list(audio_infos["audioseal"].keys()),
324
+ label="Select an Attack",
325
+ value=None,
326
+ )
327
+
328
+ # Gallery component to display images
329
+ gallery = gr.Gallery(
330
+ label="Image Gallery", columns=4, rows=1, object_fit="scale-down"
331
+ )
332
+
333
+ audio_player = gr.Audio(visible=False)
334
+ audio_map_state = gr.State({})
335
+
336
+ # Update options for the second combo box when the first one changes
337
+ def update_subcategories(selected_category):
338
+ values = list(audio_infos[selected_category].keys())
339
+ values = [(v, v) for v in values]
340
+ attack_choice.choices = values
341
+ # return gr.Dropdown.update(choices=list(image_infos[selected_category].keys()))
342
+
343
+ # Function to load images based on selections from both combo boxes
344
+ def load_audios(category, subcategory):
345
+ files = audio_infos.get(category, {}).get(subcategory, [])
346
+ images = [f for f in files if f[0].endswith(".png")]
347
+ audios = {f[0]: f[0].replace(".png", ".wav") for f in images}
348
+ return images, audios
349
+
350
+ def play_audio(selected_image, audios):
351
+ image_path = selected_image["image"]["path"]
352
+ audio_file = audios.get(image_path)
353
+ return gr.update(value=audio_file, visible=audio_file is not None)
354
+
355
+ def hide_audio_player():
356
+ # Hide the audio player when the preview is closed
357
+ return gr.update(visible=False)
358
+
359
+ def get_selected_image(select_data: gr.SelectData, audios):
360
+ if select_data is None:
361
+ return gr.update(visible=False)
362
+ selected_image = select_data.value
363
+ return play_audio(selected_image, audios)
364
+
365
+ # Update gallery based on both combo box selections
366
+ model_choice.change(
367
+ fn=update_subcategories, inputs=model_choice, outputs=attack_choice
368
+ )
369
+ attack_choice.change(
370
+ fn=load_audios,
371
+ inputs=[model_choice, attack_choice],
372
+ outputs=[gallery, audio_map_state],
373
+ )
374
+ gallery.select(
375
+ fn=get_selected_image,
376
+ inputs=[audio_map_state],
377
+ outputs=audio_player,
378
+ )
379
+ gallery.preview_close(
380
+ fn=hide_audio_player,
381
+ outputs=audio_player,
382
+ )
383
+ return gr.Column([model_choice, attack_choice, gallery, audio_player])
backend/mk_leaderboard.py ADDED
File without changes
backend/tools.py CHANGED
@@ -1,6 +1,31 @@
 
 
1
  import pandas as pd
2
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def add_avg_as_columns(
5
  benchmark_df: pd.DataFrame, attack_scores: list[str]
6
  ) -> pd.DataFrame:
 
1
+ import collections
2
+
3
  import pandas as pd
4
 
5
 
6
+ def get_leaderboard_filters(df, categories) -> dict[str, list[str]]:
7
+ # Create groups based on categories
8
+ groups = collections.OrderedDict({"Overall": set()})
9
+ for k in categories.values():
10
+ groups[k] = set()
11
+
12
+ default_selection = set()
13
+ for k, v in categories.items():
14
+ if v not in default_selection:
15
+ for k in list(df.columns):
16
+ if k.startswith(v):
17
+ groups["Overall"].add(k)
18
+ default_selection.add(k)
19
+
20
+ for col in list(df.columns):
21
+ for k in categories.keys():
22
+ if col.startswith(k):
23
+ cat = categories[k]
24
+ groups[cat].add(col)
25
+ break
26
+ return groups, default_selection
27
+
28
+
29
  def add_avg_as_columns(
30
  benchmark_df: pd.DataFrame, attack_scores: list[str]
31
  ) -> pd.DataFrame:
frontend/src/API.ts CHANGED
@@ -14,6 +14,16 @@ class API {
14
  if (!response.ok) throw new Error(`Failed to fetch ${path}`)
15
  return response.text()
16
  }
 
 
 
 
 
 
 
 
 
 
17
  }
18
 
19
  export default API
 
14
  if (!response.ok) throw new Error(`Failed to fetch ${path}`)
15
  return response.text()
16
  }
17
+
18
+ // Rename the method to fetchExamplesByType
19
+ static fetchExamplesByType(type: 'image' | 'audio' | 'video'): Promise<any> {
20
+ return fetch(`${VITE_API_SERVER_URL}/examples/${type}`).then((response) => {
21
+ if (!response.ok) {
22
+ throw new Error(`Failed to fetch examples of type ${type}`)
23
+ }
24
+ return response.json()
25
+ })
26
+ }
27
  }
28
 
29
  export default API
frontend/src/App.tsx CHANGED
@@ -1,151 +1,12 @@
1
- import { useState, useEffect } from 'react'
2
  import API from './API'
3
  import DataChart from './components/DataChart'
4
-
5
- // Define types for groups and metrics
6
- interface Groups {
7
- [group: string]: { [subgroup: string]: string[] }
8
- }
9
-
10
- interface Row {
11
- metric: string
12
- [key: string]: string | number
13
- }
14
-
15
- // New Filter Component
16
- function Filter({
17
- groups,
18
- selectedMetrics,
19
- setSelectedMetrics,
20
- }: {
21
- groups: Groups
22
- selectedMetrics: Set<string>
23
- setSelectedMetrics: (metrics: Set<string>) => void
24
- }) {
25
- const [openGroups, setOpenGroups] = useState<{ [key: string]: boolean }>({})
26
- const [openSubGroups, setOpenSubGroups] = useState<{ [key: string]: { [key: string]: boolean } }>(
27
- {}
28
- )
29
-
30
- const toggleGroup = (group: string) => {
31
- setOpenGroups((prev) => ({ ...prev, [group]: !prev[group] }))
32
- }
33
-
34
- const toggleSubGroup = (group: string, subGroup: string) => {
35
- setOpenSubGroups((prev) => ({
36
- ...prev,
37
- [group]: {
38
- ...prev[group],
39
- [subGroup]: !prev[group]?.[subGroup],
40
- },
41
- }))
42
- }
43
-
44
- return (
45
- <div className="w-11/12 flex flex-wrap gap-4 p-4 bg-gray-50 rounded shadow">
46
- {Object.entries(groups).map(([group, subGroups]) => (
47
- <div key={group} className="filter-group w-1/3 border p-2 rounded overflow-hidden">
48
- <h4
49
- onClick={() => toggleGroup(group)}
50
- className="cursor-pointer text-lg font-semibold text-blue-600 hover:underline truncate"
51
- title={group}
52
- >
53
- {group} {openGroups[group] ? '▼' : '▶'}
54
- </h4>
55
- {openGroups[group] && (
56
- <div className="filter-subgroups">
57
- {Object.entries(subGroups).map(([subGroup, metrics]) => (
58
- <div key={subGroup} className="filter-subgroup border-t pt-2 mt-2">
59
- <h5
60
- onClick={() => toggleSubGroup(group, subGroup)}
61
- className="cursor-pointer text-md font-medium text-gray-700 hover:underline truncate"
62
- title={subGroup}
63
- >
64
- {subGroup} {openSubGroups[group]?.[subGroup] ? '▼' : '▶'}
65
- </h5>
66
- {openSubGroups[group]?.[subGroup] && (
67
- <div className="filter-metrics grid grid-cols-2 gap-2 mt-2">
68
- {metrics.map((metric) => (
69
- <div key={metric} className="flex items-center space-x-2 truncate">
70
- <input
71
- type="checkbox"
72
- checked={selectedMetrics.has(metric)}
73
- onChange={(event) => {
74
- const newSet = new Set(selectedMetrics)
75
- if (event.target.checked) {
76
- newSet.add(metric)
77
- } else {
78
- newSet.delete(metric)
79
- }
80
- setSelectedMetrics(newSet)
81
- }}
82
- className="form-checkbox h-4 w-4 text-blue-600"
83
- />
84
- <label className="text-sm text-gray-600 truncate" title={metric}>
85
- {metric.includes('_') ? metric.split('_').slice(1).join('_') : metric}
86
- </label>
87
- </div>
88
- ))}
89
- </div>
90
- )}
91
- </div>
92
- ))}
93
- </div>
94
- )}
95
- </div>
96
- ))}
97
- </div>
98
- )
99
- }
100
 
101
  function App() {
102
- const [count, setCount] = useState(0)
103
- const [tableRows, setTableRows] = useState<Row[]>([])
104
- const [tableHeader, setTableHeader] = useState<string[]>([])
105
- const [chartData, setChartData] = useState<Row[]>([])
106
- const [loading, setLoading] = useState(true)
107
- const [error, setError] = useState<string | null>(null)
108
- const [groups, setGroups] = useState<Groups>({})
109
- const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
110
-
111
- const file = 'voxpopuli_1k_audio_benchmark'
112
-
113
- useEffect(() => {
114
- API.fetchStaticFile(`data/${file}.csv`)
115
- .then((response) => {
116
- const data = JSON.parse(response)
117
- const rows: Row[] = data['rows']
118
- const groups = data['groups'] as { [key: string]: string[] }
119
-
120
- // Each value of groups is a list of metrics, group them by the first part of the metric before the first _
121
- const groupsData = Object.entries(groups).reduce(
122
- (acc, [group, metrics]) => {
123
- acc[group] = metrics.reduce<{ [key: string]: string[] }>((subAcc, metric) => {
124
- const [mainGroup, subGroup] = metric.split('_')
125
- if (!subAcc[mainGroup]) {
126
- subAcc[mainGroup] = []
127
- }
128
- subAcc[mainGroup].push(metric)
129
- return subAcc
130
- }, {})
131
- return acc
132
- },
133
- {} as { [key: string]: { [key: string]: string[] } }
134
- )
135
-
136
- const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row))))
137
- setSelectedMetrics(new Set(data['selected']))
138
- setTableHeader(allKeys)
139
- setTableRows(rows)
140
- setChartData(rows)
141
- setGroups(groupsData)
142
- setLoading(false)
143
- })
144
- .catch((err) => {
145
- setError('Failed to fetch JSON: ' + err.message)
146
- setLoading(false)
147
- })
148
- }, [])
149
 
150
  return (
151
  <div className="flex flex-col items-center justify-center min-h-screen bg-gray-100">
@@ -153,60 +14,33 @@ function App() {
153
  <div className="card-body">
154
  <h2 className="card-title">🥇 Omni Seal Bench Watermarking Leaderboard</h2>
155
  <p>Simple proof of concept with Flask backend serving a React frontend.</p>
156
- <div className="card-actions justify-center mt-4">
157
- <button className="btn btn-primary" onClick={() => setCount((count) => count + 1)}>
158
- Count is {count}
159
- </button>
160
- </div>
161
  </div>
162
  </div>
163
- <DataChart data={chartData} loading={loading} error={error} headers={tableHeader} />
164
-
165
- <Filter
166
- groups={groups}
167
- selectedMetrics={selectedMetrics}
168
- setSelectedMetrics={(metrics) => {
169
- setSelectedMetrics(metrics)
170
- }}
171
- />
172
 
173
- <div className="w-11/12 max-w-4xl bg-white rounded shadow p-4 overflow-auto">
174
- <h3 className="font-bold mb-2">{file}</h3>
175
- {loading && <div>Loading...</div>}
176
- {error && <div className="text-red-500">{error}</div>}
177
- {!loading && !error && (
178
- <div className="overflow-x-auto">
179
- <table>
180
- <thead>
181
- <tr>
182
- {tableHeader.map((col, idx) => (
183
- <th key={idx}>{col}</th>
184
- ))}
185
- </tr>
186
- </thead>
187
- <tbody>
188
- {tableRows
189
- .filter((row) => selectedMetrics.has(row['metric']))
190
- .map((row, i) => (
191
- <tr key={i}>
192
- {Object.keys(row).map((column, j) => {
193
- const cell = row[column]
194
-
195
- return (
196
- <td key={j}>
197
- <div className="p-4">
198
- {isNaN(Number(cell)) ? cell : Number(Number(cell).toFixed(3))}
199
- </div>
200
- </td>
201
- )
202
- })}
203
- </tr>
204
- ))}
205
- </tbody>
206
- </table>
207
- </div>
208
- )}
209
  </div>
 
 
 
 
210
  </div>
211
  )
212
  }
 
1
+ import { useState } from 'react'
2
  import API from './API'
3
  import DataChart from './components/DataChart'
4
+ import LeaderboardTable from './components/LeaderboardTable'
5
+ import Examples from './components/Examples'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  function App() {
8
+ const file = 'voxpopuli_1k_audio'
9
+ const [activeTab, setActiveTab] = useState<'dataChart' | 'leaderboard' | 'examples'>('dataChart')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  return (
12
  <div className="flex flex-col items-center justify-center min-h-screen bg-gray-100">
 
14
  <div className="card-body">
15
  <h2 className="card-title">🥇 Omni Seal Bench Watermarking Leaderboard</h2>
16
  <p>Simple proof of concept with Flask backend serving a React frontend.</p>
 
 
 
 
 
17
  </div>
18
  </div>
 
 
 
 
 
 
 
 
 
19
 
20
+ <div className="tabs">
21
+ <a
22
+ className={`tab tab-bordered ${activeTab === 'dataChart' ? 'tab-active' : ''}`}
23
+ onClick={() => setActiveTab('dataChart')}
24
+ >
25
+ Data Chart
26
+ </a>
27
+ <a
28
+ className={`tab tab-bordered ${activeTab === 'leaderboard' ? 'tab-active' : ''}`}
29
+ onClick={() => setActiveTab('leaderboard')}
30
+ >
31
+ Leaderboard Table
32
+ </a>
33
+ <a
34
+ className={`tab tab-bordered ${activeTab === 'examples' ? 'tab-active' : ''}`}
35
+ onClick={() => setActiveTab('examples')}
36
+ >
37
+ Examples
38
+ </a>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  </div>
40
+
41
+ {activeTab === 'dataChart' && <DataChart file={file} />}
42
+ {activeTab === 'leaderboard' && <LeaderboardTable file={file} />}
43
+ {activeTab === 'examples' && <Examples />}
44
  </div>
45
  )
46
  }
frontend/src/components/DataChart.tsx CHANGED
@@ -1,47 +1,231 @@
1
- import { LineChart, Line, XAxis, YAxis, CartesianGrid, Tooltip, Legend, ResponsiveContainer } from 'recharts';
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  interface DataChartProps {
4
- data: any[];
5
- loading: boolean;
6
- error: string | null;
7
- headers: string[];
8
  }
9
 
10
- const DataChart = ({ data, loading, error, headers }: DataChartProps) => {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  return (
12
  <div className="w-11/12 max-w-4xl bg-white rounded shadow p-4 overflow-auto mb-8">
13
  <h3 className="font-bold mb-2">Data Visualization</h3>
14
  {loading && <div>Loading...</div>}
15
  {error && <div className="text-red-500">{error}</div>}
16
- {!loading && !error && data.length > 0 && (
17
- <div className="h-64 mb-4">
18
- <ResponsiveContainer width="100%" height="100%">
19
- <LineChart
20
- data={data}
21
- margin={{
22
- top: 5,
23
- right: 30,
24
- left: 20,
25
- bottom: 5,
26
- }}
27
- >
28
- <CartesianGrid strokeDasharray="3 3" />
29
- <XAxis dataKey={headers[0]} />
30
- <YAxis />
31
- <Tooltip />
32
- <Legend />
33
- {headers[1] && (
34
- <Line type="monotone" dataKey={headers[1]} stroke="#8884d8" dot={false} />
35
- )}
36
- {headers[2] && (
37
- <Line type="monotone" dataKey={headers[2]} stroke="#82ca9d" dot={false} />
38
- )}
39
- </LineChart>
40
- </ResponsiveContainer>
41
- </div>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  )}
43
  </div>
44
- );
45
- };
46
 
47
- export default DataChart;
 
1
+ import { useEffect, useState } from 'react'
2
+ import {
3
+ LineChart,
4
+ Line,
5
+ XAxis,
6
+ YAxis,
7
+ CartesianGrid,
8
+ Tooltip,
9
+ Legend,
10
+ ResponsiveContainer,
11
+ } from 'recharts'
12
+ import API from '../API'
13
 
14
  interface DataChartProps {
15
+ file: string
 
 
 
16
  }
17
 
18
+ interface Row {
19
+ metric: string
20
+ [key: string]: string | number
21
+ }
22
+
23
+ // MetricSelector Component
24
+ const MetricSelector = ({
25
+ metrics,
26
+ selectedMetric,
27
+ onMetricChange,
28
+ }: {
29
+ metrics: Set<string>
30
+ selectedMetric: string | null
31
+ onMetricChange: (event: React.ChangeEvent<HTMLSelectElement>) => void
32
+ }) => {
33
+ return (
34
+ <div className="mb-4">
35
+ <label htmlFor="metric-selector" className="block text-sm font-medium text-gray-700">
36
+ Select Metric:
37
+ </label>
38
+ <select
39
+ id="metric-selector"
40
+ value={selectedMetric || ''}
41
+ onChange={onMetricChange}
42
+ className="mt-1 block w-full pl-3 pr-10 py-2 text-base border-gray-300 focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm rounded-md"
43
+ >
44
+ <option value="">-- Select a Metric --</option>
45
+ {[...metrics].map((metric) => (
46
+ <option key={metric} value={metric}>
47
+ {metric}
48
+ </option>
49
+ ))}
50
+ </select>
51
+ </div>
52
+ )
53
+ }
54
+
55
+ // AttackSelector Component
56
+ const AttackSelector = ({
57
+ attacks,
58
+ selectedAttack,
59
+ onAttackChange,
60
+ }: {
61
+ attacks: Set<string>
62
+ selectedAttack: string | null
63
+ onAttackChange: (event: React.ChangeEvent<HTMLSelectElement>) => void
64
+ }) => {
65
+ return (
66
+ <div className="mb-4">
67
+ <label htmlFor="attack-selector" className="block text-sm font-medium text-gray-700">
68
+ Select Attack:
69
+ </label>
70
+ <select
71
+ id="attack-selector"
72
+ value={selectedAttack || ''}
73
+ onChange={onAttackChange}
74
+ className="mt-1 block w-full pl-3 pr-10 py-2 text-base border-gray-300 focus:outline-none focus:ring-blue-500 focus:border-blue-500 sm:text-sm rounded-md"
75
+ >
76
+ <option value="">-- Select an Attack --</option>
77
+ {[...attacks].map((attack) => (
78
+ <option key={attack} value={attack}>
79
+ {attack}
80
+ </option>
81
+ ))}
82
+ </select>
83
+ </div>
84
+ )
85
+ }
86
+
87
+ const DataChart = ({ file }: DataChartProps) => {
88
+ const [chartData, setChartData] = useState<Row[]>([])
89
+ const [loading, setLoading] = useState(true)
90
+ const [error, setError] = useState<string | null>(null)
91
+ const [metrics, setMetrics] = useState<Set<string>>(new Set())
92
+ const [attacks, setAttacks] = useState<Set<string>>(new Set())
93
+ const [selectedMetric, setSelectedMetric] = useState<string | null>(null)
94
+ const [selectedAttack, setSelectedAttack] = useState<string | null>(null)
95
+
96
+ useEffect(() => {
97
+ API.fetchStaticFile(`data/${file}_attacks_variations.csv`)
98
+ .then((response) => {
99
+ const data = JSON.parse(response)
100
+ const rows: Row[] = data['all_attacks_df'].map((row: any) => {
101
+ const newRow: Row = { ...row }
102
+ // Convert strength value to number if it exists and is a string
103
+ if (typeof newRow.strength === 'string') {
104
+ newRow.strength = parseFloat(newRow.strength)
105
+ }
106
+ return newRow
107
+ })
108
+
109
+ setSelectedMetric(data['metrics'][0])
110
+ setMetrics(new Set(data['metrics']))
111
+ setSelectedAttack(data['attacks_with_variations'][0])
112
+ setAttacks(new Set(data['attacks_with_variations']))
113
+ setChartData(rows)
114
+ setLoading(false)
115
+ })
116
+ .catch((err) => {
117
+ setError('Failed to fetch JSON: ' + err.message)
118
+ setLoading(false)
119
+ })
120
+ }, [])
121
+
122
+ const handleMetricChange = (event: React.ChangeEvent<HTMLSelectElement>) => {
123
+ setSelectedMetric(event.target.value)
124
+ }
125
+
126
+ const handleAttackChange = (event: React.ChangeEvent<HTMLSelectElement>) => {
127
+ setSelectedAttack(event.target.value)
128
+ }
129
+
130
+ // Sort the chart data by the 'strength' field before rendering
131
+ const sortedChartData = chartData
132
+ .filter((row) => !selectedAttack || row.attack === selectedAttack)
133
+ .sort((a, b) => (a.strength as number) - (b.strength as number))
134
+
135
  return (
136
  <div className="w-11/12 max-w-4xl bg-white rounded shadow p-4 overflow-auto mb-8">
137
  <h3 className="font-bold mb-2">Data Visualization</h3>
138
  {loading && <div>Loading...</div>}
139
  {error && <div className="text-red-500">{error}</div>}
140
+ {!loading && !error && (
141
+ <>
142
+ <MetricSelector
143
+ metrics={metrics}
144
+ selectedMetric={selectedMetric}
145
+ onMetricChange={handleMetricChange}
146
+ />
147
+
148
+ <AttackSelector
149
+ attacks={attacks}
150
+ selectedAttack={selectedAttack}
151
+ onAttackChange={handleAttackChange}
152
+ />
153
+
154
+ {chartData.length > 0 && (
155
+ <div className="h-64 mb-4">
156
+ <ResponsiveContainer width="100%" height="100%">
157
+ <LineChart
158
+ data={sortedChartData}
159
+ margin={{
160
+ top: 5,
161
+ right: 30,
162
+ left: 20,
163
+ bottom: 5,
164
+ }}
165
+ >
166
+ <CartesianGrid strokeDasharray="3 3" />
167
+ <XAxis
168
+ dataKey="strength"
169
+ domain={[
170
+ Math.min(...sortedChartData.map((item) => Number(item.strength))),
171
+ Math.max(...sortedChartData.map((item) => Number(item.strength))),
172
+ ]}
173
+ type="number"
174
+ tickFormatter={(value) => value.toString()}
175
+ label={{ value: 'Strength', position: 'insideBottomRight', offset: -5 }}
176
+ />
177
+ <YAxis
178
+ label={{
179
+ value: selectedMetric || '',
180
+ angle: -90,
181
+ position: 'insideLeft',
182
+ style: { textAnchor: 'middle' },
183
+ }}
184
+ />
185
+ <Tooltip />
186
+ <Legend />
187
+
188
+ {(() => {
189
+ // Ensure selectedMetric is not null before rendering the Line components
190
+ if (!selectedMetric) return null // Do not render lines if no metric is selected
191
+
192
+ // Get unique models from the filtered and sorted data
193
+ const models = new Set(sortedChartData.map((row) => row.model))
194
+
195
+ // Generate different colors for each model
196
+ const colors = [
197
+ '#8884d8',
198
+ '#82ca9d',
199
+ '#ffc658',
200
+ '#ff8042',
201
+ '#0088fe',
202
+ '#00C49F',
203
+ ]
204
+
205
+ // Return a Line component for each model
206
+ return [...models].map((model, index) => {
207
+ console.log(sortedChartData.filter((row) => row.model === model))
208
+ return (
209
+ <Line
210
+ key={model as string}
211
+ type="monotone"
212
+ dataKey={selectedMetric as string} // Ensure selectedMetric is a string
213
+ data={sortedChartData.filter((row) => row.model === model)}
214
+ name={model as string}
215
+ stroke={colors[index % colors.length]}
216
+ dot={false}
217
+ />
218
+ )
219
+ })
220
+ })()}
221
+ </LineChart>
222
+ </ResponsiveContainer>
223
+ </div>
224
+ )}
225
+ </>
226
  )}
227
  </div>
228
+ )
229
+ }
230
 
231
+ export default DataChart
frontend/src/components/Examples.tsx ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useState, useEffect } from 'react'
2
+ import API from '../API'
3
+
4
+ const Examples = () => {
5
+ const [fileType, setFileType] = useState<'image' | 'audio' | 'video'>('image')
6
+ const [examples, setExamples] = useState<{
7
+ [model: string]: { [attack: string]: { url: string; description: string }[] }
8
+ }>({})
9
+ const [loading, setLoading] = useState(false)
10
+ const [error, setError] = useState<string | null>(null)
11
+ const [selectedModel, setSelectedModel] = useState<string | null>(null)
12
+ const [selectedAttack, setSelectedAttack] = useState<string | null>(null)
13
+
14
+ useEffect(() => {
15
+ setLoading(true)
16
+ setError(null)
17
+ API.fetchExamplesByType(fileType)
18
+ .then((data) => {
19
+ // data is a dictionary from {[model]: {[attack]: (url, description)}}
20
+ setExamples(data)
21
+ setLoading(false)
22
+ })
23
+ .catch((err) => {
24
+ setError(err.message)
25
+ setLoading(false)
26
+ })
27
+ }, [fileType])
28
+ if (selectedModel && selectedAttack) {
29
+ console.log(examples[selectedModel][selectedAttack])
30
+ }
31
+
32
+ // Define the Gallery component within this file
33
+ const Gallery = ({
34
+ selectedModel,
35
+ selectedAttack,
36
+ examples,
37
+ fileType,
38
+ }: {
39
+ selectedModel: string
40
+ selectedAttack: string
41
+ examples: {
42
+ [model: string]: { [attack: string]: { url: string; description: string }[] }
43
+ }
44
+ fileType: 'image' | 'audio' | 'video'
45
+ }) => {
46
+ const exampleItems = examples[selectedModel][selectedAttack]
47
+
48
+ return (
49
+ <div className="example-display">
50
+ <h4>{selectedModel}</h4>
51
+ <h5>{selectedAttack}</h5>
52
+ {exampleItems.map((item, index) => (
53
+ <div key={index} className="example-item">
54
+ <p>{item.description}</p>
55
+ {fileType === 'image' && (
56
+ <img src={item.url} alt={item.description} className="example-image" />
57
+ )}
58
+ {fileType === 'audio' && <audio controls src={item.url} className="example-audio" />}
59
+ {fileType === 'video' && <video controls src={item.url} className="example-video" />}
60
+ </div>
61
+ ))}
62
+ </div>
63
+ )
64
+ }
65
+
66
+ return (
67
+ <div className="examples-container">
68
+ <h3>Examples</h3>
69
+ <div className="file-type-selector">
70
+ <label>
71
+ Select File Type:
72
+ <select
73
+ value={fileType}
74
+ onChange={(e) => setFileType(e.target.value as 'image' | 'audio' | 'video')}
75
+ >
76
+ <option value="image">Image</option>
77
+ <option value="audio">Audio</option>
78
+ <option value="video">Video</option>
79
+ </select>
80
+ </label>
81
+ </div>
82
+
83
+ <div className="model-selector">
84
+ <label>
85
+ Select Model:
86
+ <select
87
+ value={selectedModel || ''}
88
+ onChange={(e) => setSelectedModel(e.target.value || null)}
89
+ >
90
+ <option value="">-- Select a Model --</option>
91
+ {Object.keys(examples).map((model) => (
92
+ <option key={model} value={model}>
93
+ {model}
94
+ </option>
95
+ ))}
96
+ </select>
97
+ </label>
98
+ </div>
99
+
100
+ {selectedModel && (
101
+ <div className="attack-selector">
102
+ <label>
103
+ Select Attack:
104
+ <select
105
+ value={selectedAttack || ''}
106
+ onChange={(e) => setSelectedAttack(e.target.value || null)}
107
+ >
108
+ <option value="">-- Select an Attack --</option>
109
+ {Object.keys(examples[selectedModel]).map((attack) => (
110
+ <option key={attack} value={attack}>
111
+ {attack}
112
+ </option>
113
+ ))}
114
+ </select>
115
+ </label>
116
+ </div>
117
+ )}
118
+
119
+ {loading && <p>Loading files...</p>}
120
+ {error && <p className="error">Error: {error}</p>}
121
+
122
+ {selectedModel && selectedAttack && (
123
+ <Gallery
124
+ selectedModel={selectedModel}
125
+ selectedAttack={selectedAttack}
126
+ examples={examples}
127
+ fileType={fileType}
128
+ />
129
+ )}
130
+ </div>
131
+ )
132
+ }
133
+
134
+ export default Examples
frontend/src/components/LeaderboardTable.tsx ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import React, { useEffect, useState } from 'react'
2
+ import API from '../API'
3
+
4
+ interface LeaderboardTableProps {
5
+ file: string
6
+ }
7
+
8
+ interface Row {
9
+ metric: string
10
+ [key: string]: string | number
11
+ }
12
+
13
+ interface Groups {
14
+ [group: string]: { [subgroup: string]: string[] }
15
+ }
16
+
17
+ // New Filter Component
18
+ function Filter({
19
+ groups,
20
+ selectedMetrics,
21
+ setSelectedMetrics,
22
+ }: {
23
+ groups: Groups
24
+ selectedMetrics: Set<string>
25
+ setSelectedMetrics: (metrics: Set<string>) => void
26
+ }) {
27
+ const [openGroups, setOpenGroups] = useState<{ [key: string]: boolean }>({})
28
+ const [openSubGroups, setOpenSubGroups] = useState<{ [key: string]: { [key: string]: boolean } }>(
29
+ {}
30
+ )
31
+
32
+ const toggleGroup = (group: string) => {
33
+ setOpenGroups((prev) => ({ ...prev, [group]: !prev[group] }))
34
+ }
35
+
36
+ const toggleSubGroup = (group: string, subGroup: string) => {
37
+ setOpenSubGroups((prev) => ({
38
+ ...prev,
39
+ [group]: {
40
+ ...prev[group],
41
+ [subGroup]: !prev[group]?.[subGroup],
42
+ },
43
+ }))
44
+ }
45
+ return (
46
+ <div className="w-11/12 flex flex-wrap gap-4 p-4 bg-gray-50 rounded shadow">
47
+ {Object.entries(groups).map(([group, subGroups]) => (
48
+ <div key={group} className="filter-group w-1/3 border p-2 rounded overflow-hidden">
49
+ <h4
50
+ onClick={() => toggleGroup(group)}
51
+ className="cursor-pointer text-lg font-semibold text-blue-600 hover:underline truncate"
52
+ title={group}
53
+ >
54
+ {group} {openGroups[group] ? '▼' : '▶'}
55
+ </h4>
56
+ {openGroups[group] && (
57
+ <div className="filter-subgroups">
58
+ {Object.entries(subGroups).map(([subGroup, metrics]) => (
59
+ <div key={subGroup} className="filter-subgroup border-t pt-2 mt-2">
60
+ <h5
61
+ onClick={() => toggleSubGroup(group, subGroup)}
62
+ className="cursor-pointer text-md font-medium text-gray-700 hover:underline truncate"
63
+ title={subGroup}
64
+ >
65
+ {subGroup} {openSubGroups[group]?.[subGroup] ? '▼' : '▶'}
66
+ </h5>
67
+ {openSubGroups[group]?.[subGroup] && (
68
+ <div className="filter-metrics grid grid-cols-2 gap-2 mt-2">
69
+ {metrics.map((metric) => (
70
+ <div key={metric} className="flex items-center space-x-2 truncate">
71
+ <input
72
+ type="checkbox"
73
+ checked={selectedMetrics.has(metric)}
74
+ onChange={(event) => {
75
+ const newSet = new Set(selectedMetrics)
76
+ if (event.target.checked) {
77
+ newSet.add(metric)
78
+ } else {
79
+ newSet.delete(metric)
80
+ }
81
+ setSelectedMetrics(newSet)
82
+ }}
83
+ className="form-checkbox h-4 w-4 text-blue-600"
84
+ />
85
+ <label className="text-sm text-gray-600 truncate" title={metric}>
86
+ {metric.includes('_') ? metric.split('_').slice(1).join('_') : metric}
87
+ </label>
88
+ </div>
89
+ ))}
90
+ </div>
91
+ )}
92
+ </div>
93
+ ))}
94
+ </div>
95
+ )}
96
+ </div>
97
+ ))}
98
+ </div>
99
+ )
100
+ }
101
+ const LeaderboardTable: React.FC<LeaderboardTableProps> = ({ file }) => {
102
+ const [tableRows, setTableRows] = useState<Row[]>([])
103
+ const [tableHeader, setTableHeader] = useState<string[]>([])
104
+ const [loading, setLoading] = useState(true)
105
+ const [error, setError] = useState<string | null>(null)
106
+ const [groups, setGroups] = useState<Groups>({})
107
+
108
+ const [selectedMetrics, setSelectedMetrics] = useState<Set<string>>(new Set())
109
+
110
+ useEffect(() => {
111
+ API.fetchStaticFile(`data/${file}_benchmark.csv`)
112
+ .then((response) => {
113
+ const data = JSON.parse(response)
114
+ const rows: Row[] = data['rows']
115
+ const groups = data['groups'] as { [key: string]: string[] }
116
+
117
+ // Each value of groups is a list of metrics, group them by the first part of the metric before the first _
118
+ const groupsData = Object.entries(groups).reduce(
119
+ (acc, [group, metrics]) => {
120
+ acc[group] = metrics.reduce<{ [key: string]: string[] }>((subAcc, metric) => {
121
+ const [mainGroup, subGroup] = metric.split('_')
122
+ if (!subAcc[mainGroup]) {
123
+ subAcc[mainGroup] = []
124
+ }
125
+ subAcc[mainGroup].push(metric)
126
+ return subAcc
127
+ }, {})
128
+ return acc
129
+ },
130
+ {} as { [key: string]: { [key: string]: string[] } }
131
+ )
132
+
133
+ const allKeys: string[] = Array.from(new Set(rows.flatMap((row) => Object.keys(row))))
134
+ setSelectedMetrics(new Set(data['default_selected_metrics']))
135
+ setTableHeader(allKeys)
136
+ setTableRows(rows)
137
+ setGroups(groupsData)
138
+ setLoading(false)
139
+ })
140
+ .catch((err) => {
141
+ setError('Failed to fetch JSON: ' + err.message)
142
+ setLoading(false)
143
+ })
144
+ }, [file])
145
+
146
+ return (
147
+ <div className="w-11/12 max-w-4xl bg-white rounded shadow p-4 overflow-auto">
148
+ <h3 className="font-bold mb-2">{file}</h3>
149
+ {loading && <div>Loading...</div>}
150
+ {error && <div className="text-red-500">{error}</div>}
151
+
152
+ {!loading && !error && (
153
+ <div className="overflow-x-auto">
154
+ <Filter
155
+ groups={groups}
156
+ selectedMetrics={selectedMetrics}
157
+ setSelectedMetrics={setSelectedMetrics}
158
+ />
159
+ <table>
160
+ <thead>
161
+ <tr>
162
+ {tableHeader.map((col, idx) => (
163
+ <th key={idx}>{col}</th>
164
+ ))}
165
+ </tr>
166
+ </thead>
167
+ <tbody>
168
+ {tableRows
169
+ .filter((row) => selectedMetrics.has(row['metric'] as string))
170
+ .map((row, i) => (
171
+ <tr key={i}>
172
+ {Object.keys(row).map((column, j) => {
173
+ const cell = row[column]
174
+
175
+ return (
176
+ <td key={j}>
177
+ <div className="p-4">
178
+ {isNaN(Number(cell)) ? cell : Number(Number(cell).toFixed(3))}
179
+ </div>
180
+ </td>
181
+ )
182
+ })}
183
+ </tr>
184
+ ))}
185
+ </tbody>
186
+ </table>
187
+ </div>
188
+ )}
189
+ </div>
190
+ )
191
+ }
192
+
193
+ export default LeaderboardTable