Helw150 commited on
Commit
b164fe5
·
1 Parent(s): 94540c3

Move VAD to On-Device

Browse files
Files changed (4) hide show
  1. app.py +73 -95
  2. utils/assets/silero_vad.onnx +0 -3
  3. utils/snac_utils.py +0 -146
  4. utils/vad.py +0 -290
app.py CHANGED
@@ -12,10 +12,6 @@ import xxhash
12
  from datasets import Audio
13
  from transformers import AutoModel
14
  import io
15
- from pydub import AudioSegment
16
- import tempfile
17
-
18
- from utils.vad import VadOptions, collect_chunks, get_speech_timestamps
19
 
20
  if gr.NO_RELOAD:
21
  diva_model = AutoModel.from_pretrained(
@@ -25,7 +21,7 @@ if gr.NO_RELOAD:
25
  resampler = Audio(sampling_rate=16_000)
26
 
27
 
28
- @spaces.GPU(duration=20)
29
  @torch.no_grad
30
  def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
31
  sr, y = audio_input
@@ -37,7 +33,11 @@ def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
37
  )
38
  yield from diva_model.generate_stream(
39
  a["array"],
40
- None,
 
 
 
 
41
  do_sample=do_sample,
42
  max_new_tokens=256,
43
  init_outputs=prev_outs,
@@ -45,96 +45,24 @@ def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
45
  )
46
 
47
 
48
- def run_vad(ori_audio, sr, duration):
49
- _st = time.time()
50
- try:
51
- audio = ori_audio
52
- if duration < 1:
53
- return -1, ori_audio, round(time.time() - _st, 4)
54
- audio = audio.astype(np.float32) / 32768.0
55
- sampling_rate = 16000
56
- if sr != sampling_rate:
57
- audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
58
-
59
- vad_parameters = {}
60
- vad_parameters = VadOptions(**vad_parameters)
61
- speech_chunks = get_speech_timestamps(audio, vad_parameters)
62
- audio = collect_chunks(audio, speech_chunks)
63
- duration_after_vad = audio.shape[0] / sampling_rate
64
-
65
- if sr != sampling_rate:
66
- # resample to original sampling rate
67
- vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
68
- else:
69
- vad_audio = audio
70
- vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
71
- vad_audio_bytes = vad_audio.tobytes()
72
-
73
- return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
74
- except Exception as e:
75
- msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
76
- print(msg)
77
- return -1, ori_audio, round(time.time() - _st, 4)
78
-
79
-
80
- def warm_up():
81
- frames = np.ones(2048) # 1024 frames of 2 bytes each
82
- dur, frames, tcost = run_vad(frames, 16000, 10)
83
- print(f"warm up done, time_cost: {tcost:.3f} s")
84
-
85
-
86
- warm_up()
87
-
88
-
89
  @dataclass
90
  class AppState:
91
  stream: np.ndarray | None = None
92
  sampling_rate: int = 0
93
- pause_detected: bool = False
94
- started_talking: bool = False
95
  stopped: bool = False
96
  conversation: list = field(default_factory=list)
97
  model_outs: any = None
98
 
99
 
100
- def determine_pause(audio: np.ndarray, sampling_rate: int, state: AppState) -> bool:
101
- """Take in the stream, determine if a pause happened"""
102
-
103
- temp_audio = audio[-2 * sampling_rate :]
104
-
105
- duration = len(audio) / sampling_rate
106
- dur_vad, _, time_vad = run_vad(temp_audio, sampling_rate, duration)
107
-
108
- if dur_vad > 0.25 and not state.started_talking:
109
- print("started talking")
110
- state.started_talking = True
111
- return False
112
-
113
- print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
114
-
115
- return dur_vad < 0.5
116
-
117
-
118
  def process_audio(audio: tuple, state: AppState):
119
- if state.stream is None:
120
- state.stream = audio[1]
121
- state.sampling_rate = audio[0]
122
- elif audio is not None and audio[1] is not None:
123
- state.stream = np.concatenate((state.stream, audio[1]))
124
- else:
125
- return None, state
126
 
127
- pause_detected = determine_pause(state.stream, state.sampling_rate, state)
128
- state.pause_detected = pause_detected
129
 
130
- if state.pause_detected and state.started_talking:
131
- return gr.Audio(recording=False), state
132
- return None, state
133
-
134
-
135
- def response(state: AppState):
136
- if not state.pause_detected and not state.started_talking:
137
  return AppState()
 
 
138
 
139
  file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"
140
 
@@ -159,8 +87,7 @@ def response(state: AppState):
159
 
160
 
161
  def start_recording_user(state: AppState):
162
- if not state.stopped:
163
- return gr.Audio(recording=True)
164
 
165
 
166
  theme = gr.themes.Soft(
@@ -181,29 +108,80 @@ theme = gr.themes.Soft(
181
  neutral_hue="stone",
182
  )
183
 
184
- with gr.Blocks(theme=theme) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  with gr.Row():
186
- input_audio = gr.Audio(label="Input Audio", sources="microphone", type="numpy")
 
 
 
 
 
187
  with gr.Row():
188
  chatbot = gr.Chatbot(label="Conversation", type="messages")
189
  state = gr.State(value=AppState())
190
-
191
- stream = input_audio.stream(
192
  process_audio,
193
  [input_audio, state],
194
  [input_audio, state],
195
- stream_every=0.25,
196
- time_limit=10,
197
  )
198
- respond = input_audio.stop_recording(response, [state], [state, chatbot])
199
- respond.then(start_recording_user, [state], [input_audio])
 
 
 
 
200
 
201
- cancel = gr.Button("Stop Conversation", variant="stop")
202
  cancel.click(
203
  lambda: (AppState(stopped=True), gr.Audio(recording=False)),
204
  None,
205
  [state, input_audio],
206
- cancels=[respond, stream],
207
  )
208
 
209
  if __name__ == "__main__":
 
12
  from datasets import Audio
13
  from transformers import AutoModel
14
  import io
 
 
 
 
15
 
16
  if gr.NO_RELOAD:
17
  diva_model = AutoModel.from_pretrained(
 
21
  resampler = Audio(sampling_rate=16_000)
22
 
23
 
24
+ @spaces.GPU(duration=20, progress=gr.Progress(track_tqdm=True))
25
  @torch.no_grad
26
  def diva_audio(audio_input, do_sample=False, temperature=0.001, prev_outs=None):
27
  sr, y = audio_input
 
33
  )
34
  yield from diva_model.generate_stream(
35
  a["array"],
36
+ (
37
+ "Your name is DiVA, which stands for Distilled Voice Assistant. You were trained with early-fusion training to merge OpenAI's Whisper and Meta AI's Llama 3 8B to provide end-to-end voice processing. You should give brief and helpful answers, in a conversational style. The user is talking to you with their voice and you are responding with text."
38
+ if prev_outs == None
39
+ else None
40
+ ),
41
  do_sample=do_sample,
42
  max_new_tokens=256,
43
  init_outputs=prev_outs,
 
45
  )
46
 
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  @dataclass
49
  class AppState:
50
  stream: np.ndarray | None = None
51
  sampling_rate: int = 0
 
 
52
  stopped: bool = False
53
  conversation: list = field(default_factory=list)
54
  model_outs: any = None
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def process_audio(audio: tuple, state: AppState):
58
+ return audio, state
 
 
 
 
 
 
59
 
 
 
60
 
61
+ def response(state: AppState, audio: tuple):
62
+ if not audio:
 
 
 
 
 
63
  return AppState()
64
+ state.stream = audio[1]
65
+ state.sampling_rate = audio[0]
66
 
67
  file_name = f"/tmp/{xxhash.xxh32(bytes(state.stream)).hexdigest()}.wav"
68
 
 
87
 
88
 
89
  def start_recording_user(state: AppState):
90
+ return None
 
91
 
92
 
93
  theme = gr.themes.Soft(
 
108
  neutral_hue="stone",
109
  )
110
 
111
+ js = """
112
+ async function main() {
113
+ const script1 = document.createElement("script");
114
+ script1.src = "https://cdn.jsdelivr.net/npm/[email protected]/dist/ort.js";
115
+ document.head.appendChild(script1)
116
+ const script2 = document.createElement("script");
117
+ script2.onload = async () => {
118
+ console.log("vad loaded") ;
119
+ var record = document.querySelector('.record-button');
120
+ record.textContent = "Just Start Talking!"
121
+ record.style = "width: 11vw"
122
+ const myvad = await vad.MicVAD.new({
123
+ onSpeechStart: () => {
124
+ var record = document.querySelector('.record-button');
125
+ if (record != null) {
126
+ console.log(record);
127
+ record.click();
128
+ }
129
+ },
130
+ onSpeechEnd: (audio) => {
131
+ var stop = document.querySelector('.stop-button');
132
+ if (stop != null) {
133
+ console.log(stop);
134
+ stop.click();
135
+ }
136
+ }
137
+ })
138
+ myvad.start()
139
+ }
140
+ script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/[email protected]/dist/bundle.min.js";
141
+ script1.onload = () => {
142
+ console.log("onnx loaded")
143
+ document.head.appendChild(script2)
144
+ };
145
+ }
146
+ """
147
+
148
+ js_reset = """
149
+ () => {
150
+ var record = document.querySelector('.record-button');
151
+ record.textContent = "Just Start Talking!"
152
+ record.style = "width: 11vw"
153
+ }
154
+ """
155
+
156
+ with gr.Blocks(theme=theme, js=js) as demo:
157
  with gr.Row():
158
+ input_audio = gr.Audio(
159
+ label="Input Audio",
160
+ sources=["microphone"],
161
+ type="numpy",
162
+ streaming=False,
163
+ )
164
  with gr.Row():
165
  chatbot = gr.Chatbot(label="Conversation", type="messages")
166
  state = gr.State(value=AppState())
167
+ stream = input_audio.start_recording(
 
168
  process_audio,
169
  [input_audio, state],
170
  [input_audio, state],
 
 
171
  )
172
+ respond = input_audio.stop_recording(
173
+ response, [state, input_audio], [state, chatbot]
174
+ )
175
+ restart = respond.then(start_recording_user, [state], [input_audio]).then(
176
+ lambda: None, None, None, js=js_reset
177
+ )
178
 
179
+ cancel = gr.Button("Restart Conversation", variant="stop")
180
  cancel.click(
181
  lambda: (AppState(stopped=True), gr.Audio(recording=False)),
182
  None,
183
  [state, input_audio],
184
+ cancels=[respond, restart],
185
  )
186
 
187
  if __name__ == "__main__":
utils/assets/silero_vad.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
3
- size 1807524
 
 
 
 
utils/snac_utils.py DELETED
@@ -1,146 +0,0 @@
1
- import torch
2
- import time
3
- import numpy as np
4
-
5
-
6
- class SnacConfig:
7
- audio_vocab_size = 4096
8
- padded_vocab_size = 4160
9
- end_of_audio = 4097
10
-
11
-
12
- snac_config = SnacConfig()
13
-
14
-
15
- def get_time_str():
16
- time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
17
- return time_str
18
-
19
-
20
- def layershift(input_id, layer, stride=4160, shift=152000):
21
- return input_id + shift + layer * stride
22
-
23
-
24
- def generate_audio_data(snac_tokens, snacmodel, device=None):
25
- audio = reconstruct_tensors(snac_tokens, device)
26
- with torch.inference_mode():
27
- audio_hat = snacmodel.decode(audio)
28
- audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
29
- audio_data = audio_data.astype(np.int16)
30
- audio_data = audio_data.tobytes()
31
- return audio_data
32
-
33
-
34
- def get_snac(list_output, index, nums_generate):
35
-
36
- snac = []
37
- start = index
38
- for i in range(nums_generate):
39
- snac.append("#")
40
- for j in range(7):
41
- snac.append(list_output[j][start - nums_generate - 5 + j + i])
42
- return snac
43
-
44
-
45
- def reconscruct_snac(output_list):
46
- if len(output_list) == 8:
47
- output_list = output_list[:-1]
48
- output = []
49
- for i in range(7):
50
- output_list[i] = output_list[i][i + 1 :]
51
- for i in range(len(output_list[-1])):
52
- output.append("#")
53
- for j in range(7):
54
- output.append(output_list[j][i])
55
- return output
56
-
57
-
58
- def reconstruct_tensors(flattened_output, device=None):
59
- """Reconstructs the list of tensors from the flattened output."""
60
-
61
- if device is None:
62
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
-
64
- def count_elements_between_hashes(lst):
65
- try:
66
- # Find the index of the first '#'
67
- first_index = lst.index("#")
68
- # Find the index of the second '#' after the first
69
- second_index = lst.index("#", first_index + 1)
70
- # Count the elements between the two indices
71
- return second_index - first_index - 1
72
- except ValueError:
73
- # Handle the case where there aren't enough '#' symbols
74
- return "List does not contain two '#' symbols"
75
-
76
- def remove_elements_before_hash(flattened_list):
77
- try:
78
- # Find the index of the first '#'
79
- first_hash_index = flattened_list.index("#")
80
- # Return the list starting from the first '#'
81
- return flattened_list[first_hash_index:]
82
- except ValueError:
83
- # Handle the case where there is no '#'
84
- return "List does not contain the symbol '#'"
85
-
86
- def list_to_torch_tensor(tensor1):
87
- # Convert the list to a torch tensor
88
- tensor = torch.tensor(tensor1)
89
- # Reshape the tensor to have size (1, n)
90
- tensor = tensor.unsqueeze(0)
91
- return tensor
92
-
93
- flattened_output = remove_elements_before_hash(flattened_output)
94
- codes = []
95
- tensor1 = []
96
- tensor2 = []
97
- tensor3 = []
98
- tensor4 = []
99
-
100
- n_tensors = count_elements_between_hashes(flattened_output)
101
- if n_tensors == 7:
102
- for i in range(0, len(flattened_output), 8):
103
-
104
- tensor1.append(flattened_output[i + 1])
105
- tensor2.append(flattened_output[i + 2])
106
- tensor3.append(flattened_output[i + 3])
107
- tensor3.append(flattened_output[i + 4])
108
-
109
- tensor2.append(flattened_output[i + 5])
110
- tensor3.append(flattened_output[i + 6])
111
- tensor3.append(flattened_output[i + 7])
112
- codes = [
113
- list_to_torch_tensor(tensor1).to(device),
114
- list_to_torch_tensor(tensor2).to(device),
115
- list_to_torch_tensor(tensor3).to(device),
116
- ]
117
-
118
- if n_tensors == 15:
119
- for i in range(0, len(flattened_output), 16):
120
-
121
- tensor1.append(flattened_output[i + 1])
122
- tensor2.append(flattened_output[i + 2])
123
- tensor3.append(flattened_output[i + 3])
124
- tensor4.append(flattened_output[i + 4])
125
- tensor4.append(flattened_output[i + 5])
126
- tensor3.append(flattened_output[i + 6])
127
- tensor4.append(flattened_output[i + 7])
128
- tensor4.append(flattened_output[i + 8])
129
-
130
- tensor2.append(flattened_output[i + 9])
131
- tensor3.append(flattened_output[i + 10])
132
- tensor4.append(flattened_output[i + 11])
133
- tensor4.append(flattened_output[i + 12])
134
- tensor3.append(flattened_output[i + 13])
135
- tensor4.append(flattened_output[i + 14])
136
- tensor4.append(flattened_output[i + 15])
137
-
138
- codes = [
139
- list_to_torch_tensor(tensor1).to(device),
140
- list_to_torch_tensor(tensor2).to(device),
141
- list_to_torch_tensor(tensor3).to(device),
142
- list_to_torch_tensor(tensor4).to(device),
143
- ]
144
-
145
- return codes
146
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/vad.py DELETED
@@ -1,290 +0,0 @@
1
- import bisect
2
- import functools
3
- import os
4
- import warnings
5
-
6
- from typing import List, NamedTuple, Optional
7
-
8
- import numpy as np
9
-
10
-
11
- # The code below is adapted from https://github.com/snakers4/silero-vad.
12
- class VadOptions(NamedTuple):
13
- """VAD options.
14
-
15
- Attributes:
16
- threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
17
- probabilities ABOVE this value are considered as SPEECH. It is better to tune this
18
- parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
19
- min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
20
- max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
21
- than max_speech_duration_s will be split at the timestamp of the last silence that
22
- lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
23
- split aggressively just before max_speech_duration_s.
24
- min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
25
- before separating it
26
- window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
27
- WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
28
- Values other than these may affect model performance!!
29
- speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
30
- """
31
-
32
- threshold: float = 0.5
33
- min_speech_duration_ms: int = 250
34
- max_speech_duration_s: float = float("inf")
35
- min_silence_duration_ms: int = 2000
36
- window_size_samples: int = 1024
37
- speech_pad_ms: int = 400
38
-
39
-
40
- def get_speech_timestamps(
41
- audio: np.ndarray,
42
- vad_options: Optional[VadOptions] = None,
43
- **kwargs,
44
- ) -> List[dict]:
45
- """This method is used for splitting long audios into speech chunks using silero VAD.
46
-
47
- Args:
48
- audio: One dimensional float array.
49
- vad_options: Options for VAD processing.
50
- kwargs: VAD options passed as keyword arguments for backward compatibility.
51
-
52
- Returns:
53
- List of dicts containing begin and end samples of each speech chunk.
54
- """
55
- if vad_options is None:
56
- vad_options = VadOptions(**kwargs)
57
-
58
- threshold = vad_options.threshold
59
- min_speech_duration_ms = vad_options.min_speech_duration_ms
60
- max_speech_duration_s = vad_options.max_speech_duration_s
61
- min_silence_duration_ms = vad_options.min_silence_duration_ms
62
- window_size_samples = vad_options.window_size_samples
63
- speech_pad_ms = vad_options.speech_pad_ms
64
-
65
- if window_size_samples not in [512, 1024, 1536]:
66
- warnings.warn(
67
- "Unusual window_size_samples! Supported window_size_samples:\n"
68
- " - [512, 1024, 1536] for 16000 sampling_rate"
69
- )
70
-
71
- sampling_rate = 16000
72
- min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
73
- speech_pad_samples = sampling_rate * speech_pad_ms / 1000
74
- max_speech_samples = (
75
- sampling_rate * max_speech_duration_s
76
- - window_size_samples
77
- - 2 * speech_pad_samples
78
- )
79
- min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
80
- min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
81
-
82
- audio_length_samples = len(audio)
83
-
84
- model = get_vad_model()
85
- state = model.get_initial_state(batch_size=1)
86
-
87
- speech_probs = []
88
- for current_start_sample in range(0, audio_length_samples, window_size_samples):
89
- chunk = audio[current_start_sample : current_start_sample + window_size_samples]
90
- if len(chunk) < window_size_samples:
91
- chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
92
- speech_prob, state = model(chunk, state, sampling_rate)
93
- speech_probs.append(speech_prob)
94
-
95
- triggered = False
96
- speeches = []
97
- current_speech = {}
98
- neg_threshold = threshold - 0.15
99
-
100
- # to save potential segment end (and tolerate some silence)
101
- temp_end = 0
102
- # to save potential segment limits in case of maximum segment size reached
103
- prev_end = next_start = 0
104
-
105
- for i, speech_prob in enumerate(speech_probs):
106
- if (speech_prob >= threshold) and temp_end:
107
- temp_end = 0
108
- if next_start < prev_end:
109
- next_start = window_size_samples * i
110
-
111
- if (speech_prob >= threshold) and not triggered:
112
- triggered = True
113
- current_speech["start"] = window_size_samples * i
114
- continue
115
-
116
- if (
117
- triggered
118
- and (window_size_samples * i) - current_speech["start"] > max_speech_samples
119
- ):
120
- if prev_end:
121
- current_speech["end"] = prev_end
122
- speeches.append(current_speech)
123
- current_speech = {}
124
- # previously reached silence (< neg_thres) and is still not speech (< thres)
125
- if next_start < prev_end:
126
- triggered = False
127
- else:
128
- current_speech["start"] = next_start
129
- prev_end = next_start = temp_end = 0
130
- else:
131
- current_speech["end"] = window_size_samples * i
132
- speeches.append(current_speech)
133
- current_speech = {}
134
- prev_end = next_start = temp_end = 0
135
- triggered = False
136
- continue
137
-
138
- if (speech_prob < neg_threshold) and triggered:
139
- if not temp_end:
140
- temp_end = window_size_samples * i
141
- # condition to avoid cutting in very short silence
142
- if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
143
- prev_end = temp_end
144
- if (window_size_samples * i) - temp_end < min_silence_samples:
145
- continue
146
- else:
147
- current_speech["end"] = temp_end
148
- if (
149
- current_speech["end"] - current_speech["start"]
150
- ) > min_speech_samples:
151
- speeches.append(current_speech)
152
- current_speech = {}
153
- prev_end = next_start = temp_end = 0
154
- triggered = False
155
- continue
156
-
157
- if (
158
- current_speech
159
- and (audio_length_samples - current_speech["start"]) > min_speech_samples
160
- ):
161
- current_speech["end"] = audio_length_samples
162
- speeches.append(current_speech)
163
-
164
- for i, speech in enumerate(speeches):
165
- if i == 0:
166
- speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
167
- if i != len(speeches) - 1:
168
- silence_duration = speeches[i + 1]["start"] - speech["end"]
169
- if silence_duration < 2 * speech_pad_samples:
170
- speech["end"] += int(silence_duration // 2)
171
- speeches[i + 1]["start"] = int(
172
- max(0, speeches[i + 1]["start"] - silence_duration // 2)
173
- )
174
- else:
175
- speech["end"] = int(
176
- min(audio_length_samples, speech["end"] + speech_pad_samples)
177
- )
178
- speeches[i + 1]["start"] = int(
179
- max(0, speeches[i + 1]["start"] - speech_pad_samples)
180
- )
181
- else:
182
- speech["end"] = int(
183
- min(audio_length_samples, speech["end"] + speech_pad_samples)
184
- )
185
-
186
- return speeches
187
-
188
-
189
- def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
190
- """Collects and concatenates audio chunks."""
191
- if not chunks:
192
- return np.array([], dtype=np.float32)
193
-
194
- return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
195
-
196
-
197
- class SpeechTimestampsMap:
198
- """Helper class to restore original speech timestamps."""
199
-
200
- def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
201
- self.sampling_rate = sampling_rate
202
- self.time_precision = time_precision
203
- self.chunk_end_sample = []
204
- self.total_silence_before = []
205
-
206
- previous_end = 0
207
- silent_samples = 0
208
-
209
- for chunk in chunks:
210
- silent_samples += chunk["start"] - previous_end
211
- previous_end = chunk["end"]
212
-
213
- self.chunk_end_sample.append(chunk["end"] - silent_samples)
214
- self.total_silence_before.append(silent_samples / sampling_rate)
215
-
216
- def get_original_time(
217
- self,
218
- time: float,
219
- chunk_index: Optional[int] = None,
220
- ) -> float:
221
- if chunk_index is None:
222
- chunk_index = self.get_chunk_index(time)
223
-
224
- total_silence_before = self.total_silence_before[chunk_index]
225
- return round(total_silence_before + time, self.time_precision)
226
-
227
- def get_chunk_index(self, time: float) -> int:
228
- sample = int(time * self.sampling_rate)
229
- return min(
230
- bisect.bisect(self.chunk_end_sample, sample),
231
- len(self.chunk_end_sample) - 1,
232
- )
233
-
234
-
235
- @functools.lru_cache
236
- def get_vad_model():
237
- """Returns the VAD model instance."""
238
- asset_dir = os.path.join(os.path.dirname(__file__), "assets")
239
- path = os.path.join(asset_dir, "silero_vad.onnx")
240
- return SileroVADModel(path)
241
-
242
-
243
- class SileroVADModel:
244
- def __init__(self, path):
245
- try:
246
- import onnxruntime
247
- except ImportError as e:
248
- raise RuntimeError(
249
- "Applying the VAD filter requires the onnxruntime package"
250
- ) from e
251
-
252
- opts = onnxruntime.SessionOptions()
253
- opts.inter_op_num_threads = 1
254
- opts.intra_op_num_threads = 1
255
- opts.log_severity_level = 4
256
-
257
- self.session = onnxruntime.InferenceSession(
258
- path,
259
- providers=["CPUExecutionProvider"],
260
- sess_options=opts,
261
- )
262
-
263
- def get_initial_state(self, batch_size: int):
264
- h = np.zeros((2, batch_size, 64), dtype=np.float32)
265
- c = np.zeros((2, batch_size, 64), dtype=np.float32)
266
- return h, c
267
-
268
- def __call__(self, x, state, sr: int):
269
- if len(x.shape) == 1:
270
- x = np.expand_dims(x, 0)
271
- if len(x.shape) > 2:
272
- raise ValueError(
273
- f"Too many dimensions for input audio chunk {len(x.shape)}"
274
- )
275
- if sr / x.shape[1] > 31.25:
276
- raise ValueError("Input audio chunk is too short")
277
-
278
- h, c = state
279
-
280
- ort_inputs = {
281
- "input": x,
282
- "h": h,
283
- "c": c,
284
- "sr": np.array(sr, dtype="int64"),
285
- }
286
-
287
- out, h, c = self.session.run(None, ort_inputs)
288
- state = (h, c)
289
-
290
- return out, state