{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Fetch the data from the hub" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "import os\n", "import itertools\n", "import pandas as pd\n", "from concurrent.futures import ThreadPoolExecutor\n", "from tqdm import tqdm\n", "import itertools\n", "import huggingface_hub\n", "from tensorboard.backend.event_processing.event_accumulator import EventAccumulator\n", "from huggingface_hub.utils import EntryNotFoundError" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "def step_element_match(step_to_check, step_element):\n", " step_element = step_element.strip().replace(\" \", \"\")\n", " if \"-\" in step_element:\n", " a, b = step_element.split(\"-\")\n", " c = None\n", " if \"%\" in b:\n", " b, c = b.split(\"%\")\n", " return (int(a) <= step_to_check <= int(b) and\n", " (c is None or (step_to_check - int(a)) % int(c) == 0))\n", " elif \"%\" in step_element:\n", " return step_to_check % int(step_element[1:]) == 0\n", " else:\n", " return step_to_check == int(step_element)\n", " \n", "def fetch_run_results_simple(repo_name, runs_to_fetch, steps_to_fetch, prefix, agg_score_columns, column_name,\n", " seed_merge_method, oauth_token=None, prefix_file=None):\n", " if not runs_to_fetch:\n", " return\n", "\n", " def fetch_run_files(run_to_fetch):\n", " def filename_to_steps_timestamp(fn):\n", " step, ts = fn.split(\"_events.out.tfevents.\")\n", " return int(step[-7:]), int(ts[:ts.index(\".\")])\n", "\n", " run_to_fetch += \"_e\"\n", " try:\n", " eval_repo_file_names = [f.path for f in\n", " huggingface_hub.list_repo_tree(repo_name, run_to_fetch, expand=False,\n", " token=oauth_token) if\n", " \"_events.out.tfevents\" in f.path]\n", " except EntryNotFoundError:\n", " return []\n", "\n", " eval_files = [os.path.relpath(f, run_to_fetch) for f in eval_repo_file_names]\n", " timestamps = {}\n", " for fn in eval_files:\n", " steps, ts = filename_to_steps_timestamp(fn)\n", " if steps not in timestamps or timestamps[steps][0] < ts:\n", " timestamps[steps] = ts, fn\n", "\n", " results = []\n", " for eval_file, repofile in zip(eval_files, eval_repo_file_names):\n", " steps, ts = filename_to_steps_timestamp(eval_file)\n", " if not any(step_element_match(steps, step_el) for step_el in steps_to_fetch.split(\",\")):\n", " continue\n", " if timestamps[steps][1] == eval_file:\n", " results.append((run_to_fetch, steps, repofile))\n", " return results\n", "\n", " def load_run_file(data):\n", " run_to_fetch, steps, repofile = data\n", " loader = EventAccumulator(huggingface_hub.hf_hub_download(repo_name, repofile, token=oauth_token))\n", " loader.Reload()\n", " runname = run_to_fetch.removeprefix(prefix).removesuffix(\"-_e\")\n", " column_names = [\"runname\", \"seed\", \"steps\", \"agg_score\"]\n", " column_values = [runname, 0, steps, 0.0]\n", "\n", " for tag in loader.Tags()['scalars']:\n", " if not \"stderr\" in tag and tag.split('/')[0] == 'e':\n", " event_list = loader.Scalars(tag)\n", " tag = tag.split('/')\n", " column_names.append(f\"{tag[1]}/{tag[2]}\")\n", " column_values.append(event_list[-1].value)\n", "\n", " return pd.DataFrame([column_values], columns=column_names)\n", "\n", " with ThreadPoolExecutor() as pool:\n", " run_files = list(itertools.chain.from_iterable(\n", " tqdm(pool.map(fetch_run_files, runs_to_fetch), total=len(runs_to_fetch), desc=\"Fetching datafiles...\")))\n", " df = pd.concat(tqdm(pool.map(load_run_file, run_files), total=len(run_files), desc=\"Loading evals data...\"))\n", "\n", " cols_to_avg = [col for col in agg_score_columns if col in df.columns]\n", " if cols_to_avg:\n", " df['agg_score'] = df[cols_to_avg].mean(axis=1)\n", "\n", " prefix_file = prefix_file + \"_\" if prefix_file else \"\"\n", " df.to_csv(f\"{prefix_file}{repo_name.split('/')[-1]}_metrics.csv\", index=False)\n", " print(f\"Metrics saved to {repo_name.split('/')[-1]}_metrics.csv\")\n", "\n", " return df" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Fetching datafiles...: 100%|██████████| 1/1 [00:02<00:00, 2.94s/it]\n", "Loading evals data...: 100%|██████████| 82/82 [00:15<00:00, 5.37it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Metrics saved to loubna-edu_fw_ablations_metrics.csv\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
runnameseedstepsagg_scorecommonsense_qa/acccommonsense_qa/acc_normhellaswag/acchellaswag/acc_normopenbookqa/accopenbookqa/acc_norm...siqa/accsiqa/acc_normwinogrande/accwinogrande/acc_normall/accall/acc_normarc/accarc/acc_normmmlu/accmmlu/acc_norm
0edu_fineweb_350b_tokens-seed-1020000.3903260.2840.2830.3140.3250.1640.296...0.3620.4060.5110.5110.2796740.2991620.37950.38500.2659970.284605
0edu_fineweb_350b_tokens-seed-1040000.4146800.3220.3070.3430.3950.1960.320...0.3710.3880.5180.4950.2906130.3125930.42150.42850.2744010.295939
0edu_fineweb_350b_tokens-seed-1060000.4283900.3190.3110.3720.4310.2020.352...0.3730.3920.5200.5190.3039800.3233230.43150.44600.2885910.306123
0edu_fineweb_350b_tokens-seed-1080000.4436150.3400.3110.3790.4630.2040.360...0.3840.4040.5170.5170.3151480.3332840.46300.47900.2991860.314921
0edu_fineweb_350b_tokens-seed-10100000.4414570.3460.3170.3900.4540.2220.364...0.3660.3950.5140.5060.3189350.3354190.48900.48200.3021890.317653
..................................................................
0edu_fineweb_350b_tokens-seed-101600000.5071290.4300.3590.4730.5930.2820.418...0.3920.4020.5760.5750.3691370.3938980.56700.57250.3502260.374533
0edu_fineweb_350b_tokens-seed-101620000.5091180.4160.3670.4740.5920.2880.408...0.3900.4090.5720.5770.3674200.3928610.57200.57800.3482680.372947
0edu_fineweb_350b_tokens-seed-101640000.5078430.4160.3650.4670.5910.2760.408...0.3950.4060.5760.5800.3683190.3920000.56350.57150.3499430.372246
0edu_fineweb_350b_tokens-seed-101660000.5083080.4150.3640.4720.5930.2820.414...0.4010.4080.5750.5700.3705930.3931760.56400.57600.3522030.373463
0edu_fineweb_350b_tokens-seed-101670000.5094940.4290.3620.4720.5970.2900.418...0.3950.4040.5820.5780.3696660.3941360.56700.57350.3506710.374453
\n", "

82 rows × 22 columns

\n", "
" ], "text/plain": [ " runname seed steps agg_score \\\n", "0 edu_fineweb_350b_tokens-seed-1 0 2000 0.390326 \n", "0 edu_fineweb_350b_tokens-seed-1 0 4000 0.414680 \n", "0 edu_fineweb_350b_tokens-seed-1 0 6000 0.428390 \n", "0 edu_fineweb_350b_tokens-seed-1 0 8000 0.443615 \n", "0 edu_fineweb_350b_tokens-seed-1 0 10000 0.441457 \n", ".. ... ... ... ... \n", "0 edu_fineweb_350b_tokens-seed-1 0 160000 0.507129 \n", "0 edu_fineweb_350b_tokens-seed-1 0 162000 0.509118 \n", "0 edu_fineweb_350b_tokens-seed-1 0 164000 0.507843 \n", "0 edu_fineweb_350b_tokens-seed-1 0 166000 0.508308 \n", "0 edu_fineweb_350b_tokens-seed-1 0 167000 0.509494 \n", "\n", " commonsense_qa/acc commonsense_qa/acc_norm hellaswag/acc \\\n", "0 0.284 0.283 0.314 \n", "0 0.322 0.307 0.343 \n", "0 0.319 0.311 0.372 \n", "0 0.340 0.311 0.379 \n", "0 0.346 0.317 0.390 \n", ".. ... ... ... \n", "0 0.430 0.359 0.473 \n", "0 0.416 0.367 0.474 \n", "0 0.416 0.365 0.467 \n", "0 0.415 0.364 0.472 \n", "0 0.429 0.362 0.472 \n", "\n", " hellaswag/acc_norm openbookqa/acc openbookqa/acc_norm ... siqa/acc \\\n", "0 0.325 0.164 0.296 ... 0.362 \n", "0 0.395 0.196 0.320 ... 0.371 \n", "0 0.431 0.202 0.352 ... 0.373 \n", "0 0.463 0.204 0.360 ... 0.384 \n", "0 0.454 0.222 0.364 ... 0.366 \n", ".. ... ... ... ... ... \n", "0 0.593 0.282 0.418 ... 0.392 \n", "0 0.592 0.288 0.408 ... 0.390 \n", "0 0.591 0.276 0.408 ... 0.395 \n", "0 0.593 0.282 0.414 ... 0.401 \n", "0 0.597 0.290 0.418 ... 0.395 \n", "\n", " siqa/acc_norm winogrande/acc winogrande/acc_norm all/acc \\\n", "0 0.406 0.511 0.511 0.279674 \n", "0 0.388 0.518 0.495 0.290613 \n", "0 0.392 0.520 0.519 0.303980 \n", "0 0.404 0.517 0.517 0.315148 \n", "0 0.395 0.514 0.506 0.318935 \n", ".. ... ... ... ... \n", "0 0.402 0.576 0.575 0.369137 \n", "0 0.409 0.572 0.577 0.367420 \n", "0 0.406 0.576 0.580 0.368319 \n", "0 0.408 0.575 0.570 0.370593 \n", "0 0.404 0.582 0.578 0.369666 \n", "\n", " all/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n", "0 0.299162 0.3795 0.3850 0.265997 0.284605 \n", "0 0.312593 0.4215 0.4285 0.274401 0.295939 \n", "0 0.323323 0.4315 0.4460 0.288591 0.306123 \n", "0 0.333284 0.4630 0.4790 0.299186 0.314921 \n", "0 0.335419 0.4890 0.4820 0.302189 0.317653 \n", ".. ... ... ... ... ... \n", "0 0.393898 0.5670 0.5725 0.350226 0.374533 \n", "0 0.392861 0.5720 0.5780 0.348268 0.372947 \n", "0 0.392000 0.5635 0.5715 0.349943 0.372246 \n", "0 0.393176 0.5640 0.5760 0.352203 0.373463 \n", "0 0.394136 0.5670 0.5735 0.350671 0.374453 \n", "\n", "[82 rows x 22 columns]" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "token = os.getenv(\"HF_TOKEN\")\n", "repo_name = \"HuggingFaceTB/loubna-edu_fw_ablations\"\n", "runs_to_fetch = [\"tb/edu_fw_ablations-1p82G-edu_fineweb_350b_tokens-seed-1-\"]\n", "steps_to_fetch = \"%1000\"\n", "prefix = \"tb/edu_fw_ablations-1p82G-\"\n", "metrics = ['commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n", " 'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n", "agg_score_columns = metrics\n", "column_name = \"agg_score\"\n", "seed_merge_method = \"mean\"\n", "oauth_token = token\n", "\n", "# runs_to_fetch = [prefix + run for run in runs_to_fetch]\n", "fetch_run_results_simple(repo_name, runs_to_fetch, steps_to_fetch, prefix, agg_score_columns, column_name, seed_merge_method, oauth_token=token)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load csvs for FW and FW-Edu" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
runnamestepsagg_scorecommonsense_qa/acccommonsense_qa/acc_normhellaswag/acchellaswag/acc_normopenbookqa/accopenbookqa/acc_normpiqa/accpiqa/acc_normsiqa/accsiqa/acc_normwinogrande/accwinogrande/acc_normarc/accarc/acc_normmmlu/accmmlu/acc_norm
0FineWeb-Edu20000.3903260.2840.2830.3140.3250.1640.2960.6230.6320.3620.4060.5110.5110.37950.38500.2659970.284605
1FineWeb-Edu40000.4146800.3220.3070.3430.3950.1960.3200.6560.6880.3710.3880.5180.4950.42150.42850.2744010.295939
2FineWeb-Edu60000.4283900.3190.3110.3720.4310.2020.3520.6600.6700.3730.3920.5200.5190.43150.44600.2885910.306123
3FineWeb-Edu80000.4436150.3400.3110.3790.4630.2040.3600.6810.7000.3840.4040.5170.5170.46300.47900.2991860.314921
4FineWeb-Edu100000.4414570.3460.3170.3900.4540.2220.3640.6900.6960.3660.3950.5140.5060.48900.48200.3021890.317653
\n", "
" ], "text/plain": [ " runname steps agg_score commonsense_qa/acc commonsense_qa/acc_norm \\\n", "0 FineWeb-Edu 2000 0.390326 0.284 0.283 \n", "1 FineWeb-Edu 4000 0.414680 0.322 0.307 \n", "2 FineWeb-Edu 6000 0.428390 0.319 0.311 \n", "3 FineWeb-Edu 8000 0.443615 0.340 0.311 \n", "4 FineWeb-Edu 10000 0.441457 0.346 0.317 \n", "\n", " hellaswag/acc hellaswag/acc_norm openbookqa/acc openbookqa/acc_norm \\\n", "0 0.314 0.325 0.164 0.296 \n", "1 0.343 0.395 0.196 0.320 \n", "2 0.372 0.431 0.202 0.352 \n", "3 0.379 0.463 0.204 0.360 \n", "4 0.390 0.454 0.222 0.364 \n", "\n", " piqa/acc piqa/acc_norm siqa/acc siqa/acc_norm winogrande/acc \\\n", "0 0.623 0.632 0.362 0.406 0.511 \n", "1 0.656 0.688 0.371 0.388 0.518 \n", "2 0.660 0.670 0.373 0.392 0.520 \n", "3 0.681 0.700 0.384 0.404 0.517 \n", "4 0.690 0.696 0.366 0.395 0.514 \n", "\n", " winogrande/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n", "0 0.511 0.3795 0.3850 0.265997 0.284605 \n", "1 0.495 0.4215 0.4285 0.274401 0.295939 \n", "2 0.519 0.4315 0.4460 0.288591 0.306123 \n", "3 0.517 0.4630 0.4790 0.299186 0.314921 \n", "4 0.506 0.4890 0.4820 0.302189 0.317653 " ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "# load guilherme csv with all the FW runs\n", "df = pd.read_csv(\"../src_data/eval_results.csv\")\n", "\n", "# load FineWeb-Edu csv\n", "df_2 = pd.read_csv(\"./loubna-edu_fw_ablations_metrics.csv\")\n", "df_2['runname'] = df_2['runname'].replace('edu_fineweb_350b_tokens-seed-1', 'FineWeb-Edu', regex=True)\n", "df_2.drop([\"seed\", \"all/acc\", \"all/acc_norm\"], axis=1, inplace=True)\n", "df_2.head()" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
runnamestepsagg_scorecommonsense_qa/acccommonsense_qa/acc_normhellaswag/acchellaswag/acc_normopenbookqa/accopenbookqa/acc_normpiqa/acc...siqa/accsiqa/acc_normwinogrande/accwinogrande/acc_normsciq/accsciq/acc_normarc/accarc/acc_normmmlu/accmmlu/acc_norm
1253FineWeb-Edu1600000.5071290.4300.3590.4730.5930.2820.4180.744...0.3920.4020.5760.575NaNNaN0.56700.57250.3502260.374533
1254FineWeb-Edu1620000.5091180.4160.3670.4740.5920.2880.4080.747...0.3900.4090.5720.577NaNNaN0.57200.57800.3482680.372947
1255FineWeb-Edu1640000.5078430.4160.3650.4670.5910.2760.4080.737...0.3950.4060.5760.580NaNNaN0.56350.57150.3499430.372246
1256FineWeb-Edu1660000.5083080.4150.3640.4720.5930.2820.4140.740...0.4010.4080.5750.570NaNNaN0.56400.57600.3522030.373463
1257FineWeb-Edu1670000.5094940.4290.3620.4720.5970.2900.4180.738...0.3950.4040.5820.578NaNNaN0.56700.57350.3506710.374453
\n", "

5 rows × 21 columns

\n", "
" ], "text/plain": [ " runname steps agg_score commonsense_qa/acc \\\n", "1253 FineWeb-Edu 160000 0.507129 0.430 \n", "1254 FineWeb-Edu 162000 0.509118 0.416 \n", "1255 FineWeb-Edu 164000 0.507843 0.416 \n", "1256 FineWeb-Edu 166000 0.508308 0.415 \n", "1257 FineWeb-Edu 167000 0.509494 0.429 \n", "\n", " commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm \\\n", "1253 0.359 0.473 0.593 \n", "1254 0.367 0.474 0.592 \n", "1255 0.365 0.467 0.591 \n", "1256 0.364 0.472 0.593 \n", "1257 0.362 0.472 0.597 \n", "\n", " openbookqa/acc openbookqa/acc_norm piqa/acc ... siqa/acc \\\n", "1253 0.282 0.418 0.744 ... 0.392 \n", "1254 0.288 0.408 0.747 ... 0.390 \n", "1255 0.276 0.408 0.737 ... 0.395 \n", "1256 0.282 0.414 0.740 ... 0.401 \n", "1257 0.290 0.418 0.738 ... 0.395 \n", "\n", " siqa/acc_norm winogrande/acc winogrande/acc_norm sciq/acc \\\n", "1253 0.402 0.576 0.575 NaN \n", "1254 0.409 0.572 0.577 NaN \n", "1255 0.406 0.576 0.580 NaN \n", "1256 0.408 0.575 0.570 NaN \n", "1257 0.404 0.582 0.578 NaN \n", "\n", " sciq/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n", "1253 NaN 0.5670 0.5725 0.350226 0.374533 \n", "1254 NaN 0.5720 0.5780 0.348268 0.372947 \n", "1255 NaN 0.5635 0.5715 0.349943 0.372246 \n", "1256 NaN 0.5640 0.5760 0.352203 0.373463 \n", "1257 NaN 0.5670 0.5735 0.350671 0.374453 \n", "\n", "[5 rows x 21 columns]" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_full = pd.concat([df, df_2], ignore_index=True)\n", "df_full.tail()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Guilherme-Board plot" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
steps
runname
C4168
Dolma168
FineWeb168
FineWeb-Edu82
RedPajama2168
RefinedWeb168
SlimPajama168
The Pile168
\n", "
" ], "text/plain": [ " steps\n", "runname \n", "C4 168\n", "Dolma 168\n", "FineWeb 168\n", "FineWeb-Edu 82\n", "RedPajama2 168\n", "RefinedWeb 168\n", "SlimPajama 168\n", "The Pile 168" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_full.groupby(\"runname\").agg({\"steps\": \"count\"})" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "fineweb_edu_steps = df_full[df_full[\"runname\"] == \"FineWeb-Edu\"][\"steps\"].unique()\n", "# Only selects steps that are in the fineweb_edu_steps \n", "df_full = df_full[df_full[\"steps\"].isin(fineweb_edu_steps)]" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", "from matplotlib import pyplot as plt\n", "metrics = ['agg_score', 'commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n", " 'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n", "\n", "def normalize_runname(runname):\n", " return runname.replace(\"/\", \"_\")\n", "\n", "grouped = (\n", " df_full.groupby([\"runname\", \"steps\"])\n", " .agg(\n", " {\n", " key: \"mean\" for key in metrics\n", " }\n", " )\n", " .reset_index()\n", ")\n", "\n", "file_id=\"../assets/data/plots/edu_ablations\"\n", "files = {}\n", "for metric in metrics:\n", " datas = {}\n", " for name, group in grouped.groupby(\"runname\"):\n", " group = group[[\"steps\", metric]].sort_values(by=\"steps\")\n", " group = group.set_index(\"steps\")\n", " rolling_avg = group\n", " # rolling_avg = group.rolling(window=5).mean()\n", " datas[name] = {\n", " \"x\": (rolling_avg.index * 2048 * 1024 * 1e-9).tolist(),\n", " \"y\": rolling_avg[metric].tolist(),\n", " \"label\": name,\n", " }\n", " # Sort the datata based on the steps\n", " datas = {k: v for k, v in sorted(datas.items(), key=lambda x: -x[1][\"y\"][-1])}\n", " # Create a folder\n", " os.makedirs(f\"{file_id}\", exist_ok=True)\n", " with open(f\"{file_id}/{normalize_runname(metric)}.json\", \"w\") as f:\n", " json.dump({\n", " \"data\": datas,\n", " \"layout\": {\n", " \"title\": {\n", " \"text\": \"Dataset ablations\"\n", " },\n", " }\n", " }, f)\n", " files[metric] = {\"file\": f\"{normalize_runname(metric)}.json\"}\n", "# Create index\n", "with open(f\"{file_id}/index.json\", \"w\") as f:\n", " json.dump({\n", " \"files\": files,\n", " \"settings\": {\n", " \"defaultMetric\": \"agg_score\",\n", " \"slider\":{\"min\":0,\"max\":30,\"default\":5},\n", " \"caption\": \"📚 FineWeb-Edu outperforms 🍷 FineWeb and all other open web datasets on our group of evaluation tasks.\"\n", " }\n", " }, f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Barplot" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: kaleido in /Users/hynky/.pyenv/versions/3.12.2/envs/datatrove/lib/python3.12/site-packages (0.2.1)\n" ] } ], "source": [ "!pip install -U kaleido" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Plot saved to plots/edu-100k.png\n" ] } ], "source": [ "import plotly.express as px\n", "from plotly.subplots import make_subplots\n", "import plotly.graph_objects as go\n", "\n", "import json\n", "\n", "BASELINES = {\n", " \"mmlu/acc_norm\": 0.25,\n", " \"arc/acc_norm\": 0.25,\n", " \"openbookqa/acc_norm\": 0.25,\n", " \"piqa/acc_norm\": 0.5,\n", " \"hellaswag/acc_norm\": 0.25,\n", " \"siqa/acc_norm\": 0.33,\n", " \"winogrande/acc_norm\": 0.5,\n", "}\n", "\n", "\n", "def normalize_run_name(run_name):\n", " return run_name.replace(\"/\", \"_\")\n", "\n", "\n", "def save_for_bar(dir_name, df, metrics, default_metric=\"mmlu/acc_norm\", xlabel=\"Dataset\", plot_name=\"plot name\", custom_layout={}, ranges={}):\n", " import os\n", " files = {}\n", " os.makedirs(f\"../assets/data/plots/{dir_name}\", exist_ok=True)\n", " for metric in metrics:\n", " data = {}\n", " for run_name in df[\"runname\"].unique():\n", " data[run_name] = {\n", " \"x\": [run_name],\n", " \"y\": df[df[\"runname\"] == run_name][metric].tolist(),\n", " \"label\": run_name,\n", " }\n", " file_name = f\"{normalize_run_name(metric)}.json\"\n", " files[metric] = {\"file\": f\"{file_name}\"}\n", " with open(f\"../assets/data/plots/{dir_name}/{file_name}\", \"w\") as f:\n", " json.dump({\n", " \"data\": data,\n", " \"layout\": {\n", " \"showlegend\": False,\n", " \"title\": {\n", " \"text\": plot_name,\n", " },\n", " \"xaxis\": {\n", " \"title\": {\n", " \"text\": xlabel,\n", " \"standoff\": 30\n", " },\n", " \"tickangle\": 30\n", " },\n", " \"yaxis\": {\n", " \"range\": ranges.get(metric, [0, 1])\n", " },\n", " \"margin\": {\n", " \"b\": 100\n", " },\n", " **custom_layout,\n", " }\n", " }, f)\n", " with open(f\"../assets/data/plots/{dir_name}/index.json\", \"w\") as f:\n", " json.dump({\n", " \"files\": files,\n", " \"settings\": {\n", " \"defaultMetric\": default_metric,\n", " \"slider\": None,\n", " \"autoSetXRange\": False,\n", " \"type\": \"bar\"\n", " }\n", " }, f)\n", " return files\n", "\n", "def plot_metric_comparison(df, step, metrics, plot_name, run_name_replacements=None, output_file='comparison_plot_percentages.png', default_metric=\"mmlu/acc_norm\", custom_layout={}):\n", " \"\"\"\n", " Plot a comparison of the given metrics across different runs at the specified step and save the plot.\n", " \"\"\"\n", " if run_name_replacements:\n", " df['runname'] = df['runname'].replace(run_name_replacements)\n", "\n", " df_filtered = df[df['steps'] == step]\n", "\n", " # Create subplots\n", "\n", "\n", " ranges = {}\n", " for i, metric in enumerate(metrics):\n", " yrange_start = BASELINES.get(metric, 0) * 0.9\n", " yrange_end = max(df_filtered[metric])\n", " # Adjust the end\n", " yrange_end = yrange_end + (yrange_end - yrange_start) * 0.2\n", " ranges[metric] = [yrange_start, yrange_end]\n", " \n", " file_name=f\"plots/{output_file}.png\"\n", " # fig.write_image(file_name)\n", " print(f\"Plot saved to {file_name}\")\n", "\n", " save_for_bar(output_file, df_filtered, metrics, default_metric, plot_name=plot_name, custom_layout=custom_layout, ranges=ranges)\n", "\n", "\n", "metrics = [\n", " \"mmlu/acc_norm\",\n", " \"arc/acc_norm\",\n", " \"openbookqa/acc_norm\",\n", " \"piqa/acc_norm\",\n", " \"hellaswag/acc_norm\",\n", " \"siqa/acc_norm\",\n", " \"winogrande/acc_norm\",\n", "]\n", "\n", "plot_metric_comparison(df_full, 167000, metrics, output_file=\"edu-100k\", plot_name=\"Evaluation results at 350B tokens\", run_name_replacements={\n", " \"FineWeb (ours)\": \"FineWeb\"\n", "})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Thresholds ablation" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
runnamestepsagg_scorecommonsense_qa/acccommonsense_qa/acc_normhellaswag/acchellaswag/acc_normopenbookqa/accopenbookqa/acc_normpiqa/acc...siqa/accsiqa/acc_normwinogrande/accwinogrande/acc_normsciq/accsciq/acc_normarc/accarc/acc_normmmlu/accmmlu/acc_norm
0C400.3308930.1860.2330.2720.2580.1660.2860.542...0.3670.3620.5160.4970.2080.2020.21950.25100.2302940.250147
1C410000.3551120.2290.2600.2860.2880.1280.2500.614...0.3510.4040.5190.4760.5650.5180.26800.29350.2389510.250399
2C420000.3784350.2680.2780.3120.3300.1220.2760.646...0.3750.4000.5090.5000.6760.5770.30650.32300.2472750.255482
3C430000.3877950.2800.2950.3310.3800.1520.2740.660...0.3760.3870.5120.4960.7250.6210.31750.33400.2545340.267363
4C440000.3993200.2960.2980.3510.4060.1680.2820.676...0.3820.4040.5220.5030.7230.6180.32550.34700.2547620.263563
..................................................................
1171The Pile1630000.4637890.3790.3490.4410.5550.2400.3660.701...0.4050.3880.5850.5600.8750.8200.44750.44500.2993780.326313
1172The Pile1640000.4627580.3690.3440.4380.5520.2480.3480.708...0.3950.4010.5770.5670.8740.8060.44650.43550.3020830.331563
1173The Pile1650000.4650260.3830.3500.4380.5530.2340.3520.707...0.4000.4010.5690.5560.8740.8110.44600.44550.3051930.331708
1174The Pile1660000.4623490.3770.3460.4400.5570.2280.3460.711...0.3980.3980.5720.5580.8770.8110.45250.43850.3019520.331295
1175The Pile1670000.4645390.3860.3540.4340.5570.2320.3560.706...0.4020.4020.5730.5590.8670.8020.44750.43750.3019340.330810
\n", "

1176 rows × 21 columns

\n", "
" ], "text/plain": [ " runname steps agg_score commonsense_qa/acc \\\n", "0 C4 0 0.330893 0.186 \n", "1 C4 1000 0.355112 0.229 \n", "2 C4 2000 0.378435 0.268 \n", "3 C4 3000 0.387795 0.280 \n", "4 C4 4000 0.399320 0.296 \n", "... ... ... ... ... \n", "1171 The Pile 163000 0.463789 0.379 \n", "1172 The Pile 164000 0.462758 0.369 \n", "1173 The Pile 165000 0.465026 0.383 \n", "1174 The Pile 166000 0.462349 0.377 \n", "1175 The Pile 167000 0.464539 0.386 \n", "\n", " commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm \\\n", "0 0.233 0.272 0.258 \n", "1 0.260 0.286 0.288 \n", "2 0.278 0.312 0.330 \n", "3 0.295 0.331 0.380 \n", "4 0.298 0.351 0.406 \n", "... ... ... ... \n", "1171 0.349 0.441 0.555 \n", "1172 0.344 0.438 0.552 \n", "1173 0.350 0.438 0.553 \n", "1174 0.346 0.440 0.557 \n", "1175 0.354 0.434 0.557 \n", "\n", " openbookqa/acc openbookqa/acc_norm piqa/acc ... siqa/acc \\\n", "0 0.166 0.286 0.542 ... 0.367 \n", "1 0.128 0.250 0.614 ... 0.351 \n", "2 0.122 0.276 0.646 ... 0.375 \n", "3 0.152 0.274 0.660 ... 0.376 \n", "4 0.168 0.282 0.676 ... 0.382 \n", "... ... ... ... ... ... \n", "1171 0.240 0.366 0.701 ... 0.405 \n", "1172 0.248 0.348 0.708 ... 0.395 \n", "1173 0.234 0.352 0.707 ... 0.400 \n", "1174 0.228 0.346 0.711 ... 0.398 \n", "1175 0.232 0.356 0.706 ... 0.402 \n", "\n", " siqa/acc_norm winogrande/acc winogrande/acc_norm sciq/acc \\\n", "0 0.362 0.516 0.497 0.208 \n", "1 0.404 0.519 0.476 0.565 \n", "2 0.400 0.509 0.500 0.676 \n", "3 0.387 0.512 0.496 0.725 \n", "4 0.404 0.522 0.503 0.723 \n", "... ... ... ... ... \n", "1171 0.388 0.585 0.560 0.875 \n", "1172 0.401 0.577 0.567 0.874 \n", "1173 0.401 0.569 0.556 0.874 \n", "1174 0.398 0.572 0.558 0.877 \n", "1175 0.402 0.573 0.559 0.867 \n", "\n", " sciq/acc_norm arc/acc arc/acc_norm mmlu/acc mmlu/acc_norm \n", "0 0.202 0.2195 0.2510 0.230294 0.250147 \n", "1 0.518 0.2680 0.2935 0.238951 0.250399 \n", "2 0.577 0.3065 0.3230 0.247275 0.255482 \n", "3 0.621 0.3175 0.3340 0.254534 0.267363 \n", "4 0.618 0.3255 0.3470 0.254762 0.263563 \n", "... ... ... ... ... ... \n", "1171 0.820 0.4475 0.4450 0.299378 0.326313 \n", "1172 0.806 0.4465 0.4355 0.302083 0.331563 \n", "1173 0.811 0.4460 0.4455 0.305193 0.331708 \n", "1174 0.811 0.4525 0.4385 0.301952 0.331295 \n", "1175 0.802 0.4475 0.4375 0.301934 0.330810 \n", "\n", "[1176 rows x 21 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Fetching datafiles...: 100%|██████████| 4/4 [00:00<00:00, 21.68it/s]\n", "Loading evals data...: 100%|██████████| 26/26 [00:04<00:00, 5.76it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Metrics saved to loubna-ablations_faq_metrics.csv\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "token = os.getenv(\"HF_TOKEN\")\n", "repo_name = \"HuggingFaceTB/loubna-ablations_faq\"\n", "runs_to_fetch = [\"filtered_web_min_score_4_fix-seed-1-\", \"fineweb_2B_educational_minimum_score_3-seed-0-\", \"fineweb_2B_educational_regression-seed-6-\", \"fineweb_2024_10_all_2B-seed-6-\"]\n", "steps_to_fetch = \"%1000\"\n", "prefix = \"tb/ablations_faq-1p81G-\"\n", "metrics = ['commonsense_qa/acc_norm', 'hellaswag/acc_norm', 'openbookqa/acc_norm', 'piqa/acc_norm',\n", " 'siqa/acc_norm', 'winogrande/acc_norm', 'arc/acc_norm', 'mmlu/acc_norm']\n", "agg_score_columns = metrics\n", "column_name = \"agg_score\"\n", "seed_merge_method = \"mean\"\n", "oauth_token = token\n", "\n", "runs_to_fetch = [prefix + run for run in runs_to_fetch]\n", "df = fetch_run_results_simple(repo_name, runs_to_fetch, steps_to_fetch, prefix, agg_score_columns, column_name, seed_merge_method, oauth_token=token)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
runnameseedstepsagg_scorecommonsense_qa/acccommonsense_qa/acc_normhellaswag/acchellaswag/acc_normopenbookqa/accopenbookqa/acc_norm...siqa/accsiqa/acc_normwinogrande/accwinogrande/acc_normall/accall/acc_normarc/accarc/acc_normmmlu/accmmlu/acc_norm
0FineWeb (FW)040000.3899830.2750.2810.3520.3830.1520.286...0.3650.3850.5050.4930.2650540.2810460.32650.34350.2505000.264368
0FineWeb (FW)050000.3979870.3030.2970.3490.3970.1540.290...0.3750.3830.5090.5020.2685480.2826780.33400.35600.2531340.264896
0FineWeb (FW)060000.4039540.3170.3190.3590.4160.1660.284...0.3790.4000.5160.4900.2681970.2866780.33300.35900.2521020.268633
0FineWeb (FW)070000.4048590.2980.3100.3670.4240.1760.290...0.3820.3960.5110.4940.2717010.2894590.32500.35100.2562030.271874
0FineWeb (FW)080000.4032830.3300.3190.3640.4120.1760.276...0.3830.4030.5100.4930.2675330.2870180.32950.35100.2510460.269266
\n", "

5 rows × 22 columns

\n", "
" ], "text/plain": [ " runname seed steps agg_score commonsense_qa/acc \\\n", "0 FineWeb (FW) 0 4000 0.389983 0.275 \n", "0 FineWeb (FW) 0 5000 0.397987 0.303 \n", "0 FineWeb (FW) 0 6000 0.403954 0.317 \n", "0 FineWeb (FW) 0 7000 0.404859 0.298 \n", "0 FineWeb (FW) 0 8000 0.403283 0.330 \n", "\n", " commonsense_qa/acc_norm hellaswag/acc hellaswag/acc_norm openbookqa/acc \\\n", "0 0.281 0.352 0.383 0.152 \n", "0 0.297 0.349 0.397 0.154 \n", "0 0.319 0.359 0.416 0.166 \n", "0 0.310 0.367 0.424 0.176 \n", "0 0.319 0.364 0.412 0.176 \n", "\n", " openbookqa/acc_norm ... siqa/acc siqa/acc_norm winogrande/acc \\\n", "0 0.286 ... 0.365 0.385 0.505 \n", "0 0.290 ... 0.375 0.383 0.509 \n", "0 0.284 ... 0.379 0.400 0.516 \n", "0 0.290 ... 0.382 0.396 0.511 \n", "0 0.276 ... 0.383 0.403 0.510 \n", "\n", " winogrande/acc_norm all/acc all/acc_norm arc/acc arc/acc_norm \\\n", "0 0.493 0.265054 0.281046 0.3265 0.3435 \n", "0 0.502 0.268548 0.282678 0.3340 0.3560 \n", "0 0.490 0.268197 0.286678 0.3330 0.3590 \n", "0 0.494 0.271701 0.289459 0.3250 0.3510 \n", "0 0.493 0.267533 0.287018 0.3295 0.3510 \n", "\n", " mmlu/acc mmlu/acc_norm \n", "0 0.250500 0.264368 \n", "0 0.253134 0.264896 \n", "0 0.252102 0.268633 \n", "0 0.256203 0.271874 \n", "0 0.251046 0.269266 \n", "\n", "[5 rows x 22 columns]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df['runname'] = df['runname'].replace({\"filtered_web_min_score_4_fix-seed-1\": \"FW-Edu-threshold=4\",\n", " \"fineweb_2B_educational_minimum_score_3-seed-0\": \"FW-Edu-threshold=3\",\n", " \"fineweb_2B_educational_regression-seed-6\": \"FW-Edu-threshold=2\",\n", " \"fineweb_2024_10_all_2B-seed-6\": \"FineWeb (FW)\"}, regex=True)\n", "df.tail()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=4\n", "0 FW-Edu-threshold=3\n", "0 FW-Edu-threshold=3\n", "0 FW-Edu-threshold=3\n", "0 FW-Edu-threshold=3\n", "0 FW-Edu-threshold=3\n", "0 FW-Edu-threshold=2\n", "0 FW-Edu-threshold=2\n", "0 FW-Edu-threshold=2\n", "0 FW-Edu-threshold=2\n", "0 FW-Edu-threshold=2\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "0 FineWeb (FW)\n", "Name: runname, dtype: object" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df[\"runname\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Plot saved to plots/edu-8k.png\n" ] } ], "source": [ "\n", "metrics = [\n", " \"mmlu/acc_norm\",\n", " \"arc/acc_norm\",\n", " \"openbookqa/acc_norm\",\n", " \"piqa/acc_norm\",\n", " \"hellaswag/acc_norm\",\n", " \"siqa/acc_norm\",\n", " \"winogrande/acc_norm\",\n", "]\n", "plot_metric_comparison(df, 8000, metrics, output_file=\"edu-8k\", plot_name=\"FineWeb-Edu thresholding\", custom_layout={\n", " \"xaxis\": {\n", " \"title\": {\n", " \"standoff\": 60,\n", " \"text\": \"Dataset\"\n", " },\n", " \"tickangle\": 30\n", " },\n", " \"margin\": {\n", " \"b\": 120\n", " }\n", "})" ] } ], "metadata": { "kernelspec": { "display_name": "textbooks", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }