import json
import os
import string
import sys


def process(text: str) -> str:
    """Normalize a transcript string for matching.

    Lowercases the text, removes every punctuation character except the
    apostrophe (so contractions like "don't" survive), and trims
    surrounding whitespace.

    Args:
        text: Raw transcript text.

    Returns:
        The normalized string; may be empty if the input held only
        punctuation/whitespace.
    """
    text = text.lower()
    # Drop all punctuation except "'" in one C-level translate pass.
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans("", "", punctuation_to_remove)
    text = text.translate(translation_table)
    # str.strip() replaces the original char-by-char trim loop, which
    # raised IndexError (text[0] on "") for empty or all-space input.
    return text.strip()


def main() -> None:
    """Build the rare-word biasing training set for one LibriSpeech split.

    Reads the rare-word list and per-utterance transcripts, scans each
    dataset sample for rare words absent from its transcript, and writes
    the matching samples to ./train_data/<split>.json.
    """
    # Heavy third-party deps imported lazily so the module itself stays
    # importable (and `process` testable) without them installed.
    from datasets import load_dataset
    from tqdm import tqdm

    split_name = "train.other.500"

    # A set gives O(1) membership tests in the hot per-word loop below;
    # the original list made each test an O(n) scan.
    with open("./blist/all_rare_words.txt") as fin:
        rarewords = {process(word.strip()) for word in fin}

    with open(f"./transcripts/{split_name}.txt") as fin:
        transcripts = [line.strip() for line in fin]

    cache_dir = "./../cache"
    dataset = load_dataset(
        "openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True
    )

    train_data = []
    pbar = tqdm(dataset[split_name])
    for idx, sample in enumerate(pbar):
        text = process(sample["text"])
        transcript = transcripts[idx]
        # NOTE(review): `word not in transcript` is a *substring* test —
        # "art" would be suppressed by a transcript containing "part".
        # Behavior kept as-is; confirm whether a word-level check was meant.
        bwords = [
            word
            for word in text.split()
            if word in rarewords and word not in transcript
        ]
        if bwords:
            train_data.append(
                {
                    "split": split_name,
                    "idx": idx,
                    "text": text,
                    "transcript": transcript,
                    "b_words": bwords,
                }
            )
        pbar.set_description(f"Len of train data: {len(train_data)}")

    with open(f"./train_data/{split_name}.json", "w") as fout:
        json.dump(train_data, fout, indent=4)


if __name__ == "__main__":
    main()