LAP-DEV committed on
Commit
cc7de77
·
verified ·
1 Parent(s): 59cd625

Upload faster_whisper_inference.py

Browse files
modules/whisper/faster_whisper_inference.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import huggingface_hub
4
+ import numpy as np
5
+ import torch
6
+ from typing import BinaryIO, Union, Tuple, List
7
+ import faster_whisper
8
+ from faster_whisper.vad import VadOptions
9
+ import ast
10
+ import ctranslate2
11
+ import whisper
12
+ import gradio as gr
13
+ from argparse import Namespace
14
+
15
+ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
16
+ from modules.whisper.whisper_parameter import *
17
+ from modules.whisper.whisper_base import WhisperBase
18
+
19
class FasterWhisperInference(WhisperBase):
    """
    Whisper inference implementation backed by the ``faster_whisper`` package.

    Keeps a mapping of selectable model names to the value handed to
    ``faster_whisper.WhisperModel`` (``self.model_paths``) and lazily (re)loads
    the model whenever the requested model or compute type changes.
    """

    def __init__(self,
                 model_dir: str = FASTER_WHISPER_MODELS_DIR,
                 diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                 uvr_model_dir: str = UVR_MODELS_DIR,
                 output_dir: str = OUTPUT_DIR,
                 ):
        """
        Parameters
        ----------
        model_dir: str
            Directory holding the faster-whisper models. Created if missing.
        diarization_model_dir: str
            Forwarded unchanged to ``WhisperBase``.
        uvr_model_dir: str
            Forwarded unchanged to ``WhisperBase``.
        output_dir: str
            Forwarded unchanged to ``WhisperBase``.
        """
        super().__init__(
            model_dir=model_dir,
            diarization_model_dir=diarization_model_dir,
            uvr_model_dir=uvr_model_dir,
            output_dir=output_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        self.model_paths = self.get_model_paths()
        self.device = self.get_device()
        self.available_models = self.model_paths.keys()

    def _is_model_loaded(self, model_size: str, compute_type: str) -> bool:
        """
        Tell whether the currently loaded model already matches the request.

        ``update_model`` stores the *resolved* entry of ``self.model_paths`` in
        ``self.current_model_size`` (a directory path for custom models), so the
        requested name has to be resolved the same way before comparing.
        Comparing the raw name only would force a reload on every call for any
        model whose path differs from its name.
        """
        if self.model is None or self.current_compute_type != compute_type:
            return False
        if self.current_model_size is None:
            return False
        resolved = self.model_paths.get(model_size.replace("/", "--"))
        return self.current_model_size in (model_size, resolved)

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(),
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of Segment that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription

        Raises
        ----------
        ValueError
            If the suppress-tokens parameter is not a list of integers.
        """
        start_time = time.time()

        params = WhisperParameters.as_value(*whisper_params)
        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)

        # Only (re)load when the model or the compute type actually changed.
        if not self._is_model_loaded(params.model_size, params.compute_type):
            self.update_model(params.model_size, params.compute_type, progress)

        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            initial_prompt=params.initial_prompt,
            compression_ratio_threshold=params.compression_ratio_threshold,
            length_penalty=params.length_penalty,
            repetition_penalty=params.repetition_penalty,
            no_repeat_ngram_size=params.no_repeat_ngram_size,
            prefix=params.prefix,
            suppress_blank=params.suppress_blank,
            suppress_tokens=params.suppress_tokens,
            max_initial_timestamp=params.max_initial_timestamp,
            word_timestamps=params.word_timestamps,
            prepend_punctuations=params.prepend_punctuations,
            append_punctuations=params.append_punctuations,
            max_new_tokens=params.max_new_tokens,
            chunk_length=params.chunk_length,
            hallucination_silence_threshold=params.hallucination_silence_threshold,
            hotwords=params.hotwords,
            language_detection_threshold=params.language_detection_threshold,
            language_detection_segments=params.language_detection_segments,
            prompt_reset_on_temperature=params.prompt_reset_on_temperature,
        )
        progress(0, desc="Loading audio...")

        # NOTE(review): faster-whisper is documented to decode lazily while
        # `segments` is iterated, which is why progress is reported per segment.
        total_duration = info.duration
        segments_result = []
        for segment in segments:
            # Guard against a zero/empty duration instead of dividing by it.
            done_ratio = segment.start / total_duration if total_duration else 0
            progress(done_ratio, desc="Transcribing...")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress = gr.Progress()
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model. If you enter the huggingface repo id, it will try to download the model
            automatically from huggingface.
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model...")

        # A repo id such as "user/model" is stored locally as "user--model".
        model_size_dirname = model_size.replace("/", "--")
        if model_size not in self.model_paths and model_size_dirname not in self.model_paths:
            download_dir = os.path.join(self.model_dir, model_size_dirname)
            print(f"Model is not detected. Trying to download \"{model_size}\" from huggingface to "
                  f"\"{download_dir}\" ...")
            huggingface_hub.snapshot_download(
                model_size,
                local_dir=download_dir,
            )
            # Re-scan so the freshly downloaded directory becomes selectable,
            # and keep the exposed key view pointing at the new mapping.
            self.model_paths = self.get_model_paths()
            self.available_models = self.model_paths.keys()
            gr.Info(f"Model is downloaded with the name \"{model_size_dirname}\"")

        self.current_model_size = self.model_paths[model_size_dirname]

        hf_prefix = "models--Systran--faster-whisper-"
        official_model_path = os.path.join(self.model_dir, hf_prefix + model_size)
        # Skip the network when the model is a local directory, or when an
        # official model already has its cache directory under model_dir.
        local_files_only = (
            os.path.isdir(self.current_model_size)
            or (model_size in faster_whisper.available_models() and os.path.exists(official_model_path))
        )

        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.current_model_size,
            download_root=self.model_dir,
            compute_type=self.current_compute_type,
            local_files_only=local_files_only
        )

    def get_model_paths(self):
        """
        Get available models from models path including fine-tuned model.

        Every name reported by ``faster_whisper.available_models()`` maps to
        itself. Each remaining entry of ``self.model_dir`` whose name (with the
        Systran cache prefix stripped) is not reported by
        ``whisper.available_models()`` maps to ``<model_dir>/<name>``.

        Returns
        ----------
        dict
            Mapping of model name to the value passed to ``WhisperModel``.
        """
        model_paths = {model: model for model in faster_whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"
        wrong_dirs = {".locks", "faster_whisper_models_will_be_saved_here"}
        # Evaluated once instead of once per directory entry.
        openai_model_names = set(whisper.available_models())

        # Sorted so the resulting model order is deterministic across runs.
        for entry in sorted(set(os.listdir(self.model_dir)) - wrong_dirs):
            # Strip the cache prefix only when it really is a prefix; slicing on
            # a mere substring match would corrupt the name.
            if entry.startswith(faster_whisper_prefix):
                model_name = entry[len(faster_whisper_prefix):]
            else:
                model_name = entry

            if model_name not in openai_model_names:
                model_paths[model_name] = os.path.join(self.model_dir, model_name)
        return model_paths

    @staticmethod
    def get_device():
        """Return "cuda" when torch reports CUDA as available, otherwise "auto"."""
        return "cuda" if torch.cuda.is_available() else "auto"

    @staticmethod
    def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
        """
        Convert the suppress-tokens setting into a list of token ids.

        Parameters
        ----------
        suppress_tokens_str: str
            Python literal of a list of ints, e.g. "[-1]". A value that is
            already a list of ints is accepted and returned as is.

        Returns
        ----------
        List[int]
            The parsed token ids.

        Raises
        ----------
        ValueError
            If the value cannot be parsed or is not a list of ints.
        """
        error_message = "Invalid Suppress Tokens. The value must be type of List[int]"

        suppress_tokens = suppress_tokens_str
        if isinstance(suppress_tokens_str, str):
            try:
                suppress_tokens = ast.literal_eval(suppress_tokens_str)
            except Exception as e:
                # Keep the single ValueError contract but preserve the cause.
                raise ValueError(error_message) from e

        if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
            raise ValueError(error_message)
        return suppress_tokens