"""Gradio dashboard for comparing merged statistics stored under s3://fineweb-stats/summary/.

Stats files are read from ``<run>/<grouping>/<stat_name>/stats-merged.json``.
The available runs, groupings and stat names are discovered once, when this
module is imported.
"""

import itertools
import json
import math
from collections import defaultdict
from pathlib import Path

import fsspec.config
import gradio as gr
import plotly.graph_objects as go
from plotly.offline import plot

from datatrove.io import DataFolder, get_datafolder
from datatrove.utils.stats import MetricStatsDict

BASE_DATA_FOLDER = get_datafolder("s3://fineweb-stats/summary/")


def find_folders(base_folder, path):
    """Return the sorted names of the directories listed under ``path``.

    Args:
        base_folder: ``DataFolder`` to list from.
        path: Directory (relative to ``base_folder``) whose children to list.

    Returns:
        Sorted list of directory names as reported by ``base_folder.ls``. An
        entry whose name (ignoring a trailing "/") equals ``path`` itself is
        skipped, presumably because the listing can include the directory.
    """
    return sorted(
        entry["name"]
        for entry in base_folder.ls(path, detail=True)
        if entry["type"] == "directory" and entry["name"].rstrip("/") != path
    )


def find_stats_folders(base_folder: DataFolder):
    """Return the unique run folders containing at least one stats-merged.json.

    Each ``<run>/<grouping>/<stat_name>/stats-merged.json`` found under
    ``base_folder`` contributes its ``<run>`` prefix. The returned order is
    unspecified (it comes from a set).
    """
    # Find every merged stats file via a recursive glob.
    stats_merged = base_folder.glob("**/stats-merged.json")
    # Drop the file name, the stat name and the grouping (three levels) to get
    # back to the run folder; the set removes duplicates.
    stats_folders = {str(Path(x).parent.parent.parent) for x in stats_merged}
    return list(stats_folders)


# Discovery happens at import time. GROUPS and STATS are taken from the first
# run only, so this assumes at least one run exists and that all runs share
# the same layout — NOTE(review): confirm the latter.
RUNS = sorted(find_stats_folders(BASE_DATA_FOLDER))
GROUPS = [Path(x).name for x in find_folders(BASE_DATA_FOLDER, RUNS[0])]
STATS = [
    Path(x).name
    for x in find_folders(BASE_DATA_FOLDER, str(Path(RUNS[0], GROUPS[0])))
]


def load_stats(path, stat_name, group_by):
    """Load ``<path>/<group_by>/<stat_name>/stats-merged.json`` as a MetricStatsDict.

    The file is opened with fsspec ``filecache`` options pointing at
    ``/tmp/files``.
    """
    with BASE_DATA_FOLDER.open(
        f"{path}/{group_by}/{stat_name}/stats-merged.json",
        filecache={"cache_storage": "/tmp/files"},
    ) as f:
        json_stat = json.load(f)
    # Adding to an empty MetricStatsDict is required: without it the resulting
    # MetricStatsDict is malformed (root cause unknown).
    return MetricStatsDict() + MetricStatsDict(init=json_stat)


def prepare_non_grouped_data(stats: MetricStatsDict, *, normalize: bool = False):
    """Collapse histogram stats into a ``{bin_value: count}`` mapping.

    Keys are converted with ``float()``, and the ``total`` of every entry
    whose key maps to the same float is summed.

    Args:
        stats: Histogram statistics keyed by values convertible to float.
        normalize: If True, divide every count by the sum of all counts so the
            values sum to 1. Defaults to False, which returns raw counts (the
            behaviour this dashboard has always had).

    Returns:
        Dict mapping float bin values to (optionally normalized) counts.
    """
    stats_rounded = defaultdict(int)
    for key, value in stats.items():
        stats_rounded[float(key)] += value.total
    normalizer = sum(stats_rounded.values()) if normalize else 1
    if not normalizer:
        # Empty or all-zero histogram: avoid dividing by zero.
        normalizer = 1
    return {k: v / normalizer for k, v in stats_rounded.items()}


def prepare_grouped_data(stats: MetricStatsDict, top_k=100):
    """Return the ``top_k`` keys with the highest mean, mapped to that mean."""
    means = {key: value.mean for key, value in stats.items()}
    # Rank keys by their mean value (highest first) and keep the top_k.
    top_keys = sorted(means, key=lambda x: means[x], reverse=True)[:top_k]
    return {key: means[key] for key in top_keys}


def plot_scatter(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one line trace per histogram on a log-scaled x axis.

    Args:
        histograms: Mapping of trace name (crawl path) to ``{x: y}`` data.
            If every key of a histogram is a string, its points are ordered by
            y value; otherwise they are ordered by x.
        stat_name: Used for the plot title and x-axis label.

    Returns:
        The plotly ``Figure``.
    """
    fig = go.Figure()
    # Cycle through the palette so that selecting more crawls than there are
    # colours repeats colours instead of raising StopIteration.
    colors = itertools.cycle(
        [
            "rgba(31, 119, 180, 0.5)",
            "rgba(255, 127, 14, 0.5)",
            "rgba(44, 160, 44, 0.5)",
            "rgba(214, 39, 40, 0.5)",
            "rgba(148, 103, 189, 0.5)",
        ]
    )
    for name, histogram in histograms.items():
        if all(isinstance(k, str) for k in histogram.keys()):
            x = [k for k, _ in sorted(histogram.items(), key=lambda item: item[1])]
        else:
            x = sorted(histogram.keys())
        y = [histogram[k] for k in x]
        fig.add_trace(
            go.Scatter(x=x, y=y, mode="lines", name=name, line=dict(color=next(colors)))
        )
    fig.update_layout(
        title=f"Line Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        xaxis_type="log",
        width=1000,
        height=600,
    )
    return fig


def plot_bars(histograms: dict[str, dict[float, float]], stat_name: str):
    """Draw one bar trace per histogram, with bars ordered by ascending value.

    Args:
        histograms: Mapping of trace name (crawl path) to ``{key: value}`` data.
        stat_name: Used for the plot title and x-axis label.

    Returns:
        The plotly ``Figure``.
    """
    fig = go.Figure()
    for name, histogram in histograms.items():
        x = [k for k, _ in sorted(histogram.items(), key=lambda item: item[1])]
        y = [histogram[k] for k in x]
        fig.add_trace(go.Bar(x=x, y=y, name=name))
    fig.update_layout(
        title=f"Bar Plots for {stat_name}",
        xaxis_title=stat_name,
        yaxis_title="Frequency",
        autosize=True,
        width=600,
        height=600,
    )
    return fig


def update_graph(multiselect_crawls, stat_name, grouping):
    """Gradio callback: load the selected stat for each crawl and plot it.

    The "histogram" grouping is rendered as line plots; every other grouping
    as bar plots of per-key means.

    Returns:
        A plotly ``Figure``, or None when no crawl, stat or grouping is
        selected (an empty dropdown may yield None as well as an empty list).
    """
    if not multiselect_crawls or not stat_name or not grouping:
        return None

    is_histogram = grouping == "histogram"
    prepare_fc = prepare_non_grouped_data if is_histogram else prepare_grouped_data
    graph_fc = plot_scatter if is_histogram else plot_bars

    print("Loading stats")
    histograms = {
        path: prepare_fc(load_stats(path, stat_name, grouping))
        for path in multiselect_crawls
    }

    print("Plotting")
    return graph_fc(histograms, stat_name)


# Create the Gradio interface
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Define the multiselect for crawls
            multiselect_crawls = gr.Dropdown(
                choices=RUNS,
                label="Multiselect for crawls",
                multiselect=True,
            )
        with gr.Column(scale=1):
            # Define the dropdown for stat_name
            stat_name_dropdown = gr.Dropdown(
                choices=STATS,
                label="Stat name",
                multiselect=False,
            )
            # Define the dropdown for grouping
            grouping_dropdown = gr.Dropdown(
                choices=GROUPS,
                label="Grouping",
                multiselect=False,
            )
            update_button = gr.Button("Update Graph", variant="primary")
    with gr.Row():
        # Define the graph output
        graph_output = gr.Plot(label="Graph")

    update_button.click(
        fn=update_graph,
        inputs=[multiselect_crawls, stat_name_dropdown, grouping_dropdown],
        outputs=graph_output,
    )

# Launch the application
if __name__ == "__main__":
    demo.launch()