File size: 4,524 Bytes
588b387
 
 
 
722d7e2
588b387
 
 
 
 
 
 
 
0454f45
588b387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0454f45
 
 
 
 
588b387
0454f45
588b387
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84a0fab
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os

import ffmpeg
import librosa
import torch
import torchaudio
from transformers import WhisperForConditionalGeneration, WhisperProcessor

MODEL_NAME = "openai/whisper-large-v3"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[ INFO ] Device: ", device)

# Half precision is only used on the GPU; on the CPU the model stays in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Model and processor are loaded once at import time and shared by every call.
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype).to(device)
processor = WhisperProcessor.from_pretrained(MODEL_NAME)


def convert_forced_to_tokens(forced_decoder_ids, tokenizer=None):
    """Return ``forced_decoder_ids`` with every token id decoded to text.

    Handy for inspecting which tokens are being forced during generation.

    Args:
        forced_decoder_ids: Iterable of ``(position, token_id)`` pairs. A
            ``token_id`` of ``None`` is passed through unchanged.
        tokenizer: Object with a ``decode(token_id)`` method. Defaults to the
            module-level ``processor.tokenizer``.

    Returns:
        A new list of ``[position, decoded_text_or_None]`` pairs in input order.
    """
    if tokenizer is None:
        tokenizer = processor.tokenizer
    return [
        [position, None if token is None else tokenizer.decode(token)]
        for position, token in forced_decoder_ids
    ]


def change_formate(input_file):
    """Resample ``input_file`` to 16 kHz with ffmpeg and return the new file's path.

    The resampled copy is written in the same directory as the input, with
    ``16_`` prepended to the file name (e.g. ``clips/a.wav`` ->
    ``clips/16_a.wav``). An existing file of that name is overwritten. Errors
    raised by ffmpeg's ``.run()`` propagate to the caller.

    Args:
        input_file: Path to an audio file ffmpeg can read.

    Returns:
        Path of the resampled file.
    """
    # Prefix only the file name: prefixing the whole path would point into a
    # different (usually nonexistent) "16_<dir>" directory.
    directory, filename = os.path.split(input_file)
    output_file = os.path.join(directory, "16_" + filename)
    (
        ffmpeg.input(input_file)
        .output(output_file, loglevel="quiet", ar="16000")
        .run(overwrite_output=True)
    )
    return output_file


# Text prompt placed before the start-of-transcript token to bias the decoder
# toward mixed Spanish/Catalan output.
CODE_SWITCH_PROMPT = " transcribe an audio containing code-switching between es and ca"


def _build_forced_decoder_ids(prompt):
    """Return the ``[position, token_id]`` pairs forced at the start of decoding.

    Positions start at 1 (position 0 is the decoder start token, which
    ``generate`` passes separately). The layout is: the prompt's tokens, the
    model's ``decoder_start_token_id``, then ``<|ca|>``, ``<|es|>`` and
    ``<|transcribe|>``.
    """
    tokenizer = processor.tokenizer
    prompt_ids = tokenizer(prompt, add_special_tokens=False).input_ids
    forced = [[position, token_id] for position, token_id in enumerate(prompt_ids, start=1)]
    forced.append([len(forced) + 1, model.generation_config.decoder_start_token_id])

    # Force both language tokens (Catalan, then Spanish) ahead of the task token.
    # Ids are looked up by name so they track the loaded tokenizer instead of
    # being hard-coded.
    offset = len(forced)
    prefix_tokens = ("<|ca|>", "<|es|>", "<|transcribe|>")
    forced.extend(
        [offset + position, tokenizer.convert_tokens_to_ids(token)]
        for position, token in enumerate(prefix_tokens, start=1)
    )
    return forced


def generate(audio):
    """Transcribe an audio file, steering Whisper toward Catalan/Spanish code-switching.

    The file is first resampled to 16 kHz via ``change_formate`` (which writes
    a ``16_``-prefixed copy next to it), and only the first channel is used.
    Clips of at most 30 seconds use short-form generation. Longer clips use
    long-form generation with the fallback thresholds below.

    Side effect: overwrites ``forced_decoder_ids`` on the shared module-level
    ``model.config`` and ``model.generation_config``.

    Args:
        audio: Path to an audio file ffmpeg can read.

    Returns:
        The transcription text.
    """
    resampled_path = change_formate(audio)
    waveform, sample_rate = torchaudio.load(resampled_path)

    # Duration in seconds from the decoded frames (no second read of the file).
    # Assumes torchaudio's default channels-first layout: (channels, frames).
    length = waveform.shape[-1] / sample_rate
    speech = waveform[0]
    is_short = length <= 30

    if is_short:
        inputs = processor(speech, sampling_rate=sample_rate, return_tensors="pt")
    else:
        inputs = processor(
            speech,
            return_tensors="pt",
            truncation=False,
            padding="longest",
            return_attention_mask=True,
            sampling_rate=sample_rate,
        )
    # The feature extractor produces float32; cast to the model's dtype, or
    # a float16 model on the GPU fails with a dtype mismatch.
    input_features = inputs.input_features.to(device, dtype=torch_dtype)

    forced_decoder_ids = _build_forced_decoder_ids(CODE_SWITCH_PROMPT)
    model.config.forced_decoder_ids = forced_decoder_ids
    model.generation_config.forced_decoder_ids = forced_decoder_ids
    # Decoding starts from <|startofprev|> so the prompt is read as prior context.
    start_token_id = processor.tokenizer.convert_tokens_to_ids("<|startofprev|>")

    if is_short:
        pred_ids = model.generate(
            input_features,
            return_timestamps=True,
            decoder_start_token_id=start_token_id,
            max_new_tokens=128,
        )
        # Skip the start token plus all forced (prompt/prefix) tokens; this
        # assumes the returned sequence begins with them.
        text = processor.decode(pred_ids[0][len(forced_decoder_ids) + 1:], skip_special_tokens=True)
        # Drop the first character of the decoded text.
        return text[1:]

    pred_ids = model.generate(
        input_features,
        attention_mask=inputs.attention_mask.to(device),
        return_timestamps=True,
        decoder_start_token_id=start_token_id,
        logprob_threshold=-1.0,
        compression_ratio_threshold=1.35,
        temperature=(0.0, 0.2, 0.4),
        no_speech_threshold=0.1,
    )
    return processor.batch_decode(pred_ids, skip_special_tokens=True)[0]