hynky HF Staff committed on
Commit
3cb4732
·
1 Parent(s): f5e1a8f
Files changed (2) hide show
  1. app.py +191 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ import gradio as gr
4
+
5
+ from collections import defaultdict
6
+ import fsspec.config
7
+ import math
8
+ from datatrove.io import DataFolder, get_datafolder
9
+ from datatrove.utils.stats import MetricStatsDict
10
+
11
+ BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
12
+
13
+
14
def find_folders(base_folder, path):
    """Return the sorted names of the sub-directories listed under ``path``.

    Args:
        base_folder: Object exposing ``ls(path, detail=True)`` that returns
            dicts with at least the ``"name"`` and ``"type"`` keys.
        path: Directory to list. A trailing ``/`` is tolerated.

    Returns:
        Sorted list of the ``"name"`` of every entry whose ``"type"`` is
        ``"directory"``, excluding the entry that names ``path`` itself.
    """
    # The listing may contain an entry for ``path`` itself. Strip trailing
    # slashes on both sides so that entry is skipped whether or not the
    # caller passed ``path`` with a trailing "/".
    own_name = path.rstrip("/")
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        if entry["type"] == "directory" and entry["name"].rstrip("/") != own_name
    )
22
+
23
+
24
def find_stats_folders(base_folder: DataFolder):
    """List the distinct folders that contain merged statistics.

    Every ``stats-merged.json`` below ``base_folder`` is located with a
    recursive glob. Each hit is mapped to the folder three levels above it,
    i.e. the file name and the two directory levels directly above it are
    dropped (``load_stats`` addresses files as
    ``<folder>/<grouping>/<stat_name>/stats-merged.json``).

    Returns:
        The unique folder paths as strings, in no particular order.
    """
    merged_files = base_folder.glob("**/stats-merged.json")
    # A set collapses the many files that share the same top-level folder.
    unique_folders = {
        str(Path(merged_file).parent.parent.parent) for merged_file in merged_files
    }
    return list(unique_folders)
32
+
33
+
34
# Discover the selectable runs, groupings and stats once, at import time.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
print(RUNS)
# Groupings are read from the first run only. When no run was found the list
# stays empty instead of failing with an IndexError on ``RUNS[0]``.
GROUPS = (
    [Path(x).name for x in find_folders(BASE_DATA_FOLDER, RUNS[0])] if RUNS else []
)
print(GROUPS)
# Stats are read from the first grouping of the first run, with the same
# guard against an empty discovery result.
STATS = (
    [
        Path(x).name
        for x in find_folders(BASE_DATA_FOLDER, str(Path(RUNS[0], GROUPS[0])))
    ]
    if GROUPS
    else []
)
41
+
42
+
43
def load_stats(path, stat_name, group_by):
    """Load the merged statistics of one stat for one run.

    Args:
        path: Run folder, relative to ``BASE_DATA_FOLDER``.
        stat_name: Name of the stat directory.
        group_by: Name of the grouping directory.

    Returns:
        A ``MetricStatsDict`` built from
        ``<path>/<group_by>/<stat_name>/stats-merged.json``.
    """
    merged_path = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(merged_path, filecache=cache_options) as handle:
        raw_stats = json.load(handle)
        # Adding the parsed stats to an empty MetricStatsDict is kept on
        # purpose: per the original author, the MetricStatsDict is otherwise
        # malformed (reason unknown).
        empty_stats = MetricStatsDict()
        return empty_stats + MetricStatsDict(init=raw_stats)
51
+
52
+
53
def prepare_non_grouped_data(stats: MetricStatsDict, normalize: bool = False):
    """Turn histogram stats into a ``{numeric key: total}`` mapping.

    Keys are converted with ``float``; entries whose keys convert to the same
    number have their ``total`` values added together.

    Args:
        stats: Mapping from key to an object exposing a ``total`` attribute.
        normalize: When true, divide every total by the sum of all totals so
            the values add up to 1. The default (``False``) keeps the previous
            behaviour, where the computed sum was immediately overridden by 1
            and the raw totals were returned.

    Returns:
        Plain dict mapping each float key to its (optionally normalized)
        total. Values are always the result of a true division.
    """
    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total

    normalizer = 1
    if normalize:
        # ``or 1`` avoids a ZeroDivisionError for empty or all-zero stats.
        normalizer = sum(totals.values()) or 1
    return {key: total / normalizer for key, total in totals.items()}
62
+
63
+
64
def prepare_grouped_data(stats: MetricStatsDict, top_k=100):
    """Keep the ``top_k`` groups with the highest mean.

    Args:
        stats: Mapping from group key to an object exposing a ``mean``
            attribute.
        top_k: Maximum number of groups to keep.

    Returns:
        Dict mapping group key to its mean, ordered from highest to lowest
        mean; ties keep the order in which they appear in ``stats``.
    """
    mean_by_group = {}
    for group, metric in stats.items():
        mean_by_group[group] = metric.mean

    # Stable sort on the mean, largest first, then cut to the requested size.
    ranked = sorted(mean_by_group.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked[:top_k])
70
+
71
+
72
+ import math
73
+ import plotly.graph_objects as go
74
+ from plotly.offline import plot
75
+
76
+
77
def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one line per histogram on a shared figure with a log x axis.

    Args:
        histograms: Mapping from series name to its ``{key: value}``
            histogram.
        stat_name: Used in the figure title and as the x-axis title.

    Returns:
        The populated ``go.Figure``.
    """
    fig = go.Figure()

    palette = [
        "rgba(31, 119, 180, 0.5)",
        "rgba(255, 127, 14, 0.5)",
        "rgba(44, 160, 44, 0.5)",
        "rgba(214, 39, 40, 0.5)",
        "rgba(148, 103, 189, 0.5)",
    ]

    for index, (name, histogram) in enumerate(histograms.items()):
        if all(isinstance(k, str) for k in histogram):
            # String keys have no natural order: order them by ascending value.
            x = sorted(histogram, key=histogram.get)
        else:
            x = sorted(histogram)

        y = [histogram[k] for k in x]

        # Wrap around the palette: a one-shot iterator consumed with next()
        # raised StopIteration as soon as a sixth series was plotted.
        color = palette[index % len(palette)]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color))
        )

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        xaxis_type="log",
        width=1000,
        height=600,
    )

    return fig
112
+
113
+
114
def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one bar trace per histogram, bars ordered by ascending value.

    Args:
        histograms: Mapping from series name to its ``{key: value}``
            histogram.
        stat_name: Used in the figure title and as the x-axis title.

    Returns:
        The populated ``go.Figure``.
    """
    figure = go.Figure()

    for series_name, series in histograms.items():
        # Sort the (key, value) pairs once and split them into the two axes.
        ordered = sorted(series.items(), key=lambda pair: pair[1])
        keys = [key for key, _ in ordered]
        values = [value for _, value in ordered]
        figure.add_trace(go.Bar(x=keys, y=values, name=series_name))

    layout = dict(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        autosize=True,
        width=600,
        height=600,
    )
    figure.update_layout(**layout)

    return figure
133
+
134
+
135
def update_graph(multiselect_crawls, stat_name, grouping):
    """Build the figure for the selected runs, stat and grouping.

    Args:
        multiselect_crawls: Run folders to plot.
        stat_name: Name of the stat to load for every run.
        grouping: Grouping to load; ``"histogram"`` is drawn as lines via
            ``plot_scatter``, every other grouping as bars via ``plot_bars``.

    Returns:
        The figure, or ``None`` when any of the three selections is missing.
    """
    # A truthiness check covers a missing (None) selection as well as an
    # empty list; ``len(None)`` would raise a TypeError.
    if not multiselect_crawls or not stat_name or not grouping:
        return None

    is_histogram = grouping == "histogram"
    prepare_fc = prepare_non_grouped_data if is_histogram else prepare_grouped_data
    graph_fc = plot_scatter if is_histogram else plot_bars

    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }

    print("Plotting")
    return graph_fc(histograms, stat_name)
152
+
153
+
154
# Assemble the Gradio UI: selectors on the first row, the plot on the second.
# NOTE(review): the nesting below is reconstructed from the order of the
# context managers; confirm it against the original layout.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Runs to compare (several may be selected).
            multiselect_crawls = gr.Dropdown(
                label="Multiselect for crawls",
                choices=RUNS,
                multiselect=True,
            )
        with gr.Column(scale=1):
            # Single stat to plot.
            stat_name_dropdown = gr.Dropdown(
                label="Stat name",
                choices=STATS,
                multiselect=False,
            )
            # Single grouping to plot.
            grouping_dropdown = gr.Dropdown(
                label="Grouping",
                choices=GROUPS,
                multiselect=False,
            )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Output area for the figure returned by update_graph.
        graph_output = gr.Plot(label="Graph")

    # Redraw only on button click, using the three selectors as inputs.
    update_button.click(
        fn=update_graph,
        inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
        outputs=graph_output,
    )


# Start the app when run as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ datatrove
3
+ plotly