from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from tqdm import tqdm
from math import ceil
from model import generate, flush
import os
import torch
import string


def process(text):
    # Lower-case every letter
    text = text.lower()
    # Remove punctuation, keeping apostrophes (e.g. "don't")
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)
    # Trim leading/trailing whitespace (also safe on empty strings)
    return text.strip()


device = "cuda:0"
dtype = torch.float16
cache_dir = "./../cache"
model_id = "openai/whisper-small"
batch_size = 250
out_dir = "./transcripts"

dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model = (
    WhisperForConditionalGeneration
    .from_pretrained(model_id, cache_dir=cache_dir, attn_implementation="sdpa")
    .to(device)
    .to(dtype)
    .eval()
)

os.makedirs(out_dir, exist_ok=True)

for split in dataset.keys():
    data = dataset[split]
    for idx in tqdm(range(ceil(len(data) / batch_size))):
        # Slice one batch of examples and pull out the raw waveforms
        audios = data[idx * batch_size:(idx + 1) * batch_size]["audio"]
        arrays = [a["array"] for a in audios]
        transcripts = generate(arrays, model, processor)
        # Append normalized transcripts, one per line, to the split's file
        with open(os.path.join(out_dir, f"{split}.txt"), "a") as disk:
            disk.writelines([process(text) + "\n" for text in transcripts])
        # Release cached GPU memory between batches
        flush()
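
# ---------------------------------------------------------------------------
# The `generate` and `flush` helpers used above come from the local `model`
# module, which is not shown here. For reference, a minimal sketch of what
# they might look like (an assumption, not the actual implementation):
# `generate` would run batched Whisper inference on raw 16 kHz waveforms,
# and `flush` would release cached GPU memory.
#
# def generate(arrays, model, processor):
#     # Convert raw waveforms to padded log-mel input features
#     inputs = processor(arrays, sampling_rate=16000, return_tensors="pt")
#     features = inputs.input_features.to(model.device).to(model.dtype)
#     with torch.no_grad():
#         ids = model.generate(features)
#     return processor.batch_decode(ids, skip_special_tokens=True)
#
# def flush():
#     import gc
#     gc.collect()
#     torch.cuda.empty_cache()
# ---------------------------------------------------------------------------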