hynky (HF Staff) committed
Commit 4668859 · Parent: 745c60b

regex + folder definition + export

Files changed (1): app.py (+147 -82)
app.py CHANGED
@@ -1,15 +1,22 @@
+from concurrent.futures import ThreadPoolExecutor
+import enum
 from functools import partial
 import json
 from pathlib import Path
+import re
+import tempfile
+from typing import Literal
 import gradio as gr
 
 from collections import defaultdict
-import fsspec.config
-import math
 from datatrove.io import DataFolder, get_datafolder
+import plotly.graph_objects as go
 from datatrove.utils.stats import MetricStatsDict
+import plotly.express as px
+
+import gradio as gr
+PARTITION_OPTIONS = Literal[ "Top", "Bottom", "Most frequent (n_docs)"]
 
-BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
 LOG_SCALE_STATS = {
     "length",
     "n_lines",
@@ -18,35 +25,10 @@ LOG_SCALE_STATS = {
     "avg_words_per_line",
     "pages_with_lorem_ipsum",
 }
-colors = list(
-    [
-        "rgba(31, 119, 180, 0.5)",
-        "rgba(255, 127, 14, 0.5)",
-        "rgba(44, 160, 44, 0.5)",
-        "rgba(214, 39, 40, 0.5)",
-        "rgba(148, 103, 189, 0.5)",
-        "rgba(227, 119, 194, 0.5)",
-        "rgba(127, 127, 127, 0.5)",
-        "rgba(188, 189, 34, 0.5)",
-        "rgba(23, 190, 207, 0.5)",
-        "rgba(255, 193, 7, 0.5)",
-        "rgba(40, 167, 69, 0.5)",
-        "rgba(23, 162, 184, 0.5)",
-        "rgba(108, 117, 125, 0.5)",
-        "rgba(0, 123, 255, 0.5)",
-        "rgba(220, 53, 69, 0.5)",
-        "rgba(255, 159, 67, 0.5)",
-        "rgba(255, 87, 34, 0.5)",
-        "rgba(41, 182, 246, 0.5)",
-        "rgba(142, 36, 170, 0.5)",
-        "rgba(0, 188, 212, 0.5)",
-        "rgba(255, 235, 59, 0.5)",
-        "rgba(156, 39, 176, 0.5)",
-    ]
-)
 
 
 def find_folders(base_folder, path):
+    base_folder = get_datafolder(base_folder)
     return sorted(
         [
             folder["name"]
@@ -56,9 +38,10 @@ def find_folders(base_folder, path):
     )
 
 
-def find_stats_folders(base_folder: DataFolder):
+def find_stats_folders(base_folder: str):
+    base_data_folder = get_datafolder(base_folder)
     # First find all stats-merged.json using globing for stats-merged.json
-    stats_merged = base_folder.glob("**/stats-merged.json")
+    stats_merged = base_data_folder.glob("**/stats-merged.json")
 
     # Then for each of stats.merged take the all but last two parts of the path (grouping/stat_name)
     stats_folders = [str(Path(x).parent.parent.parent) for x in stats_merged]
@@ -66,14 +49,25 @@ def find_stats_folders(base_folder: DataFolder):
     return sorted(list(set(stats_folders)))
 
 
-RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
+def fetch_runs(base_folder: str):
+    runs = sorted(find_stats_folders(base_folder))
+    return runs, gr.update(choices=runs, value=None)
 
 
-def fetch_groups(runs, old_groups):
+def export_data(exported_data):
+    if not exported_data:
+        return None
+    # Assuming exported_data is a dictionary where the key is the dataset name and the value is the data to be exported
+    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as temp:
+        json.dump(exported_data, temp)
+        temp_path = temp.name
+    return gr.update(visible=True, value=temp_path)
+
+
+def fetch_groups(base_folder, datasets, old_groups):
     GROUPS = [
-        [Path(x).name for x in find_folders(BASE_DATA_FOLDER, run)] for run in runs
+        [Path(x).name for x in find_folders(base_folder, run)] for run in datasets
     ]
-    # DO the intersection
     if len(GROUPS) == 0:
         return gr.update(choices=[], value=None)
 
@@ -84,13 +78,13 @@ def fetch_groups(runs, old_groups):
     value = value[0] if value else None
 
     # now take the intersection of all grups
-    return gr.update(choices=list(new_choices), value=value)
+    return gr.update(choices=sorted(list(new_choices)), value=value)
 
 
-def fetch_stats(runs, group, old_stats):
+def fetch_stats(base_folder, datasets, group, old_stats):
     STATS = [
-        [Path(x).name for x in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")]
-        for run in runs
+        [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")]
+        for run in datasets
     ]
     if len(STATS) == 0:
         return gr.update(choices=[], value=None)
@@ -101,21 +95,21 @@ def fetch_stats(runs, group, old_stats):
     value = list(set.intersection(new_possibles_choices, {old_stats}))
     value = value[0] if value else None
 
-    return gr.update(choices=list(new_possibles_choices), value=value)
+    return gr.update(choices=sorted(list(new_possibles_choices)), value=value)
 
 
-def load_stats(path, stat_name, group_by):
-    with BASE_DATA_FOLDER.open(
+def load_stats(base_folder, path, stat_name, group_by):
+    base_folder = get_datafolder(base_folder)
+    with base_folder.open(
         f"{path}/{group_by}/{stat_name}/stats-merged.json",
-        filecache={"cache_storage": "/tmp/files"},
     ) as f:
        json_stat = json.load(f)
    # No idea why this is necessary, but it is, otheriwse the Metric StatsDict is malforme
     return MetricStatsDict() + MetricStatsDict(init=json_stat)
 
 
-def prepare_non_grouped_data(path, stat_name, grouping, normalization):
-    stats = load_stats(path, stat_name, grouping)
+def prepare_non_grouped_data(dataset_path, base_folder, grouping, stat_name, normalization):
+    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
     stats_rounded = defaultdict(lambda: 0)
     for key, value in stats.items():
         stats_rounded[float(key)] += value.total
@@ -125,10 +119,10 @@ def prepare_non_grouped_data(path, stat_name, grouping, normalization):
     return stats_rounded
 
 
-def prepare_grouped_data(path, stat_name, grouping, top_k, direction):
+def prepare_grouped_data(dataset_path, base_folder, grouping, stat_name, top_k, direction: PARTITION_OPTIONS):
     import heapq
 
-    stats = load_stats(path, stat_name, grouping)
+    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
 
     means = {key: value.mean for key, value in stats.items()}
 
@@ -136,13 +130,7 @@ def prepare_grouped_data(path, stat_name, grouping, top_k, direction):
     if direction == "Top":
         keys = heapq.nlargest(top_k, means, key=means.get)
     elif direction == "Most frequent (n_docs)":
-        n_docs = load_stats(path, "n_docs", grouping)
-        totals = {key: value.total for key, value in n_docs.items()}
-        keys = heapq.nlargest(top_k, totals, key=totals.get)
-
-    elif direction == "Most frequent (length)":
-        n_docs = load_stats(path, "length", grouping)
-        totals = {key: value.total for key, value in n_docs.items()}
+        totals = {key: value.n for key, value in stats.items()}
         keys = heapq.nlargest(top_k, totals, key=totals.get)
     else:
         keys = heapq.nsmallest(top_k, means, key=means.get)
@@ -150,17 +138,29 @@ def prepare_grouped_data(path, stat_name, grouping, top_k, direction):
     return [(key, means[key]) for key in keys]
 
 
-import math
-import plotly.graph_objects as go
-from plotly.offline import plot
+def set_alpha(color, alpha):
+    """
+    Takes a hex color and returns
+    rgba(r, g, b, a)
+    """
+    if color.startswith('#'):
+        r, g, b = int(color[1:3], 16), int(color[3:5], 16), int(color[5:7], 16)
+    else:
+        r, g, b = 0, 0, 0  # Fallback to black if the color format is not recognized
+    return f"rgba({r}, {g}, {b}, {alpha})"
+
+
 
 
 def plot_scatter(
-    histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
+    histograms: dict[str, dict[float, float]],
+    stat_name: str,
+    normalization: bool,
+    progress: gr.Progress,
 ):
     fig = go.Figure()
 
-    for i, (name, histogram) in enumerate(histograms.items()):
+    for i, (name, histogram) in enumerate(progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")):
         if all(isinstance(k, str) for k in histogram.keys()):
             x = [k for k, v in sorted(histogram.items(), key=lambda item: item[1])]
         else:
@@ -174,7 +174,7 @@ def plot_scatter(
                 y=y,
                 mode="lines",
                 name=name,
-                line=dict(color=colors[i % len(colors)]),
+                marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5)),
             )
         )
 
@@ -194,14 +194,18 @@ def plot_scatter(
     return fig
 
 
-def plot_bars(histograms: dict[str, list[tuple[str, float]]], stat_name: str):
+def plot_bars(
+    histograms: dict[str, list[tuple[str, float]]],
+    stat_name: str,
+    progress: gr.Progress,
+):
     fig = go.Figure()
 
-    for i, (name, histogram) in enumerate(histograms.items()):
+    for i, (name, histogram) in enumerate(progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")):
         x = [k for k, v in histogram]
         y = [v for k, v in histogram]
 
-        fig.add_trace(go.Bar(x=x, y=y, name=name, marker_color=colors[i % len(colors)]))
+        fig.add_trace(go.Bar(x=x, y=y, name=name, marker=dict(color=set_alpha(px.colors.qualitative.Plotly[i % len(px.colors.qualitative.Plotly)], 0.5))))
 
     fig.update_layout(
         title=f"Bar Plots for {stat_name}",
@@ -217,9 +221,16 @@ def plot_bars(histograms: dict[str, list[tuple[str, float]]], stat_name: str):
 
 
 def update_graph(
-    multiselect_crawls, stat_name, grouping, normalization, top_k, direction
+    base_folder,
+    datasets,
+    stat_name,
+    grouping,
+    normalization,
+    top_k,
+    direction,
+    progress=gr.Progress(),
 ):
-    if len(multiselect_crawls) <= 0 or not stat_name or not grouping:
+    if len(datasets) <= 0 or not stat_name or not grouping:
         return None
     # Placeholder for logic to rerender the graph based on the inputs
     prepare_fc = (
@@ -233,25 +244,48 @@ def update_graph(
         else plot_bars
     )
 
-    print("Loading stats")
-    histograms = {
-        path: prepare_fc(path, stat_name, grouping) for path in multiselect_crawls
-    }
+    with ThreadPoolExecutor() as pool:
+        data = list(
+            progress.tqdm(
+                pool.map(
+                    partial(prepare_fc, base_folder=base_folder, stat_name=stat_name, grouping=grouping),
+                    datasets,
+                ),
+                total=len(datasets),
+                desc="Loading data...",
+            )
+        )
+
+    histograms = {path: result for path, result in zip(datasets, data)}
 
-    print("Plotting")
-    return graph_fc(histograms, stat_name)
+    return graph_fc(histograms=histograms, stat_name=stat_name, progress=progress), histograms, gr.update(visible=True)
 
 
 # Create the Gradio interface
 with gr.Blocks() as demo:
+    datasets = gr.State([])
+    exported_data = gr.State([])
     with gr.Row():
         with gr.Column(scale=2):
             # Define the multiselect for crawls
-            multiselect_crawls = gr.Dropdown(
-                choices=RUNS,
-                label="Multiselect for crawls",
-                multiselect=True,
-            )
+            with gr.Row():
+                with gr.Column(scale=1):
+                    stats_folder = gr.Textbox(
+                        label="Stats Location",
+                        value="s3://fineweb-stats/summary/",
+                    )
+                    datasets_refetch = gr.Button("Fetch Datasets")
+
+                with gr.Column(scale=1):
+                    regex_select = gr.Text(label="Regex select datasets", value=".*")
+                    regex_button = gr.Button("Filter")
+            with gr.Row():
+                datasets_selected = gr.Dropdown(
+                    choices=[],
+                    label="Datasets",
+                    multiselect=True,
+                )
+
             # add a readme description
             readme_description = gr.Markdown(
                 label="Readme",
@@ -283,6 +317,7 @@ Groupings:
                 label="Stat name",
                 multiselect=False,
             )
+
             with gr.Row(visible=False) as histogram_choices:
                 normalization_checkbox = gr.Checkbox(
                     label="Normalize",
@@ -301,11 +336,15 @@ Groupings:
                        "Top",
                        "Bottom",
                        "Most frequent (n_docs)",
-                        "Most frequent (length)",
                    ],
+                    value="Top",
                )
 
             update_button = gr.Button("Update Graph", variant="primary")
+            with gr.Row():
+                export_data_button = gr.Button("Export data", visible=False)
+                export_data_json = gr.File(visible=False)
+
     with gr.Row():
         # Define the graph output
         graph_output = gr.Plot(label="Graph")
@@ -313,28 +352,54 @@ Groupings:
     update_button.click(
         fn=update_graph,
         inputs=[
-            multiselect_crawls,
+            stats_folder,
+            datasets_selected,
             stat_name_dropdown,
             grouping_dropdown,
             normalization_checkbox,
             top_select,
             direction_checkbox,
         ],
-        outputs=graph_output,
+        outputs=[graph_output, exported_data, export_data_button],
+    )
+
+    export_data_button.click(
+        fn=export_data,
+        inputs=[exported_data],
+        outputs=export_data_json,
     )
 
-    multiselect_crawls.select(
+    datasets_selected.select(
         fn=fetch_groups,
-        inputs=[multiselect_crawls, grouping_dropdown],
+        inputs=[stats_folder, datasets_selected, grouping_dropdown],
         outputs=grouping_dropdown,
    )
 
     grouping_dropdown.select(
         fn=fetch_stats,
-        inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
+        inputs=[stats_folder, datasets_selected, grouping_dropdown, stat_name_dropdown],
         outputs=stat_name_dropdown,
     )
 
+    datasets_refetch.click(
+        fn=fetch_runs,
+        inputs=[stats_folder],
+        outputs=[datasets, datasets_selected],
+    )
+
+    def update_datasets_with_regex(regex, selected_runs, all_runs):
+        if not regex:
+            return
+        new_dsts = {run for run in all_runs if re.search(regex, run)}
+        dst_union = new_dsts.union(selected_runs)
+        return gr.update(value=list(dst_union))
+
+    regex_button.click(
+        fn=update_datasets_with_regex,
+        inputs=[regex_select, datasets_selected, datasets],
+        outputs=datasets_selected,
+    )
+
     def update_grouping_options(grouping):
         if grouping == "histogram":
             return {
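
Below is a minimal, self-contained sketch (not part of the commit) of the two behaviours this change wires into the UI: regex-based dataset selection, which unions the regex matches over the fetched runs with the current selection (mirroring update_datasets_with_regex), and JSON export via a temporary file (mirroring export_data). The run names and the select_with_regex/export_histograms helpers are hypothetical, for illustration only.

import json
import re
import tempfile

# Hypothetical run names; the app discovers real ones by globbing for stats-merged.json.
ALL_RUNS = ["cc/CC-MAIN-2023-50", "cc/CC-MAIN-2024-10", "exp/dedup-test"]


def select_with_regex(regex: str, selected_runs: list[str], all_runs: list[str]) -> list[str]:
    # Mirrors update_datasets_with_regex: matches are unioned with the current
    # selection rather than replacing it, so successive filters accumulate.
    if not regex:
        return selected_runs
    matches = {run for run in all_runs if re.search(regex, run)}
    return sorted(matches.union(selected_runs))


def export_histograms(histograms: dict) -> str:
    # Mirrors export_data: dump the plotted data to a temporary .json file and
    # return its path (the app hands this path to a gr.File component).
    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as temp:
        json.dump(histograms, temp)
        return temp.name


selection = select_with_regex(r"CC-MAIN-2024", ["exp/dedup-test"], ALL_RUNS)
print(selection)  # ['cc/CC-MAIN-2024-10', 'exp/dedup-test']
print(export_histograms({run: {"1.0": 2.0, "2.0": 5.0} for run in selection}))

The union semantics are a deliberate choice: applying several regexes in a row grows the selection incrementally instead of discarding what was already picked.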