hynky's picture
hynky HF Staff
length fix grouping
f43ecb1
raw
history blame
11 kB
from functools import partial
import json
from pathlib import Path
import gradio as gr
from collections import defaultdict
import fsspec.config
import math
from datatrove.io import DataFolder, get_datafolder
from datatrove.utils.stats import MetricStatsDict
# Root folder of the precomputed statistics. Every path used below (runs,
# groupings, stat names, stats-merged.json files) is resolved relative to it.
# Created at import time; NOTE(review): presumably needs S3 access/credentials
# to be available when the app starts — confirm.
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
# Stats whose histogram line plots use a logarithmic x-axis (checked in
# plot_scatter). Only used for membership tests, so it is stored as an
# immutable frozenset to prevent accidental mutation of a module constant.
LOG_SCALE_STATS = frozenset(
    {
        "length",
        "n_lines",
        "n_docs",
        "n_words",
        "avg_words_per_line",
        "pages_with_lorem_ipsum",
    }
)
# Semi-transparent (alpha 0.5) trace colours; the plotting helpers pick
# colors[i % len(colors)] for the i-th trace. Written as RGB triples so the
# palette is easy to scan and extend.
colors = [
    f"rgba({red}, {green}, {blue}, 0.5)"
    for red, green, blue in (
        (31, 119, 180),
        (255, 127, 14),
        (44, 160, 44),
        (214, 39, 40),
        (148, 103, 189),
        (227, 119, 194),
        (127, 127, 127),
        (188, 189, 34),
        (23, 190, 207),
        (255, 193, 7),
        (40, 167, 69),
        (23, 162, 184),
        (108, 117, 125),
        (0, 123, 255),
        (220, 53, 69),
        (255, 159, 67),
        (255, 87, 34),
        (41, 182, 246),
        (142, 36, 170),
        (0, 188, 212),
        (255, 235, 59),
        (156, 39, 176),
    )
]
def find_folders(base_folder, path):
    """Return the sorted names of the directories listed under ``path``.

    Args:
        base_folder: filesystem-like object exposing ``ls(path, detail=True)``
            that returns entries carrying ``"name"`` and ``"type"`` keys.
        path: directory to list; a trailing ``/`` is accepted.

    Returns:
        Sorted list of ``"name"`` values of the directory entries. An entry
        naming ``path`` itself (the listing may include it) is excluded,
        regardless of trailing slashes on either side.
    """
    # Normalize both sides so "run" and "run/" identify the same directory.
    target = path.rstrip("/")
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        if entry["type"] == "directory" and entry["name"].rstrip("/") != target
    )
def find_stats_folders(base_folder: "DataFolder"):
    """Return the sorted, de-duplicated run folders that contain stats.

    Stats files live at ``<run>/<grouping>/<stat_name>/stats-merged.json``
    (the layout ``load_stats`` reads), so the run folder is each match with
    its last THREE components removed: the file name, the stat-name folder
    and the grouping folder. A run may itself span several path components.

    Args:
        base_folder: folder object exposing ``glob(pattern)`` that returns
            paths relative to it.

    Returns:
        Sorted list of unique run folder paths as strings.
    """
    stats_files = base_folder.glob("**/stats-merged.json")
    # parent x3: drop stats-merged.json, <stat_name>/ and <grouping>/.
    return sorted({str(Path(f).parent.parent.parent) for f in stats_files})
# Every run folder holding at least one stats-merged.json, in sorted order;
# these are the choices offered by the crawl multiselect.
RUNS = sorted(find_stats_folders(base_folder=BASE_DATA_FOLDER))
def fetch_groups(runs, old_groups):
    """Refresh the grouping dropdown for the selected runs.

    Only groupings present under *every* selected run are offered, so any
    choice can be plotted for all of them. The current selection is kept
    when it is still among the choices.

    Args:
        runs: selected run folders (may be empty or None).
        old_groups: currently selected grouping, or a falsy value.

    Returns:
        A ``gr.update`` with the sorted common groupings and the retained
        value (``None`` if it is no longer available).
    """
    if not runs:
        return gr.update(choices=[], value=None)
    groups_per_run = [
        {Path(folder).name for folder in find_folders(BASE_DATA_FOLDER, run)}
        for run in runs
    ]
    common = set.intersection(*groups_per_run)
    value = old_groups if old_groups and old_groups in common else None
    # Sorted so the dropdown order is stable; set iteration order is arbitrary.
    return gr.update(choices=sorted(common), value=value)
def fetch_stats(runs, group, old_stats):
    """Refresh the stat dropdown for the selected runs and grouping.

    Only stats available under ``<run>/<group>`` for *every* selected run are
    offered. The current selection is kept when it is still among the choices.

    Args:
        runs: selected run folders (may be empty or None).
        group: selected grouping (may be None if nothing is selected).
        old_stats: currently selected stat name, or a falsy value.

    Returns:
        A ``gr.update`` with the sorted common stat names and the retained
        value (``None`` if it is no longer available).
    """
    # Without runs or a grouping there is nothing to list; with group=None the
    # f-string below would otherwise point at a nonexistent "<run>/None" folder.
    if not runs or not group:
        return gr.update(choices=[], value=None)
    stats_per_run = [
        {
            Path(folder).name
            for folder in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")
        }
        for run in runs
    ]
    common = set.intersection(*stats_per_run)
    value = old_stats if old_stats and old_stats in common else None
    # Sorted so the dropdown order is stable; set iteration order is arbitrary.
    return gr.update(choices=sorted(common), value=value)
def load_stats(path, stat_name, group_by):
    """Load one merged stats file and return it as a ``MetricStatsDict``.

    Reads ``<path>/<group_by>/<stat_name>/stats-merged.json`` relative to
    BASE_DATA_FOLDER, passing a ``filecache`` option whose cache storage is
    ``/tmp/files``.

    Args:
        path: run folder.
        stat_name: statistic folder name.
        group_by: grouping folder name.

    Returns:
        A ``MetricStatsDict`` built from the JSON content.
    """
    stats_file = f"{path}/{group_by}/{stat_name}/stats-merged.json"
    cache_options = {"cache_storage": "/tmp/files"}
    with BASE_DATA_FOLDER.open(stats_file, filecache=cache_options) as handle:
        raw_stats = json.load(handle)
    # Workaround kept from the original: a MetricStatsDict constructed directly
    # from the JSON is malformed, while adding it to an empty instance yields a
    # usable one. The root cause is unknown.
    return MetricStatsDict() + MetricStatsDict(init=raw_stats)
def prepare_non_grouped_data(path, stat_name, grouping, normalization):
    """Build a ``{bucket value: total}`` histogram for one run.

    Stored keys are converted with ``float`` (so e.g. ``"1"`` and ``"1.0"``
    share a bucket) and the ``total`` of every entry is summed per bucket.

    Args:
        path: run folder.
        stat_name: statistic to load.
        grouping: grouping folder (``update_graph`` passes ``"histogram"``).
        normalization: when truthy, divide every bucket by the grand total
            so the values sum to 1.

    Returns:
        A plain dict mapping bucket value to total (or frequency when
        normalized). An empty or all-zero histogram is returned unnormalized
        instead of raising ZeroDivisionError.
    """
    stats = load_stats(path, stat_name, grouping)
    histogram = defaultdict(int)
    for key, value in stats.items():
        histogram[float(key)] += value.total
    if normalization:
        grand_total = sum(histogram.values())
        # Nothing to normalize when the grand total is zero; avoid dividing by it.
        if grand_total:
            return {k: v / grand_total for k, v in histogram.items()}
    # Return a plain dict in both branches for a consistent return type.
    return dict(histogram)
def prepare_grouped_data(path, stat_name, grouping, top_k, direction):
    """Select ``top_k`` groups and return their mean value of ``stat_name``.

    Args:
        path: run folder.
        stat_name: statistic whose per-group mean is returned.
        grouping: grouping folder (e.g. fqdn or suffix).
        top_k: number of groups to return. Converted to ``int``: a
            ``gr.Number`` input delivers a float unless its precision is 0,
            and ``heapq.nlargest``/``nsmallest`` fail on a float count.
        direction: ``"Top"`` / ``"Bottom"`` rank groups by the mean of
            ``stat_name``. ``"Most frequent (n_docs)"`` ranks by the
            ``total`` of the ``n_docs`` stat. ``"Most frequent (length)"``
            ranks by ``total * mean`` of the ``length`` stat. Any other value,
            including ``None``, behaves like ``"Bottom"``.

    Returns:
        List of ``(group, mean)`` pairs in ranking order.
    """
    import heapq

    top_k = int(top_k)
    stats = load_stats(path, stat_name, grouping)
    means = {key: value.mean for key, value in stats.items()}

    def _largest_by(weight_stat, weight):
        # Rank only groups that also exist for ``stat_name``; otherwise the
        # final lookup into ``means`` could raise KeyError.
        weights = {
            key: weight(value)
            for key, value in load_stats(path, weight_stat, grouping).items()
            if key in means
        }
        return heapq.nlargest(top_k, weights, key=weights.get)

    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    elif direction == "Most frequent (n_docs)":
        keys = _largest_by("n_docs", lambda stat: stat.total)
    elif direction == "Most frequent (length)":
        # NOTE(review): ranks by total * mean of the length stat, as before;
        # confirm this is the intended "amount of text" measure.
        keys = _largest_by("length", lambda stat: stat.total * stat.mean)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)
    return [(key, means[key]) for key in keys]
import math
import plotly.graph_objects as go
from plotly.offline import plot
def plot_scatter(
    histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
):
    """Draw one line trace per run for histogram-style stats.

    Numeric buckets are plotted in ascending key order. When every key of a
    run's histogram is a string, the keys are ordered by ascending value
    instead. The x-axis is logarithmic for stats listed in LOG_SCALE_STATS.

    Args:
        histograms: run name -> {bucket: total or frequency}.
        stat_name: statistic name, used for the title and x-axis label.
        normalization: selects the y-axis label ("Frequency" vs "Total").

    Returns:
        The plotly ``go.Figure``.
    """
    figure = go.Figure()
    for trace_index, (run_name, buckets) in enumerate(histograms.items()):
        has_string_keys = all(isinstance(bucket, str) for bucket in buckets.keys())
        if has_string_keys:
            # Stable sort by value keeps insertion order among equal values.
            ordered = sorted(buckets, key=buckets.get)
        else:
            ordered = sorted(buckets.keys())
        figure.add_trace(
            go.Scatter(
                x=ordered,
                y=[buckets[bucket] for bucket in ordered],
                mode="lines",
                name=run_name,
                line=dict(color=colors[trace_index % len(colors)]),
            )
        )
    figure.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency" if normalization else "Total",
        xaxis_type="log" if stat_name in LOG_SCALE_STATS else "linear",
        width=1200,
        height=600,
        showlegend=True,
    )
    return figure
def plot_bars(histograms: dict[str, list[tuple[str, float]]], stat_name: str):
    """Draw one bar trace per run from ``(group, mean)`` pairs.

    Args:
        histograms: run name -> list of ``(group, mean)`` pairs, plotted in
            the given order.
        stat_name: statistic name, used for the title and x-axis label.

    Returns:
        The plotly ``go.Figure``.
    """
    figure = go.Figure()
    for trace_index, (run_name, pairs) in enumerate(histograms.items()):
        groups, means = [], []
        for group, mean in pairs:
            groups.append(group)
            means.append(mean)
        figure.add_trace(
            go.Bar(
                x=groups,
                y=means,
                name=run_name,
                marker_color=colors[trace_index % len(colors)],
            )
        )
    figure.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return figure
def update_graph(
    multiselect_crawls, stat_name, grouping, normalization, top_k, direction
):
    """Load the selected stat for every selected run and build the figure.

    ``grouping == "histogram"`` produces a line plot of value histograms via
    prepare_non_grouped_data/plot_scatter. Any other grouping produces a bar
    plot of per-group means via prepare_grouped_data/plot_bars.

    Args:
        multiselect_crawls: selected run folders (may be empty or None).
        stat_name: selected statistic.
        grouping: selected grouping.
        normalization: histogram normalization flag.
        top_k: number of groups for bar plots.
        direction: group partition for bar plots.

    Returns:
        The plotly figure, or ``None`` when no run, stat or grouping is
        selected.
    """
    # ``not`` also covers None, which an empty multiselect may deliver;
    # ``len(None)`` would raise TypeError.
    if not multiselect_crawls or not stat_name or not grouping:
        return None
    if grouping == "histogram":
        prepare_fc = partial(prepare_non_grouped_data, normalization=normalization)
        graph_fc = partial(plot_scatter, normalization=normalization)
    else:
        prepare_fc = partial(prepare_grouped_data, top_k=top_k, direction=direction)
        graph_fc = plot_bars
    print("Loading stats")
    histograms = {
        path: prepare_fc(path, stat_name, grouping) for path in multiselect_crawls
    }
    print("Plotting")
    return graph_fc(histograms, stat_name)
# Create the Gradio interface.
# Flow: selecting crawls fills the grouping dropdown with the groupings shared
# by all selected crawls; selecting a grouping fills the stat dropdown and
# shows either the histogram or the group options; the button renders the plot.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Runs (crawls) to compare; one trace per selected run.
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
            # Usage notes displayed next to the controls.
            readme_description = gr.Markdown(
                label="Readme",
                value="""
Explanation of the tool:
Groupings:
- histogram: creates a line plot of values with their occurrences. If normalization is on, the values are frequencies summing to 1.
- (fqdn/suffix): creates a bar plot of the mean values of the stats for fully qualified domain name/suffix of domain
    * k: the number of groups to show
    * Top/Bottom: the top/bottom k groups are shown
    * Most frequent (n_docs/length): the k groups ranked by the n_docs/length stat are shown
- summary: simply shows the average value of given stat for selected crawls
""",
            )
        with gr.Column(scale=1):
            # Groupings available for all selected crawls (filled by fetch_groups).
            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )
            # Stats available for the chosen grouping (filled by fetch_stats).
            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )
            # Shown only for the "histogram" grouping.
            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,  # Default value
                )
            # Shown for every other grouping.
            with gr.Row(visible=False) as group_choices:
                # precision=0 makes Gradio deliver an int, which the top-k
                # selection needs (a float count breaks heapq.nlargest).
                top_select = gr.Number(
                    label="K",
                    value=100,
                    precision=0,
                    interactive=True,
                )
                # Default to "Top" so the UI shows the partition actually
                # used; with no selection the backend silently used "Bottom".
                direction_checkbox = gr.Radio(
                    label="Partition",
                    choices=[
                        "Top",
                        "Bottom",
                        "Most frequent (n_docs)",
                        "Most frequent (length)",
                    ],
                    value="Top",
                )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")

    update_button.click(
        fn=update_graph,
        inputs=[
            multiselect_crawls,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
        ],
        outputs=graph_output,
    )
    multiselect_crawls.select(
        fn=fetch_groups,
        inputs=[multiselect_crawls, grouping_dropdown],
        outputs=grouping_dropdown,
    )
    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    def update_grouping_options(grouping):
        """Show the histogram options for "histogram", group options otherwise."""
        is_histogram = grouping == "histogram"
        # gr.update works for any component type; the targets here are Rows,
        # so returning gr.Column(...) instances mismatched their type.
        return {
            histogram_choices: gr.update(visible=is_histogram),
            group_choices: gr.update(visible=not is_histogram),
        }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )

# Launch the application
if __name__ == "__main__":
    demo.launch()