hynky (HF Staff) committed 2990133 · 1 parent: b58aae6

metric naming

Files changed (1):
  1. app.py +201 -147

app.py CHANGED
@@ -5,29 +5,21 @@ import json
 import os
 from pathlib import Path
 import re
+import heapq
 import tempfile
 from typing import Literal
 import gradio as gr
 
 from collections import defaultdict
-from datatrove.io import DataFolder, get_datafolder
+from datatrove.io import get_datafolder
 import plotly.graph_objects as go
-from datatrove.utils.stats import MetricStatsDict
+from datatrove.utils.stats import MetricStats, MetricStatsDict
 import plotly.express as px
+import tenacity
 
 import gradio as gr
 PARTITION_OPTIONS = Literal[ "Top", "Bottom", "Most frequent (n_docs)"]
-
-LOG_SCALE_STATS = {
-    "length",
-    "n_lines",
-    "n_docs",
-    "n_words",
-    "avg_words_per_line",
-    "pages_with_lorem_ipsum",
-}
-
-STATS_LOCATION_DEFAULT = os.getenv("STATS_LOCATION_DEFAULT", "s3://")
+METRICS_LOCATION_DEFAULT = os.getenv("METRICS_LOCATION_DEFAULT", "s3://fineweb-stats/summary/")
 
 
 def find_folders(base_folder, path):
@@ -43,28 +35,31 @@ def find_folders(base_folder, path):
     )
 
 
-def find_stats_folders(base_folder: str):
+def find_metrics_folders(base_folder: str):
     base_data_folder = get_datafolder(base_folder)
-    # First find all stats-merged.json by globbing for stats-merged.json
-    stats_merged = base_data_folder.glob("**/stats-merged.json")
+    # First find all metric.json by globbing for metric.json
+    metrics_merged = base_data_folder.glob("**/metric.json")
 
-    # Then for each stats-merged.json take all but the last two parts of the path (grouping/stat_name)
-    stats_folders = [str(Path(x).parent.parent.parent) for x in stats_merged]
+    # Then for each metric.json take all but the last two parts of the path (grouping/metric_name)
+    metrics_folders = [str(Path(x).parent.parent.parent) for x in metrics_merged]
     # Finally get the unique paths
-    return sorted(list(set(stats_folders)))
+    return sorted(list(set(metrics_folders)))
 
 
 def fetch_datasets(base_folder: str):
-    datasets = sorted(find_stats_folders(base_folder))
+    datasets = sorted(find_metrics_folders(base_folder))
     return datasets, gr.update(choices=datasets, value=None), fetch_groups(base_folder, datasets, None, "union")
 
 
-def export_data(exported_data, stat_name):
+def export_data(exported_data: MetricStatsDict, metric_name: str):
     if not exported_data:
         return None
     # Assuming exported_data is a dictionary where the key is the dataset name and the value is the data to be exported
-    with tempfile.NamedTemporaryFile(mode="w", delete=False, prefix=stat_name, suffix=".json") as temp:
-        json.dump(exported_data, temp)
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, prefix=metric_name, suffix=".json") as temp:
+        json.dump({
+            name: dt.to_dict()
+            for name, dt in exported_data.items()
+        }, temp)
         temp_path = temp.name
     return gr.update(visible=True, value=temp_path)
 
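A note on the `export_data` change above: the stats objects are apparently not directly JSON-serializable, hence the new `to_dict()` conversion before `json.dump`. A minimal sketch of the same pattern, with a hypothetical `FakeStats` standing in for the datatrove objects:

```python
import json
import tempfile


class FakeStats:
    """Hypothetical stand-in for a datatrove stats entry exposing to_dict()."""

    def __init__(self, data):
        self.data = data

    def to_dict(self):
        return self.data


exported_data = {"dataset-a": FakeStats({"mean": 3.2, "n": 10})}

# Mirror of the new export path: serialize each value via to_dict()
# into a named temp file that Gradio can then offer for download.
with tempfile.NamedTemporaryFile(mode="w", delete=False, prefix="length", suffix=".json") as temp:
    json.dump({name: st.to_dict() for name, st in exported_data.items()}, temp)
    temp_path = temp.name

print(temp_path)
```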
@@ -80,7 +75,7 @@ def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
 
     if type == "intersection":
         new_choices = set.intersection(*(set(g) for g in GROUPS))
-    elif type == "union":
+    else:
         new_choices = set.union(*(set(g) for g in GROUPS))
     value = None
     if old_groups:
@@ -91,27 +86,27 @@ def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
     return gr.update(choices=sorted(list(new_choices)), value=value)
 
 
-def fetch_stats(base_folder, datasets, group, old_stats, type="intersection"):
+def fetch_metrics(base_folder, datasets, group, old_metrics, type="intersection"):
     with ThreadPoolExecutor() as executor:
-        STATS = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
-    if len(STATS) == 0:
+        metrics = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
+    if len(metrics) == 0:
         return gr.update(choices=[], value=None)
 
     if type == "intersection":
-        new_possibles_choices = set.intersection(*(set(s) for s in STATS))
-    elif type == "union":
-        new_possibles_choices = set.union(*(set(s) for s in STATS))
+        new_possibles_choices = set.intersection(*(set(s) for s in metrics))
+    else:
+        new_possibles_choices = set.union(*(set(s) for s in metrics))
     value = None
-    if old_stats:
-        value = list(set.intersection(new_possibles_choices, {old_stats}))
+    if old_metrics:
+        value = list(set.intersection(new_possibles_choices, {old_metrics}))
     value = value[0] if value else None
 
     return gr.update(choices=sorted(list(new_possibles_choices)), value=value)
 
 
-def reverse_search(base_folder, possible_datasets, grouping, stat_name):
+def reverse_search(base_folder, possible_datasets, grouping, metric_name):
     with ThreadPoolExecutor() as executor:
-        found_datasets = list(executor.map(lambda dataset: dataset if stat_exists(base_folder, dataset, stat_name, grouping) else None, possible_datasets))
+        found_datasets = list(executor.map(lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None, possible_datasets))
     found_datasets = [dataset for dataset in found_datasets if dataset is not None]
     return "\n".join(found_datasets)
 
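The choice lists for groups and metrics are plain set algebra over the per-dataset listings: intersection when every selected dataset must have the metric, union for the reverse-search path. On toy data:

```python
# Sketch of the dropdown-choice computation used by fetch_groups/fetch_metrics.
per_dataset_metrics = [
    {"length", "n_docs", "n_lines"},  # dataset A
    {"length", "n_docs"},             # dataset B
]

# "intersection": only metrics available in every selected dataset.
print(sorted(set.intersection(*per_dataset_metrics)))  # ['length', 'n_docs']
# "union": anything available in at least one dataset (reverse search).
print(sorted(set.union(*per_dataset_metrics)))         # ['length', 'n_docs', 'n_lines']
```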
@@ -122,46 +117,47 @@ def reverse_search_add(datasets, reverse_search_results):
 
 
 
-def stat_exists(base_folder, path, stat_name, group_by):
+def metric_exists(base_folder, path, metric_name, group_by):
     base_folder = get_datafolder(base_folder)
-    return base_folder.exists(f"{path}/{group_by}/{stat_name}/stats-merged.json")
+    return base_folder.exists(f"{path}/{group_by}/{metric_name}/metric.json")
 
-def load_stats(base_folder, path, stat_name, group_by):
+@tenacity.retry(stop=tenacity.stop_after_attempt(5))
+def load_metrics(base_folder, path, metric_name, group_by):
     base_folder = get_datafolder(base_folder)
     with base_folder.open(
-        f"{path}/{group_by}/{stat_name}/stats-merged.json",
+        f"{path}/{group_by}/{metric_name}/metric.json",
     ) as f:
-        json_stat = json.load(f)
-    # No idea why this is necessary, but it is, otherwise the MetricStatsDict is malformed
-    return MetricStatsDict() + MetricStatsDict(init=json_stat)
+        json_metric = json.load(f)
+    # No idea why this is necessary, but it is, otherwise the MetricStatsDict is malformed
+    return MetricStatsDict.from_dict(json_metric)
 
 
-def prepare_non_grouped_data(dataset_path, base_folder, grouping, stat_name, normalization):
-    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
-    stats_rounded = defaultdict(lambda: 0)
-    for key, value in stats.items():
-        stats_rounded[round(float(key), 2)] += value.total
+def prepare_for_non_grouped_plotting(metric, normalization, rounding):
+    metrics_rounded = defaultdict(lambda: 0)
+    for key, value in metric.items():
+        metrics_rounded[round(float(key), rounding)] += value.total
     if normalization:
-        normalizer = sum(stats_rounded.values())
-        stats_rounded = {k: v / normalizer for k, v in stats_rounded.items()}
+        normalizer = sum(metrics_rounded.values())
+        metrics_rounded = {k: v / normalizer for k, v in metrics_rounded.items()}
         # check that the sum of the values is 1
-        summed = sum(stats_rounded.values())
-    return stats_rounded
+        summed = sum(metrics_rounded.values())
+        assert abs(summed - 1) < 0.01, summed
+    return metrics_rounded
 
 
-def prepare_grouped_data(dataset_path, base_folder, grouping, stat_name, top_k, direction: PARTITION_OPTIONS, regex):
-    import heapq
-    regex_compiled = re.compile(regex) if regex else None
-
-    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
-    stats = {key: value for key, value in stats.items() if not regex or regex_compiled.match(key)}
-    means = {key: value.mean for key, value in stats.items()}
-
+def load_data(dataset_path, base_folder, grouping, metric_name):
+    metrics = load_metrics(base_folder, dataset_path, metric_name, grouping)
+    return metrics
+
+def prepare_for_group_plotting(metric, top_k, direction: PARTITION_OPTIONS, regex: str | None, rounding: int):
+    regex_compiled = re.compile(regex) if regex else None
+    metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
+    means = {key: round(float(value.mean), rounding) for key, value in metric.items()}
     # Use heap to get top_k keys
     if direction == "Top":
         keys = heapq.nlargest(top_k, means, key=means.get)
     elif direction == "Most frequent (n_docs)":
-        totals = {key: value.n for key, value in stats.items()}
+        totals = {key: int(value.n) for key, value in metric.items()}
         keys = heapq.nlargest(top_k, totals, key=totals.get)
     else:
         keys = heapq.nsmallest(top_k, means, key=means.get)
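Two things happen in this hunk: `load_metrics` gains a `tenacity` retry (five attempts) to ride out transient remote-storage errors, and the grouped preparation keeps selecting the k groups to plot with a heap rather than a full sort. The partition logic on toy data:

```python
import heapq

# Per-group summary values, as prepare_for_group_plotting would compute them.
means = {"example.com": 4.2, "foo.org": 1.1, "bar.net": 9.7, "baz.io": 3.3}
totals = {"example.com": 120, "foo.org": 5, "bar.net": 40, "baz.io": 300}
top_k = 2

# nlargest/nsmallest over a dict iterate its keys; key=dict.get ranks them
# by value in O(n log k), cheaper than sorting everything for small k.
print(heapq.nlargest(top_k, means, key=means.get))    # Top: ['bar.net', 'example.com']
print(heapq.nsmallest(top_k, means, key=means.get))   # Bottom: ['foo.org', 'baz.io']
print(heapq.nlargest(top_k, totals, key=totals.get))  # Most frequent: ['baz.io', 'example.com']
```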
@@ -181,23 +177,23 @@ def set_alpha(color, alpha):
     return f"rgba({r}, {g}, {b}, {alpha})"
 
 
-
-
 def plot_scatter(
-    histograms: dict[str, dict[float, float]],
-    stat_name: str,
+    data: dict[str, dict[float, float]],
+    metric_name: str,
+    log_scale_x: bool,
+    log_scale_y: bool,
     normalization: bool,
+    rounding: int,
     progress: gr.Progress,
 ):
     fig = go.Figure()
 
-    for i, (name, histogram) in enumerate(progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")):
-        if all(isinstance(k, str) for k in histogram.keys()):
-            x = [k for k, v in sorted(histogram.items(), key=lambda item: item[1])]
-        else:
-            x = sorted(histogram.keys())
-
-        y = [histogram[k] for k in x]
+    # First sort the histograms by their name
+    data = {name: histogram for name, histogram in sorted(data.items())}
+    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
+        histogram_prepared = prepare_for_non_grouped_plotting(histogram, normalization, rounding)
+        x = sorted(histogram_prepared.keys())
+        y = [histogram_prepared[k] for k in x]
 
         fig.add_trace(
             go.Scatter(
@@ -209,14 +205,14 @@ def plot_scatter(
             )
         )
 
-    xaxis_scale = "log" if stat_name in LOG_SCALE_STATS else "linear"
     yaxis_title = "Frequency" if normalization else "Total"
 
     fig.update_layout(
-        title=f"Line Plots for {stat_name}",
-        xaxis_title=stat_name,
+        title=f"Line Plots for {metric_name}",
+        xaxis_title=metric_name,
         yaxis_title=yaxis_title,
-        xaxis_type=xaxis_scale,
+        xaxis_type="log" if log_scale_x and len(x) > 1 else None,
+        yaxis_type="log" if log_scale_y and len(y) > 1 else None,
         width=1200,
         height=600,
         showlegend=True,
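With the global `LOG_SCALE_STATS` set gone, log scaling is now a per-plot choice threaded in as two booleans, and the `len(x) > 1` guard keeps a single-point trace on a linear axis. A self-contained sketch of the layout call:

```python
import plotly.graph_objects as go

log_scale_x = True  # stands in for the new checkbox value
x, y = [1, 10, 100, 1000], [5, 3, 8, 2]

fig = go.Figure(go.Scatter(x=x, y=y))
fig.update_layout(
    # Passing None keeps Plotly's default (linear) axis type, so the
    # checkbox value can be threaded straight into update_layout.
    xaxis_type="log" if log_scale_x and len(x) > 1 else None,
    width=800,
    height=400,
)
fig.show()
```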
@@ -226,22 +222,33 @@
 
 
 def plot_bars(
-    histograms: dict[str, list[tuple[str, float]]],
-    stat_name: str,
+    data: dict[str, list[dict[str, float]]],
+    metric_name: str,
+    top_k: int,
+    direction: PARTITION_OPTIONS,
+    regex: str | None,
+    rounding: int,
+    log_scale_x: bool,
+    log_scale_y: bool,
     progress: gr.Progress,
 ):
     fig = go.Figure()
+    x = []
+    y = []
 
-    for i, (name, histogram) in enumerate(progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")):
-        x = [k for k, v in histogram]
-        y = [v for k, v in histogram]
+    for i, (name, histogram) in enumerate(progress.tqdm(data.items(), total=len(data), desc="Plotting...")):
+        histogram_prepared = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
+        x, y = zip(*histogram_prepared)
 
         fig.add_trace(go.Bar(x=x, y=y, name=name, marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5))))
+
 
     fig.update_layout(
-        title=f"Bar Plots for {stat_name}",
-        xaxis_title=stat_name,
-        yaxis_title="Mean value",
+        title=f"Bar Plots for {metric_name}",
+        xaxis_title=metric_name,
+        yaxis_title="Avg. value",
+        xaxis_type="log" if log_scale_x and len(x) > 1 else None,
+        yaxis_type="log" if log_scale_y and len(y) > 1 else None,
         autosize=True,
         width=1200,
         height=600,
@@ -254,33 +261,26 @@
 def update_graph(
     base_folder,
     datasets,
-    stat_name,
+    metric_name,
     grouping,
+    log_scale_x,
+    log_scale_y,
+    rounding,
     normalization,
     top_k,
     direction,
     regex,
     progress=gr.Progress(),
 ):
-    if len(datasets) <= 0 or not stat_name or not grouping:
+    if len(datasets) <= 0 or not metric_name or not grouping:
         return None
     # Placeholder for logic to rerender the graph based on the inputs
-    prepare_fc = (
-        partial(prepare_non_grouped_data, normalization=normalization)
-        if grouping == "histogram"
-        else partial(prepare_grouped_data, top_k=top_k, direction=direction, regex=regex)
-    )
-    graph_fc = (
-        partial(plot_scatter, normalization=normalization)
-        if grouping == "histogram"
-        else plot_bars
-    )
 
     with ThreadPoolExecutor() as pool:
         data = list(
             progress.tqdm(
                 pool.map(
-                    partial(prepare_fc, base_folder=base_folder, stat_name=stat_name, grouping=grouping),
+                    partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping),
                     datasets,
                 ),
                 total=len(datasets),
@@ -288,30 +288,39 @@ def update_graph(
             )
         )
 
-    histograms = {path: result for path, result in zip(datasets, data)}
+    data = {path: result for path, result in zip(datasets, data)}
+    return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress), data, export_data(data, metric_name)
+
 
-    return graph_fc(histograms=histograms, stat_name=stat_name, progress=progress), histograms, gr.update(visible=True)
+def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress=gr.Progress()):
+    if rounding is None or top_k is None:
+        return None
+    graph_fc = (
+        partial(plot_scatter, normalization=normalization, rounding=rounding)
+        if grouping == "histogram"
+        else partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
+    )
+    return graph_fc(data=data, metric_name=metric_name, progress=progress, log_scale_x=log_scale_x, log_scale_y=log_scale_y)
 
 
 # Create the Gradio interface
 with gr.Blocks() as demo:
     datasets = gr.State([])
     exported_data = gr.State([])
-    stats_headline = gr.Markdown(value="# Stats Exploration")
+    metrics_headline = gr.Markdown(value="# Metrics Exploration")
     with gr.Row():
         with gr.Column(scale=2):
-            # Define the multiselect for crawls
             with gr.Row():
                 with gr.Column(scale=1):
                     base_folder = gr.Textbox(
-                        label="Stats Location",
-                        value="s3://fineweb-stats/summary/",
+                        label="Metrics Location",
+                        value=METRICS_LOCATION_DEFAULT,
                     )
                     datasets_refetch = gr.Button("Fetch Datasets")
 
                 with gr.Column(scale=1):
-                    regex_select = gr.Text(label="Regex select datasets", value=".*")
-                    regex_button = gr.Button("Filter")
+                    regex_select = gr.Text(label="Regex filter", value=".*")
+                    regex_button = gr.Button("Search")
                 with gr.Row():
                     datasets_selected = gr.Dropdown(
                        choices=[],
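The structural change here: the expensive remote load stays in `update_graph`, while the new `plot_data` is a pure re-render over already-fetched data, dispatching to the scatter or bar renderer via `functools.partial`. A stripped-down sketch of that dispatch:

```python
from functools import partial


def plot_scatter(data, metric_name, normalization):
    return f"scatter({metric_name}, norm={normalization}, n={len(data)})"


def plot_bars(data, metric_name, top_k):
    return f"bars({metric_name}, top_k={top_k}, n={len(data)})"


def plot_data(data, metric_name, grouping, normalization=True, top_k=10):
    # Bind the renderer-specific options once, so the call site is
    # identical for both plot types.
    graph_fc = (
        partial(plot_scatter, normalization=normalization)
        if grouping == "histogram"
        else partial(plot_bars, top_k=top_k)
    )
    return graph_fc(data, metric_name)


print(plot_data({"a": 1}, "length", "histogram"))  # scatter(length, norm=True, n=1)
print(plot_data({"a": 1}, "length", "fqdn"))       # bars(length, top_k=10, n=1)
```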
@@ -323,15 +332,28 @@ with gr.Blocks() as demo:
             readme_description = gr.Markdown(
                 label="Readme",
                 value="""
-Explanation of the tool:
+## How to use:
+1) Specify the metrics location (the stats block `output_folder` without the last path segment) and click "Fetch Datasets"
+2) Select the datasets you are interested in, using the dropdown or the regex filter
+3) Specify the grouping (global average/value/fqdn/suffix) and the metric name
+4) Click "Update Graph"
+
 
-Groupings:
-- histogram: creates a line plot of values with their occurrences. If normalization is on, the values are frequencies summing to 1.
-- (fqdn/suffix): creates a bar plot of the mean values of the stats for the fully qualified domain name/suffix of the domain
+## Groupings:
+- **histogram**: Creates a line plot of values with their frequencies. If normalization is on, the frequencies sum to 1.
+  * normalize:
+- **(fqdn/suffix)**: Creates a bar plot of the avg. values of the metric for the fully qualified domain name/suffix of the domain.
   * k: the number of groups to show
-  * Top/Bottom: the top/bottom k groups are shown
-- summary: simply shows the average value of the given stat for the selected crawls
-""",
+  * Top/Bottom/Most frequent (n_docs): groups with the top/bottom k values/the most prevalent docs are shown
+- **none**: Shows the average value of the given metric
+
+## Reverse search:
+To search for datasets containing a grouping and a certain metric, use the Reverse search section.
+Specify the search parameters and click "Search". This will show the found datasets in the "Found datasets" textbox. You can modify the selection after the search by removing unwanted lines and clicking "Add to selection".
+
+## Note:
+The data might not be 100% representative, due to the sampling and the optimistic merging of the metrics (fqdn/suffix).
+""",
             )
         with gr.Column(scale=1):
             # Define the dropdown for grouping
@@ -340,19 +362,39 @@ Groupings:
                 label="Grouping",
                 multiselect=False,
             )
-            # Define the dropdown for stat_name
-            stat_name_dropdown = gr.Dropdown(
+            # Define the dropdown for metric_name
+            metric_name_dropdown = gr.Dropdown(
                 choices=[],
-                label="Stat name",
+                label="Metric name",
                 multiselect=False,
             )
 
-            with gr.Row(visible=False) as histogram_choices:
-                normalization_checkbox = gr.Checkbox(
-                    label="Normalize",
-                    value=False,  # Default value
-                )
 
+            update_button = gr.Button("Update Graph", variant="primary")
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    log_scale_x_checkbox = gr.Checkbox(
+                        label="Log scale x",
+                        value=False,
+                    )
+                    log_scale_y_checkbox = gr.Checkbox(
+                        label="Log scale y",
+                        value=False,
+                    )
+                    rounding = gr.Number(
+                        label="Rounding",
+                        value=2,
+                    )
+                    normalization_checkbox = gr.Checkbox(
+                        label="Normalize",
+                        value=True,  # Default value
+                        visible=False
+                    )
+                    with gr.Row():
+                        # export_data_button = gr.Button("Export data", visible=True, link=export_data_json)
+                        export_data_json = gr.File(visible=False)
+                with gr.Column(scale=4):
                     with gr.Row(visible=False) as group_choices:
                         with gr.Column(scale=2):
                             group_regex = gr.Text(
@@ -374,19 +416,13 @@ Groupings:
                                 "Most frequent (n_docs)",
                             ],
                             value="Most frequent (n_docs)",
-                        )
-
-            update_button = gr.Button("Update Graph", variant="primary")
-            with gr.Row():
-                export_data_button = gr.Button("Export data", visible=False)
-                export_data_json = gr.File(visible=False)
-
-            with gr.Row():
+                        )
     # Define the graph output
+    with gr.Row():
         graph_output = gr.Plot(label="Graph")
 
     with gr.Row():
-        reverse_search_headline = gr.Markdown(value="# Reverse stats search")
+        reverse_search_headline = gr.Markdown(value="# Reverse metrics search")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -396,8 +432,8 @@ Groupings:
                 label="Grouping",
                 multiselect=False,
             )
-            # Define the dropdown for stat_name
-            reverse_stat_name_dropdown = gr.Dropdown(
+            # Define the dropdown for metric_name
+            reverse_metric_name_dropdown = gr.Dropdown(
                 choices=[],
                 label="Stat name",
                 multiselect=False,
@@ -411,9 +447,8 @@ Groupings:
             reverse_search_results = gr.Textbox(
                 label="Found datasets",
                 lines=10,
-                placeholder="Found datasets containing the group/stat name. You can modify the selection after search by removing unwanted lines and clicking Add to selection"
+                placeholder="Found datasets containing the group/metric name. You can modify the selection after search by removing unwanted lines and clicking Add to selection"
             )
-
 
 
     update_button.click(
@@ -421,22 +456,39 @@ Groupings:
         inputs=[
             base_folder,
             datasets_selected,
-            stat_name_dropdown,
+            metric_name_dropdown,
             grouping_dropdown,
+            log_scale_x_checkbox,
+            log_scale_y_checkbox,
+            rounding,
             normalization_checkbox,
             top_select,
             direction_checkbox,
             group_regex,
         ],
-        outputs=[graph_output, exported_data, export_data_button],
+        outputs=[graph_output, exported_data, export_data_json],
     )
 
-    export_data_button.click(
-        fn=export_data,
-        inputs=[exported_data, stat_name_dropdown],
-        outputs=export_data_json,
+    for inp in [normalization_checkbox, rounding, group_regex, direction_checkbox, top_select, log_scale_x_checkbox, log_scale_y_checkbox]:
+        inp.change(
+            fn=plot_data,
+            inputs=[
+                exported_data,
+                metric_name_dropdown,
+                normalization_checkbox,
+                rounding,
+                grouping_dropdown,
+                top_select,
+                direction_checkbox,
+                group_regex,
+                log_scale_x_checkbox,
+                log_scale_y_checkbox,
+            ],
+            outputs=[graph_output],
     )
 
+
+
    datasets_selected.change(
        fn=fetch_groups,
        inputs=[base_folder, datasets_selected, grouping_dropdown],
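Instead of a separate export button, every display-only control is now bound to the same `plot_data` callback in a loop, so tweaking an option redraws from the cached `exported_data` state without touching S3. A minimal Gradio sketch of the pattern (component names are illustrative):

```python
import gradio as gr


def render(text, upper):
    # Pure function of the current control values, like plot_data.
    return text.upper() if upper else text


with gr.Blocks() as demo:
    text = gr.Textbox(value="hello")
    upper = gr.Checkbox(label="Uppercase")
    out = gr.Textbox(label="Output")

    # One .change() binding per control, all pointing at the same callback.
    for inp in [text, upper]:
        inp.change(fn=render, inputs=[text, upper], outputs=[out])

demo.launch()
```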
@@ -444,20 +496,20 @@ Groupings:
     )
 
     grouping_dropdown.select(
-        fn=fetch_stats,
-        inputs=[base_folder, datasets_selected, grouping_dropdown, stat_name_dropdown],
-        outputs=stat_name_dropdown,
+        fn=fetch_metrics,
+        inputs=[base_folder, datasets_selected, grouping_dropdown, metric_name_dropdown],
+        outputs=metric_name_dropdown,
     )
 
     reverse_grouping_dropdown.select(
-        fn=partial(fetch_stats, type="union"),
-        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_stat_name_dropdown],
-        outputs=reverse_stat_name_dropdown,
+        fn=partial(fetch_metrics, type="union"),
+        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_metric_name_dropdown],
+        outputs=reverse_metric_name_dropdown,
     )
 
     reverse_search_button.click(
         fn=reverse_search,
-        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_stat_name_dropdown],
+        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_metric_name_dropdown],
         outputs=reverse_search_results,
     )
 
@@ -478,7 +530,9 @@ Groupings:
         if not regex:
             return
         new_dsts = {run for run in all_runs if re.search(regex, run)}
-        dst_union = new_dsts.union(selected_runs)
+        if not new_dsts:
+            return gr.update(value=list(selected_runs))
+        dst_union = new_dsts.union(selected_runs or [])
         return gr.update(value=list(dst_union))
 
     regex_button.click(
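The regex-selection fix guards two edge cases: an empty match set no longer wipes the selection, and `selected_runs or []` avoids calling `set.union` with `None` when nothing was selected yet. Extracted as a plain function:

```python
import re


def add_matching_runs(regex, all_runs, selected_runs):
    new_dsts = {run for run in all_runs if re.search(regex, run)}
    if not new_dsts:
        # Keep the current selection when the pattern matches nothing.
        return list(selected_runs or [])
    return sorted(new_dsts.union(selected_runs or []))


runs = ["cc-2023-06", "cc-2023-14", "refinedweb"]
print(add_matching_runs(r"cc-2023", runs, None))            # ['cc-2023-06', 'cc-2023-14']
print(add_matching_runs(r"nomatch", runs, ["refinedweb"]))  # ['refinedweb']
```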
@@ -490,19 +544,19 @@ Groupings:
     def update_grouping_options(grouping):
         if grouping == "histogram":
             return {
-                histogram_choices: gr.Column(visible=True),
+                normalization_checkbox: gr.Column(visible=True),
                 group_choices: gr.Column(visible=False),
             }
         else:
             return {
-                histogram_choices: gr.Column(visible=False),
+                normalization_checkbox: gr.Column(visible=False),
                 group_choices: gr.Column(visible=True),
             }
 
     grouping_dropdown.select(
         fn=update_grouping_options,
         inputs=[grouping_dropdown],
-        outputs=[histogram_choices, group_choices],
+        outputs=[normalization_checkbox, group_choices],
     )
 
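Finally, with the dedicated histogram row gone, the grouping selector toggles the normalization checkbox directly; returning a dict keyed by components lets one callback update the visibility of several outputs at once. A runnable sketch of the same toggle (using `gr.update` rather than the `gr.Column(...)` objects the app returns):

```python
import gradio as gr

with gr.Blocks() as demo:
    grouping = gr.Dropdown(choices=["histogram", "fqdn"], label="Grouping")
    normalize = gr.Checkbox(label="Normalize", visible=False)
    group_opts = gr.Textbox(label="Group options", visible=False)

    def update_grouping_options(g):
        # A dict keyed by components updates each one in a single return.
        return {
            normalize: gr.update(visible=g == "histogram"),
            group_opts: gr.update(visible=g != "histogram"),
        }

    grouping.select(fn=update_grouping_options, inputs=[grouping], outputs=[normalize, group_opts])

demo.launch()
```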