guipenedo (HF Staff) committed
Commit a5f2bd2 · unverified · 1 Parent(s): 13ccbad
Files changed (1):
  1. app.py +82 -79
app.py CHANGED
@@ -1,25 +1,23 @@
-from concurrent.futures import ThreadPoolExecutor
-import enum
-from functools import partial
+import heapq
 import json
 import os
-from pathlib import Path
 import re
-import heapq
 import tempfile
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from pathlib import Path
 from typing import Literal
-import gradio as gr
 
-from collections import defaultdict
-from datatrove.io import get_datafolder
-import plotly.graph_objects as go
-from datatrove.utils.stats import MetricStats, MetricStatsDict
+import gradio as gr
 import plotly.express as px
+import plotly.graph_objects as go
 import tenacity
+from datatrove.io import get_datafolder
+from datatrove.utils.stats import MetricStatsDict
 
-import gradio as gr
-PARTITION_OPTIONS = Literal[ "Top", "Bottom", "Most frequent (n_docs)"]
-METRICS_LOCATION_DEFAULT = os.getenv("METRICS_LOCATION_DEFAULT", "s3://fineweb-stats/summary/")
+PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
+METRICS_LOCATION_DEFAULT = os.getenv("METRICS_LOCATION_DEFAULT", "hf://datasets/HuggingFaceFW-Dev/summary-stats-files")
 
 
 def find_folders(base_folder, path):
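The two constants above drive the rest of the app: PARTITION_OPTIONS is a typing.Literal, so type checkers restrict the later `direction` argument to those three strings, and the metrics location can be overridden through the METRICS_LOCATION_DEFAULT environment variable. A standalone sketch of both mechanisms (the s3:// override value is invented for illustration):

import os
from typing import Literal

PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]

def describe(direction: PARTITION_OPTIONS) -> str:
    # a type checker flags any call site passing a string outside the Literal
    return f"partition: {direction}"

# hypothetical override; otherwise the hf:// default from the diff is used
os.environ["METRICS_LOCATION_DEFAULT"] = "s3://my-bucket/summary/"
print(os.getenv("METRICS_LOCATION_DEFAULT", "hf://datasets/HuggingFaceFW-Dev/summary-stats-files"))
print(describe("Top"))       # fine
# describe("Sideways")       # rejected by a type checker, though not at runtime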
@@ -74,7 +72,7 @@ def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
         return gr.update(choices=[], value=None)
 
     if type == "intersection":
-        new_choices = set.intersection(*(set(g) for g in GROUPS))
+        new_choices = set.intersection(*(set(g) for g in GROUPS))
     else:
         new_choices = set.union(*(set(g) for g in GROUPS))
     value = None
@@ -88,7 +86,8 @@ def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
 
 def fetch_metrics(base_folder, datasets, group, old_metrics, type="intersection"):
     with ThreadPoolExecutor() as executor:
-        metrics = list(executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
+        metrics = list(
+            executor.map(lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")], datasets))
     if len(metrics) == 0:
         return gr.update(choices=[], value=None)
 
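fetch_metrics (like fetch_groups above it) fans the per-run listing out over a thread pool. executor.map returns results in input order even though the calls run concurrently, which is what later makes zip(datasets, data) in update_graph line up. A minimal sketch with an invented stand-in for find_folders:

from concurrent.futures import ThreadPoolExecutor

def list_metric_folders(run: str) -> list[str]:
    # hypothetical stand-in for find_folders(base_folder, f"{run}/{group}")
    return [f"{run}/line_length", f"{run}/word_count"]

with ThreadPoolExecutor() as executor:
    metrics = list(executor.map(list_metric_folders, ["run_a", "run_b"]))

print(metrics)  # results come back in the same order as the inputs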
@@ -106,7 +105,9 @@ def fetch_metrics(base_folder, datasets, group, old_metrics, type="intersection"
 
 def reverse_search(base_folder, possible_datasets, grouping, metric_name):
     with ThreadPoolExecutor() as executor:
-        found_datasets = list(executor.map(lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None, possible_datasets))
+        found_datasets = list(executor.map(
+            lambda dataset: dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None,
+            possible_datasets))
     found_datasets = [dataset for dataset in found_datasets if dataset is not None]
     return "\n".join(found_datasets)
 
@@ -116,16 +117,16 @@ def reverse_search_add(datasets, reverse_search_results):
     return sorted(list(set(datasets + reverse_search_results.strip().split("\n"))))
 
 
-
 def metric_exists(base_folder, path, metric_name, group_by):
     base_folder = get_datafolder(base_folder)
     return base_folder.exists(f"{path}/{group_by}/{metric_name}/metric.json")
 
+
 @tenacity.retry(stop=tenacity.stop_after_attempt(5))
 def load_metrics(base_folder, path, metric_name, group_by):
     base_folder = get_datafolder(base_folder)
     with base_folder.open(
-        f"{path}/{group_by}/{metric_name}/metric.json",
+        f"{path}/{group_by}/{metric_name}/metric.json",
     ) as f:
         json_metric = json.load(f)
     # No idea why this is necessary, but it is, otherwise the MetricStatsDict is malformed
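load_metrics wraps a remote read in tenacity's retry decorator: any exception triggers a re-invocation, and stop_after_attempt(5) gives up after five tries (tenacity then raises a RetryError). A self-contained sketch of the same pattern, with an invented flaky reader in place of the real remote open:

import tenacity

calls = {"n": 0}

@tenacity.retry(stop=tenacity.stop_after_attempt(5))
def flaky_read() -> str:
    # hypothetical: fails twice, then succeeds, like a transient remote-storage error
    calls["n"] += 1
    if calls["n"] < 3:
        raise IOError("transient read error")
    return '{"some_key": {"total": 1}}'

print(flaky_read(), "after", calls["n"], "attempts")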
@@ -149,6 +150,7 @@ def load_data(dataset_path, base_folder, grouping, metric_name):
     metrics = load_metrics(base_folder, dataset_path, metric_name, grouping)
     return metrics
 
+
 def prepare_for_group_plotting(metric, top_k, direction: PARTITION_OPTIONS, regex: str | None, rounding: int):
     regex_compiled = re.compile(regex) if regex else None
     metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
@@ -162,7 +164,6 @@ def prepare_for_group_plotting(metric, top_k, direction: PARTITION_OPTIONS, rege
     else:
         keys = heapq.nsmallest(top_k, means, key=means.get)
 
-
     means = [means[key] for key in keys]
     stds = [metric[key].standard_deviation for key in keys]
     return keys, means, stds
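The Top/Bottom partitioning above is heapq doing the work: nlargest/nsmallest iterate over the dict's keys with means.get as the key function, avoiding a full sort of a potentially huge grouping. A sketch with toy values:

import heapq

means = {"en": 0.91, "fr": 0.85, "de": 0.72, "it": 0.64}  # invented per-group means
top_k = 2

print(heapq.nlargest(top_k, means, key=means.get))   # ['en', 'fr']  ("Top")
print(heapq.nsmallest(top_k, means, key=means.get))  # ['it', 'de']  ("Bottom")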
@@ -181,13 +182,13 @@ def set_alpha(color, alpha):
 
 
 def plot_scatter(
-        data: dict[str, dict[float, float]],
-        metric_name: str,
-        log_scale_x: bool,
-        log_scale_y: bool,
-        normalization: bool,
-        rounding: int,
-        progress: gr.Progress,
+        data: dict[str, dict[float, float]],
+        metric_name: str,
+        log_scale_x: bool,
+        log_scale_y: bool,
+        normalization: bool,
+        rounding: int,
+        progress: gr.Progress,
 ):
     fig = go.Figure()
 
@@ -225,15 +226,15 @@
 
 
 def plot_bars(
-        data: dict[str, list[dict[str, float]]],
-        metric_name: str,
-        top_k: int,
-        direction: PARTITION_OPTIONS,
-        regex: str | None,
-        rounding: int,
-        log_scale_x: bool,
-        log_scale_y: bool,
-        progress: gr.Progress,
+        data: dict[str, list[dict[str, float]]],
+        metric_name: str,
+        top_k: int,
+        direction: PARTITION_OPTIONS,
+        regex: str | None,
+        rounding: int,
+        log_scale_x: bool,
+        log_scale_y: bool,
+        progress: gr.Progress,
 ):
     fig = go.Figure()
     x = []
@@ -243,9 +244,9 @@
     x, y, stds = prepare_for_group_plotting(histogram, top_k, direction, regex, rounding)
 
     fig.add_trace(go.Bar(
-        x=x,
-        y=y,
-        name=f"{name} Mean",
+        x=x,
+        y=y,
+        name=f"{name} Mean",
         marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
         error_y=dict(type='data', array=stds, visible=True)
     ))
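The trace above is a standard Plotly bar with data-type error bars, colored with a translucent entry from the qualitative palette. A runnable sketch with invented values and a simplified set_alpha (the real helper is defined earlier in app.py and may differ):

import plotly.express as px
import plotly.graph_objects as go

def set_alpha(color: str, alpha: float) -> str:
    # simplified stand-in: '#RRGGBB' -> 'rgba(r, g, b, a)'
    r, g, b = (int(color[i:i + 2], 16) for i in (1, 3, 5))
    return f"rgba({r}, {g}, {b}, {alpha})"

fig = go.Figure(go.Bar(
    x=["en", "fr", "de"],
    y=[0.91, 0.85, 0.72],
    name="toy-dataset Mean",
    marker=dict(color=set_alpha(px.colors.qualitative.Plotly[0], 0.5)),
    error_y=dict(type="data", array=[0.05, 0.04, 0.08], visible=True),
))
fig.show()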
@@ -266,18 +267,18 @@
 
 
 def update_graph(
-        base_folder,
-        datasets,
-        metric_name,
-        grouping,
-        log_scale_x,
-        log_scale_y,
-        rounding,
-        normalization,
-        top_k,
-        direction,
-        regex,
-        progress=gr.Progress(),
+        base_folder,
+        datasets,
+        metric_name,
+        grouping,
+        log_scale_x,
+        log_scale_y,
+        rounding,
+        normalization,
+        top_k,
+        direction,
+        regex,
+        progress=gr.Progress(),
 ):
     if len(datasets) <= 0 or not metric_name or not grouping:
         return None
@@ -296,9 +297,12 @@ def update_graph(
     )
 
     data = {path: result for path, result in zip(datasets, data)}
-    return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress), data, export_data(data, metric_name)
+    return plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x,
+                     log_scale_y, progress), data, export_data(data, metric_name)
+
 
-def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y, progress=gr.Progress()):
+def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direction, regex, log_scale_x, log_scale_y,
+              progress=gr.Progress()):
     if rounding is None or top_k is None:
         return None
     graph_fc = (
@@ -306,8 +310,8 @@ def plot_data(data, metric_name, normalization, rounding, grouping, top_k, direc
         if grouping == "histogram"
         else partial(plot_bars, top_k=top_k, direction=direction, regex=regex, rounding=rounding)
     )
-    return graph_fc(data=data, metric_name=metric_name, progress=progress, log_scale_x=log_scale_x, log_scale_y=log_scale_y)
-
+    return graph_fc(data=data, metric_name=metric_name, progress=progress, log_scale_x=log_scale_x,
+                    log_scale_y=log_scale_y)
 
 
 # Create the Gradio interface
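plot_data dispatches through functools.partial: the kwargs that differ between the two plot functions are bound up front, and the shared ones are passed once at the single call site. A sketch of the same shape with toy signatures (the grouping value is just any non-histogram example):

from functools import partial

def plot_scatter(data, normalization):
    return f"scatter(normalization={normalization})"

def plot_bars(data, top_k, direction):
    return f"bars(top_k={top_k}, direction={direction})"

grouping = "fqdn"  # any non-histogram grouping takes the bar path
graph_fc = (
    partial(plot_scatter, normalization=True)
    if grouping == "histogram"
    else partial(plot_bars, top_k=100, direction="Top")
)
print(graph_fc(data={}))  # bars(top_k=100, direction=Top)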
@@ -376,7 +380,6 @@ The data might not be 100% representative, due to the sampling and optimistic me
             multiselect=False,
         )
 
-
         update_button = gr.Button("Update Graph", variant="primary")
 
         with gr.Row():
@@ -414,7 +417,7 @@ The data might not be 100% representative, due to the sampling and optimistic me
             value=100,
             interactive=True,
         )
-
+
         direction_checkbox = gr.Radio(
             label="Partition",
             choices=[
@@ -423,14 +426,14 @@ The data might not be 100% representative, due to the sampling and optimistic me
                 "Most frequent (n_docs)",
             ],
             value="Most frequent (n_docs)",
-        )
+        )
     # Define the graph output
     with gr.Row():
         graph_output = gr.Plot(label="Graph")
-
+
     with gr.Row():
         reverse_search_headline = gr.Markdown(value="# Reverse metrics search")
-
+
     with gr.Row():
         with gr.Column(scale=1):
             # Define the dropdown for grouping
@@ -445,7 +448,7 @@ The data might not be 100% representative, due to the sampling and optimistic me
                 label="Stat name",
                 multiselect=False,
             )
-
+
         with gr.Column(scale=1):
             reverse_search_button = gr.Button("Search")
             reverse_search_add_button = gr.Button("Add to selection")
@@ -457,7 +460,6 @@ The data might not be 100% representative, due to the sampling and optimistic me
                 placeholder="Found datasets containing the group/metric name. You can modify the selection after search by removing unwanted lines and clicking Add to selection"
             )
 
-
     update_button.click(
         fn=update_graph,
         inputs=[
@@ -476,25 +478,24 @@ The data might not be 100% representative, due to the sampling and optimistic me
         outputs=[graph_output, exported_data, export_data_json],
     )
 
-    for inp in [normalization_checkbox, rounding, group_regex, direction_checkbox, top_select, log_scale_x_checkbox, log_scale_y_checkbox]:
+    for inp in [normalization_checkbox, rounding, group_regex, direction_checkbox, top_select, log_scale_x_checkbox,
+                log_scale_y_checkbox]:
         inp.change(
             fn=plot_data,
             inputs=[
-                exported_data,
-                metric_name_dropdown,
-                normalization_checkbox,
-                rounding,
-                grouping_dropdown,
-                top_select,
-                direction_checkbox,
-                group_regex,
-                log_scale_x_checkbox,
-                log_scale_y_checkbox,
-            ],
-            outputs=[graph_output],
-        )
-
-
+                exported_data,
+                metric_name_dropdown,
+                normalization_checkbox,
+                rounding,
+                grouping_dropdown,
+                top_select,
+                direction_checkbox,
+                group_regex,
+                log_scale_x_checkbox,
+                log_scale_y_checkbox,
+            ],
+            outputs=[graph_output],
+        )
 
     datasets_selected.change(
         fn=fetch_groups,
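The loop above wires every display-only control to plot_data, so toggling log scales, rounding, or the partition re-renders from the cached exported_data instead of refetching metrics; only the Update Graph button runs the full update_graph fetch. A minimal sketch of the same wiring in a toy Blocks app (component names here are invented):

import gradio as gr

def replot(cached, log_x):
    return f"replotting {cached} with log_x={log_x}"

with gr.Blocks() as demo:
    cached_data = gr.State({"toy": 1})
    log_x = gr.Checkbox(label="Log scale x")
    out = gr.Textbox()
    for inp in [log_x]:
        inp.change(fn=replot, inputs=[cached_data, inp], outputs=[out])

# demo.launch()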
@@ -526,13 +527,13 @@ The data might not be 100% representative, due to the sampling and optimistic me
         outputs=datasets_selected,
     )
 
-
     datasets_refetch.click(
         fn=fetch_datasets,
         inputs=[base_folder],
         outputs=[datasets, datasets_selected, reverse_grouping_dropdown],
     )
 
+
     def update_datasets_with_regex(regex, selected_runs, all_runs):
         if not regex:
             return
@@ -542,12 +543,14 @@ The data might not be 100% representative, due to the sampling and optimistic me
         dst_union = new_dsts.union(selected_runs or [])
         return gr.update(value=sorted(list(dst_union)))
 
+
     regex_button.click(
         fn=update_datasets_with_regex,
         inputs=[regex_select, datasets_selected, datasets],
         outputs=datasets_selected,
     )
 
+
     def update_grouping_options(grouping):
         if grouping == "histogram":
             return {
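update_datasets_with_regex (only partially visible in the hunks above) matches the entered pattern against all known runs and unions the hits with the current selection. A sketch of that logic, assuming a re.match-based filter and with invented run names:

import re

all_runs = ["cc-2023-50", "cc-2024-10", "redpajama-v2"]
selected = ["redpajama-v2"]
pattern = r"cc-2024-.*"

new_dsts = {run for run in all_runs if re.match(pattern, run)}
print(sorted(new_dsts.union(selected)))  # ['cc-2024-10', 'redpajama-v2']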
@@ -560,13 +563,13 @@ The data might not be 100% representative, due to the sampling and optimistic me
             group_choices: gr.Column(visible=True),
         }
 
+
    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[normalization_checkbox, group_choices],
    )
 
-
# Launch the application
if __name__ == "__main__":
    demo.launch()
 