from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from tqdm import tqdm
from math import ceil
from model import generate, flush
import numpy as np
import os
import torch
import string
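
# `generate` and `flush` are project-local helpers imported from model.py,
# which is not shown here. A minimal sketch of what they are assumed to do
# (batched Whisper inference and GPU-memory cleanup; the signatures below
# are this script's own convention, not a library API):
#
#   def generate(arrays, model, processor):
#       # Raw waveforms to log-mel features (LibriSpeech audio is 16 kHz),
#       # then batched greedy decoding back to text.
#       inputs = processor(arrays, sampling_rate=16000, return_tensors="pt")
#       features = inputs.input_features.to(model.device).to(model.dtype)
#       with torch.no_grad():
#           ids = model.generate(features)
#       return processor.batch_decode(ids, skip_special_tokens=True)
#
#   def flush():
#       # Drop dead references and release cached GPU memory between batches.
#       import gc
#       gc.collect()
#       torch.cuda.empty_cache()
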
def process(text):
    """Normalize a transcript: lowercase, drop punctuation (keeping
    apostrophes), and trim surrounding whitespace."""
    text = text.lower()

    # Translation table that deletes every punctuation character except "'",
    # so contractions like "don't" survive normalization.
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)

    # strip() trims surrounding whitespace and is safe on empty strings.
    return text.strip()
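
# For example: process("Hello, World!  ") -> "hello world" and
# process("Don't stop...") -> "don't stop".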


device = "cuda:0"
dtype = torch.float16
cache_dir = "./../cache"
model_id = "openai/whisper-small"
batch_size = 250
out_dir = "./transcripts"

# LibriSpeech ASR corpus; every available split gets transcribed below.
dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)

processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model = (
    WhisperForConditionalGeneration
    .from_pretrained(model_id, cache_dir=cache_dir, attn_implementation="sdpa")
    .to(device)
    .to(dtype)
    .eval()
)

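# Each dataset row is assumed to follow the standard datasets audio layout:
#   {"audio": {"array": np.ndarray, "sampling_rate": 16000, ...}, "text": "...", ...}
# so the batching below only needs to pull out the raw waveform arrays.
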
for split in dataset.keys():

    data = dataset[split]

    os.makedirs(out_dir, exist_ok=True)

    # Fixed-size batches; ceil() picks up the final partial batch.
    for idx in tqdm(range(ceil(len(data) / batch_size))):

        audios = data[idx * batch_size: (idx + 1) * batch_size]["audio"]
        arrays = [a["array"] for a in audios]

        transcripts = generate(arrays, model, processor)

        # Append one normalized transcript per line to this split's file;
        # the with-block closes the file, so no explicit close() is needed.
        with open(os.path.join(out_dir, f"{split}.txt"), "a") as disk:
            disk.writelines([process(text) + "\n" for text in transcripts])

        flush()
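
# Note: the transcript files are opened in append mode, so re-running the
# script adds duplicate lines; clear out_dir before a fresh run.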