jhj0517 committed
Commit fe5e707 · Parent: aa74840

modularize vad

modules/vad/__init__.py ADDED
File without changes
modules/vad/silero_vad.py ADDED
@@ -0,0 +1,240 @@
+from faster_whisper.vad import VadOptions
+import numpy as np
+from typing import BinaryIO, Union, List, Optional
+import warnings
+import faster_whisper
+import gradio as gr
+
+
+class SileroVAD:
+    def __init__(self):
+        self.sampling_rate = 16000
+
+    def run(self,
+            audio: Union[str, BinaryIO, np.ndarray],
+            vad_parameters: VadOptions,
+            progress: gr.Progress = gr.Progress()):
+        """
+        Run VAD
+
+        Parameters
+        ----------
+        audio: Union[str, BinaryIO, np.ndarray]
+            Audio path or file binary or Audio numpy array
+        vad_parameters:
+            Options for VAD processing.
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
+
+        Returns
+        ----------
+        audio: np.ndarray
+            Pre-processed audio with VAD
+        """
+
+        sampling_rate = self.sampling_rate
+
+        if not isinstance(audio, np.ndarray):
+            audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate)
+
+        duration = audio.shape[0] / sampling_rate
+        duration_after_vad = duration
+
+        if vad_parameters is None:
+            vad_parameters = VadOptions()
+        elif isinstance(vad_parameters, dict):
+            vad_parameters = VadOptions(**vad_parameters)
+        speech_chunks = self.get_speech_timestamps(
+            audio=audio,
+            vad_options=vad_parameters,
+            progress=progress
+        )
+        audio = self.collect_chunks(audio, speech_chunks)
+        duration_after_vad = audio.shape[0] / sampling_rate
+
+        return audio
+
+    @staticmethod
+    def get_speech_timestamps(
+            audio: np.ndarray,
+            vad_options: Optional[VadOptions] = None,
+            progress: gr.Progress = gr.Progress(),
+            **kwargs,
+    ) -> List[dict]:
+        """This method is used for splitting long audios into speech chunks using silero VAD.
+
+        Args:
+            audio: One dimensional float array.
+            vad_options: Options for VAD processing.
+            kwargs: VAD options passed as keyword arguments for backward compatibility.
+            progress: Gradio progress to indicate progress.
+
+        Returns:
+            List of dicts containing begin and end samples of each speech chunk.
+        """
+        if vad_options is None:
+            vad_options = VadOptions(**kwargs)
+
+        threshold = vad_options.threshold
+        min_speech_duration_ms = vad_options.min_speech_duration_ms
+        max_speech_duration_s = vad_options.max_speech_duration_s
+        min_silence_duration_ms = vad_options.min_silence_duration_ms
+        window_size_samples = vad_options.window_size_samples
+        speech_pad_ms = vad_options.speech_pad_ms
+
+        if window_size_samples not in [512, 1024, 1536]:
+            warnings.warn(
+                "Unusual window_size_samples! Supported window_size_samples:\n"
+                " - [512, 1024, 1536] for 16000 sampling_rate"
+            )
+
+        sampling_rate = 16000
+        min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
+        speech_pad_samples = sampling_rate * speech_pad_ms / 1000
+        max_speech_samples = (
+            sampling_rate * max_speech_duration_s
+            - window_size_samples
+            - 2 * speech_pad_samples
+        )
+        min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
+        min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
+
+        audio_length_samples = len(audio)
+
+        model = faster_whisper.vad.get_vad_model()
+        state = model.get_initial_state(batch_size=1)
+
+        speech_probs = []
+        for current_start_sample in range(0, audio_length_samples, window_size_samples):
+            progress(current_start_sample/audio_length_samples, desc="Preprocessing using VAD..")
+
+            chunk = audio[current_start_sample: current_start_sample + window_size_samples]
+            if len(chunk) < window_size_samples:
+                chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
+            speech_prob, state = model(chunk, state, sampling_rate)
+            speech_probs.append(speech_prob)
+
+        triggered = False
+        speeches = []
+        current_speech = {}
+        neg_threshold = threshold - 0.15
+
+        # to save potential segment end (and tolerate some silence)
+        temp_end = 0
+        # to save potential segment limits in case of maximum segment size reached
+        prev_end = next_start = 0
+
+        for i, speech_prob in enumerate(speech_probs):
+            if (speech_prob >= threshold) and temp_end:
+                temp_end = 0
+                if next_start < prev_end:
+                    next_start = window_size_samples * i
+
+            if (speech_prob >= threshold) and not triggered:
+                triggered = True
+                current_speech["start"] = window_size_samples * i
+                continue
+
+            if (
+                triggered
+                and (window_size_samples * i) - current_speech["start"] > max_speech_samples
+            ):
+                if prev_end:
+                    current_speech["end"] = prev_end
+                    speeches.append(current_speech)
+                    current_speech = {}
+                    # previously reached silence (< neg_thres) and is still not speech (< thres)
+                    if next_start < prev_end:
+                        triggered = False
+                    else:
+                        current_speech["start"] = next_start
+                    prev_end = next_start = temp_end = 0
+                else:
+                    current_speech["end"] = window_size_samples * i
+                    speeches.append(current_speech)
+                    current_speech = {}
+                    prev_end = next_start = temp_end = 0
+                    triggered = False
+                    continue
+
+            if (speech_prob < neg_threshold) and triggered:
+                if not temp_end:
+                    temp_end = window_size_samples * i
+                # condition to avoid cutting in very short silence
+                if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
+                    prev_end = temp_end
+                if (window_size_samples * i) - temp_end < min_silence_samples:
+                    continue
+                else:
+                    current_speech["end"] = temp_end
+                    if (
+                        current_speech["end"] - current_speech["start"]
+                    ) > min_speech_samples:
+                        speeches.append(current_speech)
+                    current_speech = {}
+                    prev_end = next_start = temp_end = 0
+                    triggered = False
+                    continue
+
+        if (
+            current_speech
+            and (audio_length_samples - current_speech["start"]) > min_speech_samples
+        ):
+            current_speech["end"] = audio_length_samples
+            speeches.append(current_speech)
+
+        for i, speech in enumerate(speeches):
+            if i == 0:
+                speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
+            if i != len(speeches) - 1:
+                silence_duration = speeches[i + 1]["start"] - speech["end"]
+                if silence_duration < 2 * speech_pad_samples:
+                    speech["end"] += int(silence_duration // 2)
+                    speeches[i + 1]["start"] = int(
+                        max(0, speeches[i + 1]["start"] - silence_duration // 2)
+                    )
+                else:
+                    speech["end"] = int(
+                        min(audio_length_samples, speech["end"] + speech_pad_samples)
+                    )
+                    speeches[i + 1]["start"] = int(
+                        max(0, speeches[i + 1]["start"] - speech_pad_samples)
+                    )
+            else:
+                speech["end"] = int(
+                    min(audio_length_samples, speech["end"] + speech_pad_samples)
+                )
+
+        return speeches
+
+    @staticmethod
+    def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
+        """Collects and concatenates audio chunks."""
+        if not chunks:
+            return np.array([], dtype=np.float32)
+
+        return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks])
+
+    @staticmethod
+    def format_timestamp(
+            seconds: float,
+            always_include_hours: bool = False,
+            decimal_marker: str = ".",
+    ) -> str:
+        assert seconds >= 0, "non-negative timestamp expected"
+        milliseconds = round(seconds * 1000.0)
+
+        hours = milliseconds // 3_600_000
+        milliseconds -= hours * 3_600_000
+
+        minutes = milliseconds // 60_000
+        milliseconds -= minutes * 60_000
+
+        seconds = milliseconds // 1_000
+        milliseconds -= seconds * 1_000
+
+        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+        return (
+            f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+        )
+
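
Taken on its own, the new SileroVAD class can be exercised outside the Gradio app. The snippet below is a usage sketch rather than code from this commit: sample.wav is a placeholder path, the option values are arbitrary, and the default gr.Progress() callback is effectively a no-op when no Gradio event is running.

# Hypothetical usage sketch for modules/vad/silero_vad.py (not part of this commit).
from faster_whisper.vad import VadOptions

from modules.vad.silero_vad import SileroVAD

vad = SileroVAD()
options = VadOptions(threshold=0.5, min_silence_duration_ms=1000, speech_pad_ms=400)

# run() decodes the file to a 16 kHz mono float array, scores it window by
# window with the Silero model bundled in faster-whisper, and returns only
# the detected speech regions concatenated together.
speech_only = vad.run(audio="sample.wav", vad_parameters=options)

kept_seconds = speech_only.shape[0] / vad.sampling_rate
print("speech kept:", SileroVAD.format_timestamp(kept_seconds, always_include_hours=True))
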
modules/whisper/faster_whisper_inference.py CHANGED
@@ -62,21 +62,6 @@ class FasterWhisperInference(WhisperBase):
         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
 
-        if params.lang == "Automatic Detection":
-            params.lang = None
-        else:
-            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-            params.lang = language_code_dict[params.lang]
-
-        vad_options = VadOptions(
-            threshold=params.threshold,
-            min_speech_duration_ms=params.min_speech_duration_ms,
-            max_speech_duration_s=params.max_speech_duration_s,
-            min_silence_duration_ms=params.min_silence_duration_ms,
-            window_size_samples=params.window_size_samples,
-            speech_pad_ms=params.speech_pad_ms
-        )
-
         segments, info = self.model.transcribe(
             audio=audio,
             language=params.lang,
@@ -88,8 +73,6 @@ class FasterWhisperInference(WhisperBase):
             patience=params.patience,
             temperature=params.temperature,
             compression_ratio_threshold=params.compression_ratio_threshold,
-            vad_filter=params.vad_filter,
-            vad_parameters=vad_options
         )
         progress(0, desc="Loading audio..")
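
The removed block had leaned on faster-whisper's built-in VAD filter inside model.transcribe(); both the language-code mapping and the VadOptions construction now live only in WhisperBase.run (next file). For contrast, the library-level path being dropped here looks roughly like the following when called directly, a sketch with a placeholder model size and file name rather than code from this repository:

# Sketch of faster-whisper's built-in VAD path, which this class no longer uses.
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions

model = WhisperModel("base")  # placeholder model size
segments, info = model.transcribe(
    "sample.wav",  # placeholder audio path
    vad_filter=True,  # the flag removed from the transcribe() call above
    vad_parameters=VadOptions(threshold=0.5, speech_pad_ms=400),
)
print(sum(segment.end - segment.start for segment in segments))  # seconds of speech transcribed
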
modules/whisper/whisper_base.py CHANGED
@@ -7,11 +7,14 @@ from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 from datetime import datetime
 from argparse import Namespace
+from faster_whisper.vad import VadOptions
+from dataclasses import astuple
 
 from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
 from modules.whisper.whisper_parameter import *
 from modules.diarize.diarizer import Diarizer
+from modules.vad.silero_vad import SileroVAD
 
 
 class WhisperBase(ABC):
@@ -35,6 +38,7 @@ class WhisperBase(ABC):
         self.diarizer = Diarizer(
             model_dir=args.diarization_model_dir
         )
+        self.vad = SileroVAD()
 
     @abstractmethod
     def transcribe(self,
@@ -79,6 +83,21 @@ class WhisperBase(ABC):
         """
         params = WhisperParameters.as_value(*whisper_params)
 
+        if params.vad_filter:
+            vad_options = VadOptions(
+                threshold=params.threshold,
+                min_speech_duration_ms=params.min_speech_duration_ms,
+                max_speech_duration_s=params.max_speech_duration_s,
+                min_silence_duration_ms=params.min_silence_duration_ms,
+                window_size_samples=params.window_size_samples,
+                speech_pad_ms=params.speech_pad_ms
+            )
+            self.vad.run(
+                audio=audio,
+                vad_parameters=vad_options,
+                progress=progress
+            )
+
         if params.lang == "Automatic Detection":
             params.lang = None
         else:
@@ -88,7 +107,7 @@ class WhisperBase(ABC):
         result, elapsed_time = self.transcribe(
             audio,
             progress,
-            *whisper_params
+            *astuple(params)
         )
 
         if params.is_diarize:
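
Two details of the new base-class flow are worth spelling out. First, *astuple(params) converts the WhisperParameters value object back into the flat positional tuple that the abstract transcribe() signature expects, so the concrete backends keep their signatures unchanged. A self-contained toy sketch of that mechanism (the dataclass and its fields below are invented for illustration and are not the project's real parameter class):

from dataclasses import dataclass, astuple

# Toy stand-in for the project's parameter dataclass; field names are illustrative only.
@dataclass
class FakeParams:
    model_size: str
    lang: str
    vad_filter: bool

def fake_transcribe(model_size, lang, vad_filter):
    # Receives the unpacked fields as ordinary positional arguments, mirroring
    # self.transcribe(audio, progress, *astuple(params)) in WhisperBase.run.
    return f"{model_size} / {lang} / vad={vad_filter}"

params = FakeParams("large-v2", "en", True)
print(fake_transcribe(*astuple(params)))

Second, as the hunk above is written, the return value of self.vad.run(...) is not assigned back to audio, so the speech-only array produced by SileroVAD does not yet replace the audio handed to self.transcribe; assigning the result back to audio would complete that hand-off.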