|
import torch |
|
import numpy as np |
|
import gc |
|
from typing import List |
|
|
|
def flush():
    """Reclaim memory held by dead Python objects and PyTorch's CUDA cache.

    Runs a full garbage-collection pass first and only then asks PyTorch's
    CUDA caching allocator to release its unoccupied cached blocks. Returns
    nothing.
    """
    # Collect first: tensors kept alive only by reference cycles are freed
    # here, so their blocks are no longer occupied when the cache is emptied.
    gc.collect()
    # Then release the caching allocator's unused blocks.
    torch.cuda.empty_cache()
    return None
|
|
|
@torch.no_grad()
def generate(
    arrays,
    model,
    processor,
    max_new_tokens: int = 444,
    *,
    prompt_ids=(50258, 50259, 50359, 50363),
    eos_token_id: int = 50257,
    sampling_rate: int = 16000,
) -> List[str]:
    """Greedily decode one transcript per audio array with a Whisper model.

    Args:
        arrays: a list of audio arrays, handed to ``processor`` unchanged.
        model: the Whisper model to use. It must expose ``model.model.encoder``,
            ``model.model.decoder``, ``model.proj_out``, ``model.device`` and
            ``model.dtype``.
        processor: the Whisper processor, used both to extract input features
            and to turn the generated token ids back into text.
        max_new_tokens: maximum number of tokens generated after the prompt.
            Values <= 0 generate nothing.
            NOTE(review): 444 plus the 4-token default prompt is 448, which
            looks chosen to match Whisper's decoder position limit. Confirm
            before raising it.
        prompt_ids: decoder prompt prepended to every sequence. The default is
            presumably <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
            in the multilingual Whisper vocabulary. Verify against
            ``processor.tokenizer`` for your checkpoint.
        eos_token_id: token that ends a sequence. Finished sequences are also
            padded with it. The default is presumably <|endoftext|>; verify.
        sampling_rate: sampling rate (Hz) forwarded to ``processor``.

    Returns:
        One decoded string per input array, from
        ``processor.batch_decode(..., skip_special_tokens=True)``.

    Raises:
        ValueError: if ``prompt_ids`` is empty.
    """
    prompt = list(prompt_ids)
    if not prompt:
        raise ValueError("prompt_ids must contain at least one token")

    features = processor(arrays, sampling_rate=sampling_rate, return_tensors="pt").input_features
    batch_size = features.shape[0]
    device = model.device

    # Encode the audio once; every decoding step cross-attends to this.
    encoder_hidden_states = model.model.encoder(
        features.to(device=device, dtype=model.dtype)
    ).last_hidden_state

    # Every sequence starts from the same prompt.
    decoder_ids = torch.tensor([prompt], dtype=torch.long, device=device).repeat(batch_size, 1)
    # True while a sequence has not yet produced eos_token_id.
    active = torch.ones(batch_size, dtype=torch.bool, device=device)

    for _ in range(max_new_tokens):
        if not active.any():
            break
        # NOTE(review): the full prefix is re-decoded every step (no KV cache),
        # so cost grows quadratically with the output length.
        hidden = model.model.decoder(
            input_ids=decoder_ids, encoder_hidden_states=encoder_hidden_states
        ).last_hidden_state
        logits = model.proj_out(hidden[:, -1, :])
        # argmax(logits) == argmax(softmax(logits)). Skipping the softmax saves
        # work and avoids probabilities collapsing into ties in low precision.
        next_ids = logits.argmax(dim=-1, keepdim=True)
        # Pad already-finished sequences with EOS so the batch stays
        # rectangular. Presumably stripped later by skip_special_tokens.
        next_ids = next_ids.masked_fill(~active.unsqueeze(-1), eos_token_id)
        decoder_ids = torch.cat((decoder_ids, next_ids), dim=-1)
        active &= next_ids.squeeze(-1) != eos_token_id

    return processor.batch_decode(decoder_ids, skip_special_tokens=True)