import ast
import itertools
import os
import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial

import gradio as gr
import huggingface_hub
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from huggingface_hub.file_download import repo_folder_name
from huggingface_hub.hf_api import RepoFile
from huggingface_hub.utils import EntryNotFoundError

FALLBACK_TOKEN_NAME = "HF_TOKEN"
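# Hub API calls read their token from this environment variable;
# fetch_repo_structure prefers the Gradio OAuth token when one is provided.
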
def is_array_like(x):
    return isinstance(x, (list, tuple, np.ndarray))


def get_task_type(df):
    # Infer the task type from the shape of the first row's predictions
    if all(isinstance(pred, str) for pred in df['predictions'].iloc[0]):
        return "generative"
    if all(is_array_like(pred) and all(isinstance(item, float) for item in pred) for pred in df['predictions'].iloc[0]):
        return "multiple_choice"
    return "mixed"

def fix_df(df):
    # For some reason some metrics and predictions are stored as strings
    for col in ["predictions", "metrics", "choices", "gold", "gold_index"]:
        df[col] = [ast.literal_eval(x) if isinstance(x, str) else x for x in df[col].values]
    return df

def get_run_name_seed(run_name):
    if "-seed-" not in run_name:
        return run_name, 5  # default seed when none is encoded in the run name
    run_name, seed = run_name.split("-seed-")
    return run_name, int(seed)

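# Example (hypothetical run name): "ind_minhash-CC-MAIN-2013-20-seed-1"
#   -> ("ind_minhash-CC-MAIN-2013-20", 1); names without "-seed-" get the default seed 5.
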
def fetch_repo_structure(repo_name, oauth_token: gr.OAuthToken | None = None):
    token = os.environ.get(FALLBACK_TOKEN_NAME)
    if oauth_token:
        token = oauth_token.token

    files = list(huggingface_hub.list_repo_tree(repo_name, "details", recursive=False, token=token))
    runs = {file.path.split('/')[-1] for file in files if isinstance(file, huggingface_hub.hf_api.RepoFolder)}
    if not runs:
        return {}, gr.update(choices=[], value=None)

    def process_run(run):
        run_files = list(huggingface_hub.list_repo_tree(repo_name, f"details/{run}", recursive=False, token=token))
        return run, [file.path.split('/')[-1] for file in run_files if isinstance(file, huggingface_hub.hf_api.RepoFolder)]

    # Fetch the checkpoint folders of all runs in parallel
    with ThreadPoolExecutor() as executor:
        results = list(executor.map(process_run, runs))

    checkpoints_dict = dict(results)
    return checkpoints_dict, gr.update(choices=list(checkpoints_dict), value=None)

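# Assumed details-repo layout, inferred from the paths built above and below:
#   details/<run_name>/<checkpoint>/<task_name>_<%Y-%m-%dT%H-%M-%S.%f>.parquet
# fetch_repo_structure returns {run_name: [checkpoint, ...]} plus a dropdown update.
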
def update_checkpoints(selected_runs, checkpoints):
    if not selected_runs:
        return gr.update(choices=[], value=None)
    common_checkpoints = set(checkpoints[selected_runs[0]])
    for run in selected_runs[1:]:
        common_checkpoints.intersection_update(set(checkpoints[run]))
    common_checkpoints = sorted(common_checkpoints)
    return gr.update(choices=common_checkpoints, value=common_checkpoints[0] if common_checkpoints else None)

def select_runs_by_regex(runs, current_selected, regex_to_select):
    comp_re = re.compile(regex_to_select)
    return sorted(set((current_selected if current_selected else []) +
                      [run for run in runs if comp_re.fullmatch(run)]))


def select_runs_by_language(runs, current_selected, language):
    if language:
        return select_runs_by_regex(runs, current_selected, f".*-{language}-.*")
    return current_selected

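# Example: selecting "fr" expands to the regex ".*-fr-.*", so any run whose name contains
# "-fr-" (hypothetical, e.g. "ind_minhash-fr-2013-20-seed-1") is added to the current selection.
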
def fetch_available_tasks(repo_name, runs_to_fetch, checkpoint) -> dict[str, dict[str, str]]:
    token = os.environ.get(FALLBACK_TOKEN_NAME)

    all_tasks = defaultdict(lambda: defaultdict(dict))
    for run in runs_to_fetch:
        try:
            files = huggingface_hub.list_repo_tree(repo_name, f"details/{run}/{checkpoint}", token=token)
            parquet_files = [f.path.split('/')[-1] for f in files if f.path.endswith('.parquet')]
            for full_filename in parquet_files:
                task_name, date_str = full_filename.replace('.parquet', '').rsplit('_', 1)
                date = datetime.strptime(date_str, '%Y-%m-%dT%H-%M-%S.%f')
                # Keep only the most recent details file per (task, run)
                if run not in all_tasks[task_name] or date > all_tasks[task_name][run]['date']:
                    all_tasks[task_name][run] = {'filename': full_filename, 'date': date}
        except EntryNotFoundError:
            print(f"Checkpoint not found for run: {run}")

    # Only keep tasks that are available for every selected run
    available_tasks = {
        task: {run: info['filename'] for run, info in runs.items()}
        for task, runs in all_tasks.items()
        if set(runs.keys()) == set(runs_to_fetch)
    }
    return available_tasks

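# Example (hypothetical filename): "custom|mmlu|5_2024-06-01T12-30-00.000000.parquet"
#   -> task "custom|mmlu|5" with timestamp 2024-06-01T12-30-00.000000; only the newest file
#      per (task, run) is kept, and a task is offered only if every selected run has it.
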
def fetch_run_results(repo_name, runs_to_fetch, checkpoint,
                      oauth_token: gr.OAuthToken | None = None, progress=gr.Progress()):
    task_runs_dict = fetch_available_tasks(repo_name, runs_to_fetch, checkpoint)
    task_names = list(task_runs_dict.keys())
    return gr.update(choices=task_names, value=task_names[0] if task_names else None), task_runs_dict

def filter_with_metric(df, selected_runs, metric_name):
    if df is None or not selected_runs or not metric_name:
        return None
    kept_metrics = [f"metric_{metric_name}_{run_name}" for run_name in selected_runs]
    other_metrics = [col for col in df.columns if col.startswith("metric_") and col not in kept_metrics]
    df = df.drop(columns=other_metrics)
    widths = get_column_widths(df)
    df = concise_runname_metric(df, selected_runs, metric_name)
    return gr.update(value=df, column_widths=widths)

def get_column_widths(df):
    column_widths = []
    for col in df.columns:
        if col == "full_prompt":
            column_widths.append("300px")
        elif col in ["choices", "gold"]:
            column_widths.append("250px")
        elif col.startswith("metric_"):
            column_widths.append("100px")
        else:
            column_widths.append("200px")  # Default width for other columns
    return column_widths

def concise_runname_metric(df, run_names, metric_name):
    """
    Turns metric columns (metric_{metric}_{run_name}) into {metric}_{i}
    """
    for idx, run_name in enumerate(run_names):
        original_column = f"metric_{metric_name}_{run_name}"
        if original_column in df.columns:
            # Copy the metric values into the concise column and drop the original
            df[f"{metric_name}_{idx}"] = df[original_column]
            df = df.drop(columns=[original_column])
    return df

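# Example (hypothetical runs/metric): with selected_runs = ["run-a", "run-b"] and metric "acc",
# columns "metric_acc_run-a" and "metric_acc_run-b" become "acc_0" and "acc_1".
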
def load_task_data(repo_name, runs_to_fetch, checkpoint, task_name, tasks_files, progress=gr.Progress()):
    token = os.environ.get(FALLBACK_TOKEN_NAME)
    if not runs_to_fetch or not task_name:
        return None, None, None

    def fetch_run_file(run_to_fetch):
        file_path = f"details/{run_to_fetch}/{checkpoint}/{tasks_files[task_name][run_to_fetch]}"
        try:
            cached_path = huggingface_hub.hf_hub_download(repo_name, file_path, token=token)
            df = pd.read_parquet(cached_path)
            return df, run_to_fetch
        except EntryNotFoundError:
            print(f"File not found: {file_path}")
            return None, run_to_fetch

    with ThreadPoolExecutor() as pool:
        results = list(progress.tqdm(pool.map(fetch_run_file, runs_to_fetch), total=len(runs_to_fetch),
                                     desc="Fetching run data..."))

    # Keep dfs and run_names aligned: skip runs whose file could not be downloaded
    dfs = [fix_df(df) for df, _ in results if df is not None]
    run_names = [run for df, run in results if df is not None]
    if not dfs:
        return None, None, gr.update(choices=[], value=None)

    task_type = get_task_type(dfs[0])

    def prepare_df(df, run_name, task_type):
        def get_choice_predictions(row, task_type):
            # For some evals predictions are strings, for others a list of per-choice scores
            predictions = row['predictions']
            if task_type == "generative":
                return predictions
            if task_type == "multiple_choice":
                n_choices = len(row['choices'])
                return row['choices'][np.argmax([pred[0] for pred in predictions[:n_choices]])]
            if task_type == "mixed":
                return predictions[0]
            return predictions

        prepared_df = pd.DataFrame({
            'full_prompt': df['full_prompt'],
            run_name: df.apply(partial(get_choice_predictions, task_type=task_type), axis=1)
        })

        # Assume all rows share the same metric keys
        metrics = df['metrics']
        for metric_key in metrics[0].keys():
            prepared_df[f'metric_{metric_key}_{run_name}'] = [metric[metric_key] for metric in metrics]
        return prepared_df.set_index('full_prompt')

    def get_gold_label(row, task_type):
        if task_type == "generative":
            return row['gold']
        return [row['choices'][idx] for idx in row['gold_index']]

    # Prepare the base DataFrame with choices and gold labels
    combined_df = dfs[0][['full_prompt', 'choices']].set_index('full_prompt')
    combined_df['gold'] = dfs[0].apply(lambda row: get_gold_label(row, task_type), axis=1).values

    # Join the per-run prediction and metric columns
    for df, run_name in zip(dfs, run_names):
        prepared_df = prepare_df(df, run_name, task_type)
        combined_df = combined_df.join(prepared_df, how='outer')

    available_metrics = list(set("_".join(col.split('_')[1:-1]) for col in combined_df.columns if col.startswith("metric_")))
    combined_df = combined_df.reset_index()
    return combined_df, filter_with_metric(combined_df, runs_to_fetch, available_metrics[0]), gr.update(choices=available_metrics, value=available_metrics[0])

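# The combined DataFrame returned above has one row per prompt with columns (roughly):
#   full_prompt | choices | gold | <run_name> (prediction) | metric_<metric>_<run_name> ...
# filter_with_metric then keeps a single metric and shortens its column names for display.
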
def render_results_table(df: pd.DataFrame):
    if df is None or df.empty:
        return None

    # Select a subset of 100 examples
    df_subset = df.sample(n=min(100, len(df)), random_state=42)

    # Prepare the data for display
    display_data = []
    for _, row in df_subset.iterrows():
        example_data = {
            'text': row['example'],
            'choices': row['choices'],
            'gold_index': row['gold_index'],
        }
        for run in df['run'].unique():
            run_data = df[(df['run'] == run) & (df['example'] == row['example'])]
            if not run_data.empty:
                example_data[f'{run}_prediction'] = run_data['predictions'].values[0]
                example_data[f'{run}_score'] = run_data['metrics'].values[0]
        display_data.append(example_data)
    return pd.DataFrame(display_data)

with gr.Blocks() as demo:
    runs_checkpoints = gr.State({})
    results_df_full = gr.State(None)
    tasks_files = gr.State({})
    login_button = gr.LoginButton(visible=False)
    repo = gr.Textbox(label="HF Repo", value="HuggingFaceFW-Dev/multiligual-ablation-logs-dev", visible=True)

    with gr.Column():
        gr.Markdown("# FineWeb experiments results explorer")
        with gr.Row():
            with gr.Column():
                select_by_regex_text = gr.Textbox(label="Regex to select runs",
                                                  value="ind_minhash(-CC-MAIN-|_)\\d{4}-\\d{2}-seed.*")
                select_by_regex_button = gr.Button("Select matching runs")
            with gr.Column():
                select_by_language = gr.Dropdown(choices=["ar", "fr", "ru", "hi", "th", "tr", "zh", "sw", "te"],
                                                 interactive=True, label="Select by language",
                                                 info="Choose a language to prefill the regex")
        selected_runs = gr.Dropdown(choices=[], interactive=True, multiselect=True, label="Selected runs")
        checkpoint = gr.Dropdown(choices=[], interactive=True, label="Checkpoint")
        fetch_res = gr.Button("Fetch results")
        task_name = gr.Dropdown(choices=[], interactive=True, label="Task name")
        metric_name = gr.Dropdown(choices=[], interactive=True, label="Metric")
        results_df = gr.Dataframe(interactive=False, wrap=True)

    # Run selection
    gr.on(
        triggers=[repo.change],
        fn=fetch_repo_structure, inputs=[repo], outputs=[runs_checkpoints, selected_runs],
    )
    gr.on(
        triggers=[select_by_regex_button.click],
        fn=select_runs_by_regex,
        inputs=[runs_checkpoints, selected_runs, select_by_regex_text], outputs=[selected_runs]
    )
    gr.on(
        triggers=[select_by_language.change],
        fn=select_runs_by_language,
        inputs=[runs_checkpoints, selected_runs, select_by_language], outputs=[selected_runs]
    )

    # Update checkpoints based on selected runs
    gr.on(
        triggers=[selected_runs.change],
        fn=update_checkpoints,
        inputs=[selected_runs, runs_checkpoints],
        outputs=[checkpoint]
    )

    # Fetch available tasks
    gr.on(
        triggers=[fetch_res.click],
        fn=fetch_run_results,
        inputs=[repo, selected_runs, checkpoint],
        outputs=[task_name, tasks_files]
    )

    # Update results when the task name or metric changes
    gr.on(
        triggers=[task_name.change],
        fn=load_task_data,
        inputs=[repo, selected_runs, checkpoint, task_name, tasks_files],
        outputs=[results_df_full, results_df, metric_name]
    )
    gr.on(
        triggers=[metric_name.change],
        fn=filter_with_metric,
        inputs=[results_df_full, selected_runs, metric_name],
        outputs=[results_df]
    )

    demo.load(fn=fetch_repo_structure, inputs=[repo], outputs=[runs_checkpoints, selected_runs])

demo.launch()