|
from datetime import datetime |
|
import json |
|
import re |
|
import heapq |
|
from collections import defaultdict |
|
import tempfile |
|
from typing import Dict, Tuple, List, Literal |
|
import gradio as gr |
|
from datatrove.utils.stats import MetricStatsDict |
|
|
|
from src.logic.graph_settings import Grouping |
|
|
|
# Selection modes accepted by prepare_for_group_plotting: highest means,
# lowest means, or the groups backed by the most documents.
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
|
|
|
def prepare_for_non_grouped_plotting(metric: Dict[str, "MetricStatsDict"], normalization: bool, rounding: int) -> Dict[float, float]:
    """Bucket per-value totals by rounded key, optionally as fractions of the whole.

    Args:
        metric: Mapping whose keys parse as floats and whose values expose a
            numeric ``total`` attribute.
        normalization: When true, divide every bucket by the sum of all
            buckets so the returned values add up to 1.
        rounding: Number of decimals the keys are rounded to; keys that
            collide after rounding are summed into one bucket.

    Returns:
        Mapping from rounded key to summed (or normalized) total. When
        normalization is requested but the grand total is zero (empty input
        or only zero totals), the un-normalized buckets are returned as a
        plain dict instead of failing.
    """
    buckets = defaultdict(lambda: 0)
    for key, value in metric.items():
        buckets[round(float(key), rounding)] += value.total

    if not normalization:
        return buckets

    normalizer = sum(buckets.values())
    if normalizer == 0:
        # Nothing to scale by: avoids a ZeroDivisionError for all-zero totals
        # and a failing sanity check for empty input.
        return dict(buckets)

    normalized = {bucket: total / normalizer for bucket, total in buckets.items()}
    # Sanity check on the arithmetic: the fractions must add up to ~1.
    assert abs(sum(normalized.values()) - 1) < 0.01
    return normalized
|
|
|
def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]: |
|
regex_compiled = re.compile(regex) if regex else None |
|
metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)} |
|
means = {key: round(float(value.mean), rounding) for key, value in metric.items()} |
|
if direction == "Top": |
|
keys = heapq.nlargest(top_k, means, key=means.get) |
|
elif direction == "Most frequent (n_docs)": |
|
totals = {key: int(value.n) for key, value in metric.items()} |
|
keys = heapq.nlargest(top_k, totals, key=totals.get) |
|
else: |
|
keys = heapq.nsmallest(top_k, means, key=means.get) |
|
|
|
means = [means[key] for key in keys] |
|
stds = [metric[key].standard_deviation for key in keys] |
|
return keys, means, stds |
|
|
|
def export_data(exported_data: Dict[str, "MetricStatsDict"], metric_name: str, grouping: "Grouping"):
    """Dump the given stats to a JSON file and return it as a visible ``gr.File``.

    Args:
        exported_data: Mapping from a name to an object whose ``to_dict()``
            returns a mapping of key -> mapping of fields.
        metric_name: Used as the first part of the export file name.
        grouping: Used as the second part of the export file name.

    Returns:
        ``None`` when there is nothing to export; otherwise a ``gr.File``
        pointing at the written JSON file. For every name the file holds a
        list of rows (``{"value": key, **fields}``) sorted by ``"value"``.
    """
    if not exported_data:
        return None

    import os  # local import: not needed anywhere else in this module

    timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    # Path separators in the name parts would be read as directories and make
    # the write fail, so flatten them.
    file_name = re.sub(r"[\\/]+", "_", f"{metric_name}_{grouping}_{timestamp}.json")
    # A fresh temporary directory keeps exports out of the working directory
    # while preserving the readable file name.
    file_path = os.path.join(tempfile.mkdtemp(), file_name)

    # Build the payload before opening the file so a serialization failure
    # does not leave a truncated export behind.
    payload = {
        name: sorted(
            ({"value": key, **value} for key, value in dt.to_dict().items()),
            key=lambda row: row["value"],
        )
        for name, dt in exported_data.items()
    }
    with open(file_path, 'w', encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    return gr.File(value=file_path, visible=True)