hynky HF Staff committed on
Commit
6c72e3f
·
1 Parent(s): 219feb6

add readme

Browse files
Files changed (1) hide show
  1. app.py +164 -30
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  from pathlib import Path
3
  import gradio as gr
@@ -9,6 +10,13 @@ from datatrove.io import DataFolder, get_datafolder
9
  from datatrove.utils.stats import MetricStatsDict
10
 
11
  BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
 
 
 
 
 
 
 
12
 
13
 
14
  def find_folders(base_folder, path):
@@ -32,10 +40,41 @@ def find_stats_folders(base_folder: DataFolder):
32
 
33
 
34
  RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
35
- GROUPS = [Path(x).name for x in find_folders(BASE_DATA_FOLDER, RUNS[0])]
36
- STATS = [
37
- Path(x).name for x in find_folders(BASE_DATA_FOLDER, str(Path(RUNS[0], GROUPS[0])))
38
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  def load_stats(path, stat_name, group_by):
@@ -48,23 +87,29 @@ def load_stats(path, stat_name, group_by):
48
  return MetricStatsDict() + MetricStatsDict(init=json_stat)
49
 
50
 
51
- def prepare_non_grouped_data(stats: MetricStatsDict):
52
-
53
  stats_rounded = defaultdict(lambda: 0)
54
  for key, value in stats.items():
55
  stats_rounded[float(key)] += value.total
56
- normalizer = sum(stats_rounded.values())
57
- normalizer = 1
58
- stats_rounded = {k: v / normalizer for k, v in stats_rounded.items()}
59
  return stats_rounded
60
 
61
 
62
- def prepare_grouped_data(stats: MetricStatsDict, top_k=100):
 
 
63
  means = {key: value.mean for key, value in stats.items()}
64
 
65
- # Take the top_k most frequent keys
66
- top_keys = sorted(means, key=lambda x: means[x], reverse=True)[:top_k]
67
- return {key: means[key] for key in top_keys}
 
 
 
 
 
68
 
69
 
70
  import math
@@ -72,7 +117,9 @@ import plotly.graph_objects as go
72
  from plotly.offline import plot
73
 
74
 
75
- def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
 
 
76
  fig = go.Figure()
77
 
78
  colors = iter(
@@ -82,6 +129,10 @@ def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
82
  "rgba(44, 160, 44, 0.5)",
83
  "rgba(214, 39, 40, 0.5)",
84
  "rgba(148, 103, 189, 0.5)",
 
 
 
 
85
  ]
86
  )
87
 
@@ -97,12 +148,15 @@ def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
97
  go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=next(colors)))
98
  )
99
 
 
 
 
100
  fig.update_layout(
101
  title=f"Line Plots for {stat_name}",
102
  xaxis_title=stat_name,
103
- yaxis_title="Frequency",
104
- xaxis_type="log",
105
- width=1000,
106
  height=600,
107
  )
108
 
@@ -121,23 +175,31 @@ def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
121
  fig.update_layout(
122
  title=f"Bar Plots for {stat_name}",
123
  xaxis_title=stat_name,
124
- yaxis_title="Frequency",
125
  autosize=True,
126
- width=600,
127
  height=600,
128
  )
129
 
130
  return fig
131
 
132
 
133
- def update_graph(multiselect_crawls, stat_name, grouping):
 
 
134
  if len(multiselect_crawls) <= 0 or not stat_name or not grouping:
135
  return None
136
  # Placeholder for logic to rerender the graph based on the inputs
137
  prepare_fc = (
138
- prepare_non_grouped_data if grouping == "histogram" else prepare_grouped_data
 
 
 
 
 
 
 
139
  )
140
- graph_fc = plot_scatter if grouping == "histogram" else plot_bars
141
 
142
  print("Loading stats")
143
  histograms = {
@@ -159,19 +221,54 @@ with gr.Blocks() as demo:
159
  label="Multiselect for crawls",
160
  multiselect=True,
161
  )
162
- with gr.Column(scale=1):
163
- # Define the dropdown for stat_name
164
- stat_name_dropdown = gr.Dropdown(
165
- choices=STATS,
166
- label="Stat name",
167
- multiselect=False,
 
 
 
 
 
 
 
 
 
 
 
168
  )
 
169
  # Define the dropdown for grouping
170
  grouping_dropdown = gr.Dropdown(
171
- choices=GROUPS,
172
  label="Grouping",
173
  multiselect=False,
174
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  update_button = gr.Button("Update Graph", variant="primary")
176
  with gr.Row():
177
  # Define the graph output
@@ -179,10 +276,47 @@ with gr.Blocks() as demo:
179
 
180
  update_button.click(
181
  fn=update_graph,
182
- inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
 
 
 
 
 
 
 
183
  outputs=graph_output,
184
  )
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  # Launch the application
188
  if __name__ == "__main__":
 
1
+ from functools import partial
2
  import json
3
  from pathlib import Path
4
  import gradio as gr
 
10
  from datatrove.utils.stats import MetricStatsDict
11
 
12
  BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
13
+ LOG_SCALE_STATS = {
14
+ "length",
15
+ "n_lines",
16
+ "n_docs",
17
+ "avg_words_per_line",
18
+ "pages_with_lorem_ipsum",
19
+ }
20
 
21
 
22
  def find_folders(base_folder, path):
 
40
 
41
 
42
  RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
43
+
44
+
45
def fetch_groups(runs, old_groups):
    """Refresh the grouping dropdown for the currently selected runs.

    Lists the folders found under each selected run and offers only the
    group names that are present in every one of them.

    Args:
        runs: Selected run paths, resolved against ``BASE_DATA_FOLDER``.
            An empty or ``None`` selection yields an empty dropdown.
        old_groups: The grouping currently selected in the dropdown (a single
            name), or a falsy value when nothing is selected.

    Returns:
        A ``gr.update`` whose ``choices`` are the sorted groups common to all
        runs and whose ``value`` is the previous selection if it is still
        available, otherwise ``None``.
    """
    if not runs:
        return gr.update(choices=[], value=None)

    groups_per_run = [
        {Path(folder).name for folder in find_folders(BASE_DATA_FOLDER, run)}
        for run in runs
    ]
    # Only groups that exist for every selected run can be compared.
    common_groups = set.intersection(*groups_per_run)

    # Keep the user's selection only while every selected run still has it.
    value = old_groups if old_groups and old_groups in common_groups else None

    # Sets are unordered; sort so the dropdown order is stable across refreshes.
    return gr.update(choices=sorted(common_groups), value=value)
61
+
62
+
63
def fetch_stats(runs, group, old_stats):
    """Refresh the stat-name dropdown for the selected runs and grouping.

    Lists the folders found under ``<run>/<group>`` for each selected run and
    offers only the stat names that are present in every one of them.

    Args:
        runs: Selected run paths, resolved against ``BASE_DATA_FOLDER``.
        group: The selected grouping name. When it is unset there is no
            folder to list, so an empty dropdown is returned.
        old_stats: The stat currently selected in the dropdown (a single
            name), or a falsy value when nothing is selected.

    Returns:
        A ``gr.update`` whose ``choices`` are the sorted stats common to all
        runs and whose ``value`` is the previous selection if it is still
        available, otherwise ``None``.
    """
    # Without a grouping the lookup path would be "<run>/None"; bail out early.
    if not runs or not group:
        return gr.update(choices=[], value=None)

    stats_per_run = [
        {
            Path(folder).name
            for folder in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")
        }
        for run in runs
    ]
    # Only stats that exist for every selected run can be compared.
    common_stats = set.intersection(*stats_per_run)

    # Keep the user's selection only while every selected run still has it.
    value = old_stats if old_stats and old_stats in common_stats else None

    # Sets are unordered; sort so the dropdown order is stable across refreshes.
    return gr.update(choices=sorted(common_stats), value=value)
78
 
79
 
80
  def load_stats(path, stat_name, group_by):
 
87
  return MetricStatsDict() + MetricStatsDict(init=json_stat)
88
 
89
 
90
def prepare_non_grouped_data(stats: "MetricStatsDict", normalization):
    """Collapse histogram stats into a ``{float(key): total}`` mapping.

    Keys that convert to the same float (for example ``"1"`` and ``"1.0"``)
    are merged by summing their totals.

    Args:
        stats: Mapping whose keys are accepted by ``float()`` and whose
            values expose a ``.total`` attribute.
        normalization: When truthy, every total is divided by the sum of all
            totals so that the values add up to 1.

    Returns:
        The accumulated totals, or their normalized frequencies when
        ``normalization`` is truthy. If the totals sum to zero the
        unnormalized totals are returned, because no scaling is possible.
    """
    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total

    if not normalization:
        return totals

    normalizer = sum(totals.values())
    if not normalizer:
        # Empty or all-zero histogram: dividing would raise ZeroDivisionError.
        return dict(totals)
    return {key: total / normalizer for key, total in totals.items()}
98
 
99
 
100
def prepare_grouped_data(stats: "MetricStatsDict", top_k, direction):
    """Select the ``top_k`` groups with the largest or smallest mean.

    Args:
        stats: Mapping from group key to a value exposing a ``.mean``
            attribute.
        top_k: Number of groups to keep. Any value accepted by ``int()`` is
            truncated to an int; ``None`` falls back to 100, the default of
            the "K" input; negative values are treated as 0.
        direction: ``"Top"`` keeps the largest means; any other value keeps
            the smallest.

    Returns:
        A dict ``{key: mean}`` ordered from best to worst for the chosen
        direction.
    """
    import heapq

    means = {key: value.mean for key, value in stats.items()}

    # top_k is fed from a UI number input and is not guaranteed to be an int;
    # heapq's selection functions raise TypeError for floats and None.
    limit = 100 if top_k is None else max(int(top_k), 0)

    select = heapq.nlargest if direction == "Top" else heapq.nsmallest
    keys = select(limit, means, key=means.get)
    return {key: means[key] for key in keys}
113
 
114
 
115
  import math
 
117
  from plotly.offline import plot
118
 
119
 
120
+ def plot_scatter(
121
+ histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
122
+ ):
123
  fig = go.Figure()
124
 
125
  colors = iter(
 
129
  "rgba(44, 160, 44, 0.5)",
130
  "rgba(214, 39, 40, 0.5)",
131
  "rgba(148, 103, 189, 0.5)",
132
+ "rgba(227, 119, 194, 0.5)",
133
+ "rgba(127, 127, 127, 0.5)",
134
+ "rgba(188, 189, 34, 0.5)",
135
+ "rgba(23, 190, 207, 0.5)",
136
  ]
137
  )
138
 
 
148
  go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=next(colors)))
149
  )
150
 
151
+ xaxis_scale = "log" if stat_name in LOG_SCALE_STATS else "linear"
152
+ yaxis_title = "Frequency" if normalization else "Total"
153
+
154
  fig.update_layout(
155
  title=f"Line Plots for {stat_name}",
156
  xaxis_title=stat_name,
157
+ yaxis_title=yaxis_title,
158
+ xaxis_type=xaxis_scale,
159
+ width=1200,
160
  height=600,
161
  )
162
 
 
175
  fig.update_layout(
176
  title=f"Bar Plots for {stat_name}",
177
  xaxis_title=stat_name,
178
+ yaxis_title="Mean value",
179
  autosize=True,
180
+ width=1200,
181
  height=600,
182
  )
183
 
184
  return fig
185
 
186
 
187
+ def update_graph(
188
+ multiselect_crawls, stat_name, grouping, normalization, top_k, direction
189
+ ):
190
  if len(multiselect_crawls) <= 0 or not stat_name or not grouping:
191
  return None
192
  # Placeholder for logic to rerender the graph based on the inputs
193
  prepare_fc = (
194
+ partial(prepare_non_grouped_data, normalization=normalization)
195
+ if grouping == "histogram"
196
+ else partial(prepare_grouped_data, top_k=top_k, direction=direction)
197
+ )
198
+ graph_fc = (
199
+ partial(plot_scatter, normalization=normalization)
200
+ if grouping == "histogram"
201
+ else plot_bars
202
  )
 
203
 
204
  print("Loading stats")
205
  histograms = {
 
221
  label="Multiselect for crawls",
222
  multiselect=True,
223
  )
224
+ # add a readme description
225
+ readme_description = gr.Markdown(
226
+ label="Readme",
227
+ value="""
228
+ Explaination of the tool:
229
+
230
+ Groupings:
231
+ - histogram: creates a line plot of values with their occurences. If normalization is on, the values are frequencies summing to 1.
232
+ - (fqdn/suffix): creates a bar plot of the mean values of the stats for full qualied domain name/suffix of domain
233
+ * k: the number of groups to show
234
+ * Top/Bottom: the top/bottom k groups are shown
235
+ - summary: simply shows the average value of given stat for selected crawls
236
+
237
+
238
+
239
+
240
+ """,
241
  )
242
+ with gr.Column(scale=1):
243
  # Define the dropdown for grouping
244
  grouping_dropdown = gr.Dropdown(
245
+ choices=[],
246
  label="Grouping",
247
  multiselect=False,
248
  )
249
+ # Define the dropdown for stat_name
250
+ stat_name_dropdown = gr.Dropdown(
251
+ choices=[],
252
+ label="Stat name",
253
+ multiselect=False,
254
+ )
255
+ with gr.Row(visible=False) as histogram_choices:
256
+ normalization_checkbox = gr.Checkbox(
257
+ label="Normalize",
258
+ value=False, # Default value
259
+ )
260
+
261
+ with gr.Row(visible=False) as group_choices:
262
+ top_select = gr.Number(
263
+ label="K",
264
+ value=100,
265
+ interactive=True,
266
+ )
267
+ direction_checkbox = gr.Radio(
268
+ label="Partition",
269
+ choices=["Top", "Bottom"],
270
+ )
271
+
272
  update_button = gr.Button("Update Graph", variant="primary")
273
  with gr.Row():
274
  # Define the graph output
 
276
 
277
  update_button.click(
278
  fn=update_graph,
279
+ inputs=[
280
+ multiselect_crawls,
281
+ stat_name_dropdown,
282
+ grouping_dropdown,
283
+ normalization_checkbox,
284
+ top_select,
285
+ direction_checkbox,
286
+ ],
287
  outputs=graph_output,
288
  )
289
 
290
+ multiselect_crawls.select(
291
+ fn=fetch_groups,
292
+ inputs=[multiselect_crawls, grouping_dropdown],
293
+ outputs=grouping_dropdown,
294
+ )
295
+
296
+ grouping_dropdown.select(
297
+ fn=fetch_stats,
298
+ inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
299
+ outputs=stat_name_dropdown,
300
+ )
301
+
302
def update_grouping_options(grouping):
    """Show only the option row that applies to the selected grouping.

    Args:
        grouping: The selected grouping name.

    Returns:
        A dict mapping the ``histogram_choices`` and ``group_choices``
        containers to visibility updates: the histogram options are shown for
        ``"histogram"``, the top-k options for every other grouping.
    """
    is_histogram = grouping == "histogram"
    # Both containers are rows, so use the type-agnostic gr.update instead of
    # constructing a different layout component just to flip visibility.
    return {
        histogram_choices: gr.update(visible=is_histogram),
        group_choices: gr.update(visible=not is_histogram),
    }
313
+
314
+ grouping_dropdown.select(
315
+ fn=update_grouping_options,
316
+ inputs=[grouping_dropdown],
317
+ outputs=[histogram_choices, group_choices],
318
+ )
319
+
320
 
321
  # Launch the application
322
  if __name__ == "__main__":