## LibriSpeech

In [None]:
from datasets import load_dataset

# Local Hugging Face cache directory (the same path is used by the other cells).
cache_dir = "./../cache"
dataset = load_dataset(path="openslr/librispeech_asr", cache_dir=cache_dir)

In [None]:
from torchmetrics import WordErrorRate, CharErrorRate
from edit_distance import SequenceMatcher
from tqdm import tqdm
import jiwer

# Built once: the jiwer pipeline does not depend on the input, and
# correct_text is called twice per utterance in the evaluation loop.
_CORRECT_TEXT_TRANSFORM = jiwer.Compose(
    [
        jiwer.ExpandCommonEnglishContractions(),
        jiwer.ToLowerCase(),
        jiwer.RemoveMultipleSpaces(),
        jiwer.Strip(),
        jiwer.RemovePunctuation(),
        jiwer.ReduceToListOfListOfWords(),
    ]
)


def correct_text(text):
    """Normalize text with jiwer and split it into word lists.

    Applies, in order: contraction expansion, lower-casing, collapsing of
    repeated spaces, stripping, punctuation removal, and reduction to a
    list of word lists. Callers in this notebook take element ``[0]``
    when passing a single string.

    Args:
        text: A string, or a list of strings.

    Returns:
        The output of ``jiwer.ReduceToListOfListOfWords``.
    """
    return _CORRECT_TEXT_TRANSFORM(text)

def align_gt_asr(gt, asr):
    """Align reference and hypothesis sequences into ``[gt, asr]`` pairs.

    Expands each opcode of ``SequenceMatcher(a=gt, b=asr).get_opcodes()``
    into per-element pairs:

    * ``equal`` / ``replace``: elements are paired position by position. If
      a ``replace`` span has different lengths on the two sides, the shorter
      side is padded with ``""`` so no element is silently dropped.
    * ``delete``: each reference element is paired with ``""``.
    * ``insert``: each hypothesis element is paired with ``""`` on the left.

    Args:
        gt: Reference (ground-truth) sequence, e.g. a list of words.
        asr: Hypothesis (ASR) sequence.

    Returns:
        A list of two-element lists ``[gt_item, asr_item]`` in alignment order.
    """
    from itertools import zip_longest

    best_path = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(a=gt, b=asr).get_opcodes():
        if tag == "delete":
            best_path.extend([gt[i], ""] for i in range(i1, i2))
        elif tag in ("replace", "equal"):
            # zip_longest rather than zip: an unequal-length replace span
            # must keep every element instead of truncating to the shorter side.
            best_path.extend(
                [gt_item, asr_item]
                for gt_item, asr_item in zip_longest(
                    gt[i1:i2], asr[j1:j2], fillvalue=""
                )
            )
        elif tag == "insert":
            best_path.extend(["", asr[j]] for j in range(j1, j2))
    return best_path

import string
# Deletes every string.punctuation character except the apostrophe, so
# contractions such as "don't" survive normalization.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation.replace("'", ""))


def process(text):
    """Normalize a transcript line: lower-case, drop punctuation, trim spaces.

    Only the space character is trimmed from both ends; inner spacing is
    left untouched.

    Args:
        text: Input string.

    Returns:
        The normalized string, or ``""`` when nothing remains (empty,
        all-space, or punctuation-only input).
    """
    text = text.lower().translate(_PUNCTUATION_TABLE)
    # str.strip(" ") is safe on empty strings, unlike indexing text[0]/text[-1].
    return text.strip(" ")

In [None]:
from tqdm import tqdm

gens = []
texts = []

# Per-utterance number of aligned (reference, hypothesis) word pairs that differ.
unmatches = []

for split in ["validation.clean"]:
    data = dataset[split]
    # Read only the text column; fetching data[idx] would load the full row,
    # audio included, just to read "text".
    ground_truths = data["text"]
    # Assumes line i of the transcript file is the ASR output for row i — TODO confirm.
    with open(f"./transcripts/{split}.txt", "r", encoding="utf-8") as f:
        for idx, line in enumerate(tqdm(f, total=len(data))):
            preds = process(line.rstrip())
            text = ground_truths[idx]

            path = align_gt_asr(correct_text(text)[0], correct_text(preds)[0])
            unmatches.append(sum(1 for ref_word, hyp_word in path if ref_word != hyp_word))

            # texts.append(process(text))
            # gens.append(preds)

In [None]:
import numpy as np

# Number of utterances with at least one mismatched aligned word pair.
np.count_nonzero(np.asarray(unmatches))

In [None]:
def align_gt_asr(gt, asr):
    """Align reference and hypothesis sequences into ``[gt, asr]`` pairs.

    Expands each opcode of ``SequenceMatcher(a=gt, b=asr).get_opcodes()``
    into per-element pairs:

    * ``equal`` / ``replace``: elements are paired position by position. If
      a ``replace`` span has different lengths on the two sides, the shorter
      side is padded with ``""`` so no element is silently dropped.
    * ``delete``: each reference element is paired with ``""``.
    * ``insert``: each hypothesis element is paired with ``""`` on the left.

    Args:
        gt: Reference (ground-truth) sequence, e.g. a list of words.
        asr: Hypothesis (ASR) sequence.

    Returns:
        A list of two-element lists ``[gt_item, asr_item]`` in alignment order.
    """
    from itertools import zip_longest

    best_path = []
    for tag, i1, i2, j1, j2 in SequenceMatcher(a=gt, b=asr).get_opcodes():
        if tag == "delete":
            best_path.extend([gt[i], ""] for i in range(i1, i2))
        elif tag in ("replace", "equal"):
            # zip_longest rather than zip: an unequal-length replace span
            # must keep every element instead of truncating to the shorter side.
            best_path.extend(
                [gt_item, asr_item]
                for gt_item, asr_item in zip_longest(
                    gt[i1:i2], asr[j1:j2], fillvalue=""
                )
            )
        elif tag == "insert":
            best_path.extend(["", asr[j]] for j in range(j1, j2))
    return best_path

# align_gt_asr(correct_text(text), correct_text(preds))

In [None]:
# Inspect normalization of whatever `text` was last bound to (presumably the
# final reference sentence from the evaluation loop, if that cell ran last).
correct_text(text=text)

In [None]:
# Sanity check: correct_text on a list of strings rather than a single string.
correct_text(text=["hello", "hey"])

In [None]:
## Word error rate (WER) of the Whisper-small transcripts, per split
## validation.clean 4.62
## validation.other 8.11
## test.clean 4.22
## test.other 8.56


In [None]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
import numpy as np
import torch

device = "cuda:0"
dtype = torch.float16
cache_dir = "./../cache"
model_id = "openai/whisper-small"

# Processor and model are both loaded from model_id so they cannot drift
# apart when a different checkpoint is selected.
processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model = (
    WhisperForConditionalGeneration.from_pretrained(
        model_id, cache_dir=cache_dir, attn_implementation="sdpa"
    )
    .to(device)
    .to(dtype)
    .eval()
)

## Biasing List

In [None]:
import sys, os
import json
import string
from tqdm import tqdm
# Deletes every string.punctuation character except the apostrophe, so
# contractions such as "don't" survive normalization.
_PUNCTUATION_TABLE = str.maketrans("", "", string.punctuation.replace("'", ""))


def process(text):
    """Normalize a transcript line: lower-case, drop punctuation, trim spaces.

    Only the space character is trimmed from both ends; inner spacing is
    left untouched.

    Args:
        text: Input string.

    Returns:
        The normalized string, or ``""`` when nothing remains (empty,
        all-space, or punctuation-only input).
    """
    text = text.lower().translate(_PUNCTUATION_TABLE)
    # str.strip(" ") is safe on empty strings, unlike indexing text[0]/text[-1].
    return text.strip(" ")

split_name = "train.clean.100"

# Explicit UTF-8 so reading does not depend on the platform's locale encoding.
with open("./blist/all_rare_words.txt", encoding="utf-8") as fin:
    rarewords = [process(word.strip()) for word in fin]
# Set copy for O(1) membership tests in the per-word loop below; the list is
# kept so `rarewords` keeps its type for any later cell.
rareword_set = set(rarewords)

with open(f"./transcripts/{split_name}.txt", encoding="utf-8") as fin:
    transcripts = [line.strip() for line in fin]

from datasets import load_dataset

cache_dir = "./../cache"
dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)

split_data = dataset[split_name]
# Transcript line i is paired with dataset row i; fail early with a clear
# message instead of an IndexError partway through the split.
if len(transcripts) < len(split_data):
    raise ValueError(
        f"{split_name}: {len(transcripts)} transcript lines for "
        f"{len(split_data)} dataset rows"
    )

# Same rule as process(): remove every string.punctuation char except "'".
transcript_punct_table = str.maketrans("", "", string.punctuation.replace("'", ""))

train_data = []

# Iterate only the text column so the audio of each row is not loaded.
pbar = tqdm(split_data["text"])
for idx, raw_text in enumerate(pbar):

    text = process(raw_text)
    transcript = transcripts[idx]

    # Whole-word comparison against the transcript normalized like `text`.
    # The transcript lines are not normalized above, so a raw substring test
    # ("word not in transcript") counted a correctly recognized word as missed
    # whenever its casing differed, and hid a miss that appears inside a
    # longer word.
    transcript_words = set(transcript.lower().translate(transcript_punct_table).split())

    bwords = [
        word for word in text.split()
        if word in rareword_set and word not in transcript_words
    ]

    if bwords:
        train_data.append({
            "split": split_name,
            "idx": idx,
            "text": text,
            "transcript": transcript,
            "b_words": bwords,
        })
    pbar.set_description(f"Len of train data: {len(train_data)}")

In [None]:
# Create the output directory first so the dump does not fail on a fresh checkout.
os.makedirs("./train_data", exist_ok=True)
with open(f"./train_data/{split_name}.json", "w", encoding="utf-8") as fout:
    json.dump(train_data, fout, indent=4)