{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "59efc3d7-a57f-43cc-8aa3-34bb57de0251",
   "metadata": {},
   "source": [
    "## Librispeech"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "327243de-fd0f-449d-998a-63282a1c67a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "cache_dir = \"./../cache\"\n",
    "# trust_remote_code matches the loader call in the Biasing List section.\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "456889e1-f8cc-440b-bf6b-f6fbfafc367d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics import WordErrorRate, CharErrorRate\n",
    "from edit_distance import SequenceMatcher\n",
    "from tqdm import tqdm\n",
    "import jiwer\n",
    "import string\n",
    "\n",
    "# Built once; correct_text() reuses it instead of re-creating the jiwer\n",
    "# pipeline on every call (it is invoked twice per utterance in the loop below).\n",
    "_JIWER_TRANSFORMS = jiwer.Compose(\n",
    "    [\n",
    "        jiwer.ExpandCommonEnglishContractions(),\n",
    "        jiwer.ToLowerCase(),\n",
    "        jiwer.RemoveMultipleSpaces(),\n",
    "        jiwer.Strip(),\n",
    "        jiwer.RemovePunctuation(),\n",
    "        jiwer.ReduceToListOfListOfWords(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "def correct_text(text):\n",
    "    \"\"\"Normalize text for WER-style comparison.\n",
    "\n",
    "    Returns jiwer's list-of-list-of-words representation\n",
    "    (one inner word list per input sentence).\n",
    "    \"\"\"\n",
    "    return _JIWER_TRANSFORMS(text)\n",
    "\n",
    "def align_gt_asr(gt, asr):\n",
    "    \"\"\"Align a ground-truth word list against an ASR hypothesis word list.\n",
    "\n",
    "    Returns a list of [gt_word, asr_word] pairs; a deletion pairs a\n",
    "    ground-truth word with \"\" and an insertion pairs \"\" with an ASR word.\n",
    "\n",
    "    NOTE(review): for 'replace' opcodes zip() silently truncates when the\n",
    "    two spans differ in length -- confirm this SequenceMatcher always\n",
    "    emits equal-length replace spans.\n",
    "    \"\"\"\n",
    "    sm = SequenceMatcher(a=gt, b=asr)\n",
    "    best_path = []\n",
    "\n",
    "    for tag, i1, i2, j1, j2 in sm.get_opcodes():\n",
    "\n",
    "        if tag == \"delete\":\n",
    "            for i in range(i1, i2):\n",
    "                best_path.append([gt[i], \"\"])\n",
    "\n",
    "        if tag == \"replace\" or tag == \"equal\":\n",
    "            for i, j in zip(range(i1, i2), range(j1, j2)):\n",
    "                best_path.append([gt[i], asr[j]])\n",
    "\n",
    "        if tag == \"insert\":\n",
    "            for j in range(j1, j2):\n",
    "                best_path.append([\"\", asr[j]])\n",
    "\n",
    "    return best_path\n",
    "\n",
    "def process(text):\n",
    "    \"\"\"Lower-case text, drop punctuation (keeping apostrophes) and strip\n",
    "    surrounding spaces. Safe on empty strings.\"\"\"\n",
    "    text = text.lower()\n",
    "\n",
    "    # Remove punctuation except apostrophes (keeps contractions intact).\n",
    "    punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
    "    translation_table = str.maketrans('', '', punctuation_to_remove)\n",
    "    text = text.translate(translation_table)\n",
    "\n",
    "    # strip(' ') replaces the old character-popping while-loop, which\n",
    "    # raised IndexError on empty or all-space input.\n",
    "    return text.strip(' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bc907b0-2ebe-46ac-b6a1-02919e69af88",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "gens = []\n",
    "texts = []\n",
    "\n",
    "# Per-utterance count of misaligned words (0 == perfect match).\n",
    "unmatches = []\n",
    "\n",
    "for split in [\"validation.clean\"]:\n",
    "    data = dataset[split]\n",
    "    with open(f\"./transcripts/{split}.txt\", \"r\") as f:\n",
    "        for idx, line in enumerate(tqdm(f)):\n",
    "            preds = process(line.rstrip())\n",
    "            text = data[idx][\"text\"]\n",
    "\n",
    "            # correct_text() returns a list of sentences; [0] is the only one.\n",
    "            path = align_gt_asr(correct_text(text)[0], correct_text(preds)[0])\n",
    "            un = sum(1 for a, b in path if a != b)\n",
    "\n",
    "            unmatches.append(un)\n",
    "\n",
    "            # texts.append(process(text))\n",
    "            # gens.append(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac10009-1b47-4e2f-a232-f71b23ee983e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Number of utterances with at least one word-level mismatch.\n",
    "np.count_nonzero(unmatches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afdc9f74-c2cf-4d52-8563-1bd827f6d900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: an identical second copy of align_gt_asr() used to live here; it was\n",
    "# removed so the definition in the earlier cell stays the single source of truth.\n",
    "\n",
    "# align_gt_asr(correct_text(text), correct_text(preds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cdfd3d9-6c22-4ccd-a22b-df8e79fc20b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory: relies on `text` leaking from the evaluation loop above.\n",
    "correct_text(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c33f46a-f3dd-435f-81e3-e7b10ae03470",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_text([\"hello\", \"hey\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cfab12a-2b2c-4c00-bd80-ab571c012f29",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Transcript of whisper small WER\n",
    "## validation.clean 4.62\n",
    "## validation.other 8.11\n",
    "## test.clean 4.22\n",
    "## test.other 8.56\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24cb2d8f-9ce2-42f2-bbf0-522106078aac",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
    "from datasets import load_dataset\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "device = \"cuda:0\"\n",
    "dtype = torch.float16\n",
    "cache_dir = \"./../cache\"\n",
    "model_id = \"openai/whisper-small\"\n",
    "\n",
    "# Reuse model_id rather than repeating the checkpoint name literal.\n",
    "processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)\n",
    "model = WhisperForConditionalGeneration.from_pretrained(model_id, cache_dir=cache_dir, attn_implementation=\"sdpa\").to(device).to(dtype).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5fa6f8e-43f2-44ce-b719-2d8fde4067ce",
   "metadata": {},
   "source": [
    "## Biasing List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc0f934-d208-445e-aecd-31df73be6986",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "import json\n",
    "import string\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Duplicated from the Librispeech section so this section can run standalone.\n",
    "def process(text):\n",
    "    \"\"\"Lower-case text, drop punctuation (keeping apostrophes) and strip\n",
    "    surrounding spaces. Safe on empty strings.\"\"\"\n",
    "    text = text.lower()\n",
    "\n",
    "    # Remove punctuation except apostrophes (keeps contractions intact).\n",
    "    punctuation_to_remove = string.punctuation.replace(\"'\", \"\")\n",
    "    translation_table = str.maketrans('', '', punctuation_to_remove)\n",
    "    text = text.translate(translation_table)\n",
    "\n",
    "    # strip(' ') replaces the old while-loop, which crashed on empty input.\n",
    "    return text.strip(' ')\n",
    "\n",
    "split_name = \"train.clean.100\"\n",
    "\n",
    "# A set makes the per-word membership test O(1) instead of scanning a list\n",
    "# for every word of every utterance.\n",
    "with open(\"./blist/all_rare_words.txt\") as fin:\n",
    "    rarewords = {process(word.strip()) for word in fin}\n",
    "\n",
    "with open(f\"./transcripts/{split_name}.txt\") as fin:\n",
    "    transcripts = [line.strip() for line in fin]\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "cache_dir = \"./../cache\"\n",
    "dataset = load_dataset(\"openslr/librispeech_asr\", cache_dir=cache_dir, trust_remote_code=True)\n",
    "\n",
    "train_data = []\n",
    "\n",
    "pbar = tqdm(dataset[split_name])\n",
    "for idx, sample in enumerate(pbar):\n",
    "\n",
    "    text = process(sample[\"text\"])\n",
    "    transcript = transcripts[idx]\n",
    "\n",
    "    bwords = []\n",
    "    for word in text.split():\n",
    "        # NOTE(review): `word not in transcript` is a substring test against\n",
    "        # the raw (unprocessed) transcript line, so e.g. \"art\" inside\n",
    "        # \"apartment\" counts as present -- confirm this is intended rather\n",
    "        # than a token-level check on process(transcript).split().\n",
    "        if word in rarewords and word not in transcript:\n",
    "            bwords.append(word)\n",
    "\n",
    "    if len(bwords) > 0:\n",
    "        train_data.append({\n",
    "            \"split\": split_name,\n",
    "            \"idx\": idx,\n",
    "            \"text\": text,\n",
    "            \"transcript\": transcript,\n",
    "            \"b_words\": bwords,\n",
    "        })\n",
    "    pbar.set_description(f\"Len of train data: {len(train_data)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cac9a909-e1ce-426a-bda3-b65ba3985d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the mined biasing examples for downstream training.\n",
    "with open(f\"./train_data/{split_name}.json\", \"w\") as fout:\n",
    "    json.dump(train_data, fout, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}