Datasets-Metrics-Viewer / src /logic /data_fetching.py
hynky's picture
hynky HF Staff
Refactor the code
75448af
raw
history blame
5.27 kB
from functools import partial
import os
import json
import re
import tempfile
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
from datatrove.io import get_datafolder
from datatrove.utils.stats import MetricStatsDict
import gradio as gr
import tenacity
from src.logic.graph_settings import Grouping
def find_folders(base_folder: str, path: str) -> List[str]:
    """List the immediate sub-directories of ``path`` inside ``base_folder``.

    Args:
        base_folder: Location handed to ``get_datafolder``.
        path: Directory to inspect, relative to ``base_folder`` ("" for the root).

    Returns:
        Sorted directory entries one level below ``path``, excluding ``path`` itself.
        Returns an empty list when ``path`` does not exist. Entries are the keys
        returned by ``find``; callers pass them back in as ``path``, so they are
        presumably relative to ``base_folder`` (TODO confirm).
    """
    folder_fs = get_datafolder(base_folder)
    if not folder_fs.exists(path):
        return []

    entries = folder_fs.find(path, maxdepth=1, withdirs=True, detail=True)
    # ``find`` with withdirs=True can report the searched directory itself; drop it.
    own_path = path.rstrip("/")
    subdirs = []
    for name, info in entries.items():
        if info["type"] != "directory":
            continue
        if name.rstrip("/") == own_path:
            continue
        subdirs.append(name)
    return sorted(subdirs)
def fetch_datasets(base_folder: str):
    """Return the sorted top-level folders of ``base_folder`` (one per dataset).

    Raises:
        ValueError: if ``base_folder`` contains no sub-directories.
    """
    found = find_folders(base_folder, "")
    if not found:
        raise ValueError("No datasets found")
    return sorted(found)
def fetch_groups(base_folder: str, datasets: List[str], old_groups: str, type: str = "intersection"):
    """Build the grouping dropdown from the sub-folders of the selected datasets.

    Args:
        base_folder: Root location containing one folder per dataset.
        datasets: Selected dataset folders.
        old_groups: Previously selected grouping, kept if still available.
        type: "intersection" keeps groups present in every dataset; any other
            value keeps groups present in at least one dataset.

    Returns:
        ``gr.update`` clearing the dropdown when nothing is selected, otherwise a
        ``gr.Dropdown`` with sorted choices and a pre-selected value (or None).
    """
    if not datasets:
        return gr.update(choices=[], value=None)

    def _group_names(run):
        # Only the last path component is shown as a choice.
        return [Path(folder).name for folder in find_folders(base_folder, run)]

    with ThreadPoolExecutor() as executor:
        per_dataset = [set(names) for names in executor.map(_group_names, datasets)]
    if not per_dataset:
        return gr.update(choices=[], value=None)

    combine = set.intersection if type == "intersection" else set.union
    choices = combine(*per_dataset)

    # Keep the previous selection if it is still offered; otherwise auto-pick a sole choice.
    selected = old_groups if old_groups and old_groups in choices else None
    if selected is None and len(choices) == 1:
        (selected,) = choices
    return gr.Dropdown(choices=sorted(choices), value=selected)
def fetch_metrics(base_folder: str, datasets: List[str], group: str, old_metrics: str, type: str = "intersection"):
    """Build the metric dropdown from ``<dataset>/<group>/`` sub-folders.

    Args:
        base_folder: Root location containing one folder per dataset.
        datasets: Selected dataset folders (may be empty or None).
        group: Selected grouping folder name.
        old_metrics: Previously selected metric, kept if still available.
        type: "intersection" keeps metrics present in every dataset; any other
            value keeps metrics present in at least one dataset.

    Returns:
        ``gr.update`` clearing the dropdown when there is nothing to list, otherwise a
        ``gr.Dropdown`` with sorted choices and a pre-selected value (or None).
    """
    # Guard datasets as well as group: a None/empty selection must clear the
    # dropdown instead of crashing executor.map (matches fetch_groups).
    if not group or not datasets:
        return gr.update(choices=[], value=None)

    with ThreadPoolExecutor() as executor:
        metrics_per_dataset = list(
            executor.map(
                lambda run: [Path(x).name for x in find_folders(base_folder, f"{run}/{group}")],
                datasets,
            )
        )
    if not metrics_per_dataset:
        return gr.update(choices=[], value=None)

    metric_sets = (set(names) for names in metrics_per_dataset)
    if type == "intersection":
        choices = set.intersection(*metric_sets)
    else:
        choices = set.union(*metric_sets)

    # Keep the previous selection if it is still offered; otherwise auto-pick a sole choice.
    value = old_metrics if old_metrics and old_metrics in choices else None
    if value is None and len(choices) == 1:
        value = next(iter(choices))
    return gr.Dropdown(choices=sorted(choices), value=value)
def reverse_search(base_folder: str, possible_datasets: List[str], grouping: str, metric_name: str) -> str:
    """Return, newline-separated, the datasets that have ``metric_name`` under ``grouping``.

    Existence checks run concurrently; input order is preserved in the output.
    """
    candidates = list(possible_datasets)

    def _has_metric(dataset):
        return metric_exists(base_folder, dataset, metric_name, grouping)

    with ThreadPoolExecutor() as executor:
        hits = list(executor.map(_has_metric, candidates))

    matching = [
        dataset
        for dataset, hit in zip(candidates, hits)
        if hit and dataset is not None
    ]
    return "\n".join(matching)
def reverse_search_add(datasets: List[str], reverse_search_results: str) -> List[str]:
    """Merge newline-separated reverse-search results into the current selection.

    Args:
        datasets: Currently selected datasets (None is treated as empty).
        reverse_search_results: One dataset per line, as produced by ``reverse_search``.
            Blank lines and a None/empty value add nothing.

    Returns:
        Sorted, de-duplicated union of both inputs.
    """
    # splitlines() also handles "\r\n"; skipping blanks stops an empty search result
    # from adding a bogus "" dataset.
    found = [line.strip() for line in (reverse_search_results or "").splitlines()]
    merged = set(datasets or [])
    merged.update(name for name in found if name)
    # Sorted for a deterministic order, matching update_datasets_with_regex.
    return sorted(merged)
def metric_exists(base_folder: str, path: str, metric_name: str, group_by: str) -> bool:
    """Return whether ``<path>/<group_by>/<metric_name>/metric.json`` exists in ``base_folder``."""
    data_folder = get_datafolder(base_folder)
    metric_file = f"{path}/{group_by}/{metric_name}/metric.json"
    return data_folder.exists(metric_file)
@tenacity.retry(
    stop=tenacity.stop_after_attempt(5),
    # Back off between attempts so a transient storage/network error has time to
    # clear; immediate back-to-back retries rarely help.
    wait=tenacity.wait_exponential(multiplier=0.25, max=2),
)
def load_metrics(base_folder: str, path: str, metric_name: str, group_by: str) -> MetricStatsDict:
    """Read ``<path>/<group_by>/<metric_name>/metric.json`` and parse it.

    Args:
        base_folder: Location handed to ``get_datafolder``.
        path: Dataset folder relative to ``base_folder``.
        metric_name: Metric folder name.
        group_by: Grouping folder name.

    Returns:
        The file contents converted with ``MetricStatsDict.from_dict``.

    Raises:
        tenacity.RetryError: after 5 failed attempts (tenacity's default when
            ``reraise`` is not set).
    """
    data_folder = get_datafolder(base_folder)
    with data_folder.open(f"{path}/{group_by}/{metric_name}/metric.json") as f:
        json_metric = json.load(f)
    return MetricStatsDict.from_dict(json_metric)
def load_data(dataset_path: str, base_folder: str, grouping: str, metric_name: str) -> MetricStatsDict:
    """Load one dataset's metric; takes ``dataset_path`` first so it can be
    ``partial``-bound and mapped over dataset paths (see ``fetch_graph_data``)."""
    return load_metrics(
        base_folder=base_folder,
        path=dataset_path,
        metric_name=metric_name,
        group_by=grouping,
    )
def fetch_graph_data(
    base_folder: str,
    datasets: List[str],
    metric_name: str,
    grouping: Grouping,
    progress=gr.Progress(),
):
    """Load ``metric_name`` under ``grouping`` for every selected dataset in parallel.

    Args:
        base_folder: Root location containing one folder per dataset.
        datasets: Selected dataset folders (may be empty or None).
        metric_name: Metric folder to load.
        grouping: Grouping folder to load from.
        progress: Gradio progress tracker used to report loading progress.

    Returns:
        ``(data, None)`` where ``data`` maps each dataset path to its loaded
        ``MetricStatsDict``, or ``None`` when any required input is missing.
    """
    # `not datasets` also covers None (cleared selection); len(None) would raise.
    if not datasets or not metric_name or not grouping:
        # NOTE(review): this returns a bare None while the success path returns a
        # 2-tuple; confirm every caller handles both shapes.
        return None

    loader = partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping)
    with ThreadPoolExecutor() as pool:
        results = list(
            progress.tqdm(
                pool.map(loader, datasets),
                total=len(datasets),
                desc="Loading data...",
            )
        )
    # Executor.map yields results in input order, so zip pairs each path with its own data.
    data = dict(zip(datasets, results))
    # NOTE(review): the meaning of the second tuple element is not visible here.
    return data, None
def update_datasets_with_regex(regex: str, selected_runs: List[str], all_runs: List[str]):
    """Add every run whose name matches ``regex`` (``re.search``) to the selection.

    Args:
        regex: Pattern to search for in run names. An empty value clears the selection.
        selected_runs: Currently selected runs (may be None).
        all_runs: Every available run (None is treated as empty).

    Returns:
        ``[]`` for an empty pattern. ``selected_runs`` unchanged when the pattern is
        invalid or matches nothing. Otherwise the sorted union of the matches and
        ``selected_runs``.
    """
    if not regex:
        return []
    try:
        pattern = re.compile(regex)
    except re.error:
        # An incomplete/invalid pattern (e.g. an unbalanced "(") keeps the current
        # selection instead of raising.
        return selected_runs
    new_dsts = {run for run in (all_runs or []) if pattern.search(run)}
    if not new_dsts:
        return selected_runs
    return sorted(new_dsts.union(selected_runs or []))