File size: 1,964 Bytes
2cddd11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import numpy as np
import gc
from typing import List

def flush():
    """Reclaim memory held by unreachable objects and PyTorch's CUDA cache.

    A garbage-collection pass runs first so that tensors which are no longer
    referenced are actually released; only then is PyTorch asked to drop its
    cached, currently-unused CUDA memory. Returns ``None``.
    """
    # Order matters: collect garbage before emptying the CUDA cache.
    for release in (gc.collect, torch.cuda.empty_cache):
        release()

@torch.no_grad()
def generate(
    arrays,
    model,
    processor,
    max_new_tokens=444,
    *,
    sampling_rate: int = 16000,
    decoder_prompt_ids=(50258, 50259, 50359, 50363),
    eos_token_id: int = 50257,
) -> List[str]:
    """Greedily transcribe a batch of audio arrays with a Whisper model.

    The encoder runs once for the whole batch. The decoder then runs step by
    step, appending the arg-max token for every sample, until every sample has
    produced ``eos_token_id`` or ``max_new_tokens`` steps have been taken.

    Args:
        arrays: a batch of audio arrays, passed straight to ``processor``.
        model: the Whisper model to use; must expose ``model.encoder``,
            ``model.decoder``, ``proj_out``, ``device`` and ``dtype``.
        processor: the Whisper processor used to build the input features and
            to decode the generated token ids.
        max_new_tokens: maximum number of decoding steps (tokens appended per
            sample). Values <= 0 perform no decoding.
        sampling_rate: sampling rate of ``arrays``, forwarded to ``processor``.
        decoder_prompt_ids: token ids that every sample's decoder input starts
            with. The default presumably is the multilingual Whisper
            ``<|startoftranscript|><|en|><|transcribe|><|notimestamps|>``
            prompt -- verify against the tokenizer of the checkpoint in use.
        eos_token_id: end-of-text token id; a sample stops once it emits it.

    Returns:
        One transcript per input array, decoded with special tokens skipped.
        An empty batch yields an empty list.

    Raises:
        ValueError: if ``decoder_prompt_ids`` is empty (the decoder would have
            no position to predict from).
    """
    if len(decoder_prompt_ids) == 0:
        raise ValueError("decoder_prompt_ids must contain at least one token id")
    if len(arrays) == 0:
        return []

    inputs = processor(
        arrays, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features
    batch_size = inputs.shape[0]

    # Encode once; the encoder output is reused at every decoding step.
    encoder_hidden_states = model.model.encoder(
        inputs.to(device=model.device, dtype=model.dtype)
    ).last_hidden_state

    # One copy of the prompt per sample: shape (batch_size, len(prompt)).
    decoder_ids = torch.tensor(
        decoder_prompt_ids, dtype=torch.long, device=model.device
    ).repeat(batch_size, 1)

    # True for samples that have not yet produced the end-of-text token.
    inference_continues = torch.ones(
        batch_size, dtype=torch.bool, device=model.device
    )

    # NOTE: the whole prefix is re-decoded at every step (no KV cache), so the
    # decoder cost grows quadratically with the number of generated tokens.
    for _ in range(max_new_tokens):
        if not inference_continues.any():
            break

        last_hidden_state = model.model.decoder(
            input_ids=decoder_ids, encoder_hidden_states=encoder_hidden_states
        ).last_hidden_state

        # Only the last position predicts the next token, so project just it.
        logits = model.proj_out(last_hidden_state[:, -1, :])

        # Greedy pick. softmax is monotonic and cannot change the arg-max,
        # so it is skipped and the arg-max is taken over the raw logits.
        pred_idx = logits.argmax(dim=-1, keepdim=True)

        # Finished samples keep emitting end-of-text so the batch stays
        # rectangular; those trailing tokens are dropped by batch_decode.
        pred_idx[~inference_continues, :] = eos_token_id

        decoder_ids = torch.cat((decoder_ids, pred_idx), dim=-1)

        # Stop tracking every sample that just produced end-of-text.
        inference_continues &= pred_idx.squeeze(-1) != eos_token_id

    return processor.batch_decode(decoder_ids, skip_special_tokens=True)