# storage/prompting/generate_transcripts.py
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from tqdm import tqdm
from math import ceil
from model import generate, flush
import numpy as np
import os
import torch
import string
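
# `generate` and `flush` are imported from a local model.py that is not part
# of this upload. A minimal sketch of what they plausibly do (hypothetical,
# for illustration only, assuming 16 kHz audio):
#
#     def generate(arrays, model, processor):
#         # Convert raw waveforms to padded log-mel features on the model's device.
#         inputs = processor(arrays, sampling_rate=16000, return_tensors="pt")
#         features = inputs.input_features.to(model.device).to(model.dtype)
#         with torch.no_grad():
#             ids = model.generate(features)
#         return processor.batch_decode(ids, skip_special_tokens=True)
#
#     def flush():
#         # Free cached GPU memory between batches.
#         import gc
#         gc.collect()
#         torch.cuda.empty_cache()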

def process(text):
    """Normalize a transcript for writing to disk."""
    # Lowercase every letter.
    text = text.lower()
    # Remove punctuation, but keep apostrophes (they carry meaning in
    # contractions such as "don't").
    punctuation_to_remove = string.punctuation.replace("'", "")
    translation_table = str.maketrans('', '', punctuation_to_remove)
    text = text.translate(translation_table)
    # Trim leading and trailing whitespace (safe even on an empty string).
    return text.strip()
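
# Example: process("  Hello, World!  ") returns "hello world", while
# process("Don't stop!") returns "don't stop" (apostrophes survive).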

# Inference configuration.
device = "cuda:0"
dtype = torch.float16
cache_dir = "./../cache"
model_id = "openai/whisper-small"
batch_size = 250  # utterances per batch; reduce if the GPU runs out of memory
out_dir = "./transcripts"

dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)
processor = WhisperProcessor.from_pretrained(model_id, cache_dir=cache_dir)
model = (
    WhisperForConditionalGeneration.from_pretrained(
        model_id, cache_dir=cache_dir, attn_implementation="sdpa"
    )
    .to(device)
    .to(dtype)
    .eval()
)
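
# Optional smoke test (hypothetical, for illustration): transcribe a single
# utterance before committing to the full run, e.g.:
#
#     sample = dataset[list(dataset.keys())[0]][0]["audio"]["array"]
#     print(process(generate([sample], model, processor)[0]))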

for split in dataset.keys():
    data = dataset[split]
    os.makedirs(out_dir, exist_ok=True)
    for idx in tqdm(range(ceil(len(data) / batch_size))):
        # Slice one batch of examples and pull out the raw waveforms.
        audios = data[idx * batch_size: (idx + 1) * batch_size]["audio"]
        arrays = [a["array"] for a in audios]
        transcripts = generate(arrays, model, processor)
        # Append the normalized transcripts, one per line; the with-block
        # closes the file, so no explicit close() is needed.
        with open(os.path.join(out_dir, f"{split}.txt"), "a") as disk:
            disk.writelines([process(text) + "\n" for text in transcripts])
        flush()
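
# Note: the per-split transcript files are opened in append mode ("a"), so
# rerunning the script will duplicate lines; clear out_dir before regenerating.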