import torch
import numpy as np
import gc
from typing import List
def flush():
    gc.collect()
    torch.cuda.empty_cache()

@torch.no_grad()
def generate(arrays, model, processor, max_new_tokens=444) -> List[str]:
    """
    arrays: a list of audio arrays sampled at 16 kHz
    model: the Whisper model to use
    processor: the Whisper processor to use
    """
    inputs = processor(arrays, sampling_rate=16000, return_tensors="pt").input_features
    # Cache the encoder hidden states so they are computed only once for the whole batch
    encoder_hidden_states = model.model.encoder(inputs.to(model.device).to(model.dtype)).last_hidden_state
    # Forced decoder prompt: <|startoftranscript|> <|en|> <|transcribe|> <|notimestamps|>
    decoder_ids = torch.tensor([[50258, 50259, 50359, 50363] for _ in range(inputs.shape[0])]).to(model.device)
    # Boolean mask of samples that are still generating (i.e. have not yet emitted <|endoftext|>)
    inference_continues = torch.ones(inputs.shape[0], dtype=torch.bool).to(model.device)
    while inference_continues.any() and max_new_tokens > 0:
        last_hidden_state = model.model.decoder(input_ids=decoder_ids, encoder_hidden_states=encoder_hidden_states).last_hidden_state
        # A small optimization: only project the hidden state of the last token
        last_token_hidden_state = last_hidden_state[:, -1, :]
        logits = model.proj_out(last_token_hidden_state)
        # Greedy sampling: pick the most probable next token for each sample
        probas = torch.softmax(logits, dim=-1)
        pred_idx = torch.argmax(probas, dim=-1, keepdim=True)
        # For samples that have already finished, keep padding with the <|endoftext|> token
        pred_idx[~inference_continues, :] = 50257
        decoder_ids = torch.cat((decoder_ids, pred_idx), dim=-1)
        # Mark samples that just produced the end-of-text token as finished
        reached_end_of_text = (pred_idx.squeeze(-1) == 50257)
        inference_continues &= ~reached_end_of_text
        max_new_tokens -= 1
    transcripts = processor.batch_decode(decoder_ids, skip_special_tokens=True)
    return transcripts
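

# Example usage: a minimal, illustrative sketch of how `generate` and `flush` could be
# called. The checkpoint name ("openai/whisper-large-v3"), the dtype choice, and the
# dummy audio arrays below are assumptions for demonstration, not part of this module.
# Note that the hard-coded decoder prompt in `generate` forces English transcription.
if __name__ == "__main__":
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    model_id = "openai/whisper-large-v3"  # assumed checkpoint; any Whisper model exposing proj_out works
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=dtype).to(device)
    model.eval()

    # Stand-in for real recordings: a batch of 5-second, 16 kHz mono float arrays
    audio_arrays = [np.random.randn(16000 * 5).astype(np.float32) for _ in range(4)]

    transcripts = generate(audio_arrays, model, processor)
    for transcript in transcripts:
        print(transcript)

    # Free cached GPU memory once the batch is done
    del model
    flush()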