from collections import defaultdict
from functools import partial
import heapq
import json
import math
from pathlib import Path

import fsspec.config
import gradio as gr
import plotly.graph_objects as go

from datatrove.io import DataFolder, get_datafolder
from datatrove.utils.stats import MetricStatsDict
BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")
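
# Stats listed here are drawn with a logarithmic x-axis in plot_scatter below.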
LOG_SCALE_STATS = {
    "length",
    "n_lines",
    "n_docs",
    "avg_words_per_line",
    "pages_with_lorem_ipsum",
}

def find_folders(base_folder, path):
    return sorted(
        [
            folder["name"]
            for folder in base_folder.ls(path, detail=True)
            if folder["type"] == "directory" and folder["name"].rstrip("/") != path
        ]
    )
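
# Expected layout of the data folder (inferred from the glob and load paths used below):
#   <run>/<grouping>/<stat_name>/stats-merged.json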

def find_stats_folders(base_folder: DataFolder):
    # First find all stats-merged.json files by globbing
    stats_merged = base_folder.glob("**/stats-merged.json")
    # Then strip the trailing grouping/stat_name/stats-merged.json parts to get the run folder
    stats_folders = [str(Path(x).parent.parent.parent) for x in stats_merged]
    # Finally keep only the unique run paths
    return list(set(stats_folders))
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))

def fetch_groups(runs, old_groups):
    GROUPS = [
        [Path(x).name for x in find_folders(BASE_DATA_FOLDER, run)] for run in runs
    ]
    if len(GROUPS) == 0:
        return gr.update(choices=[], value=None)

    # Only offer groupings that exist for every selected run
    new_choices = set.intersection(*(set(g) for g in GROUPS))
    # Keep the previous selection if it is still among the valid choices
    value = None
    if old_groups:
        value = list(set.intersection(new_choices, {old_groups}))
        value = value[0] if value else None
    return gr.update(choices=list(new_choices), value=value)
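
# fetch_stats mirrors fetch_groups one level deeper: it lists the stat names available
# under <run>/<group> and keeps only those common to all selected runs.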

def fetch_stats(runs, group, old_stats):
    STATS = [
        [Path(x).name for x in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")]
        for run in runs
    ]
    if len(STATS) == 0:
        return gr.update(choices=[], value=None)

    new_possible_choices = set.intersection(*(set(s) for s in STATS))
    value = None
    if old_stats:
        value = list(set.intersection(new_possible_choices, {old_stats}))
        value = value[0] if value else None
    return gr.update(choices=list(new_possible_choices), value=value)

def load_stats(path, stat_name, group_by):
    with BASE_DATA_FOLDER.open(
        f"{path}/{group_by}/{stat_name}/stats-merged.json",
        filecache={"cache_storage": "/tmp/files"},
    ) as f:
        json_stat = json.load(f)
        # No idea why this is necessary, but it is; otherwise the MetricStatsDict is malformed
        return MetricStatsDict() + MetricStatsDict(init=json_stat)
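
# The filecache argument is passed through to fsspec, so each stats-merged.json is
# presumably cached locally under /tmp/files rather than re-read from S3 on every call.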

def prepare_non_grouped_data(stats: MetricStatsDict, normalization):
    stats_rounded = defaultdict(lambda: 0)
    for key, value in stats.items():
        stats_rounded[float(key)] += value.total
    if normalization:
        normalizer = sum(stats_rounded.values())
        stats_rounded = {k: v / normalizer for k, v in stats_rounded.items()}
    return stats_rounded
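
# Illustrative example (made-up numbers): with totals {"1": 2, "2": 6} and normalization
# on, prepare_non_grouped_data returns {1.0: 0.25, 2.0: 0.75}, i.e. frequencies summing to 1.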

def prepare_grouped_data(stats: MetricStatsDict, top_k, direction):
    means = {key: value.mean for key, value in stats.items()}
    # Use a heap to get the top_k (or bottom_k) keys by mean value
    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)
    return {key: means[key] for key in keys}
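
# Illustrative example (made-up numbers): with per-group means {"a.com": 3.0, "b.org": 2.5,
# "c.net": 1.0}, top_k=2 and direction="Top" yields {"a.com": 3.0, "b.org": 2.5}.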

def plot_scatter(
    histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
):
    fig = go.Figure()
    colors = iter(
        [
            "rgba(31, 119, 180, 0.5)",
            "rgba(255, 127, 14, 0.5)",
            "rgba(44, 160, 44, 0.5)",
            "rgba(214, 39, 40, 0.5)",
            "rgba(148, 103, 189, 0.5)",
            "rgba(227, 119, 194, 0.5)",
            "rgba(127, 127, 127, 0.5)",
            "rgba(188, 189, 34, 0.5)",
            "rgba(23, 190, 207, 0.5)",
        ]
    )
    for name, histogram in histograms.items():
        if all(isinstance(k, str) for k in histogram.keys()):
            x = [k for k, v in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram.keys())
        y = [histogram[k] for k in x]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=next(colors)))
        )

    xaxis_scale = "log" if stat_name in LOG_SCALE_STATS else "linear"
    yaxis_title = "Frequency" if normalization else "Total"

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title=yaxis_title,
        xaxis_type=xaxis_scale,
        width=1200,
        height=600,
    )
    return fig

def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    fig = go.Figure()
    for name, histogram in histograms.items():
        x = [k for k, v in sorted(histogram.items(), key=lambda item: item[1])]
        y = [histogram[k] for k in x]
        fig.add_trace(go.Bar(x=x, y=y, name=name))

    fig.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
    )
    return fig

def update_graph(
    multiselect_crawls, stat_name, grouping, normalization, top_k, direction
):
    if len(multiselect_crawls) <= 0 or not stat_name or not grouping:
        return None

    # Pick the data preparation and plotting functions based on the selected grouping
    prepare_fc = (
        partial(prepare_non_grouped_data, normalization=normalization)
        if grouping == "histogram"
        else partial(prepare_grouped_data, top_k=top_k, direction=direction)
    )
    graph_fc = (
        partial(plot_scatter, normalization=normalization)
        if grouping == "histogram"
        else plot_bars
    )

    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }

    print("Plotting")
    return graph_fc(histograms, stat_name)

# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
            # Add a readme description
            readme_description = gr.Markdown(
                label="Readme",
                value="""
Explanation of the tool:

Groupings:
- histogram: creates a line plot of values with their occurrences. If normalization is on, the values are frequencies summing to 1.
- (fqdn/suffix): creates a bar plot of the mean values of the stat for the fully qualified domain name / suffix of the domain
    * K: the number of groups to show
    * Top/Bottom: whether the top or bottom K groups are shown
- summary: simply shows the average value of the given stat for the selected crawls
""",
            )
        with gr.Column(scale=1):
            # Define the dropdown for grouping
            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )
            # Define the dropdown for stat_name
            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )
            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,  # Default value
                )
            with gr.Row(visible=False) as group_choices:
                top_select = gr.Number(
                    label="K",
                    value=100,
                    interactive=True,
                )
                direction_checkbox = gr.Radio(
                    label="Partition",
                    choices=["Top", "Bottom"],
                )

    update_button = gr.Button("Update Graph", variant="primary")

    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")
    update_button.click(
        fn=update_graph,
        inputs=[
            multiselect_crawls,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
        ],
        outputs=graph_output,
    )

    multiselect_crawls.select(
        fn=fetch_groups,
        inputs=[multiselect_crawls, grouping_dropdown],
        outputs=grouping_dropdown,
    )

    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    # Show the normalization option for histograms and the K/Top-Bottom options otherwise
    def update_grouping_options(grouping):
        if grouping == "histogram":
            return {
                histogram_choices: gr.Column(visible=True),
                group_choices: gr.Column(visible=False),
            }
        else:
            return {
                histogram_choices: gr.Column(visible=False),
                group_choices: gr.Column(visible=True),
            }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )
# Launch the application
if __name__ == "__main__":
    demo.launch()