|
import json |
|
from pathlib import Path |
|
import gradio as gr |
|
|
|
from collections import defaultdict |
|
import fsspec.config |
|
import math |
|
from datatrove.io import DataFolder, get_datafolder |
|
from datatrove.utils.stats import MetricStatsDict |
|
|
|
# Root folder (an S3 location) handed to the folder-discovery helpers and used
# by load_stats to open the merged stats files.
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
|
|
|
|
|
def find_folders(base_folder, path):
    """Return the sorted names of the directory entries listed under ``path``.

    Args:
        base_folder: Object exposing ``ls(path, detail=True)`` that yields
            dict entries with at least ``"name"`` and ``"type"`` keys.
        path: Location to list, passed to ``ls`` unchanged.

    Returns:
        A sorted list of the ``"name"`` values of every entry whose type is
        ``"directory"``, excluding the entry for ``path`` itself.
    """
    names = []
    for entry in base_folder.ls(path, detail=True):
        if entry["type"] != "directory":
            continue
        # Skip the entry that is the listed path itself (it may carry a
        # trailing slash in the listing).
        if entry["name"].rstrip("/") == path:
            continue
        names.append(entry["name"])
    names.sort()
    return names
|
|
|
|
|
def find_stats_folders(base_folder: DataFolder):
    """Find every folder that sits three levels above a ``stats-merged.json``.

    Args:
        base_folder: Folder whose ``glob`` is searched recursively for
            ``stats-merged.json`` files.

    Returns:
        A list of unique folder paths (as strings) in no particular order.
    """
    unique_roots = set()
    for merged_file in base_folder.glob("**/stats-merged.json"):
        # Climb from the file to its directory, then two more levels up.
        stat_dir = Path(merged_file).parent
        unique_roots.add(str(stat_dir.parent.parent))
    return list(unique_roots)
|
|
|
|
|
# Every run folder that contains merged stats, in sorted order.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
print(RUNS)

# Grouping names are read from the first run only; indexing RUNS[0] raises
# IndexError when no run was found.
# NOTE(review): presumably every run exposes the same groupings -- confirm.
GROUPS = [
    Path(group_path).name
    for group_path in find_folders(BASE_DATA_FOLDER, RUNS[0])
]
print(GROUPS)

# Stat names are read from the first grouping of the first run only.
STATS = [
    Path(stat_path).name
    for stat_path in find_folders(BASE_DATA_FOLDER, str(Path(RUNS[0]) / GROUPS[0]))
]
|
|
|
|
|
def load_stats(path, stat_name, group_by):
    """Load the merged stats of one run for a given grouping and stat.

    Reads ``{path}/{group_by}/{stat_name}/stats-merged.json`` through
    ``BASE_DATA_FOLDER`` and parses it as JSON.

    Args:
        path: Run folder, relative to ``BASE_DATA_FOLDER``.
        stat_name: Name of the stat folder inside the grouping.
        group_by: Name of the grouping folder inside the run.

    Returns:
        The sum of an empty ``MetricStatsDict`` and one initialised from the
        parsed JSON.
    """
    stats_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    # NOTE(review): presumably enables a local file cache under /tmp/files --
    # confirm the backend honours this option.
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(stats_file, filecache=cache_options) as handle:
        raw_stats = json.load(handle)

    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
|
|
|
|
|
def prepare_non_grouped_data(stats: MetricStatsDict, *, normalize: bool = False):
    """Collapse histogram stats into a ``{float key: total}`` mapping.

    Keys are converted with ``float`` and the ``total`` of entries that map
    to the same float are summed.

    The previous version computed the grand total and then immediately
    overwrote it with ``1``, so normalisation never happened. That behaviour
    remains the default; normalisation is now an explicit opt-in.

    Args:
        stats: Mapping whose keys are convertible to ``float`` and whose
            values expose a numeric ``total`` attribute.
        normalize: When true, divide every value by the grand total so the
            values sum to 1. A grand total of zero leaves the values as-is.

    Returns:
        A plain dict mapping each float key to its (optionally normalised)
        total, always as a float.
    """
    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total

    normalizer = 1
    if normalize:
        grand_total = sum(totals.values())
        # Guard against dividing by zero on empty or all-zero histograms.
        if grand_total:
            normalizer = grand_total

    # True division keeps the historical float-valued output even when the
    # normalizer is 1.
    return {key: total / normalizer for key, total in totals.items()}
|
|
|
|
|
def prepare_grouped_data(stats: MetricStatsDict, top_k=100):
    """Return the ``top_k`` groups with the highest mean, best first.

    Args:
        stats: Mapping from group key to an object exposing ``mean``.
        top_k: Maximum number of groups to keep.

    Returns:
        A dict from group key to mean, ordered by descending mean. Groups
        with equal means keep their original relative order.
    """
    mean_by_group = {}
    for group, metric in stats.items():
        mean_by_group[group] = metric.mean

    ranked = sorted(mean_by_group.items(), key=lambda pair: pair[1], reverse=True)
    return dict(ranked[:top_k])
|
|
|
|
|
import math |
|
import plotly.graph_objects as go |
|
from plotly.offline import plot |
|
|
|
|
|
def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one line trace per histogram on a shared, log-x figure.

    The previous version pulled colours from a five-element iterator with
    ``next``, so a sixth histogram raised ``StopIteration``. Colours are now
    reused cyclically, which allows any number of histograms.

    Args:
        histograms: Mapping from trace name to a ``{x: y}`` histogram.
        stat_name: Stat name used in the figure title and x-axis label.

    Returns:
        A ``plotly.graph_objects.Figure`` with one ``Scatter`` trace per
        histogram.
    """
    # Semi-transparent palette, reused from the start once exhausted.
    palette = [
        "rgba(31, 119, 180, 0.5)",
        "rgba(255, 127, 14, 0.5)",
        "rgba(44, 160, 44, 0.5)",
        "rgba(214, 39, 40, 0.5)",
        "rgba(148, 103, 189, 0.5)",
    ]

    fig = go.Figure()

    for index, (name, histogram) in enumerate(histograms.items()):
        if all(isinstance(key, str) for key in histogram):
            # String keys have no numeric order: order them by ascending value.
            x = [key for key, _ in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram)

        y = [histogram[key] for key in x]
        color = palette[index % len(palette)]

        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=color))
        )

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        xaxis_type="log",
        width=1000,
        height=600,
    )

    return fig
|
|
|
|
|
def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one bar trace per histogram, bars ordered by ascending value.

    Args:
        histograms: Mapping from trace name to a ``{category: value}`` dict.
        stat_name: Stat name used in the figure title and x-axis label.

    Returns:
        A ``plotly.graph_objects.Figure`` with one ``Bar`` trace per
        histogram.
    """
    fig = go.Figure()

    for name, histogram in histograms.items():
        # Sort the (category, value) pairs once, then split them into axes.
        ascending = sorted(histogram.items(), key=lambda item: item[1])
        categories = [category for category, _ in ascending]
        heights = [height for _, height in ascending]
        fig.add_trace(go.Bar(x=categories, y=heights, name=name))

    layout = dict(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        autosize=True,
        width=600,
        height=600,
    )
    fig.update_layout(**layout)

    return fig
|
|
|
|
|
def update_graph(multiselect_crawls, stat_name, grouping):
    """Build the figure for the selected runs, stat and grouping.

    The ``"histogram"`` grouping is drawn as line plots of summed totals;
    every other grouping is drawn as bar plots of the top group means.

    Args:
        multiselect_crawls: Selected run folders; may be empty or ``None``.
        stat_name: Selected stat name; may be empty or ``None``.
        grouping: Selected grouping name; may be empty or ``None``.

    Returns:
        The figure produced by the matching plot function, or ``None`` when
        any of the three selections is missing.
    """
    # A truthiness check also covers None; the earlier len() call raised
    # TypeError for a None selection.
    if not multiselect_crawls or not stat_name or not grouping:
        return None

    is_histogram = grouping == "histogram"
    prepare_fc = prepare_non_grouped_data if is_histogram else prepare_grouped_data
    graph_fc = plot_scatter if is_histogram else plot_bars

    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }

    print("Plotting")
    return graph_fc(histograms, stat_name)
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    # Top row: run picker in the wider column, stat/grouping controls and the
    # trigger button in the narrower one.
    with gr.Row():
        with gr.Column(scale=2):
            multiselect_crawls = gr.Dropdown(
                label="Multiselect for crawls", choices=RUNS, multiselect=True
            )
        with gr.Column(scale=1):
            stat_name_dropdown = gr.Dropdown(
                label="Stat name", choices=STATS, multiselect=False
            )
            grouping_dropdown = gr.Dropdown(
                label="Grouping", choices=GROUPS, multiselect=False
            )
            # NOTE(review): nesting was reconstructed from a flattened source --
            # confirm the button belongs inside this column.
            update_button = gr.Button("Update Graph", variant="primary")

    # Bottom row: the rendered figure.
    with gr.Row():
        graph_output = gr.Plot(label="Graph")

    # Redraw only on an explicit click, passing the three current selections.
    update_button.click(
        fn=update_graph,
        inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
        outputs=graph_output,
    )
|
|
|
|
|
|
|
# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|