# Fine-tune Wav2Vec2-BERT 2.0 (w2v-bert-2.0) for Vietnamese speech recognition

[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phineas-pta/fine-tune-whisper-vi/blob/main/train/w2v-bert-v2.ipynb)

On Colab: mount Google Drive through the file-browser GUI before training, because the notebook changes into a Drive folder to save its outputs.

On Kaggle: select the free T4×2 accelerator; with two GPUs the overall batch size is doubled automatically.

In [None]:
from huggingface_hub import notebook_login
# log in to the Hugging Face Hub with an interactive token prompt
# (presumably required for gated datasets such as Common Voice used below — confirm)
notebook_login()
# non-interactive alternative: !huggingface-cli login --token=███

In [None]:
# workaround for a bug in `datasets` package: swap the preinstalled cudf/dask-cuda packages for cudf-cu12
%pip uninstall -y cudf dask-cuda dask-cudf
%pip install -q cudf-cu12 --extra-index-url=https://pypi.nvidia.com
# -U upgrades to the latest releases; bitsandbytes provides the 8-bit AdamW optimizer used for training
%pip install -qU 'datasets[audio]' accelerate transformers jiwer bitsandbytes
# `evaluate` is deliberately not installed: installing it and then running `import evaluate` throws an error on kaggle,
# so WER is computed with `jiwer` directly

In [None]:
import torch
from dataclasses import dataclass
import datasets as hugDS
from transformers import Wav2Vec2BertForCTC, SeamlessM4TFeatureExtractor, Wav2Vec2CTCTokenizer, TrainingArguments, Trainer
import jiwer

In [None]:
SAMPLING_RATE = 16_000
def load_my_data(mode, **kwargs):
	tmp = hugDS.load_dataset(**kwargs, trust_remote_code=True, streaming=True).cast_column("audio", hugDS.Audio(sampling_rate=SAMPLING_RATE))
	match mode:
		case 0:
			return tmp
		case 1:
			return tmp.select_columns(["audio", "transcription"])
		case 2:
			return tmp.select_columns(["audio", "sentence"]).rename_column("sentence", "transcription")
		case _:
			raise ValueError("oh no!")

MY_DATA = hugDS.IterableDatasetDict()

# training sources paired with their approximate sample counts (total: 1.5M samples)
TRAIN_SOURCES = [
	(load_my_data(path="google/fleurs", name="vi_vn", split="train", mode=1), 3_000),
	(load_my_data(path="vivos", split="train", mode=2), 11_700),
	(load_my_data(path="doof-ferb/fpt_fosd", split="train", mode=0), 25_900),
	(load_my_data(path="doof-ferb/infore1_25hours", split="train", mode=0), 14_900),
	(load_my_data(path="doof-ferb/vlsp2020_vinai_100h", split="train", mode=0), 56_400),
	(load_my_data(path="doof-ferb/LSVSC", split="train", mode=1), 45_000),
	(load_my_data(path="quocanh34/viet_vlsp", split="train", mode=0), 171_000),
	(load_my_data(path="linhtran92/viet_youtube_asr_corpus_v2", split="train", mode=1), 195_000),
	(load_my_data(path="doof-ferb/infore2_audiobooks", split="train", mode=0), 315_000),
	(load_my_data(path="linhtran92/viet_bud500", split="train", mode=0), 634_000),
]
_train_total = sum(size for _, size in TRAIN_SOURCES)

# `concatenate_datasets` on streaming datasets yields the sources back to back, so a run bounded by
# `max_steps` would only ever reach the first (small) datasets in the list.
# Interleave them instead: every sample is drawn from a source picked with probability proportional
# to its size, so any prefix of the stream is a mix of all sources.
MY_DATA["train"] = hugDS.interleave_datasets(
	[ds for ds, _ in TRAIN_SOURCES],
	probabilities=[size / _train_total for _, size in TRAIN_SOURCES],
	seed=42, # reproducible mixing order
)

# evaluation reads the whole test stream, so simple concatenation is fine here
MY_DATA["test"] = hugDS.concatenate_datasets([ # total: ~2k samples
	load_my_data(path="mozilla-foundation/common_voice_16_1", name="vi", split="test", mode=2), # 1.3k
	# remove FLEURS because error when running in batch
	load_my_data(path="vivos", split="test", mode=2), # .7k
])

In [None]:
modelID = "trick4kid/w2v-bert-2.0-vietnamese-CV16.0" # checkpoint providing feature extractor, tokenizer and weights
FEATURE_EXTRACTOR = SeamlessM4TFeatureExtractor.from_pretrained(modelID)
TOKENIZER = Wav2Vec2CTCTokenizer.from_pretrained(modelID)

# config overrides applied on top of the checkpoint's own config
_model_overrides = dict(
	ctc_loss_reduction="mean",
	add_adapter=True,
	mask_time_prob=0., # no time masking during training
	layerdrop=0., # never drop encoder layers
	pad_token_id=TOKENIZER.pad_token_id, # keep the model's pad id in sync with the tokenizer
	vocab_size=len(TOKENIZER), # output layer sized to the full tokenizer vocabulary
)
MODEL = Wav2Vec2BertForCTC.from_pretrained(modelID, **_model_overrides)

DUMMY_TOKEN = -100 # label value marking padded positions

In [None]:
def prepare_dataset(batch):
	"""Attach model inputs to a single streamed example and return it (the dict is mutated in place).

	Adds:
		input_features: log-Mel features of the decoded ``audio`` clip, from FEATURE_EXTRACTOR
		labels: token ids of ``transcription``, from TOKENIZER
	NOTE(review): the transcription is tokenized as-is (no case folding); confirm each dataset's
	casing matches the tokenizer vocabulary, otherwise characters may map to the unknown token.
	"""
	waveform = batch["audio"]["array"]
	extracted = FEATURE_EXTRACTOR(waveform, sampling_rate=SAMPLING_RATE)
	batch["input_features"] = extracted.input_features[0] # single clip in, take its entry from the returned batch
	encoded = TOKENIZER(batch["transcription"])
	batch["labels"] = encoded.input_ids
	return batch

# mapping a streaming dataset is lazy and does not take `num_proc`
MY_DATA = MY_DATA.map(prepare_dataset)

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
	"""Collate examples into one padded CTC training batch.

	Audio features and label ids have different lengths and need different padding, so each
	is padded by its own processor: FEATURE_EXTRACTOR for ``input_features``, TOKENIZER for
	``labels``. Padded label positions are then overwritten with DUMMY_TOKEN so the loss ignores them.
	"""

	def __call__(self, features):
		audio_items, label_items = [], []
		for example in features:
			audio_items.append({"input_features": example["input_features"]})
			label_items.append({"input_ids": example["labels"]})

		batch = FEATURE_EXTRACTOR.pad(audio_items, padding=True, return_tensors="pt")
		padded_labels = TOKENIZER.pad(label_items, padding=True, return_tensors="pt")
		is_padding = padded_labels.attention_mask.ne(1)
		batch["labels"] = padded_labels["input_ids"].masked_fill(is_padding, DUMMY_TOKEN)
		return batch

DATA_COLLATOR = DataCollatorCTCWithPadding()

In [None]:
JIWER_TRANS = jiwer.Compose([ # DO NOT use `jiwer.RemoveEmptyStrings` it can cause rows count mismatch
	jiwer.ToLowerCase(),
	jiwer.RemoveKaldiNonWords(),
	jiwer.RemoveMultipleSpaces(),
	jiwer.Strip(),
	jiwer.RemovePunctuation(),
	jiwer.ReduceToListOfListOfWords(),
])

def compute_metrics(pred):
	"""Compute the word error rate over the whole evaluation set.

	pred: the Trainer's EvalPrediction; ``predictions`` holds the CTC logits and ``label_ids``
	the collator labels, whose padding is DUMMY_TOKEN.
	Both references and hypotheses are normalised with JIWER_TRANS before scoring.
	returns {"wer": float}; the value is NaN when no sample has a non-empty reference
	"""
	pred_logits, label_ids = pred.predictions, pred.label_ids
	# the Trainer hands over numpy arrays: `torch.argmax` rejects them, the array's own argmax does not
	pred_ids = pred_logits.argmax(axis=-1)
	# when eval batches of different lengths are concatenated the Trainer pads the logits with -100;
	# such frames would argmax to token 0, so map them to the pad token, which CTC decoding drops
	pred_ids[(pred_logits == -100).all(axis=-1)] = TOKENIZER.pad_token_id
	label_ids[label_ids == DUMMY_TOKEN] = TOKENIZER.pad_token_id # replace -100 with the pad_token_id

	references = TOKENIZER.batch_decode(label_ids, group_tokens=False) # labels are not CTC output: do not merge repeats
	hypotheses = TOKENIZER.batch_decode(pred_ids)

	# jiwer refuses references that are empty after JIWER_TRANS (e.g. made only of bracketed
	# unknown-token markers, which RemoveKaldiNonWords strips); skip such samples instead of
	# letting a single one abort the whole evaluation
	keep = [i for i, words in enumerate(JIWER_TRANS(references)) if words]
	if not keep:
		return {"wer": float("nan")}

	wer = jiwer.wer( # score every kept sample, not just the first one
		reference=[references[i] for i in keep],
		hypothesis=[hypotheses[i] for i in keep],
		reference_transform=JIWER_TRANS, hypothesis_transform=JIWER_TRANS,
	)
	return {"wer": wer}

In [None]:
# mount gdrive using GUI before training
# colab: work inside Google Drive so the relative SAVE_PATH used for checkpoints lands on Drive
%cd '/content/drive/My Drive/coder'

# kaggle: use the writable working directory instead, optionally clearing a previous run's output
# %cd /kaggle/working
# !rm -rf ./my-w2v-bert

In [None]:
SAVE_PATH = "./my-w2v-bert"
BATCH_SIZE = 4 # should be a power of 2
# kaggle free P100 trains faster than colab free T4
# kaggle free T4×2: no speed-up, but the overall batch size is doubled automatically

# colab free tier can only run for 8-12h max daily
# kaggle free tier can only run for 30h max weekly but max 12h per session

has_bf16 = torch.cuda.is_bf16_supported() # GPU Ampere or later

TRAINING_ARGS = TrainingArguments(
	output_dir=SAVE_PATH,
	per_device_train_batch_size=BATCH_SIZE,
	per_device_eval_batch_size=BATCH_SIZE,
	fp16=not has_bf16,
	bf16=has_bf16, tf32=has_bf16,
	# torch_compile=True, # SDPA does not support wav2vec yet
	report_to=["tensorboard"],

	max_steps=1200, # no `num_train_epochs` because the data is streamed
	logging_steps=25,
	save_steps=50,
	eval_steps=50, # equal to save_steps, as required by load_best_model_at_end
	eval_strategy="steps", # formerly `evaluation_strategy`, which recent transformers no longer accept
	save_total_limit=2,

	optim="adamw_bnb_8bit", # 8-bit AdamW optimizer: lower vram usage than default AdamW
	learning_rate=5e-5,
	warmup_ratio=.05, # keep between 5-15%
	gradient_accumulation_steps=(8 + BATCH_SIZE - 1) // BATCH_SIZE, # ceil(8 / BATCH_SIZE): effective batch size at least 8 per device
	gradient_checkpointing=True,
	gradient_checkpointing_kwargs={"use_reentrant": False},
	load_best_model_at_end=True,
	metric_for_best_model="wer",
	greater_is_better=False, # WER is better when lower
)

TRAINER = Trainer(
	args=TRAINING_ARGS,
	model=MODEL,
	train_dataset=MY_DATA["train"],
	eval_dataset=MY_DATA["test"],
	data_collator=DATA_COLLATOR,
	compute_metrics=compute_metrics,
	processing_class=FEATURE_EXTRACTOR, # not TOKENIZER; replaces the deprecated `tokenizer=` argument
)

In [None]:
# train for max_steps, evaluating and checkpointing at the intervals set in TRAINING_ARGS;
# pass resume_from_checkpoint=True to continue from the latest checkpoint in the output dir
TRAINER.train() # resume_from_checkpoint=True # only if resume

In [None]:
TRAINER.save_model()
!zip -FSr res.zip ./my-w2v-bert