{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# fine-tune wav2vec BERT v2\n", "\n", "[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/train/w2v-bert-v2.ipynb)\n", "\n", "on colab: mount gdrive using GUI before training\n", "\n", "on kaggle: select kaggle free T4×2 for auto double batch size" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "notebook_login()\n", "# !huggingface-cli login --token=███" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "scrolled": true, "trusted": true }, "outputs": [], "source": [ "# workaround for a bug in `datasets` package\n", "%pip uninstall -y cudf dask-cuda dask-cudf\n", "%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com\n", "%pip install -qU 'datasets[audio]' accelerate transformers jiwer bitsandbytes\n", "# install then `import evaluate` throw error on kaggle" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "import torch\n", "from dataclasses import dataclass\n", "import datasets as hugDS\n", "from transformers import Wav2Vec2BertForCTC, SeamlessM4TFeatureExtractor, Wav2Vec2CTCTokenizer, TrainingArguments, Trainer\n", "import jiwer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "SAMPLING_RATE = 16_000\n", "def load_my_data(mode, **kwargs):\n", "\ttmp = hugDS.load_dataset(**kwargs, trust_remote_code=True, streaming=True).cast_column(\"audio\", hugDS.Audio(sampling_rate=SAMPLING_RATE))\n", "\tmatch mode:\n", "\t\tcase 0:\n", "\t\t\treturn tmp\n", "\t\tcase 1:\n", "\t\t\treturn tmp.select_columns([\"audio\", \"transcription\"])\n", "\t\tcase 2:\n", "\t\t\treturn tmp.select_columns([\"audio\", \"sentence\"]).rename_column(\"sentence\", \"transcription\")\n", "\t\tcase _:\n", "\t\t\traise ValueError(\"oh no!\")\n", "\n", "MY_DATA = hugDS.IterableDatasetDict()\n", "\n", "MY_DATA[\"train\"] = hugDS.concatenate_datasets([ # total: 1.5M samples\n", "\tload_my_data(path=\"google/fleurs\", name=\"vi_vn\", split=\"train\", mode=1), # 3k\n", "\tload_my_data(path=\"vivos\", split=\"train\", mode=2), # 11.7k\n", "\tload_my_data(path=\"doof-ferb/fpt_fosd\", split=\"train\", mode=0), # 25.9k\n", "\tload_my_data(path=\"doof-ferb/infore1_25hours\", split=\"train\", mode=0), # 14.9k\n", "\tload_my_data(path=\"doof-ferb/vlsp2020_vinai_100h\", split=\"train\", mode=0), # 56.4k\n", "\tload_my_data(path=\"doof-ferb/LSVSC\", split=\"train\", mode=1), # 45k\n", "\tload_my_data(path=\"quocanh34/viet_vlsp\", split=\"train\", mode=0), # 171k\n", "\tload_my_data(path=\"linhtran92/viet_youtube_asr_corpus_v2\", split=\"train\", mode=1), # 195k\n", "\tload_my_data(path=\"doof-ferb/infore2_audiobooks\", split=\"train\", mode=0), # 315k\n", "\tload_my_data(path=\"linhtran92/viet_bud500\", split=\"train\", mode=0), # 634k\n", "])\n", "\n", "MY_DATA[\"test\"] = hugDS.concatenate_datasets([ # total: 15k samples\n", "\tload_my_data(path=\"mozilla-foundation/common_voice_16_1\", name=\"vi\", split=\"test\", mode=2), # 1.3k\n", "\t# remove FLEURS because error when running in batch\n", "\tload_my_data(path=\"vivos\", split=\"test\", mode=2), # .7k\n", "])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "modelID = \"trick4kid/w2v-bert-2.0-vietnamese-CV16.0\"\n", "FEATURE_EXTRACTOR = SeamlessM4TFeatureExtractor.from_pretrained(modelID)\n", "TOKENIZER = Wav2Vec2CTCTokenizer.from_pretrained(modelID)\n", "MODEL = Wav2Vec2BertForCTC.from_pretrained(\n", "\tmodelID, ctc_loss_reduction=\"mean\", add_adapter=True, mask_time_prob=0.,\n", "\tlayerdrop=0., pad_token_id=TOKENIZER.pad_token_id, vocab_size=len(TOKENIZER)\n", ")\n", "\n", "DUMMY_TOKEN = -100" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "def prepare_dataset(batch):\n", "\taudio = batch[\"audio\"]\n", "\tbatch[\"input_features\"] = FEATURE_EXTRACTOR(audio[\"array\"], sampling_rate=SAMPLING_RATE).input_features[0] # compute log-Mel input features\n", "\tbatch[\"labels\"] = TOKENIZER(batch[\"transcription\"]).input_ids # encode target text to label ids\n", "\treturn batch\n", "MY_DATA = MY_DATA.map(prepare_dataset) # no `num_proc` coz streaming" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "@dataclass\n", "class DataCollatorCTCWithPadding:\n", "\tdef __call__(self, features):\n", "\t\t# split inputs and labels since they have to be of different lengths and need different padding methods\n", "\t\tinput_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", "\t\tlabel_features = [{\"input_ids\" : feature[\"labels\"] } for feature in features]\n", "\n", "\t\tbatch = FEATURE_EXTRACTOR.pad(input_features, padding=True, return_tensors=\"pt\")\n", "\t\tlabels_batch = TOKENIZER.pad(label_features, padding=True, return_tensors=\"pt\")\n", "\t\tlabels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), DUMMY_TOKEN) # replace padding with -100 to ignore loss correctly\n", "\n", "\t\tbatch[\"labels\"] = labels\n", "\t\treturn batch\n", "\n", "DATA_COLLATOR = DataCollatorCTCWithPadding()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "JIWER_TRANS = jiwer.Compose([ # DO NOT use `jiwer.RemoveEmptyStrings` it can cause rows count mismatch\n", "\tjiwer.ToLowerCase(),\n", "\tjiwer.RemoveKaldiNonWords(),\n", "\tjiwer.RemoveMultipleSpaces(),\n", "\tjiwer.Strip(),\n", "\tjiwer.RemovePunctuation(),\n", "\tjiwer.ReduceToListOfListOfWords(),\n", "])\n", "\n", "def compute_metrics(pred):\n", "\tpred_logits, label_ids = pred.predictions, pred.label_ids\n", "\tpred_ids = torch.argmax(pred_logits, axis=-1)\n", "\tlabel_ids[label_ids == DUMMY_TOKEN] = TOKENIZER.pad_token_id # replace -100 with the pad_token_id\n", "\n", "\twer = jiwer.wer( # we do not want to group tokens when computing the metrics\n", "\t\treference=TOKENIZER.batch_decode(label_ids, group_tokens=False)[0],\n", "\t\thypothesis=TOKENIZER.batch_decode(pred_ids)[0],\n", "\t\treference_transform=JIWER_TRANS, hypothesis_transform=JIWER_TRANS\n", "\t)\n", "\treturn {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "# mount gdrive using GUI before training\n", "%cd '/content/drive/My Drive/coder'\n", "\n", "# %cd /kaggle/working\n", "# !rm -rf ./my-w2v-bert" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "SAVE_PATH = \"./my-w2v-bert\"\n", "BATCH_SIZE = 4 # should be a power of 2\n", "# kaggle free P100 train faster than colab free T4\n", "# kaggle free T4×2: no speed up but auto double batch size\n", "\n", "# colab free tier can only run for 8-12h max daily\n", "# kaggle free tier can only run for 30h max weekly but max 12h per session\n", "\n", "has_bf16 = torch.cuda.is_bf16_supported() # GPU Ampere or later\n", "\n", "TRAINING_ARGS = TrainingArguments(\n", "\toutput_dir=SAVE_PATH,\n", "\tper_device_train_batch_size=BATCH_SIZE,\n", "\tper_device_eval_batch_size=BATCH_SIZE,\n", "\tfp16=not has_bf16,\n", "\tbf16=has_bf16, tf32=has_bf16,\n", "\t# torch_compile=True, # SDPA not support wav2vec yet\n", "\treport_to=[\"tensorboard\"],\n", "\n", "\tmax_steps=1200, # no `num_train_epochs` coz streaming\n", "\tlogging_steps=25,\n", "\tsave_steps=50,\n", "\teval_steps=50,\n", "\tevaluation_strategy=\"steps\",\n", "\tsave_total_limit=2,\n", "\n", "\toptim=\"adamw_bnb_8bit\", # 8-bit AdamW optimizer: lower vram usage than default AdamW\n", "\tlearning_rate=5e-5,\n", "\twarmup_ratio=.05, # keep between 5-15%\n", "\tgradient_accumulation_steps=1 if BATCH_SIZE >= 8 else 8 // BATCH_SIZE, # keep effective batch size as min 8 per device\n", "\tgradient_checkpointing=True,\n", "\tgradient_checkpointing_kwargs={\"use_reentrant\": False},\n", "\tload_best_model_at_end=True,\n", "\tmetric_for_best_model=\"wer\",\n", "\tgreater_is_better=False, # WER is better when lower\n", ")\n", "\n", "TRAINER = Trainer(\n", "\targs=TRAINING_ARGS,\n", "\tmodel=MODEL,\n", "\ttrain_dataset=MY_DATA[\"train\"],\n", "\teval_dataset=MY_DATA[\"test\"],\n", "\tdata_collator=DATA_COLLATOR,\n", "\tcompute_metrics=compute_metrics,\n", "\ttokenizer=FEATURE_EXTRACTOR, # not TOKENIZER\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "scrolled": true, "trusted": true }, "outputs": [], "source": [ "TRAINER.train() # resume_from_checkpoint=True # only if resume" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false }, "trusted": true }, "outputs": [], "source": [ "TRAINER.save_model()\n", "!zip -FSr res.zip ./my-w2v-bert" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "private_outputs": true, "provenance": [] }, "kaggle": { "accelerator": "nvidiaTeslaT4", "dataSources": [], "isGpuEnabled": true, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 0 }