# Build training data from a LibriSpeech split: keep samples whose normalised
# reference text contains rare words (from the biasing list) that do not appear
# in the corresponding baseline transcript.
import json
import os
import string

from datasets import load_dataset
from tqdm import tqdm


def process(text):
    """Lowercase, drop punctuation (keeping apostrophes), and trim whitespace."""
    text = text.lower()

    # Remove all punctuation except apostrophes, which belong inside words like "don't".
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # Trim leading and trailing whitespace.
    return text.strip()
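
# Example behaviour:
#   process("Hello, World!  ") -> "hello world"
#   process("Don't stop")      -> "don't stop"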


split_name = "train.other.500"

# Biasing list of rare words; a set gives O(1) membership checks in the loop below.
with open("./blist/all_rare_words.txt") as fin:
    rarewords = {process(word.strip()) for word in fin}

# Baseline transcripts for this split, one line per sample, in dataset order.
with open(f"./transcripts/{split_name}.txt") as fin:
    transcripts = [line.strip() for line in fin]

cache_dir = "./../cache"
dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
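
# Sanity check (assumption): there should be exactly one transcript line per
# sample in this split, in the same order as the dataset.
assert len(transcripts) == len(dataset[split_name]), \
    f"expected {len(dataset[split_name])} transcripts, found {len(transcripts)}"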

train_data = []

pbar = tqdm(dataset[split_name])
for idx, sample in enumerate(pbar):
    text = process(sample["text"])
    transcript = transcripts[idx]

    # Rare words that occur in the reference text but are missing from the
    # baseline transcript (note: a plain substring check against the transcript).
    bwords = []
    for word in text.split():
        if word in rarewords and word not in transcript:
            bwords.append(word)

    if len(bwords) > 0:
        train_data.append({
            "split": split_name,
            "idx": idx,
            "text": text,
            "transcript": transcript,
            "b_words": bwords,
        })
    pbar.set_description(f"Len of train data: {len(train_data)}")
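
# Create the output directory if it does not exist yet (assumption: it may be missing).
os.makedirs("./train_data", exist_ok=True)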

with open(f"./train_data/{split_name}.json", "w") as fout:
    json.dump(train_data, fout, indent=4)
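
# Each saved record has the form (illustrative values):
# {"split": "train.other.500", "idx": 0, "text": "...", "transcript": "...", "b_words": ["..."]}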