from collections import defaultdict
from functools import partial
import heapq
import json
import math
from pathlib import Path

import fsspec.config
import gradio as gr
import plotly.graph_objects as go

from datatrove.io import DataFolder, get_datafolder
from datatrove.utils.stats import MetricStatsDict

BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")

# Stats that are easier to read on a logarithmic x-axis
LOG_SCALE_STATS = {
    "length",
    "n_lines",
    "n_docs",
    "n_words",
    "avg_words_per_line",
    "pages_with_lorem_ipsum",
}

# One semi-transparent color per selected crawl
colors = [
    "rgba(31, 119, 180, 0.5)",
    "rgba(255, 127, 14, 0.5)",
    "rgba(44, 160, 44, 0.5)",
    "rgba(214, 39, 40, 0.5)",
    "rgba(148, 103, 189, 0.5)",
    "rgba(227, 119, 194, 0.5)",
    "rgba(127, 127, 127, 0.5)",
    "rgba(188, 189, 34, 0.5)",
    "rgba(23, 190, 207, 0.5)",
    "rgba(255, 193, 7, 0.5)",
    "rgba(40, 167, 69, 0.5)",
    "rgba(23, 162, 184, 0.5)",
    "rgba(108, 117, 125, 0.5)",
    "rgba(0, 123, 255, 0.5)",
    "rgba(220, 53, 69, 0.5)",
    "rgba(255, 159, 67, 0.5)",
    "rgba(255, 87, 34, 0.5)",
    "rgba(41, 182, 246, 0.5)",
    "rgba(142, 36, 170, 0.5)",
    "rgba(0, 188, 212, 0.5)",
    "rgba(255, 235, 59, 0.5)",
    "rgba(156, 39, 176, 0.5)",
]


def find_folders(base_folder, path):
    return sorted(
        [
            folder["name"]
            for folder in base_folder.ls(path, detail=True)
            if folder["type"] == "directory" and not folder["name"].rstrip("/") == path
        ]
    )


def find_stats_folders(base_folder: DataFolder):
    # First find all stats-merged.json files by globbing for stats-merged.json
    stats_merged = base_folder.glob("**/stats-merged.json")
    # Then, for each match, keep everything but the last two path components (grouping/stat_name)
    stats_folders = [str(Path(x).parent.parent.parent) for x in stats_merged]
    # Finally keep only the unique paths
    return sorted(list(set(stats_folders)))


RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))


def fetch_groups(runs, old_groups):
    GROUPS = [
        [Path(x).name for x in find_folders(BASE_DATA_FOLDER, run)] for run in runs
    ]
    if len(GROUPS) == 0:
        return gr.update(choices=[], value=None)

    # Only offer groupings that exist for every selected run
    new_choices = set.intersection(*(set(g) for g in GROUPS))
    value = None
    if old_groups:
        value = list(set.intersection(new_choices, {old_groups}))
        value = value[0] if value else None
    return gr.update(choices=list(new_choices), value=value)


def fetch_stats(runs, group, old_stats):
    STATS = [
        [Path(x).name for x in find_folders(BASE_DATA_FOLDER, f"{run}/{group}")]
        for run in runs
    ]
    if len(STATS) == 0:
        return gr.update(choices=[], value=None)

    # Only offer stats that exist for every selected run
    new_possibles_choices = set.intersection(*(set(s) for s in STATS))
    value = None
    if old_stats:
        value = list(set.intersection(new_possibles_choices, {old_stats}))
        value = value[0] if value else None
    return gr.update(choices=list(new_possibles_choices), value=value)


def load_stats(path, stat_name, group_by):
    with BASE_DATA_FOLDER.open(
        f"{path}/{group_by}/{stat_name}/stats-merged.json",
        filecache={"cache_storage": "/tmp/files"},
    ) as f:
        json_stat = json.load(f)
        # No idea why this is necessary, but it is; otherwise the MetricStatsDict is malformed
        return MetricStatsDict() + MetricStatsDict(init=json_stat)
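

# Assumed layout of BASE_DATA_FOLDER (inferred from load_stats/find_stats_folders above):
#   <run>/<grouping>/<stat_name>/stats-merged.json
# where <run> is what appears in the crawls multiselect, <grouping> is e.g. "histogram",
# "fqdn", "suffix" or "summary", and <stat_name> is e.g. "length" or "n_docs".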


def prepare_non_grouped_data(path, stat_name, grouping, normalization):
    stats = load_stats(path, stat_name, grouping)

    stats_rounded = defaultdict(lambda: 0)
    for key, value in stats.items():
        stats_rounded[float(key)] += value.total
    if normalization:
        normalizer = sum(stats_rounded.values())
        stats_rounded = {k: v / normalizer for k, v in stats_rounded.items()}
    return stats_rounded


def prepare_grouped_data(path, stat_name, grouping, top_k, direction):
    stats = load_stats(path, stat_name, grouping)
    means = {key: value.mean for key, value in stats.items()}

    # Use a heap to select the top_k keys
    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    elif direction == "Most frequent (n_docs)":
        n_docs = load_stats(path, "n_docs", grouping)
        totals = {key: value.total for key, value in n_docs.items()}
        keys = heapq.nlargest(top_k, totals, key=totals.get)
    elif direction == "Most frequent (length)":
        lengths = load_stats(path, "length", grouping)
        totals = {key: value.total * value.mean for key, value in lengths.items()}
        keys = heapq.nlargest(top_k, totals, key=totals.get)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)

    return [(key, means[key]) for key in keys]


def plot_scatter(
    histograms: dict[str, dict[float, float]], stat_name: str, normalization: bool
):
    fig = go.Figure()
    for i, (name, histogram) in enumerate(histograms.items()):
        # String keys are sorted by their value, numeric keys by the key itself
        if all(isinstance(k, str) for k in histogram.keys()):
            x = [k for k, v in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram.keys())
        y = [histogram[k] for k in x]
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="lines",
                name=name,
                line=dict(color=colors[i % len(colors)]),
            )
        )

    xaxis_scale = "log" if stat_name in LOG_SCALE_STATS else "linear"
    yaxis_title = "Frequency" if normalization else "Total"

    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title=yaxis_title,
        xaxis_type=xaxis_scale,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig


def plot_bars(histograms: dict[str, list[tuple[str, float]]], stat_name: str):
    fig = go.Figure()
    for i, (name, histogram) in enumerate(histograms.items()):
        x = [k for k, v in histogram]
        y = [v for k, v in histogram]
        fig.add_trace(go.Bar(x=x, y=y, name=name, marker_color=colors[i % len(colors)]))

    fig.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Mean value",
        autosize=True,
        width=1200,
        height=600,
        showlegend=True,
    )
    return fig


def update_graph(
    multiselect_crawls, stat_name, grouping, normalization, top_k, direction
):
    if len(multiselect_crawls) <= 0 or not stat_name or not grouping:
        return None

    # Pick the data preparation and plotting functions based on the selected grouping
    prepare_fc = (
        partial(prepare_non_grouped_data, normalization=normalization)
        if grouping == "histogram"
        else partial(prepare_grouped_data, top_k=top_k, direction=direction)
    )
    graph_fc = (
        partial(plot_scatter, normalization=normalization)
        if grouping == "histogram"
        else plot_bars
    )

    print("Loading stats")
    histograms = {
        path: prepare_fc(path, stat_name, grouping) for path in multiselect_crawls
    }

    print("Plotting")
    return graph_fc(histograms, stat_name)
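

# Minimal usage sketch (assumes the S3 bucket is reachable and the chosen run, grouping
# and stat actually exist); update_graph can be called outside the Gradio UI for quick
# debugging, e.g.:
#
#   fig = update_graph(RUNS[:1], "length", "histogram", normalization=True, top_k=100, direction="Top")
#   fig.show()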


# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Multiselect for crawls
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
            # Readme description of the tool
            readme_description = gr.Markdown(
                label="Readme",
                value="""
Explanation of the tool:

Groupings:
- histogram: creates a line plot of values with their occurrences. If normalization is on, the values are frequencies summing to 1.
- (fqdn/suffix): creates a bar plot of the mean value of the stat per fully qualified domain name/domain suffix
    * k: the number of groups to show
    * Top/Bottom: whether the top or bottom k groups are shown
- summary: simply shows the average value of the given stat for the selected crawls
""",
            )
        with gr.Column(scale=1):
            # Dropdown for grouping
            grouping_dropdown = gr.Dropdown(
                choices=[],
                label="Grouping",
                multiselect=False,
            )
            # Dropdown for stat name
            stat_name_dropdown = gr.Dropdown(
                choices=[],
                label="Stat name",
                multiselect=False,
            )

            with gr.Row(visible=False) as histogram_choices:
                normalization_checkbox = gr.Checkbox(
                    label="Normalize",
                    value=False,  # Default value
                )
            with gr.Row(visible=False) as group_choices:
                top_select = gr.Number(
                    label="K",
                    value=100,
                    interactive=True,
                )
                direction_checkbox = gr.Radio(
                    label="Partition",
                    choices=[
                        "Top",
                        "Bottom",
                        "Most frequent (n_docs)",
                        "Most frequent (length)",
                    ],
                )

    update_button = gr.Button("Update Graph", variant="primary")

    with gr.Row():
        # Graph output
        graph_output = gr.Plot(label="Graph")

    update_button.click(
        fn=update_graph,
        inputs=[
            multiselect_crawls,
            stat_name_dropdown,
            grouping_dropdown,
            normalization_checkbox,
            top_select,
            direction_checkbox,
        ],
        outputs=graph_output,
    )

    multiselect_crawls.select(
        fn=fetch_groups,
        inputs=[multiselect_crawls, grouping_dropdown],
        outputs=grouping_dropdown,
    )

    grouping_dropdown.select(
        fn=fetch_stats,
        inputs=[multiselect_crawls, grouping_dropdown, stat_name_dropdown],
        outputs=stat_name_dropdown,
    )

    def update_grouping_options(grouping):
        # Show the normalization toggle for histograms, the top-k controls otherwise
        if grouping == "histogram":
            return {
                histogram_choices: gr.Column(visible=True),
                group_choices: gr.Column(visible=False),
            }
        else:
            return {
                histogram_choices: gr.Column(visible=False),
                group_choices: gr.Column(visible=True),
            }

    grouping_dropdown.select(
        fn=update_grouping_options,
        inputs=[grouping_dropdown],
        outputs=[histogram_choices, group_choices],
    )

# Launch the application
if __name__ == "__main__":
    demo.launch()