# File size: 2,126 Bytes
# aeda668
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
from pathlib import Path
from gector import GecBERTModel
from faster_whisper import WhisperModel, BatchedInferencePipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from text_processing.inverse_normalize import InverseNormalizer
import time

import torch
# Allow TF32 math for float32 matmuls / cuDNN convolutions in any torch code
# run below (faster on supporting GPUs, at slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Benchmark inputs. AUDIO_PATH is resolved relative to the current working
# directory (unlike the GECToR paths below, which are relative to this file).
AUDIO_PATH = "HA1.wav"
BATCH_SIZE = 32
CHUNK_LENGTH = 15


inverse_normalizer = InverseNormalizer('vi')
current_dir = Path(__file__).parent.as_posix()
whisper_model = WhisperModel("pho_distill_q8", device="auto", compute_type="auto")
batched_model = BatchedInferencePipeline(model=whisper_model, use_vad_model=True, chunk_length=CHUNK_LENGTH)
gector_model = GecBERTModel(
    vocab_path=os.path.join(current_dir, "gector/vocabulary"),
    model_paths=[os.path.join(current_dir, "gector/Model_GECTOR")],
    split_chunk=True
)
normalizer = BasicTextNormalizer()


def _transcribe_and_correct(pipeline: BatchedInferencePipeline, audio_path: str) -> str:
    """Transcribe *audio_path* in Vietnamese and return the corrected text.

    Each segment is passed through ``BasicTextNormalizer`` and then the
    Vietnamese ``InverseNormalizer``. The resulting list is then corrected by
    the GECToR model.

    Args:
        pipeline: The batched Whisper pipeline used for transcription.
        audio_path: Path of the audio file to transcribe.

    Returns:
        The corrected segments, stripped and joined by single spaces.
    """
    segments, _info = pipeline.transcribe(audio_path, language="vi", batch_size=BATCH_SIZE)
    transcriptions = [segment.text for segment in segments]
    normalized = [inverse_normalizer.inverse_normalize(normalizer(text)) for text in transcriptions]
    corrected = gector_model(normalized)
    # Join with a space rather than '' so that words at segment boundaries are
    # not glued together. Stripping and dropping empties avoids doubled spaces.
    pieces = (text.strip() for text in corrected)
    return " ".join(piece for piece in pieces if piece)


def _run_timed(pipeline: BatchedInferencePipeline) -> None:
    """Run the full pipeline on AUDIO_PATH, printing the text and elapsed seconds."""
    # perf_counter is monotonic and high-resolution; time.time() is not meant
    # for measuring intervals.
    start = time.perf_counter()
    print(_transcribe_and_correct(pipeline, AUDIO_PATH))
    print(time.perf_counter() - start)


####start transcriptions#####
# NOTE(review): the first timed run may also absorb one-time warm-up cost
# (e.g. the first GECToR / GPU call), which could bias the comparison.
print("Distill model")
_run_timed(batched_model)

# Drop the first model, presumably so both Whisper models are not resident at once.
del whisper_model, batched_model
#######################################################################################

print("Student model")

whisper_model = WhisperModel("pho_distill_fp16", device="auto", compute_type="auto")
batched_model = BatchedInferencePipeline(model=whisper_model, use_vad_model=True, chunk_length=CHUNK_LENGTH)

_run_timed(batched_model)