import gc
from typing import List

import numpy as np
import torch


def flush():
    """Release cached GPU memory after large tensors are freed."""
    gc.collect()
    torch.cuda.empty_cache()


@torch.no_grad()
def generate(arrays, model, processor, max_new_tokens=444) -> List[str]:
    """
    arrays: a list of audio arrays
    model: the Whisper model to use
    processor: the Whisper processor to use
    """
    inputs = processor(arrays, sampling_rate=16000, return_tensors="pt").input_features

    # Cache the encoder hidden states; the encoder only needs to run once per batch
    encoder_hidden_states = model.model.encoder(
        inputs.to(model.device).to(model.dtype)
    ).last_hidden_state

    # Forced decoder prompt: <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
    decoder_ids = torch.tensor(
        [[50258, 50259, 50359, 50363] for _ in range(inputs.shape[0])]
    ).to(model.device)

    # Tensor to keep track of which samples have reached the end-of-text token
    inference_continues = torch.ones(inputs.shape[0], dtype=torch.bool).to(model.device)

    while inference_continues.any() and max_new_tokens > 0:
        last_hidden_state = model.model.decoder(
            input_ids=decoder_ids, encoder_hidden_states=encoder_hidden_states
        ).last_hidden_state

        # A small optimization: only project the hidden state of the last token
        last_token_hidden_state = last_hidden_state[:, -1, :]
        logits = model.proj_out(last_token_hidden_state)

        # Greedy sampling
        probas = torch.softmax(logits, dim=-1)
        pred_idx = torch.argmax(probas, dim=-1, keepdim=True)

        # Fill the samples where inference has stopped with the <|endoftext|> token (50257)
        pred_idx[~inference_continues, :] = 50257
        decoder_ids = torch.cat((decoder_ids, pred_idx), dim=-1)

        # Check whether any sample has reached the end-of-text token
        reached_end_of_text = pred_idx.squeeze(-1) == 50257
        inference_continues &= ~reached_end_of_text
        max_new_tokens -= 1

    transcripts = processor.batch_decode(decoder_ids, skip_special_tokens=True)
    return transcripts
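

# --- Usage sketch (illustrative, not from the original) ---------------------
# A minimal way to drive generate() end to end. It assumes a multilingual
# Whisper checkpoint whose tokenizer matches the hard-coded prompt/EOS ids
# above (e.g. "openai/whisper-large-v2"); the checkpoint name and the
# placeholder audio array are assumptions, not part of the original code.
if __name__ == "__main__":
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-large-v2", torch_dtype=torch.float16
    ).to("cuda")

    # `arrays` should be a list of 1-D float audio arrays sampled at 16 kHz,
    # e.g. the "array" field of a 🤗 Datasets audio column. A second of
    # silence is used here purely as a placeholder.
    arrays = [np.zeros(16000, dtype=np.float32)]

    transcripts = generate(arrays, model, processor)
    print(transcripts)

    # Free GPU memory once the model is no longer needed.
    del model
    flush()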