import heapq
import json
import os
import re
import tempfile
from collections import defaultdict
from datetime import datetime
from typing import Dict, List, Literal, Tuple

import gradio as gr
from datatrove.utils.stats import MetricStatsDict

from src.logic.graph_settings import Grouping

PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]


def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
    """Bucket a metric's entries by their rounded numeric key and sum their totals.

    Args:
        metric: Mapping whose keys are parseable by ``float`` and whose values
            expose a ``.total`` attribute.
        normalization: If true, divide each bucket by the sum of all buckets so
            the result sums to 1. When that sum is 0 (e.g. ``metric`` is empty)
            normalisation is skipped and the raw buckets are returned.
        rounding: Number of decimal places passed to ``round`` for each key.

    Returns:
        Mapping of rounded key -> summed total (or share of the grand total
        when normalised).
    """
    metrics_rounded = defaultdict(lambda: 0)
    for key, value in metric.items():
        # Several raw keys may round to the same bucket; accumulate them.
        metrics_rounded[round(float(key), rounding)] += value.total

    if normalization:
        normalizer = sum(metrics_rounded.values())
        # Guard the empty / all-zero case: dividing would raise
        # ZeroDivisionError, and an empty dict would fail the sanity check.
        if normalizer:
            metrics_rounded = {k: v / normalizer for k, v in metrics_rounded.items()}
            assert abs(sum(metrics_rounded.values()) - 1) < 0.01
    return metrics_rounded


def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
    """Select up to ``top_k`` groups of a metric and return their means and stds.

    Args:
        metric: Mapping of group name -> stats object exposing ``.mean``,
            ``.n`` and ``.standard_deviation``.
        top_k: Maximum number of groups to return.
        direction: ``"Top"`` picks the largest rounded means,
            ``"Most frequent (n_docs)"`` picks the largest ``.n``; any other
            value (``"Bottom"``) picks the smallest rounded means.
        regex: Optional pattern; only keys for which ``re.match`` succeeds
            (i.e. matching at the start of the key) are considered.
        rounding: Decimal places applied to the means (also used for ranking).

    Returns:
        ``(keys, means, stds)`` as parallel lists in ranked order. Means are
        rounded; standard deviations are returned as-is.
    """
    regex_compiled = re.compile(regex) if regex else None
    metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}

    means = {key: round(float(value.mean), rounding) for key, value in metric.items()}
    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    elif direction == "Most frequent (n_docs)":
        totals = {key: int(value.n) for key, value in metric.items()}
        keys = heapq.nlargest(top_k, totals, key=totals.get)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)

    means = [means[key] for key in keys]
    stds = [metric[key].standard_deviation for key in keys]
    return keys, means, stds


def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
    """Write ``exported_data`` to a timestamped JSON file and expose it via Gradio.

    Each entry becomes a list of ``{"value": key, **stats}`` records built from
    ``dt.to_dict()``, sorted by ``"value"``. The file goes into a freshly
    created temporary directory so concurrent or repeated exports cannot
    overwrite each other.

    Args:
        exported_data: Mapping of series name -> object with a ``to_dict()``
            method whose values are themselves mappings.
        metric_name: Used (sanitised) in the output filename.
        grouping: Its ``str()`` form is used (sanitised) in the output filename.

    Returns:
        A visible ``gr.File`` pointing at the written file, or ``None`` when
        there is nothing to export.
    """
    if not exported_data:
        return None

    # Build the payload first so a failure here leaves no half-written file.
    payload = {
        name: sorted(
            ({"value": key, **value} for key, value in dt.to_dict().items()),
            key=lambda entry: entry["value"],
        )
        for name, dt in exported_data.items()
    }

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # metric_name / grouping are interpolated into a path: replace anything
    # that is not a plain filename character (notably path separators).
    safe_stem = re.sub(r"[^\w.\-]+", "_", f"{metric_name}_{grouping}")
    export_dir = tempfile.mkdtemp(prefix="metric_export_")
    file_name = os.path.join(export_dir, f"{safe_stem}_{timestamp}.json")

    with open(file_name, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    return gr.File(value=file_name, visible=True)