hynky's picture
hynky HF Staff
demo
3cb4732
raw
history blame
5.67 kB
import json
from pathlib import Path
import gradio as gr
from collections import defaultdict
import fsspec.config
import math
from datatrove.io import DataFolder, get_datafolder
from datatrove.utils.stats import MetricStatsDict
# Shared DataFolder rooted at the S3 stats summary prefix; all listing (ls),
# globbing and file opening in this app goes through this single object.
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
def find_folders(base_folder, path):
    """Return the sorted names of the immediate sub-directories of ``path``.

    Args:
        base_folder: filesystem-like object exposing fsspec's
            ``ls(path, detail=True)`` (e.g. a datatrove ``DataFolder``).
        path: directory to list. A trailing ``/`` is ignored.

    Returns:
        Sorted list of the ``name`` entries of every listed item whose
        ``type`` is ``"directory"``, excluding ``path`` itself (in case the
        backend includes the listed directory in its own listing).
    """
    # Normalize both sides of the self-exclusion check so "run" and "run/"
    # behave the same; otherwise the folder itself leaks into the result.
    normalized_path = path.rstrip("/")
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        if entry["type"] == "directory"
        and entry["name"].rstrip("/") != normalized_path
    )
def find_stats_folders(base_folder: "DataFolder"):
    """Return the run folders that contain at least one merged stats file.

    Stats files live at ``<run>/<grouping>/<stat_name>/stats-merged.json``,
    so the run folder is what remains after dropping the last three path
    components (file name, stat name and grouping).

    Args:
        base_folder: folder exposing fsspec's ``glob``; the paths it returns
            are ``/``-separated.

    Returns:
        Sorted list of unique run folder paths. Sorting makes the order
        deterministic; plain ``set`` iteration order is not.
    """
    from pathlib import PurePosixPath

    # Find every merged-stats file under the base folder.
    stats_merged = base_folder.glob("**/stats-merged.json")
    # Drop file name, stat_name and grouping. PurePosixPath keeps "/"
    # separators on every OS, matching the fsspec path convention.
    run_folders = {
        str(PurePosixPath(stats_file).parent.parent.parent)
        for stats_file in stats_merged
    }
    return sorted(run_folders)
# Discover the available runs once at import time; the UI choices are built
# from RUNS, GROUPS and STATS below.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
print(RUNS)
if not RUNS:
    # Fail with a clear message instead of an IndexError on RUNS[0].
    raise RuntimeError("No stats-merged.json files found in BASE_DATA_FOLDER")
# Groupings and stat names are read from the first run (and its first
# grouping) only, so the UI offers exactly what that run contains.
GROUPS = [Path(x).name for x in find_folders(BASE_DATA_FOLDER, RUNS[0])]
print(GROUPS)
if not GROUPS:
    raise RuntimeError(f"No grouping folders found in run {RUNS[0]!r}")
# Join with "/" like load_stats does: these are fsspec paths, not OS paths.
STATS = [
    Path(x).name for x in find_folders(BASE_DATA_FOLDER, f"{RUNS[0]}/{GROUPS[0]}")
]
def load_stats(path, stat_name, group_by):
    """Load the merged ``MetricStatsDict`` for one run, grouping and stat.

    Reads ``<path>/<group_by>/<stat_name>/stats-merged.json`` through
    ``BASE_DATA_FOLDER``, opened with
    ``filecache={"cache_storage": "/tmp/files"}`` (presumably a local file
    cache for the remote object).
    """
    stats_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(stats_file, filecache=cache_options) as handle:
        raw_stats = json.load(handle)
    # Adding the loaded stats to an empty MetricStatsDict is required:
    # without it the resulting MetricStatsDict is malformed (cause unknown).
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
def prepare_non_grouped_data(stats: "MetricStatsDict", normalize: bool = False):
    """Collapse a histogram-style stats mapping into ``{bin value: count}``.

    Keys are converted to ``float``, so string keys denoting the same number
    (e.g. ``"1"`` and ``"1.0"``) are merged and their ``total`` values summed.

    Args:
        stats: mapping of bin key -> object with a ``total`` attribute.
        normalize: if True, divide every count by the sum of all counts so
            the values sum to 1. Defaults to False, which returns the raw
            totals (as floats).

    Returns:
        Plain ``dict`` mapping float bin value to float count/frequency.
    """
    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total
    # Normalization is opt-in; raw totals are plotted by default. Fall back to
    # 1 for an empty or all-zero histogram so the division is always safe.
    normalizer = sum(totals.values()) if normalize else 1
    if not normalizer:
        normalizer = 1
    return {k: v / normalizer for k, v in totals.items()}
def prepare_grouped_data(stats: "MetricStatsDict", top_k=100):
    """Return the ``top_k`` groups with the highest mean, mapped to that mean.

    Args:
        stats: mapping of group key -> object with a ``mean`` attribute.
        top_k: maximum number of groups to keep; ``None`` keeps all of them.

    Returns:
        ``dict`` of key -> mean in descending order of mean. Ties keep the
        iteration order of ``stats`` because ``sorted`` is stable.
    """
    means = {key: value.mean for key, value in stats.items()}
    # Rank by mean value (not by frequency/count) and keep the largest top_k.
    top_keys = sorted(means, key=means.__getitem__, reverse=True)[:top_k]
    return {key: means[key] for key in top_keys}
import math
import plotly.graph_objects as go
from plotly.offline import plot
def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one line trace per histogram on a shared log-scaled x axis.

    Args:
        histograms: trace name -> {x value: frequency}. If every key of a
            histogram is a string, its points are ordered by ascending
            frequency; otherwise by ascending key.
        stat_name: used in the title and as the x-axis label.

    Returns:
        The populated ``plotly.graph_objects.Figure``.
    """
    from itertools import cycle

    fig = go.Figure()
    # Cycle the palette: with more histograms than colours, colours are reused
    # instead of next() raising StopIteration (6+ selected runs used to crash).
    colors = cycle(
        [
            "rgba(31, 119, 180, 0.5)",
            "rgba(255, 127, 14, 0.5)",
            "rgba(44, 160, 44, 0.5)",
            "rgba(214, 39, 40, 0.5)",
            "rgba(148, 103, 189, 0.5)",
        ]
    )
    for name, histogram in histograms.items():
        if all(isinstance(k, str) for k in histogram):
            # String keys have no numeric order; order the points by value.
            x = [k for k, _ in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram)
        y = [histogram[k] for k in x]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=next(colors)))
        )
    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        xaxis_type="log",
        width=1000,
        height=600,
    )
    return fig
def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Build a bar chart with one bar trace per histogram.

    Within each trace the bars are ordered by ascending value.

    Args:
        histograms: trace name -> {category: value}.
        stat_name: used in the title and as the x-axis label.

    Returns:
        The populated ``plotly.graph_objects.Figure``.
    """
    figure = go.Figure()
    for trace_name, hist in histograms.items():
        ordered = sorted(hist.items(), key=lambda pair: pair[1])
        categories = [category for category, _ in ordered]
        heights = [height for _, height in ordered]
        figure.add_trace(go.Bar(x=categories, y=heights, name=trace_name))
    figure.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        autosize=True,
        width=600,
        height=600,
    )
    return figure
def update_graph(multiselect_crawls, stat_name, grouping):
    """Gradio callback: load the selected stats and return a plotly figure.

    Args:
        multiselect_crawls: run folders selected in the multiselect dropdown;
            may be ``None`` or empty when nothing is selected.
        stat_name: selected stat folder name.
        grouping: selected grouping folder. ``"histogram"`` is prepared with
            prepare_non_grouped_data and drawn with plot_scatter; any other
            grouping uses prepare_grouped_data and plot_bars.

    Returns:
        A plotly figure, or ``None`` when any input is missing.
    """
    # `not` covers both None and []; the previous len() check raised
    # TypeError when the dropdown value was None.
    if not multiselect_crawls or not stat_name or not grouping:
        return None
    is_histogram = grouping == "histogram"
    prepare_fc = prepare_non_grouped_data if is_histogram else prepare_grouped_data
    graph_fc = plot_scatter if is_histogram else plot_bars
    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }
    print("Plotting")
    return graph_fc(histograms, stat_name)
# Create the Gradio interface: one row holding the run multiselect (scale=2)
# next to the stat/grouping pickers and the update button (scale=1), and a
# second row holding the plot.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls (choices are the discovered RUNS)
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
        with gr.Column(scale=1):
            # Define the dropdown for stat_name (STATS comes from the first
            # run's first grouping)
            stat_name_dropdown = gr.Dropdown(
                choices=STATS,
                label="Stat name",
                multiselect=False,
            )
            # Define the dropdown for grouping (GROUPS comes from the first run)
            grouping_dropdown = gr.Dropdown(
                choices=GROUPS,
                label="Grouping",
                multiselect=False,
            )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")
    # The graph is re-rendered only on an explicit button click; the three
    # dropdown values are passed positionally as
    # update_graph(multiselect_crawls, stat_name, grouping).
    update_button.click(
        fn=update_graph,
        inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
        outputs=graph_output,
    )
# Launch the application when run as a script
if __name__ == "__main__":
    demo.launch()