"""Helpers that turn datatrove metric statistics into plot-ready data and JSON exports for the Gradio UI."""
from datetime import datetime
import json
import re
import heapq
from collections import defaultdict
from typing import Dict, Tuple, List, Literal
import gradio as gr
from datatrove.utils.stats import MetricStatsDict

from src.logic.graph_settings import Grouping

PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]

def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
    """Aggregate a histogram-style metric into {rounded value: total}, optionally normalized to sum to 1."""
    metrics_rounded = defaultdict(lambda: 0)
    for key, value in metric.items():
        # Bucket nearby values together by rounding the (numeric) key.
        metrics_rounded[round(float(key), rounding)] += value.total
    if normalization:
        normalizer = sum(metrics_rounded.values())
        metrics_rounded = {k: v / normalizer for k, v in metrics_rounded.items()}
        # Sanity check: the normalized distribution should sum to ~1.
        assert abs(sum(metrics_rounded.values()) - 1) < 0.01
    return metrics_rounded

def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
    """Select the top_k groups (by mean or by document count) and return their names, rounded means and standard deviations."""
    regex_compiled = re.compile(regex) if regex else None
    # Keep only groups whose name matches the optional regex filter.
    metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
    means = {key: round(float(value.mean), rounding) for key, value in metric.items()}
    if direction == "Top":
        keys = heapq.nlargest(top_k, means, key=means.get)
    elif direction == "Most frequent (n_docs)":
        # Rank by number of documents rather than by the metric's mean.
        totals = {key: int(value.n) for key, value in metric.items()}
        keys = heapq.nlargest(top_k, totals, key=totals.get)
    else:
        keys = heapq.nsmallest(top_k, means, key=means.get)

    means = [means[key] for key in keys]
    stds = [metric[key].standard_deviation for key in keys]
    return keys, means, stds

def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
    """Write the exported statistics to a timestamped JSON file and return it as a downloadable Gradio File."""
    if not exported_data:
        return None

    file_name = f"{metric_name}_{grouping}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
    with open(file_name, 'w') as f:
        json.dump({
            # Flatten each MetricStatsDict into a list of records sorted by the metric value.
            name: sorted([{"value": key, **value} for key, value in dt.to_dict().items()], key=lambda x: x["value"])
            for name, dt in exported_data.items()
        }, f, indent=2)
    return gr.File(value=file_name, visible=True)
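

# --- Hypothetical usage sketch (added for illustration; not part of the original module) ---
# A minimal demo of the two plotting helpers above. _FakeStats is an invented stand-in that
# only mimics the attributes these functions read (total, mean, n, standard_deviation);
# real callers would pass entries coming from datatrove's MetricStatsDict instead.
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class _FakeStats:
        total: float
        mean: float
        n: int
        standard_deviation: float

    # Histogram-style metric: stringified value -> stats (only .total is read here).
    histogram = {"1.04": _FakeStats(5, 1.04, 5, 0.0), "1.06": _FakeStats(15, 1.06, 15, 0.0)}
    print(prepare_for_non_grouped_plotting(histogram, normalization=True, rounding=1))
    # -> {1.0: 0.25, 1.1: 0.75} (the two keys fall into different buckets at rounding=1)

    # Grouped metric: group name -> stats (.mean, .n and .standard_deviation are read).
    grouped = {
        "en": _FakeStats(1000.0, 250.0, 4, 30.0),
        "fr": _FakeStats(200.0, 100.0, 2, 10.0),
        "de": _FakeStats(400.0, 400.0, 1, 0.0),
    }
    keys, means, stds = prepare_for_group_plotting(grouped, top_k=2, direction="Top", regex=None, rounding=2)
    print(keys, means, stds)  # -> ['de', 'en'] [400.0, 250.0] [0.0, 30.0]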