{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fetch the data from the hub"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import itertools\n",
"import pandas as pd\n",
"from concurrent.futures import ThreadPoolExecutor\n",
"from tqdm import tqdm\n",
"import itertools\n",
"import huggingface_hub\n",
"from tensorboard.backend.event_processing.event_accumulator import EventAccumulator\n",
"from huggingface_hub.utils import EntryNotFoundError"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"def step_element_match(step_to_check, step_element):\n",
" step_element = step_element.strip().replace(\" \", \"\")\n",
" if \"-\" in step_element:\n",
" a, b = step_element.split(\"-\")\n",
" c = None\n",
" if \"%\" in b:\n",
" b, c = b.split(\"%\")\n",
" return (int(a) <= step_to_check <= int(b) and\n",
" (c is None or (step_to_check - int(a)) % int(c) == 0))\n",
" elif \"%\" in step_element:\n",
" return step_to_check % int(step_element[1:]) == 0\n",
" else:\n",
" return step_to_check == int(step_element)\n",
" \n",
"def fetch_run_results_simple(repo_name, runs_to_fetch, steps_to_fetch, prefix, agg_score_columns, column_name,\n",
" seed_merge_method, oauth_token=None, prefix_file=None):\n",
" if not runs_to_fetch:\n",
" return\n",
"\n",
" def fetch_run_files(run_to_fetch):\n",
" def filename_to_steps_timestamp(fn):\n",
" step, ts = fn.split(\"_events.out.tfevents.\")\n",
" return int(step[-7:]), int(ts[:ts.index(\".\")])\n",
"\n",
" run_to_fetch += \"_e\"\n",
" try:\n",
" eval_repo_file_names = [f.path for f in\n",
" huggingface_hub.list_repo_tree(repo_name, run_to_fetch, expand=False,\n",
" token=oauth_token) if\n",
" \"_events.out.tfevents\" in f.path]\n",
" except EntryNotFoundError:\n",
" return []\n",
"\n",
" eval_files = [os.path.relpath(f, run_to_fetch) for f in eval_repo_file_names]\n",
" timestamps = {}\n",
" for fn in eval_files:\n",
" steps, ts = filename_to_steps_timestamp(fn)\n",
" if steps not in timestamps or timestamps[steps][0] < ts:\n",
" timestamps[steps] = ts, fn\n",
"\n",
" results = []\n",
" for eval_file, repofile in zip(eval_files, eval_repo_file_names):\n",
" steps, ts = filename_to_steps_timestamp(eval_file)\n",
" if not any(step_element_match(steps, step_el) for step_el in steps_to_fetch.split(\",\")):\n",
" continue\n",
" if timestamps[steps][1] == eval_file:\n",
" results.append((run_to_fetch, steps, repofile))\n",
" return results\n",
"\n",
" def load_run_file(data):\n",
" run_to_fetch, steps, repofile = data\n",
" loader = EventAccumulator(huggingface_hub.hf_hub_download(repo_name, repofile, token=oauth_token))\n",
" loader.Reload()\n",
" runname = run_to_fetch.removeprefix(prefix).removesuffix(\"-_e\")\n",
" column_names = [\"runname\", \"seed\", \"steps\", \"agg_score\"]\n",
" column_values = [runname, 0, steps, 0.0]\n",
"\n",
" for tag in loader.Tags()['scalars']:\n",
" if not \"stderr\" in tag and tag.split('/')[0] == 'e':\n",
" event_list = loader.Scalars(tag)\n",
" tag = tag.split('/')\n",
" column_names.append(f\"{tag[1]}/{tag[2]}\")\n",
" column_values.append(event_list[-1].value)\n",
"\n",
" return pd.DataFrame([column_values], columns=column_names)\n",
"\n",
" with ThreadPoolExecutor() as pool:\n",
" run_files = list(itertools.chain.from_iterable(\n",
" tqdm(pool.map(fetch_run_files, runs_to_fetch), total=len(runs_to_fetch), desc=\"Fetching datafiles...\")))\n",
" df = pd.concat(tqdm(pool.map(load_run_file, run_files), total=len(run_files), desc=\"Loading evals data...\"))\n",
"\n",
" cols_to_avg = [col for col in agg_score_columns if col in df.columns]\n",
" if cols_to_avg:\n",
" df['agg_score'] = df[cols_to_avg].mean(axis=1)\n",
"\n",
" prefix_file = prefix_file + \"_\" if prefix_file else \"\"\n",
" df.to_csv(f\"{prefix_file}{repo_name.split('/')[-1]}_metrics.csv\", index=False)\n",
" print(f\"Metrics saved to {repo_name.split('/')[-1]}_metrics.csv\")\n",
"\n",
" return df"
]
},
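{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, purely illustrative check of the `steps_to_fetch` selector syntax handled by `step_element_match` above: a bare number matches an exact step, `%N` matches every N-th step, and `a-b%N` matches every N-th step inside the range `a-b`. The cell below only exercises the helper defined in the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative examples only, exercising the selector syntax parsed by step_element_match\n",
"assert step_element_match(2000, \"%1000\")            # every 1000 steps\n",
"assert not step_element_match(2500, \"%1000\")\n",
"assert step_element_match(6000, \"2000-10000%2000\")  # range with a stride\n",
"assert step_element_match(167000, \"167000\")         # exact step"
]
},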
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching datafiles...: 100%|██████████| 1/1 [00:02<00:00, 2.94s/it]\n",
"Loading evals data...: 100%|██████████| 82/82 [00:15<00:00, 5.37it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metrics saved to loubna-edu_fw_ablations_metrics.csv\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
" runname seed steps agg_score \\\n",
"0 edu_fineweb_350b_tokens-seed-1 0 2000 0.390326 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 4000 0.414680 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 6000 0.428390 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 8000 0.443615 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 10000 0.441457 \n",
".. ... ... ... ... \n",
"0 edu_fineweb_350b_tokens-seed-1 0 160000 0.507129 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 162000 0.509118 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 164000 0.507843 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 166000 0.508308 \n",
"0 edu_fineweb_350b_tokens-seed-1 0 167000 0.509494 \n",
"\n",
" commonsense_qa/acc commonsense_qa/acc_norm hellaswag/acc \\\n",
"0 0.284 0.283 0.314 \n",
"0 0.322 0.307 0.343 \n",
"0 0.319 0.311 0.372 \n",
"0 0.340 0.311 0.379 \n",
"0 0.346 0.317 0.390 \n",
".. ... ... ... \n",
"0 0.430 0.359 0.473 \n",
"0 0.416 0.367 0.474 \n",
"0 0.416 0.365 0.467 \n",
"0 0.415 0.364 0.472 \n",
"0 0.429 0.362 0.472 \n",
"\n",
" hellaswag/acc_norm openbookqa/acc openbookqa/acc_norm ... siqa/acc \\\n",
"0 0.325 0.164 0.296 ... 0.362 \n",
"0 0.395 0.196 0.320 ... 0.371 \n",
"0 0.431 0.202 0.352 ... 0.373 \n",
"0 0.463 0.204 0.360 ... 0.384 \n",
"0 0.454 0.222 0.364 ... 0.366 \n",
".. ... ... ... ... ... \n",
"0 0.593 0.282 0.418 ... 0.392 \n",
"0 0.592 0.288 0.408 ... 0.390 \n",
"0 0.591 0.276 0.408 ... 0.395 \n",
"0 0.593 0.282 0.414 ... 0.401 \n",
"0 0.597 0.290 0.418 ... 0.395 \n",
"\n",
" siqa/acc_norm winogrande/acc winogrande/acc_norm all/acc \\\n",
"0 0.406 0.511 0.511 0.279674 \n",
"0 0.388 0.518 0.495 0.290613 \n",
"0 0.392 0.520 0.519 0.303980 \n",
"0 0.404 0.517 0.517 0.315148 \n",
"0 0.395 0.514 0.506 0.318935 \n",
".. ... ... ... ... \n",
"0 0.402 0.576 0.575 0.369137 \n",
"0 0.409 0.572 0.577 0.367420 \n",
"0 0.406 0.576 0.580 0.368319 \n",
"0 0.408 0.575 0.570 0.370593 \n",
"0 0.404 0.582 0.578 0.369666 \n",
"\n",
" all/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n",
"0 0.299162 0.3795 0.3850 0.265997 0.284605 \n",
"0 0.312593 0.4215 0.4285 0.274401 0.295939 \n",
"0 0.323323 0.4315 0.4460 0.288591 0.306123 \n",
"0 0.333284 0.4630 0.4790 0.299186 0.314921 \n",
"0 0.335419 0.4890 0.4820 0.302189 0.317653 \n",
".. ... ... ... ... ... \n",
"0 0.393898 0.5670 0.5725 0.350226 0.374533 \n",
"0 0.392861 0.5720 0.5780 0.348268 0.372947 \n",
"0 0.392000 0.5635 0.5715 0.349943 0.372246 \n",
"0 0.393176 0.5640 0.5760 0.352203 0.373463 \n",
"0 0.394136 0.5670 0.5735 0.350671 0.374453 \n",
"\n",
"[82 rows x 22 columns]"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token = os.getenv(\"HF_TOKEN\")\n",
"repo_name = \"HuggingFaceTB/loubna-edu_fw_ablations\"\n",
"runs_to_fetch = [\"tb/edu_fw_ablations-1p82G-edu_fineweb_350b_tokens-seed-1-\"]\n",
"steps_to_fetch = \"%1000\"\n",
"prefix = \"tb/edu_fw_ablations-1p82G-\"\n",
"metrics = ['commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n",
" 'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n",
"agg_score_columns = metrics\n",
"column_name = \"agg_score\"\n",
"seed_merge_method = \"mean\"\n",
"oauth_token = token\n",
"\n",
"# runs_to_fetch = [prefix + run for run in runs_to_fetch]\n",
"fetch_run_results_simple(repo_name, runs_to_fetch, steps_to_fetch, prefix, agg_score_columns, column_name, seed_merge_method, oauth_token=token)"
]
},
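{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (illustrative sketch): the plotting cells below convert checkpoint steps to training tokens with a fixed `2048 * 1024` tokens-per-step factor; at the final checkpoint (step 167000) this corresponds to roughly 350B tokens, matching the run name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: same steps -> tokens conversion as in the plotting cell further down\n",
"tokens_b = 167000 * 2048 * 1024 * 1e-9\n",
"print(f\"step 167000 ~ {tokens_b:.0f}B training tokens\")"
]
},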
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load csvs for FW and FW-Edu"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" runname steps agg_score commonsense_qa/acc commonsense_qa/acc_norm \\\n",
"0 FineWeb-Edu 2000 0.390326 0.284 0.283 \n",
"1 FineWeb-Edu 4000 0.414680 0.322 0.307 \n",
"2 FineWeb-Edu 6000 0.428390 0.319 0.311 \n",
"3 FineWeb-Edu 8000 0.443615 0.340 0.311 \n",
"4 FineWeb-Edu 10000 0.441457 0.346 0.317 \n",
"\n",
" hellaswag/acc hellaswag/acc_norm openbookqa/acc openbookqa/acc_norm \\\n",
"0 0.314 0.325 0.164 0.296 \n",
"1 0.343 0.395 0.196 0.320 \n",
"2 0.372 0.431 0.202 0.352 \n",
"3 0.379 0.463 0.204 0.360 \n",
"4 0.390 0.454 0.222 0.364 \n",
"\n",
" piqa/acc piqa/acc_norm siqa/acc siqa/acc_norm winogrande/acc \\\n",
"0 0.623 0.632 0.362 0.406 0.511 \n",
"1 0.656 0.688 0.371 0.388 0.518 \n",
"2 0.660 0.670 0.373 0.392 0.520 \n",
"3 0.681 0.700 0.384 0.404 0.517 \n",
"4 0.690 0.696 0.366 0.395 0.514 \n",
"\n",
" winogrande/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n",
"0 0.511 0.3795 0.3850 0.265997 0.284605 \n",
"1 0.495 0.4215 0.4285 0.274401 0.295939 \n",
"2 0.519 0.4315 0.4460 0.288591 0.306123 \n",
"3 0.517 0.4630 0.4790 0.299186 0.314921 \n",
"4 0.506 0.4890 0.4820 0.302189 0.317653 "
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"# load guilherme csv with all the FW runs\n",
"df = pd.read_csv(\"../src_data/eval_results.csv\")\n",
"\n",
"# load FineWeb-Edu csv\n",
"df_2 = pd.read_csv(\"./loubna-edu_fw_ablations_metrics.csv\")\n",
"df_2['runname'] = df_2['runname'].replace('edu_fineweb_350b_tokens-seed-1', 'FineWeb-Edu', regex=True)\n",
"df_2.drop([\"seed\", \"all/acc\", \"all/acc_norm\"], axis=1, inplace=True)\n",
"df_2.head()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" runname steps agg_score commonsense_qa/acc \\\n",
"1253 FineWeb-Edu 160000 0.507129 0.430 \n",
"1254 FineWeb-Edu 162000 0.509118 0.416 \n",
"1255 FineWeb-Edu 164000 0.507843 0.416 \n",
"1256 FineWeb-Edu 166000 0.508308 0.415 \n",
"1257 FineWeb-Edu 167000 0.509494 0.429 \n",
"\n",
" commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm \\\n",
"1253 0.359 0.473 0.593 \n",
"1254 0.367 0.474 0.592 \n",
"1255 0.365 0.467 0.591 \n",
"1256 0.364 0.472 0.593 \n",
"1257 0.362 0.472 0.597 \n",
"\n",
" openbookqa/acc openbookqa/acc_norm piqa/acc ... siqa/acc \\\n",
"1253 0.282 0.418 0.744 ... 0.392 \n",
"1254 0.288 0.408 0.747 ... 0.390 \n",
"1255 0.276 0.408 0.737 ... 0.395 \n",
"1256 0.282 0.414 0.740 ... 0.401 \n",
"1257 0.290 0.418 0.738 ... 0.395 \n",
"\n",
" siqa/acc_norm winogrande/acc winogrande/acc_norm sciq/acc \\\n",
"1253 0.402 0.576 0.575 NaN \n",
"1254 0.409 0.572 0.577 NaN \n",
"1255 0.406 0.576 0.580 NaN \n",
"1256 0.408 0.575 0.570 NaN \n",
"1257 0.404 0.582 0.578 NaN \n",
"\n",
" sciq/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n",
"1253 NaN 0.5670 0.5725 0.350226 0.374533 \n",
"1254 NaN 0.5720 0.5780 0.348268 0.372947 \n",
"1255 NaN 0.5635 0.5715 0.349943 0.372246 \n",
"1256 NaN 0.5640 0.5760 0.352203 0.373463 \n",
"1257 NaN 0.5670 0.5735 0.350671 0.374453 \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_full = pd.concat([df, df_2], ignore_index=True)\n",
"df_full.tail()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Guilherme-Board plot"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" steps\n",
"runname \n",
"C4 168\n",
"Dolma 168\n",
"FineWeb 168\n",
"FineWeb-Edu 82\n",
"RedPajama2 168\n",
"RefinedWeb 168\n",
"SlimPajama 168\n",
"The Pile 168"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_full.groupby(\"runname\").agg({\"steps\": \"count\"})"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"fineweb_edu_steps = df_full[df_full[\"runname\"] == \"FineWeb-Edu\"][\"steps\"].unique()\n",
"# Only selects steps that are in the fineweb_edu_steps \n",
"df_full = df_full[df_full[\"steps\"].isin(fineweb_edu_steps)]"
]
},
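{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check (illustrative): after the filtering above, every run should be left with at most the FineWeb-Edu checkpoint steps."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: number of remaining checkpoints per run after aligning on FineWeb-Edu steps\n",
"print(df_full.groupby(\"runname\")[\"steps\"].nunique())"
]
},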
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from matplotlib import pyplot as plt\n",
"metrics = ['agg_score', 'commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n",
" 'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n",
"\n",
"def normalize_runname(runname):\n",
" return runname.replace(\"/\", \"_\")\n",
"\n",
"grouped = (\n",
" df_full.groupby([\"runname\", \"steps\"])\n",
" .agg(\n",
" {\n",
" key: \"mean\" for key in metrics\n",
" }\n",
" )\n",
" .reset_index()\n",
")\n",
"\n",
"file_id=\"../assets/data/plots/edu_ablations\"\n",
"files = {}\n",
"for metric in metrics:\n",
" datas = {}\n",
" for name, group in grouped.groupby(\"runname\"):\n",
" group = group[[\"steps\", metric]].sort_values(by=\"steps\")\n",
" group = group.set_index(\"steps\")\n",
" rolling_avg = group\n",
" # rolling_avg = group.rolling(window=5).mean()\n",
" datas[name] = {\n",
" \"x\": (rolling_avg.index * 2048 * 1024 * 1e-9).tolist(),\n",
" \"y\": rolling_avg[metric].tolist(),\n",
" \"label\": name,\n",
" }\n",
" # Sort the datata based on the steps\n",
" datas = {k: v for k, v in sorted(datas.items(), key=lambda x: -x[1][\"y\"][-1])}\n",
" # Create a folder\n",
" os.makedirs(f\"{file_id}\", exist_ok=True)\n",
" with open(f\"{file_id}/{normalize_runname(metric)}.json\", \"w\") as f:\n",
" json.dump({\n",
" \"data\": datas,\n",
" \"layout\": {\n",
" \"title\": {\n",
" \"text\": \"Dataset ablations\"\n",
" },\n",
" }\n",
" }, f)\n",
" files[metric] = {\"file\": f\"{normalize_runname(metric)}.json\"}\n",
"# Create index\n",
"with open(f\"{file_id}/index.json\", \"w\") as f:\n",
" json.dump({\n",
" \"files\": files,\n",
" \"settings\": {\n",
" \"defaultMetric\": \"agg_score\",\n",
" \"slider\":{\"min\":0,\"max\":30,\"default\":5},\n",
" \"caption\": \"📚 FineWeb-Edu outperforms 🍷 FineWeb and all other open web datasets on our group of evaluation tasks.\"\n",
" }\n",
" }, f)"
]
},
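{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check (illustrative): reload one of the JSON files written above to confirm the structure the plotting frontend expects (a `data` dict keyed by run name plus a `layout` dict)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: read back the agg_score plot data written by the previous cell\n",
"with open(f\"{file_id}/agg_score.json\") as f:\n",
"    payload = json.load(f)\n",
"print(list(payload[\"data\"].keys()))  # run names, ordered by final agg_score\n",
"print(payload[\"layout\"][\"title\"][\"text\"])"
]
},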
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Barplot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: kaleido in /Users/hynky/.pyenv/versions/3.12.2/envs/datatrove/lib/python3.12/site-packages (0.2.1)\n"
]
}
],
"source": [
"!pip install -U kaleido"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Plot saved to plots/edu-100k.png\n"
]
}
],
"source": [
"import plotly.express as px\n",
"from plotly.subplots import make_subplots\n",
"import plotly.graph_objects as go\n",
"\n",
"import json\n",
"\n",
"BASELINES = {\n",
" \"mmlu/acc_norm\": 0.25,\n",
" \"arc/acc_norm\": 0.25,\n",
" \"openbookqa/acc_norm\": 0.25,\n",
" \"piqa/acc_norm\": 0.5,\n",
" \"hellaswag/acc_norm\": 0.25,\n",
" \"siqa/acc_norm\": 0.33,\n",
" \"winogrande/acc_norm\": 0.5,\n",
"}\n",
"\n",
"\n",
"def normalize_run_name(run_name):\n",
" return run_name.replace(\"/\", \"_\")\n",
"\n",
"\n",
"def save_for_bar(dir_name, df, metrics, default_metric=\"mmlu/acc_norm\", xlabel=\"Dataset\", plot_name=\"plot name\", custom_layout={}, ranges={}):\n",
" import os\n",
" files = {}\n",
" os.makedirs(f\"../assets/data/plots/{dir_name}\", exist_ok=True)\n",
" for metric in metrics:\n",
" data = {}\n",
" for run_name in df[\"runname\"].unique():\n",
" data[run_name] = {\n",
" \"x\": [run_name],\n",
" \"y\": df[df[\"runname\"] == run_name][metric].tolist(),\n",
" \"label\": run_name,\n",
" }\n",
" file_name = f\"{normalize_run_name(metric)}.json\"\n",
" files[metric] = {\"file\": f\"{file_name}\"}\n",
" with open(f\"../assets/data/plots/{dir_name}/{file_name}\", \"w\") as f:\n",
" json.dump({\n",
" \"data\": data,\n",
" \"layout\": {\n",
" \"showlegend\": False,\n",
" \"title\": {\n",
" \"text\": plot_name,\n",
" },\n",
" \"xaxis\": {\n",
" \"title\": {\n",
" \"text\": xlabel,\n",
" \"standoff\": 30\n",
" },\n",
" \"tickangle\": 30\n",
" },\n",
" \"yaxis\": {\n",
" \"range\": ranges.get(metric, [0, 1])\n",
" },\n",
" \"margin\": {\n",
" \"b\": 100\n",
" },\n",
" **custom_layout,\n",
" }\n",
" }, f)\n",
" with open(f\"../assets/data/plots/{dir_name}/index.json\", \"w\") as f:\n",
" json.dump({\n",
" \"files\": files,\n",
" \"settings\": {\n",
" \"defaultMetric\": default_metric,\n",
" \"slider\": None,\n",
" \"autoSetXRange\": False,\n",
" \"type\": \"bar\"\n",
" }\n",
" }, f)\n",
" return files\n",
"\n",
"def plot_metric_comparison(df, step, metrics, plot_name, run_name_replacements=None, output_file='comparison_plot_percentages.png', default_metric=\"mmlu/acc_norm\", custom_layout={}):\n",
" \"\"\"\n",
" Plot a comparison of the given metrics across different runs at the specified step and save the plot.\n",
" \"\"\"\n",
" if run_name_replacements:\n",
" df['runname'] = df['runname'].replace(run_name_replacements)\n",
"\n",
" df_filtered = df[df['steps'] == step]\n",
"\n",
" # Create subplots\n",
"\n",
"\n",
" ranges = {}\n",
" for i, metric in enumerate(metrics):\n",
" yrange_start = BASELINES.get(metric, 0) * 0.9\n",
" yrange_end = max(df_filtered[metric])\n",
" # Adjust the end\n",
" yrange_end = yrange_end + (yrange_end - yrange_start) * 0.2\n",
" ranges[metric] = [yrange_start, yrange_end]\n",
" \n",
" file_name=f\"plots/{output_file}.png\"\n",
" # fig.write_image(file_name)\n",
" print(f\"Plot saved to {file_name}\")\n",
"\n",
" save_for_bar(output_file, df_filtered, metrics, default_metric, plot_name=plot_name, custom_layout=custom_layout, ranges=ranges)\n",
"\n",
"\n",
"metrics = [\n",
" \"mmlu/acc_norm\",\n",
" \"arc/acc_norm\",\n",
" \"openbookqa/acc_norm\",\n",
" \"piqa/acc_norm\",\n",
" \"hellaswag/acc_norm\",\n",
" \"siqa/acc_norm\",\n",
" \"winogrande/acc_norm\",\n",
"]\n",
"\n",
"plot_metric_comparison(df_full, 167000, metrics, output_file=\"edu-100k\", plot_name=\"Evaluation results at 350B tokens\", run_name_replacements={\n",
" \"FineWeb (ours)\": \"FineWeb\"\n",
"})"
]
},
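{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional check (illustrative): the bar-plot export above writes one JSON file per metric plus an `index.json` under `../assets/data/plots/edu-100k/`; the snippet below just reloads that index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: inspect the index written by save_for_bar for the bar plot above\n",
"with open(\"../assets/data/plots/edu-100k/index.json\") as f:\n",
"    bar_index = json.load(f)\n",
"print(bar_index[\"settings\"][\"defaultMetric\"], list(bar_index[\"files\"].keys()))"
]
},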
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Thresholds ablation"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" runname steps agg_score commonsense_qa/acc \\\n",
"0 C4 0 0.330893 0.186 \n",
"1 C4 1000 0.355112 0.229 \n",
"2 C4 2000 0.378435 0.268 \n",
"3 C4 3000 0.387795 0.280 \n",
"4 C4 4000 0.399320 0.296 \n",
"... ... ... ... ... \n",
"1171 The Pile 163000 0.463789 0.379 \n",
"1172 The Pile 164000 0.462758 0.369 \n",
"1173 The Pile 165000 0.465026 0.383 \n",
"1174 The Pile 166000 0.462349 0.377 \n",
"1175 The Pile 167000 0.464539 0.386 \n",
"\n",
" commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm \\\n",
"0 0.233 0.272 0.258 \n",
"1 0.260 0.286 0.288 \n",
"2 0.278 0.312 0.330 \n",
"3 0.295 0.331 0.380 \n",
"4 0.298 0.351 0.406 \n",
"... ... ... ... \n",
"1171 0.349 0.441 0.555 \n",
"1172 0.344 0.438 0.552 \n",
"1173 0.350 0.438 0.553 \n",
"1174 0.346 0.440 0.557 \n",
"1175 0.354 0.434 0.557 \n",
"\n",
" openbookqa/acc openbookqa/acc_norm piqa/acc ... siqa/acc \\\n",
"0 0.166 0.286 0.542 ... 0.367 \n",
"1 0.128 0.250 0.614 ... 0.351 \n",
"2 0.122 0.276 0.646 ... 0.375 \n",
"3 0.152 0.274 0.660 ... 0.376 \n",
"4 0.168 0.282 0.676 ... 0.382 \n",
"... ... ... ... ... ... \n",
"1171 0.240 0.366 0.701 ... 0.405 \n",
"1172 0.248 0.348 0.708 ... 0.395 \n",
"1173 0.234 0.352 0.707 ... 0.400 \n",
"1174 0.228 0.346 0.711 ... 0.398 \n",
"1175 0.232 0.356 0.706 ... 0.402 \n",
"\n",
" siqa/acc_norm winogrande/acc winogrande/acc_norm sciq/acc \\\n",
"0 0.362 0.516 0.497 0.208 \n",
"1 0.404 0.519 0.476 0.565 \n",
"2 0.400 0.509 0.500 0.676 \n",
"3 0.387 0.512 0.496 0.725 \n",
"4 0.404 0.522 0.503 0.723 \n",
"... ... ... ... ... \n",
"1171 0.388 0.585 0.560 0.875 \n",
"1172 0.401 0.577 0.567 0.874 \n",
"1173 0.401 0.569 0.556 0.874 \n",
"1174 0.398 0.572 0.558 0.877 \n",
"1175 0.402 0.573 0.559 0.867 \n",
"\n",
" sciq/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n",
"0 0.202 0.2195 0.2510 0.230294 0.250147 \n",
"1 0.518 0.2680 0.2935 0.238951 0.250399 \n",
"2 0.577 0.3065 0.3230 0.247275 0.255482 \n",
"3 0.621 0.3175 0.3340 0.254534 0.267363 \n",
"4 0.618 0.3255 0.3470 0.254762 0.263563 \n",
"... ... ... ... ... ... \n",
"1171 0.820 0.4475 0.4450 0.299378 0.326313 \n",
"1172 0.806 0.4465 0.4355 0.302083 0.331563 \n",
"1173 0.811 0.4460 0.4455 0.305193 0.331708 \n",
"1174 0.811 0.4525 0.4385 0.301952 0.331295 \n",
"1175 0.802 0.4475 0.4375 0.301934 0.330810 \n",
"\n",
"[1176 rows x 21 columns]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching datafiles...: 100%|██████████| 4/4 [00:00<00:00, 21.68it/s]\n",
"Loading evals data...: 100%|██████████| 26/26 [00:04<00:00, 5.76it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Metrics saved to loubna-ablations_faq_metrics.csv\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"token = os.getenv(\"HF_TOKEN\")\n",
"repo_name = \"HuggingFaceTB/loubna-ablations_faq\"\n",
"runs_to_fetch = [\"filtered_web_min_score_4_fix-seed-1-\", \"fineweb_2B_educational_minimum_score_3-seed-0-\", \"fineweb_2B_educational_regression-seed-6-\", \"fineweb_2024_10_all_2B-seed-6-\"]\n",
"steps_to_fetch = \"%1000\"\n",
"prefix = \"tb/ablations_faq-1p81G-\"\n",
"metrics = ['commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n",
" 'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n",
"agg_score_columns = metrics\n",
"column_name = \"agg_score\"\n",
"seed_merge_method = \"mean\"\n",
"oauth_token = token\n",
"\n",
"runs_to_fetch = [prefix + run for run in runs_to_fetch]\n",
"df = fetch_run_results_simple(repo_name, runs_to_fetch, steps_to_fetch, prefix, agg_score_columns, column_name, seed_merge_method, oauth_token=token)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" runname seed steps agg_score commonsense_qa/acc \\\n",
"0 FineWeb (FW) 0 4000 0.389983 0.275 \n",
"0 FineWeb (FW) 0 5000 0.397987 0.303 \n",
"0 FineWeb (FW) 0 6000 0.403954 0.317 \n",
"0 FineWeb (FW) 0 7000 0.404859 0.298 \n",
"0 FineWeb (FW) 0 8000 0.403283 0.330 \n",
"\n",
" commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm openbookqa/acc \\\n",
"0 0.281 0.352 0.383 0.152 \n",
"0 0.297 0.349 0.397 0.154 \n",
"0 0.319 0.359 0.416 0.166 \n",
"0 0.310 0.367 0.424 0.176 \n",
"0 0.319 0.364 0.412 0.176 \n",
"\n",
" openbookqa/acc_norm ... siqa/acc siqa/acc_norm winogrande/acc \\\n",
"0 0.286 ... 0.365 0.385 0.505 \n",
"0 0.290 ... 0.375 0.383 0.509 \n",
"0 0.284 ... 0.379 0.400 0.516 \n",
"0 0.290 ... 0.382 0.396 0.511 \n",
"0 0.276 ... 0.383 0.403 0.510 \n",
"\n",
" winogrande/acc_norm all/acc all/acc_norm arc/acc arc/acc_norm \\\n",
"0 0.493 0.265054 0.281046 0.3265 0.3435 \n",
"0 0.502 0.268548 0.282678 0.3340 0.3560 \n",
"0 0.490 0.268197 0.286678 0.3330 0.3590 \n",
"0 0.494 0.271701 0.289459 0.3250 0.3510 \n",
"0 0.493 0.267533 0.287018 0.3295 0.3510 \n",
"\n",
" mmlu/acc mmlu/acc_norm \n",
"0 0.250500 0.264368 \n",
"0 0.253134 0.264896 \n",
"0 0.252102 0.268633 \n",
"0 0.256203 0.271874 \n",
"0 0.251046 0.269266 \n",
"\n",
"[5 rows x 22 columns]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['runname'] = df['runname'].replace({\"filtered_web_min_score_4_fix-seed-1\": \"FW-Edu-threshold=4\",\n",
" \"fineweb_2B_educational_minimum_score_3-seed-0\": \"FW-Edu-threshold=3\",\n",
" \"fineweb_2B_educational_regression-seed-6\": \"FW-Edu-threshold=2\",\n",
" \"fineweb_2024_10_all_2B-seed-6\": \"FineWeb (FW)\"}, regex=True)\n",
"df.tail()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=4\n",
"0 FW-Edu-threshold=3\n",
"0 FW-Edu-threshold=3\n",
"0 FW-Edu-threshold=3\n",
"0 FW-Edu-threshold=3\n",
"0 FW-Edu-threshold=3\n",
"0 FW-Edu-threshold=2\n",
"0 FW-Edu-threshold=2\n",
"0 FW-Edu-threshold=2\n",
"0 FW-Edu-threshold=2\n",
"0 FW-Edu-threshold=2\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"0 FineWeb (FW)\n",
"Name: runname, dtype: object"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df[\"runname\"]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Plot saved to plots/edu-8k.png\n"
]
}
],
"source": [
"\n",
"metrics = [\n",
" \"mmlu/acc_norm\",\n",
" \"arc/acc_norm\",\n",
" \"openbookqa/acc_norm\",\n",
" \"piqa/acc_norm\",\n",
" \"hellaswag/acc_norm\",\n",
" \"siqa/acc_norm\",\n",
" \"winogrande/acc_norm\",\n",
"]\n",
"plot_metric_comparison(df, 8000, metrics, output_file=\"edu-8k\", plot_name=\"FineWeb-Edu thresholding\", custom_layout={\n",
" \"xaxis\": {\n",
" \"title\": {\n",
" \"standoff\": 60,\n",
" \"text\": \"Dataset\"\n",
" },\n",
" \"tickangle\": 30\n",
" },\n",
" \"margin\": {\n",
" \"b\": 120\n",
" }\n",
"})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "textbooks",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}