File size: 5,647 Bytes
3cb4732 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import json
from pathlib import Path
import gradio as gr
from collections import defaultdict
import fsspec.config
import math
from datatrove.io import DataFolder, get_datafolder
from datatrove.utils.stats import MetricStatsDict
# datatrove DataFolder rooted at the S3 prefix that holds the merged summary
# statistics; the discovery helpers and the stats loader address files relative to it.
_SUMMARY_STATS_URL = "s3://fineweb-stats/summary/"
BASE_DATA_FOLDER = get_datafolder(_SUMMARY_STATS_URL)
def find_folders(base_folder, path):
    """Return the sorted names of the directories listed under ``path``.

    Uses ``base_folder.ls(path, detail=True)`` and keeps only entries whose
    ``type`` is ``"directory"``. An entry naming ``path`` itself (some
    filesystems include it in the listing) is skipped. Trailing slashes are
    ignored on BOTH sides of that comparison, so ``"run"`` and ``"run/"``
    behave the same.

    Args:
        base_folder: filesystem-like object exposing ``ls(path, detail=True)``.
        path: folder to list; ``str()`` is applied before comparing.

    Returns:
        Sorted list of the matching entries' ``"name"`` values, as returned by ``ls``.
    """
    own_name = str(path).rstrip("/")
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        if entry["type"] == "directory" and entry["name"].rstrip("/") != own_name
    )
def find_stats_folders(base_folder: "DataFolder"):
    """Return the unique run folders that contain at least one ``stats-merged.json``.

    The folder is derived by stripping three trailing components from each
    match. This implies a ``<run>/<grouping>/<stat_name>/stats-merged.json``
    layout below ``base_folder``; that layout is inferred from the depth and is
    not verified here.

    Args:
        base_folder: object exposing ``glob(pattern)``, e.g. a datatrove DataFolder.

    Returns:
        Sorted list of unique run-folder paths as strings.
    """
    # Remote (S3) paths always use "/"; PurePosixPath keeps that separator on
    # every OS, whereas Path would switch to "\\" on Windows.
    from pathlib import PurePosixPath

    stats_files = base_folder.glob("**/stats-merged.json")
    # Drop the file name, stat_name and grouping components -> run folder.
    runs = {str(PurePosixPath(f).parent.parent.parent) for f in stats_files}
    return sorted(runs)
# Dropdown choices, discovered once at import time from the stats bucket.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
if not RUNS:
    # Fail with an explicit message instead of a bare IndexError on RUNS[0].
    raise RuntimeError(
        "No stats-merged.json files found under the stats base folder; "
        "cannot populate the crawl dropdown."
    )
# NOTE(review): groupings and stat names are read from the first run only, which
# assumes every run exposes the same grouping/stat layout — confirm.
GROUPS = [Path(x).name for x in find_folders(BASE_DATA_FOLDER, RUNS[0])]
if not GROUPS:
    raise RuntimeError(f"No grouping folders found under run {RUNS[0]!r}.")
STATS = [
    Path(x).name for x in find_folders(BASE_DATA_FOLDER, str(Path(RUNS[0], GROUPS[0])))
]
def load_stats(path, stat_name, group_by):
    """Read ``<path>/<group_by>/<stat_name>/stats-merged.json`` into a MetricStatsDict.

    The file is opened through BASE_DATA_FOLDER with fsspec file caching
    stored under ``/tmp/files``.
    """
    stats_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(stats_file, filecache=cache_options) as handle:
        raw_stats = json.load(handle)
    # Adding onto an empty MetricStatsDict is deliberate: the original author
    # found the directly constructed MetricStatsDict malformed without this step
    # (root cause unknown).
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
def prepare_non_grouped_data(stats: "MetricStatsDict", normalize: bool = False):
    """Collapse histogram-style stats into a ``{float(bucket): value}`` mapping.

    Each key of ``stats`` is converted with ``float()``. Keys that parse to the
    same number (e.g. ``"1"`` and ``"1.0"``) are merged by summing their
    ``.total`` values.

    Args:
        stats: mapping of bucket key -> object with a ``.total`` attribute
            (a MetricStatsDict in this app).
        normalize: if True, divide every value by the grand total so the values
            sum to 1. Defaults to False, which keeps the previous raw-total output.

    Returns:
        dict mapping float bucket -> float value.
    """
    totals = defaultdict(float)
    for key, value in stats.items():
        totals[float(key)] += value.total
    normalizer = sum(totals.values()) if normalize else 1
    if not normalizer:
        # Empty or all-zero histogram: nothing meaningful to divide by.
        normalizer = 1
    return {k: v / normalizer for k, v in totals.items()}
def prepare_grouped_data(stats: "MetricStatsDict", top_k=100):
    """Return the ``top_k`` groups with the largest ``.mean``, as ``{key: mean}``.

    Groups are ranked by their mean value, NOT by how often they occur.
    Equal means keep the iteration order of ``stats`` (``sorted`` is stable,
    including with ``reverse=True``). ``top_k=None`` keeps every group.

    Args:
        stats: mapping of group key -> object with a ``.mean`` attribute
            (a MetricStatsDict in this app).
        top_k: maximum number of groups to keep.

    Returns:
        dict of the selected keys to their means, ordered from largest mean down.
    """
    means = {key: value.mean for key, value in stats.items()}
    ranked = sorted(means, key=means.__getitem__, reverse=True)
    return {key: means[key] for key in ranked[:top_k]}
import math
import plotly.graph_objects as go
from plotly.offline import plot
def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one line trace per entry of ``histograms`` on a log-scaled x axis.

    Args:
        histograms: trace name -> {x value: y value}. If every x key of a
            histogram is a string, its points are ordered by y value; otherwise
            they are ordered by x.
        stat_name: used in the title and as the x-axis label.

    Returns:
        A plotly ``go.Figure``.
    """
    palette = [
        "rgba(31, 119, 180, 0.5)",
        "rgba(255, 127, 14, 0.5)",
        "rgba(44, 160, 44, 0.5)",
        "rgba(214, 39, 40, 0.5)",
        "rgba(148, 103, 189, 0.5)",
    ]
    fig = go.Figure()
    for index, (name, histogram) in enumerate(histograms.items()):
        if all(isinstance(k, str) for k in histogram):
            x = [k for k, _ in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram)
        y = [histogram[k] for k in x]
        # Cycle through the palette: a plain iterator raised StopIteration once
        # more than len(palette) traces were plotted.
        color = palette[index % len(palette)]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color))
        )
    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        xaxis_type="log",
        width=1000,
        height=600,
    )
    return fig
def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one bar trace per entry of ``histograms``, bars ordered by ascending value.

    Args:
        histograms: trace name -> {category: value}.
        stat_name: used in the title and as the x-axis label.

    Returns:
        A plotly ``go.Figure``.
    """
    figure = go.Figure()
    for trace_name, values in histograms.items():
        ordered_pairs = sorted(values.items(), key=lambda pair: pair[1])
        categories = [category for category, _ in ordered_pairs]
        heights = [values[category] for category in categories]
        figure.add_trace(go.Bar(x=categories, y=heights, name=trace_name))
    layout_options = dict(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        autosize=True,
        width=600,
        height=600,
    )
    figure.update_layout(**layout_options)
    return figure
def update_graph(multiselect_crawls, stat_name, grouping):
    """Load the chosen stat for every selected crawl and build the matching figure.

    The grouping ``"histogram"`` produces a line plot (``plot_scatter``) of
    per-bucket totals. Any other grouping produces a bar plot (``plot_bars``)
    of the groups with the largest means.

    Args:
        multiselect_crawls: selected run paths (may be empty or None).
        stat_name: selected stat name.
        grouping: selected grouping folder name.

    Returns:
        A plotly figure, or None if any input is missing.
    """
    # `not` treats both None and [] as "nothing selected"; len(None) would raise TypeError.
    if not multiselect_crawls or not stat_name or not grouping:
        return None
    is_histogram = grouping == "histogram"
    prepare_fc = prepare_non_grouped_data if is_histogram else prepare_grouped_data
    graph_fc = plot_scatter if is_histogram else plot_bars
    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }
    print("Plotting")
    return graph_fc(histograms, stat_name)
# Create the Gradio interface.
# Layout: a first row with the crawl multiselect (wider column, scale=2) next to a
# column holding the stat / grouping dropdowns and the update button; a second row
# holds the plot. Nothing is plotted until the button is clicked.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls; choices are the discovered RUNS.
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
        with gr.Column(scale=1):
            # Define the dropdown for stat_name (single choice from STATS).
            stat_name_dropdown = gr.Dropdown(
                choices=STATS,
                label="Stat name",
                multiselect=False,
            )
            # Define the dropdown for grouping (single choice from GROUPS).
            grouping_dropdown = gr.Dropdown(
                choices=GROUPS,
                label="Grouping",
                multiselect=False,
            )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")
    # On click, the three selector values are passed positionally to update_graph
    # and its return value (a figure or None) is shown in graph_output.
    update_button.click(
        fn=update_graph,
        inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
        outputs=graph_output,
    )
# Launch the application
if __name__ == "__main__":
    demo.launch()
|