|
from functools import partial |
|
import json |
|
from pathlib import Path |
|
import gradio as gr |
|
|
|
from collections import defaultdict |
|
import fsspec.config |
|
import math |
|
from datatrove.io import DataFolder, get_datafolder |
|
from datatrove.utils.stats import MetricStatsDict |
|
|
|
# Root folder that holds the per-run summary statistics. Built when the module
# is imported; every other path in this file is resolved relative to it.
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")

# Stats whose line plot (plot_scatter) uses a logarithmic x-axis instead of a
# linear one.
LOG_SCALE_STATS = {
    "length",
    "n_lines",
    "n_docs",
    "avg_words_per_line",
    "pages_with_lorem_ipsum",
}
|
|
|
|
|
def find_folders(base_folder, path):
    """Return the sorted names of the sub-directories listed under *path*.

    Entries come from ``base_folder.ls(path, detail=True)``; only entries whose
    ``"type"`` is ``"directory"`` are kept, and each is returned by its
    ``"name"`` as reported by the listing.

    Some filesystems include the listed directory itself among the results;
    that entry is dropped. Trailing slashes are ignored on BOTH sides of the
    comparison, so ``path="run"`` and ``path="run/"`` behave the same.
    """
    own_name = path.rstrip("/")
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        if entry["type"] == "directory" and entry["name"].rstrip("/") != own_name
    )
|
|
|
|
|
def find_stats_folders(base_folder: DataFolder):
    """Collect the distinct run directories that contain merged stats.

    Every ``stats-merged.json`` matched under *base_folder* is assumed to live
    at ``<run>/<group>/<stat>/stats-merged.json`` (the layout load_stats reads),
    so stepping three parents up from the file yields ``<run>``. Duplicates are
    removed; the order of the returned list is unspecified.
    """
    run_dirs = set()
    for stats_file in base_folder.glob("**/stats-merged.json"):
        run_dirs.add(str(Path(stats_file).parent.parent.parent))
    return list(run_dirs)
|
|
|
|
|
# Sorted list of every run that has merged stats. Computed once when the
# module is imported and used as the choices of the crawl dropdown.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
|
|
|
|
|
def fetch_groups(runs, old_groups):
    """Return a dropdown update listing the groupings shared by every run.

    Args:
        runs: selected run paths (the multiselect value; may be empty or None).
        old_groups: the grouping currently selected, if any.

    Returns:
        ``gr.update`` whose choices are the grouping folder names present in
        ALL selected runs (sorted, so the order is stable between calls). The
        previous selection is kept only if it is still one of those choices.
    """
    if not runs:
        return gr.update(choices=[], value=None)

    group_sets = [
        {Path(folder).name for folder in find_folders(BASE_DATA_FOLDER, run)}
        for run in runs
    ]
    shared = set.intersection(*group_sets)

    # Keep the prior selection only when every selected run still offers it.
    value = old_groups if old_groups in shared else None
    return gr.update(choices=sorted(shared), value=value)
|
|
|
|
|
def fetch_stats(runs, group, old_stats):
    """Return a dropdown update listing the stats available for *group* in every run.

    Args:
        runs: selected run paths (the multiselect value; may be empty or None).
        group: the selected grouping folder name (may be None if unset).
        old_stats: the stat currently selected, if any.

    Returns:
        ``gr.update`` whose choices are the stat folder names found under
        ``<run>/<group>`` for ALL selected runs (sorted for a stable order). The
        previous selection is kept only if it is still one of those choices.
    """
    # Without runs or a grouping there is nothing to list (and listing
    # "<run>/None" would only fail).
    if not runs or not group:
        return gr.update(choices=[], value=None)

    stat_sets = [
        {
            Path(folder).name
            for folder in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")
        }
        for run in runs
    ]
    shared = set.intersection(*stat_sets)

    value = old_stats if old_stats in shared else None
    return gr.update(choices=sorted(shared), value=value)
|
|
|
|
|
def load_stats(path, stat_name, group_by):
    """Read one merged stats file and wrap it in a ``MetricStatsDict``.

    The file is ``<path>/<group_by>/<stat_name>/stats-merged.json`` inside
    ``BASE_DATA_FOLDER``; it is opened with ``filecache`` options pointing the
    cache storage at ``/tmp/files``.
    """
    stats_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(stats_file, filecache=cache_options) as handle:
        raw_stats = json.load(handle)

    # The result is built by adding the parsed stats onto an empty
    # MetricStatsDict, exactly as before.
    # NOTE(review): confirm whether the empty-dict addition is required.
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
|
|
|
|
|
def prepare_non_grouped_data(stats: MetricStatsDict, normalization):
    """Turn histogram-style stats into a ``{value: count}`` dict for plotting.

    Each key of *stats* is converted with ``float()`` (the JSON loaded by
    load_stats has string keys). Entries whose keys map to the same float
    are summed using each entry's ``.total``.

    Args:
        stats: mapping of value -> stats object exposing ``.total``.
        normalization: if true, divide every count by the grand total so that
            the values sum to 1.

    Returns:
        A plain dict mapping float value -> total, or frequency when
        normalized. If the grand total is 0, the raw counts are returned
        rather than raising ZeroDivisionError.
    """
    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total

    if not normalization:
        return dict(totals)

    normalizer = sum(totals.values())
    if not normalizer:
        # Nothing to normalize by (empty or all-zero histogram).
        return dict(totals)
    return {k: v / normalizer for k, v in totals.items()}
|
|
|
|
|
def prepare_grouped_data(stats: MetricStatsDict, top_k, direction):
    """Keep the *top_k* groups with the highest or lowest mean value.

    Args:
        stats: mapping of group key -> stats object exposing ``.mean``.
        top_k: number of groups to keep. Non-integer numbers (e.g. a float
            delivered by a UI number box) are truncated with ``int()``;
            ``None`` keeps every group.
        direction: ``"Top"`` keeps the largest means; any other value
            (including None) keeps the smallest.

    Returns:
        dict of selected group key -> mean, in descending order of mean for
        "Top" and ascending order otherwise.
    """
    import heapq

    means = {key: value.mean for key, value in stats.items()}
    # heapq.nlargest/nsmallest require an int count; a float raises TypeError.
    count = len(means) if top_k is None else int(top_k)
    select = heapq.nlargest if direction == "Top" else heapq.nsmallest
    keys = select(count, means, key=means.get)
    return {key: means[key] for key in keys}
|
|
|
|
|
import math |
|
import plotly.graph_objects as go |
|
from plotly.offline import plot |
|
|
|
|
|
def plot_scatter(
    histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
):
    """Build a line plot with one trace per crawl.

    Args:
        histograms: crawl path -> {x value: y value}. Non-string keys are
            plotted in ascending key order. If every key is a string, points
            are ordered by ascending y value instead.
        stat_name: the plotted stat. Used for the title and x-axis label, and
            to choose a log x-axis when it is listed in LOG_SCALE_STATS.
        normalization: only selects the y-axis title ("Frequency" vs "Total").

    Returns:
        A plotly ``go.Figure``.
    """
    from itertools import cycle

    fig = go.Figure()

    # Semi-transparent palette. Cycled so that more crawls than colors can be
    # plotted; a plain iterator raised StopIteration on the 10th trace.
    colors = cycle(
        [
            "rgba(31, 119, 180, 0.5)",
            "rgba(255, 127, 14, 0.5)",
            "rgba(44, 160, 44, 0.5)",
            "rgba(214, 39, 40, 0.5)",
            "rgba(148, 103, 189, 0.5)",
            "rgba(227, 119, 194, 0.5)",
            "rgba(127, 127, 127, 0.5)",
            "rgba(188, 189, 34, 0.5)",
            "rgba(23, 190, 207, 0.5)",
        ]
    )

    for name, histogram in histograms.items():
        if all(isinstance(k, str) for k in histogram.keys()):
            # Categorical keys have no natural order; order them by value.
            x = [k for k, _ in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram.keys())

        y = [histogram[k] for k in x]

        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=next(colors)))
        )

    xaxis_scale = "log" if stat_name in LOG_SCALE_STATS else "linear"
    yaxis_title = "Frequency" if normalization else "Total"

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title=yaxis_title,
        xaxis_type=xaxis_scale,
        width=1200,
        height=600,
    )

    return fig
|
|
|
|
|
def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Build a bar chart with one bar trace per crawl.

    Within each trace the bars are ordered by ascending value. Ties keep the
    mapping's iteration order, because the sort is stable.

    Returns:
        A plotly ``go.Figure`` titled after *stat_name*.
    """
    figure = go.Figure()

    for trace_name, values in histograms.items():
        ordered_keys = sorted(values, key=values.__getitem__)
        bar_heights = [values[key] for key in ordered_keys]
        figure.add_trace(go.Bar(x=ordered_keys, y=bar_heights, name=trace_name))

    layout = dict(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
    )
    figure.update_layout(**layout)

    return figure
|
|
|
|
|
def update_graph(
    multiselect_crawls, stat_name, grouping, normalization, top_k, direction
):
    """Load the selected stat for every selected crawl and build its figure.

    Args:
        multiselect_crawls: selected run paths (may be empty or None).
        stat_name: selected stat folder name.
        grouping: selected grouping. ``"histogram"`` produces a line plot via
            prepare_non_grouped_data/plot_scatter. Anything else produces a
            bar plot via prepare_grouped_data/plot_bars.
        normalization: forwarded to the histogram path only.
        top_k, direction: forwarded to the grouped path only.

    Returns:
        A plotly figure, or ``None`` when crawls, stat, or grouping are
        missing.
    """
    # `not` handles both [] and None; len(None) used to raise TypeError.
    if not multiselect_crawls or not stat_name or not grouping:
        return None

    if grouping == "histogram":
        prepare_fc = partial(prepare_non_grouped_data, normalization=normalization)
        graph_fc = partial(plot_scatter, normalization=normalization)
    else:
        prepare_fc = partial(prepare_grouped_data, top_k=top_k, direction=direction)
        graph_fc = plot_bars

    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }

    print("Plotting")
    return graph_fc(histograms, stat_name)
|
|
|
|
|
|
|
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )

            readme_description = gr.Markdown(
                label="Readme",
                value="""
                Explanation of the tool:

                Groupings:
                - histogram: creates a line plot of values with their occurrences. If normalization is on, the values are frequencies summing to 1.
                - (fqdn/suffix): creates a bar plot of the mean values of the stats for the fully qualified domain name/suffix of the domain
                    * k: the number of groups to show
                    * Top/Bottom: the top/bottom k groups are shown
                - summary: simply shows the average value of given stat for selected crawls
                """,
            )

        with gr.Column(scale=1):
            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )

            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )

            # Controls only relevant to the "histogram" grouping.
            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,
                )

            # Controls only relevant to the bar-plot groupings.
            with gr.Row(visible=False) as group_choices:
                # precision=0 makes the component deliver an int, which
                # heapq.nlargest/nsmallest require.
                top_select = gr.Number(
                    label="K",
                    value=100,
                    precision=0,
                    interactive=True,
                )

                # An explicit default: an unset radio (None) was silently
                # treated as "Bottom" by prepare_grouped_data.
                direction_checkbox = gr.Radio(
                    label="Partition",
                    choices=["Top", "Bottom"],
                    value="Top",
                )

            update_button = gr.Button("Update Graph", variant="primary")

    with gr.Row():
        graph_output = gr.Plot(label="Graph")

    update_button.click(
        fn=update_graph,
        inputs=[
            multiselect_crawls,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
        ],
        outputs=graph_output,
    )

    # Refresh the grouping choices whenever the crawl selection changes.
    multiselect_crawls.select(
        fn=fetch_groups,
        inputs=[multiselect_crawls, grouping_dropdown],
        outputs=grouping_dropdown,
    )

    # Refresh the stat choices whenever a grouping is picked.
    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    def update_grouping_options(grouping):
        """Show only the option row that applies to the selected grouping."""
        is_histogram = grouping == "histogram"
        # gr.update targets the existing Row components; no new component
        # type is implied, unlike returning gr.Column(...) for a Row.
        return {
            histogram_choices: gr.update(visible=is_histogram),
            group_choices: gr.update(visible=not is_histogram),
        }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )
|
|
|
|
|
|
|
# Launch the Gradio app only when this file is executed directly, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()
|
|