|
from concurrent.futures import ThreadPoolExecutor |
|
import enum |
|
from functools import partial |
|
import json |
|
import os |
|
from pathlib import Path |
|
import re |
|
import tempfile |
|
from typing import Literal |
|
import gradio as gr |
|
|
|
from collections import defaultdict |
|
from datatrove.io import DataFolder, get_datafolder |
|
import plotly.graph_objects as go |
|
from datatrove.utils.stats import MetricStatsDict |
|
import plotly.express as px |
|
|
|
import gradio as gr |
|
# Accepted values for the `direction` argument of `prepare_grouped_data`
# (they mirror the choices of the "Partition" radio in the UI).
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]

# Stat names whose histogram x-axis is drawn on a logarithmic scale
# (looked up by `plot_scatter`).
LOG_SCALE_STATS = {
    "length",
    "n_lines",
    "n_docs",
    "n_words",
    "avg_words_per_line",
    "pages_with_lorem_ipsum",
}

# Default stats location, read from the environment with "s3://" as fallback.
STATS_LOCATION_DEFAULT = os.environ.get("STATS_LOCATION_DEFAULT", "s3://")
|
|
|
|
|
def find_folders(base_folder, path):
    """Return the sorted names of the directories listed under ``path``.

    Args:
        base_folder: Location handed to ``get_datafolder``.
        path: Path, relative to ``base_folder``, whose entries are listed.

    Returns:
        A sorted list of entry names of type ``"directory"``; an empty list
        when ``path`` does not exist in the data folder.
    """
    data_folder = get_datafolder(base_folder)
    if not data_folder.exists(path):
        return []

    directories = []
    for entry in data_folder.ls(path, detail=True):
        if entry["type"] != "directory":
            continue
        # Leave out an entry whose name is the queried path itself.
        if entry["name"].rstrip("/") == path:
            continue
        directories.append(entry["name"])

    directories.sort()
    return directories
|
|
|
|
|
def find_stats_folders(base_folder: str):
    """Return the sorted, de-duplicated folders that contain merged stats.

    Every ``stats-merged.json`` found below ``base_folder`` is mapped to the
    folder three levels above it.

    Args:
        base_folder: Location handed to ``get_datafolder``.

    Returns:
        A sorted list of unique folder paths (as strings).
    """
    data_folder = get_datafolder(base_folder)

    dataset_roots = set()
    for merged_file in data_folder.glob("**/stats-merged.json"):
        # The file sits at <folder>/<group>/<stat>/stats-merged.json,
        # so the folder of interest is the third ancestor.
        dataset_roots.add(str(Path(merged_file).parent.parent.parent))

    return sorted(dataset_roots)
|
|
|
|
|
def fetch_datasets(base_folder: str):
    """Discover the datasets under ``base_folder`` and reset dependent widgets.

    Returns:
        A 3-tuple of: the sorted dataset list, an update that sets these as
        dropdown choices with no value selected, and the update produced by
        ``fetch_groups`` over all datasets in ``"union"`` mode.
    """
    datasets = sorted(find_stats_folders(base_folder))
    selection_update = gr.update(choices=datasets, value=None)
    groups_update = fetch_groups(base_folder, datasets, None, "union")
    return datasets, selection_update, groups_update
|
|
|
|
|
def export_data(exported_data):
    """Dump ``exported_data`` to a temporary ``.json`` file for download.

    Returns:
        ``None`` when there is nothing to export; otherwise an update that
        makes the target component visible and points it at the file path.
    """
    if exported_data:
        # delete=False: the file has to outlive this function so it can be served.
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as handle:
            json.dump(exported_data, handle)
        return gr.update(visible=True, value=handle.name)
    return None
|
|
|
|
|
def fetch_groups(base_folder, datasets, old_groups, type="intersection"):
    """Compute the grouping choices available for the given datasets.

    Args:
        base_folder: Location of the stats, forwarded to ``find_folders``.
        datasets: Dataset paths whose immediate sub-folders are the groupings.
        old_groups: Currently selected grouping; kept if it is still offered.
        type: ``"intersection"`` to offer only groupings present in every
            dataset, ``"union"`` to offer groupings present in any of them.

    Returns:
        A ``gr.update`` carrying the sorted choices and the retained value
        (``None`` when the old value is gone). With no datasets the choices
        are empty.

    Raises:
        ValueError: If ``type`` is neither ``"intersection"`` nor ``"union"``.
    """
    if not datasets:
        return gr.update(choices=[], value=None)

    # Fail fast, before any remote listing, instead of hitting an unbound
    # local after all the I/O has already been done.
    if type not in ("intersection", "union"):
        raise ValueError(f"Unknown type {type!r}; expected 'intersection' or 'union'")

    def list_group_names(run):
        return {Path(folder).name for folder in find_folders(base_folder, run)}

    # Listing is I/O bound, so the datasets are scanned concurrently.
    with ThreadPoolExecutor() as executor:
        groups_per_dataset = list(executor.map(list_group_names, datasets))

    combine = set.intersection if type == "intersection" else set.union
    new_choices = combine(*groups_per_dataset)

    value = old_groups if old_groups and old_groups in new_choices else None
    return gr.update(choices=sorted(new_choices), value=value)
|
|
|
|
|
def fetch_stats(base_folder, datasets, group, old_stats, type="intersection"):
    """Compute the stat-name choices available for ``group`` in the datasets.

    Args:
        base_folder: Location of the stats, forwarded to ``find_folders``.
        datasets: Dataset paths; stats are the sub-folders of
            ``<dataset>/<group>``.
        group: Selected grouping name.
        old_stats: Currently selected stat name; kept if it is still offered.
        type: ``"intersection"`` to offer only stats present in every dataset,
            ``"union"`` to offer stats present in any of them.

    Returns:
        A ``gr.update`` carrying the sorted choices and the retained value
        (``None`` when the old value is gone). With no datasets the choices
        are empty.

    Raises:
        ValueError: If ``type`` is neither ``"intersection"`` nor ``"union"``.
    """
    print("Fetching stats")
    if not datasets:
        return gr.update(choices=[], value=None)

    # Fail fast, before any remote listing, instead of hitting an unbound
    # local after all the I/O has already been done.
    if type not in ("intersection", "union"):
        raise ValueError(f"Unknown type {type!r}; expected 'intersection' or 'union'")

    def list_stat_names(run):
        return {Path(folder).name for folder in find_folders(base_folder, f"{run}/{group}")}

    # Listing is I/O bound, so the datasets are scanned concurrently.
    with ThreadPoolExecutor() as executor:
        stats_per_dataset = list(executor.map(list_stat_names, datasets))

    combine = set.intersection if type == "intersection" else set.union
    new_choices = combine(*stats_per_dataset)

    value = old_stats if old_stats and old_stats in new_choices else None
    return gr.update(choices=sorted(new_choices), value=value)
|
|
|
|
|
def reverse_search(base_folder, possible_datasets, grouping, stat_name):
    """Find the datasets that contain ``stat_name`` under ``grouping``.

    Returns:
        The matching dataset paths, in input order, joined by newlines
        (an empty string when nothing matches).
    """

    def keep_if_present(dataset):
        if stat_exists(base_folder, dataset, stat_name, grouping):
            return dataset
        return None

    # Existence checks are I/O bound, so they run concurrently.
    with ThreadPoolExecutor() as executor:
        hits = executor.map(keep_if_present, possible_datasets)
        return "\n".join(hit for hit in hits if hit is not None)
|
|
|
|
|
def reverse_search_add(datasets, reverse_search_results):
    """Merge the reverse-search results into the current dataset selection.

    Args:
        datasets: Currently selected datasets (may be ``None``).
        reverse_search_results: Newline-separated dataset paths; may be
            empty, ``None``, or contain blank lines after manual editing.

    Returns:
        The sorted union of both, without duplicates or empty entries.
    """
    selected = set(datasets or [])
    for line in (reverse_search_results or "").splitlines():
        name = line.strip()
        # Skip blank lines: splitting empty text must not add "" to the selection.
        if name:
            selected.add(name)
    return sorted(selected)
|
|
|
|
|
|
|
def stat_exists(base_folder, path, stat_name, group_by):
    """Tell whether the merged stats file for the given stat exists.

    The file checked is ``<path>/<group_by>/<stat_name>/stats-merged.json``
    inside ``base_folder``.
    """
    data_folder = get_datafolder(base_folder)
    merged_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    return data_folder.exists(merged_file)
|
|
|
def load_stats(base_folder, path, stat_name, group_by):
    """Load ``<path>/<group_by>/<stat_name>/stats-merged.json`` as stats.

    Returns:
        The result of adding a ``MetricStatsDict`` built from the parsed JSON
        to an empty ``MetricStatsDict``.
    """
    data_folder = get_datafolder(base_folder)
    merged_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    with data_folder.open(merged_file) as stream:
        raw_stats = json.load(stream)

    # NOTE(review): the addition to an empty dict is kept as-is; presumably it
    # normalizes the loaded entries - confirm against MetricStatsDict.
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
|
|
|
|
|
def prepare_non_grouped_data(dataset_path, base_folder, grouping, stat_name, normalization):
    """Build the histogram ``{value: total}`` of one stat for one dataset.

    Keys of the loaded stats are converted to ``float``; keys that convert to
    the same float have their totals summed.

    Args:
        dataset_path: Dataset whose stats are loaded.
        base_folder: Location of the stats.
        grouping: Grouping folder name the stat is read from.
        stat_name: Name of the stat.
        normalization: When true, totals are divided by their sum.

    Returns:
        A plain ``dict`` mapping float values to totals (or to frequencies
        when normalized). If the totals sum to zero they are returned
        unnormalized, since dividing is impossible.
    """
    stats = load_stats(base_folder, dataset_path, stat_name, grouping)

    totals = defaultdict(int)
    for key, value in stats.items():
        totals[float(key)] += value.total

    if normalization:
        normalizer = sum(totals.values())
        # Guard against an all-zero histogram instead of raising ZeroDivisionError.
        if normalizer:
            return {k: v / normalizer for k, v in totals.items()}

    # Always hand back a plain dict, matching the normalized branch.
    return dict(totals)
|
|
|
|
|
def prepare_grouped_data(dataset_path, base_folder, grouping, stat_name, top_k, direction: PARTITION_OPTIONS, regex):
    """Select up to ``top_k`` groups of a stat and return their mean values.

    Args:
        dataset_path: Dataset whose stats are loaded.
        base_folder: Location of the stats.
        grouping: Grouping folder name the stat is read from.
        stat_name: Name of the stat.
        top_k: Number of groups to keep. A float is truncated to an int, a
            negative number is treated as 0, and ``None`` keeps every group.
        direction: ``"Top"`` keeps the largest means, ``"Most frequent
            (n_docs)"`` keeps the groups with the largest ``n``; any other
            value keeps the smallest means.
        regex: Optional pattern; only group keys it matches from their start
            (``re.match``) are considered.

    Returns:
        A list of ``(group_key, mean)`` tuples in selection order.
    """
    import heapq

    pattern = re.compile(regex) if regex else None

    stats = load_stats(base_folder, dataset_path, stat_name, grouping)
    stats = {key: value for key, value in stats.items() if pattern is None or pattern.match(key)}
    means = {key: value.mean for key, value in stats.items()}

    # The UI number field may deliver a float or None, which heapq rejects.
    limit = len(means) if top_k is None else max(int(top_k), 0)

    if direction == "Top":
        keys = heapq.nlargest(limit, means, key=means.get)
    elif direction == "Most frequent (n_docs)":
        counts = {key: value.n for key, value in stats.items()}
        keys = heapq.nlargest(limit, counts, key=counts.get)
    else:
        keys = heapq.nsmallest(limit, means, key=means.get)

    return [(key, means[key]) for key in keys]
|
|
|
|
|
def set_alpha(color, alpha):
    """Convert a hex color to an ``rgba(r, g, b, a)`` string.

    Args:
        color: ``#rrggbb`` hex color; the shorthand ``#rgb`` form is expanded
            by doubling each digit. Anything not starting with ``#`` is
            rendered as black.
        alpha: Alpha component, inserted as-is.

    Returns:
        The string ``"rgba(r, g, b, alpha)"``.

    Raises:
        ValueError: If the text after ``#`` is not valid hex of a usable length.
    """
    if not color.startswith("#"):
        return f"rgba(0, 0, 0, {alpha})"

    digits = color[1:]
    if len(digits) in (3, 4):
        # Shorthand notation: "#f0a" means "#ff00aa".
        digits = "".join(ch * 2 for ch in digits)

    r, g, b = (int(digits[i:i + 2], 16) for i in (0, 2, 4))
    return f"rgba({r}, {g}, {b}, {alpha})"
|
|
|
|
|
|
|
|
|
def plot_scatter(
    histograms: dict[str, dict[float, float]],
    stat_name: str,
    normalization: bool,
    progress: gr.Progress,
):
    """Draw one line trace per dataset histogram.

    Args:
        histograms: Mapping of trace name to its ``{x: y}`` histogram.
        stat_name: Stat being plotted; used for the title, the x-axis label
            and to pick a log x-axis when listed in ``LOG_SCALE_STATS``.
        normalization: Selects the y-axis label ("Frequency" vs "Total").
        progress: Progress tracker wrapped around the iteration.

    Returns:
        The populated plotly ``Figure``.
    """
    palette = px.colors.qualitative.Plotly
    fig = go.Figure()

    tracked_items = progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")
    for idx, (name, histogram) in enumerate(tracked_items):
        if all(isinstance(key, str) for key in histogram):
            # String keys have no natural order: sort them by ascending value.
            x = sorted(histogram, key=histogram.get)
        else:
            x = sorted(histogram)

        y = [histogram[key] for key in x]

        # Cycle through the palette when there are more traces than colors.
        trace_color = set_alpha(palette[idx % len(palette)], 0.5)
        trace = go.Scatter(
            x=x,
            y=y,
            mode="lines",
            name=name,
            marker={"color": trace_color},
        )
        fig.add_trace(trace)

    if stat_name in LOG_SCALE_STATS:
        xaxis_scale = "log"
    else:
        xaxis_scale = "linear"

    if normalization:
        yaxis_title = "Frequency"
    else:
        yaxis_title = "Total"

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title=yaxis_title,
        xaxis_type=xaxis_scale,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig
|
|
|
|
|
def plot_bars(
    histograms: dict[str, list[tuple[str, float]]],
    stat_name: str,
    progress: gr.Progress,
):
    """Draw one bar trace per dataset from its ``(group, mean)`` pairs.

    Args:
        histograms: Mapping of trace name to a list of ``(x, y)`` pairs.
        stat_name: Stat being plotted; used for the title and x-axis label.
        progress: Progress tracker wrapped around the iteration.

    Returns:
        The populated plotly ``Figure``.
    """
    palette = px.colors.qualitative.Plotly
    fig = go.Figure()

    tracked_items = progress.tqdm(histograms.items(), total=len(histograms), desc="Plotting...")
    for idx, (name, histogram) in enumerate(tracked_items):
        x = [pair[0] for pair in histogram]
        y = [pair[1] for pair in histogram]

        # Cycle through the palette when there are more traces than colors.
        bar_color = set_alpha(palette[idx % len(palette)], 0.5)
        fig.add_trace(go.Bar(x=x, y=y, name=name, marker={"color": bar_color}))

    fig.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )

    return fig
|
|
|
|
|
def update_graph(
    base_folder,
    datasets,
    stat_name,
    grouping,
    normalization,
    top_k,
    direction,
    regex,
    progress=gr.Progress(),
):
    """Load the selected stat for every dataset and plot it.

    The ``"histogram"`` grouping is drawn as line plots of value totals (or
    frequencies when ``normalization`` is on); every other grouping is drawn
    as bars of the per-group means, filtered by ``regex`` and limited by
    ``top_k`` / ``direction``.

    Returns:
        A 3-tuple matching the three wired outputs: the figure, the plotted
        data keyed by dataset (kept for export), and an update controlling
        the visibility of the export button. When the selection is incomplete
        the figure is ``None``, the data is empty and the button is hidden.
    """
    # `not datasets` also covers a dropdown that reports None.
    if not datasets or not stat_name or not grouping:
        # One value per wired output; a lone None would not satisfy three outputs.
        return None, [], gr.update(visible=False)

    if grouping == "histogram":
        prepare_fc = partial(prepare_non_grouped_data, normalization=normalization)
        graph_fc = partial(plot_scatter, normalization=normalization)
    else:
        prepare_fc = partial(prepare_grouped_data, top_k=top_k, direction=direction, regex=regex)
        graph_fc = plot_bars

    load_one = partial(prepare_fc, base_folder=base_folder, stat_name=stat_name, grouping=grouping)

    # Loading is I/O bound, so the datasets are fetched concurrently;
    # pool.map keeps the results in dataset order.
    with ThreadPoolExecutor() as pool:
        data = list(
            progress.tqdm(
                pool.map(load_one, datasets),
                total=len(datasets),
                desc="Loading data...",
            )
        )

    histograms = dict(zip(datasets, data))
    figure = graph_fc(histograms=histograms, stat_name=stat_name, progress=progress)

    return figure, histograms, gr.update(visible=True)
|
|
|
|
|
|
|
# UI definition: component layout first, then the event wiring.
with gr.Blocks() as demo:
    # Session state: every discovered dataset, and the data behind the last plot.
    datasets = gr.State([])
    exported_data = gr.State([])
    stats_headline = gr.Markdown(value="# Stats Exploration")
    with gr.Row():
        with gr.Column(scale=2):

            with gr.Row():
                with gr.Column(scale=1):
                    base_folder = gr.Textbox(
                        label="Stats Location",
                        value="s3://fineweb-stats/summary/",
                    )
                    datasets_refetch = gr.Button("Fetch Datasets")

                with gr.Column(scale=1):
                    regex_select = gr.Text(label="Regex select datasets", value=".*")
                    regex_button = gr.Button("Filter")
            with gr.Row():
                datasets_selected = gr.Dropdown(
                    choices=[],
                    label="Datasets",
                    multiselect=True,
                )

            readme_description = gr.Markdown(
                label="Readme",
                value="""
Explaination of the tool:

Groupings:
- histogram: creates a line plot of values with their occurences. If normalization is on, the values are frequencies summing to 1.
- (fqdn/suffix): creates a bar plot of the mean values of the stats for full qualied domain name/suffix of domain
    * k: the number of groups to show
    * Top/Bottom: the top/bottom k groups are shown
- summary: simply shows the average value of given stat for selected crawls
""",
            )
        with gr.Column(scale=1):

            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )

            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )

            # Options shown only for the "histogram" grouping.
            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,
                )

            # Options shown for every other grouping.
            with gr.Row(visible=False) as group_choices:
                with gr.Column(scale=2):
                    group_regex = gr.Text(
                        label="Group Regex",
                        value=None,
                    )
                    with gr.Row():
                        top_select = gr.Number(
                            label="N Groups",
                            value=100,
                            interactive=True,
                        )

                        direction_checkbox = gr.Radio(
                            label="Partition",
                            choices=[
                                "Top",
                                "Bottom",
                                "Most frequent (n_docs)",
                            ],
                            value="Most frequent (n_docs)",
                        )

            update_button = gr.Button("Update Graph", variant="primary")
            with gr.Row():
                export_data_button = gr.Button("Export data", visible=False)
                export_data_json = gr.File(visible=False)

    with gr.Row():

        graph_output = gr.Plot(label="Graph")

    with gr.Row():
        reverse_search_headline = gr.Markdown(value="# Reverse stats search")

    with gr.Row():
        with gr.Column(scale=1):

            reverse_grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )

            reverse_stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )

        with gr.Column(scale=1):
            reverse_search_button = gr.Button("Search")
            reverse_search_add_button = gr.Button("Add to selection")

        with gr.Column(scale=2):
            reverse_search_results = gr.Textbox(
                label="Found datasets",
                lines=10,
                placeholder="Found datasets containing the group/stat name. You can modify the selection after search by removing unwanted lines and clicking Add to selection"
            )

    # Plot the current selection and reveal the export button.
    update_button.click(
        fn=update_graph,
        inputs=[
            base_folder,
            datasets_selected,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
            group_regex,
        ],
        outputs=[graph_output, exported_data, export_data_button],
    )

    export_data_button.click(
        fn=export_data,
        inputs=[exported_data],
        outputs=export_data_json,
    )

    # Changing the dataset selection refreshes the available groupings.
    datasets_selected.change(
        fn=fetch_groups,
        inputs=[base_folder, datasets_selected, grouping_dropdown],
        outputs=grouping_dropdown,
    )

    # Picking a grouping refreshes the available stat names.
    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[base_folder, datasets_selected, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    # Reverse search works over all discovered datasets, hence the union.
    reverse_grouping_dropdown.select(
        fn=partial(fetch_stats, type="union"),
        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_stat_name_dropdown],
        outputs=reverse_stat_name_dropdown,
    )

    reverse_search_button.click(
        fn=reverse_search,
        inputs=[base_folder, datasets, reverse_grouping_dropdown, reverse_stat_name_dropdown],
        outputs=reverse_search_results,
    )

    reverse_search_add_button.click(
        fn=reverse_search_add,
        inputs=[datasets_selected, reverse_search_results],
        outputs=datasets_selected,
    )

    datasets_refetch.click(
        fn=fetch_datasets,
        inputs=[base_folder],
        outputs=[datasets, datasets_selected, reverse_grouping_dropdown],
    )

    def update_datasets_with_regex(regex, selected_runs, all_runs):
        """Add every dataset whose path matches ``regex`` to the selection.

        Args:
            regex: Pattern searched (``re.search``) in each dataset path.
            selected_runs: Currently selected datasets (may be ``None``).
            all_runs: Every discovered dataset (may be ``None``).

        Returns:
            An update setting the selection to the sorted union of the old
            selection and the matches; a no-op update when ``regex`` is empty.

        Raises:
            gr.Error: If ``regex`` is not a valid pattern.
        """
        if not regex:
            # Returning None would clear the dropdown; leave it untouched instead.
            return gr.update()
        try:
            pattern = re.compile(regex)
        except re.error as err:
            raise gr.Error(f"Invalid regex: {err}") from err
        matching = {run for run in (all_runs or []) if pattern.search(run)}
        # Sorted so the resulting selection order is deterministic.
        return gr.update(value=sorted(matching.union(selected_runs or [])))

    regex_button.click(
        fn=update_datasets_with_regex,
        inputs=[regex_select, datasets_selected, datasets],
        outputs=datasets_selected,
    )

    def update_grouping_options(grouping):
        """Show the option row that applies to the selected grouping.

        Returns:
            A mapping from each option container to a visibility update:
            the histogram options for ``"histogram"``, the group options
            for anything else.
        """
        show_histogram = grouping == "histogram"
        # Both containers are rows, so use type-agnostic updates.
        return {
            histogram_choices: gr.update(visible=show_histogram),
            group_choices: gr.update(visible=not show_histogram),
        }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )


if __name__ == "__main__":
    demo.launch()
|
|