update
examples/fsmn_vad_by_webrtcvad/step_4_train_model.py  CHANGED
@@ -270,7 +270,7 @@ def main():
             dice_loss = dice_loss_fn.forward(probs, targets)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -351,7 +351,7 @@ def main():
             dice_loss = dice_loss_fn.forward(probs, targets)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
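In both the train and evaluation loops, the change sets the LSNR term's weight in the combined objective to 0.3 while keeping unit weights on the two VAD terms. For context, the snippet below is a minimal, self-contained sketch of that weighted multi-loss pattern; the three loss values are stand-ins, not the repository's actual BCE, Dice, and LSNR loss modules.

    import torch

    # Stand-in frame-level VAD outputs and targets, shaped [batch, time].
    probs = torch.rand(4, 100)
    targets = (torch.rand(4, 100) > 0.5).float()

    # Placeholder terms; in the repo these come from BCELoss, a Dice loss
    # module, and model.lsnr_loss_fn. The values here are purely illustrative.
    bce_loss = torch.nn.functional.binary_cross_entropy(probs, targets)
    dice_loss = 1.0 - (2.0 * (probs * targets).sum() + 1e-8) / (probs.sum() + targets.sum() + 1e-8)
    lsnr_loss = torch.rand(())  # hypothetical SNR-estimation loss term

    # Same weighting as the commit: unit weights on the VAD terms, 0.3 on LSNR.
    loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss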
examples/silero_vad_by_webrtcvad/step_4_train_model.py  CHANGED
@@ -270,7 +270,7 @@ def main():
             dice_loss = dice_loss_fn.forward(probs, targets)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
@@ -351,7 +351,7 @@ def main():
             dice_loss = dice_loss_fn.forward(probs, targets)
             lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
 
-            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
+            loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
             if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
                 logger.info(f"find nan or inf in loss. continue.")
                 continue
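The guard that follows the new loss line is unchanged but worth noting: a non-finite loss skips the batch entirely rather than propagating NaN into gradients and optimizer state. A minimal sketch of that skip-on-NaN pattern in a generic training loop (the model, data, and optimizer here are placeholders, not this repository's):

    import torch

    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    for step in range(100):
        x = torch.randn(8, 10)
        y = torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)

        # If the loss is nan/inf, skip the batch before calling backward();
        # once NaN reaches the gradients, Adam's moment buffers stay poisoned.
        if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
            continue

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()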
main.py  CHANGED
@@ -1,14 +1,25 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import argparse
+from functools import lru_cache
+import json
 import logging
 import platform
+import tempfile
+import time
+from typing import Dict, Tuple
 
 import gradio as gr
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
+import numpy as np
 
 import log
-from project_settings import environment, log_directory, time_zone_info
+from project_settings import environment, project_path, log_directory, time_zone_info
 from toolbox.os.command import Command
+from toolbox.torchaudio.models.vad.fsmn_vad.inference_fsmn_vad import InferenceFSMNVad
+from toolbox.torchaudio.utils.visualization import process_speech_probs
 
 log.setup_size_rotating(log_directory=log_directory, tz_info=time_zone_info)
 
@@ -36,13 +47,136 @@ def shell(cmd: str):
     return Command.popen(cmd)
 
 
+def get_infer_cls_by_model_name(model_name: str):
+    if model_name.__contains__("fsmn"):
+        infer_cls = InferenceFSMNVad
+    else:
+        raise AssertionError
+    return infer_cls
+
+
+vad_engines: Dict[str, dict] = None
+
+
+@lru_cache(maxsize=1)
+def load_vad_model(infer_cls, **kwargs):
+    infer_engine = infer_cls(**kwargs)
+    return infer_engine
+
+
+def generate_image(signal: np.ndarray, speech_probs: np.ndarray, sample_rate: int = 8000, title: str = ""):
+    duration = np.arange(0, len(signal)) / sample_rate
+    plt.figure(figsize=(12, 5))
+    plt.plot(duration, signal, color='b')
+    plt.plot(duration, speech_probs, color='gray')
+    plt.title(title)
+
+    temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+    plt.savefig(temp_file.name, bbox_inches="tight")
+    plt.close()
+    return temp_file.name
+
+
+def when_click_vad_button(audio_file_t = None, audio_microphone_t = None, engine: str = None):
+    if audio_file_t is None and audio_microphone_t is None:
+        raise gr.Error(f"audio file and microphone is null.")
+    if audio_file_t is not None and audio_microphone_t is not None:
+        gr.Warning(f"both audio file and microphone file is provided, audio file taking priority.")
+    audio_t: Tuple = audio_file_t or audio_microphone_t
+
+    sample_rate, signal = audio_t
+    audio_duration = signal.shape[-1] // 8000
+    audio = np.array(signal / (1 << 15), dtype=np.float32)
+
+    infer_engine_param = vad_engines.get(engine)
+    if infer_engine_param is None:
+        raise gr.Error(f"invalid denoise engine: {engine}.")
+
+    try:
+        infer_cls = infer_engine_param["infer_cls"]
+        kwargs = infer_engine_param["kwargs"]
+        infer_engine = load_vad_model(infer_cls=infer_cls, **kwargs)
+
+        begin = time.time()
+        vad_info = infer_engine.infer(audio)
+        time_cost = time.time() - begin
+
+        fpr = time_cost / audio_duration
+        info = {
+            "time_cost": round(time_cost, 4),
+            "audio_duration": round(audio_duration, 4),
+            "fpr": round(fpr, 4)
+        }
+        message = json.dumps(info, ensure_ascii=False, indent=4)
+
+        probs = vad_info["probs"]
+        lsnr = vad_info["lsnr"]
+        lsnr = lsnr / np.max(np.abs(lsnr))
+
+        frame_step = infer_engine.config.hop_size
+        probs = process_speech_probs(audio, probs, frame_step)
+        lsnr = process_speech_probs(audio, lsnr, frame_step)
+        probs_image = generate_image(audio, probs)
+        lsnr_image = generate_image(audio, lsnr)
+
+    except Exception as e:
+        raise gr.Error(f"vad failed, error type: {type(e)}, error text: {str(e)}.")
+
+    return probs_image, lsnr_image, message
+
+
 def main():
     args = get_args()
 
+    # engines
+    global vad_engines
+    vad_engines = {
+        filename.stem: {
+            "infer_cls": get_infer_cls_by_model_name(filename.stem),
+            "kwargs": {
+                "pretrained_model_path_or_zip_file": filename.as_posix()
+            }
+        }
+        for filename in (project_path / "trained_models").glob("*.zip")
+        if filename.name not in (
+            "cnn-vad-by-webrtcvad-nx-dns3.zip",
+            "fsmn-vad-by-webrtcvad-nx-dns3.zip",
+            "examples.zip",
+            "sound-2-ch32.zip",
+            "sound-3-ch32.zip",
+            "sound-4-ch32.zip",
+            "sound-8-ch32.zip",
+        )
+    }
+
+    # choices
+    vad_engine_choices = list(vad_engines.keys())
+
     # ui
     with gr.Blocks() as blocks:
         gr.Markdown(value="vad.")
         with gr.Tabs():
+            with gr.TabItem("vad"):
+                with gr.Row():
+                    with gr.Column(variant="panel", scale=5):
+                        with gr.Tabs():
+                            with gr.TabItem("file"):
+                                vad_audio_file = gr.Audio(label="audio")
+                            with gr.TabItem("microphone"):
+                                vad_audio_microphone = gr.Audio(sources="microphone", label="audio")
+
+                        vad_engine = gr.Dropdown(choices=vad_engine_choices, value=vad_engine_choices[0], label="engine")
+                        vad_button = gr.Button(variant="primary")
+                    with gr.Column(variant="panel", scale=5):
+                        vad_vad_image = gr.Image(label="vad")
+                        vad_lsnr_image = gr.Image(label="lsnr")
+                        vad_message = gr.Textbox(lines=1, max_lines=20, label="message")
+
+                vad_button.click(
+                    when_click_vad_button,
+                    inputs=[vad_audio_file, vad_audio_microphone, vad_engine],
+                    outputs=[vad_vad_image, vad_lsnr_image, vad_message]
+                )
             with gr.TabItem("shell"):
                 shell_text = gr.Textbox(label="cmd")
                 shell_button = gr.Button("run")
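main.py now builds the `vad_engines` registry from the zipped models under `trained_models/` and constructs each engine lazily through `load_vad_model`, which is wrapped in `functools.lru_cache` so repeated requests for the same engine reuse one loaded model. A reduced sketch of that memoization idea; `FakeEngine` is a stand-in for the real inference classes, and the pattern works because the cached call receives only hashable arguments (a class and string keyword values):

    from functools import lru_cache


    class FakeEngine:
        """Stand-in for an inference class with an expensive constructor."""
        def __init__(self, pretrained_model_path_or_zip_file: str):
            print(f"loading {pretrained_model_path_or_zip_file} ...")


    @lru_cache(maxsize=1)
    def load_vad_model(infer_cls, **kwargs):
        return infer_cls(**kwargs)


    a = load_vad_model(FakeEngine, pretrained_model_path_or_zip_file="model-a.zip")
    b = load_vad_model(FakeEngine, pretrained_model_path_or_zip_file="model-a.zip")
    assert a is b  # cache hit: the constructor ran only once

With `maxsize=1`, selecting a different engine evicts the previous one, so at most one model stays resident at a time.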
toolbox/torchaudio/models/vad/fsmn_vad/inference_fsmn_vad.py  CHANGED
@@ -60,7 +60,7 @@ class InferenceFSMNVad(object):
         shutil.rmtree(model_path)
         return config, model
 
-    def infer(self, signal: torch.Tensor) ->
+    def infer(self, signal: torch.Tensor) -> dict:
         # signal shape: [num_samples,], value between -1 and 1.
 
         inputs = torch.tensor(signal, dtype=torch.float32)
@@ -73,11 +73,20 @@
         # probs shape: [b, t, 1]
         probs = torch.squeeze(probs, dim=-1)
         # probs shape: [b, t]
-
         probs = probs.numpy()
         probs = probs[0]
-
-
+
+        # lsnr shape: [b, t, 1]
+        lsnr = torch.squeeze(lsnr, dim=-1)
+        # lsnr shape: [b, t]
+        lsnr = lsnr.numpy()
+        lsnr = lsnr[0]
+
+        result = {
+            "probs": probs,
+            "lsnr": lsnr,
+        }
+        return result
 
     def post_process(self, probs: List[float]):
         return
@@ -88,11 +97,11 @@ def get_args():
     parser.add_argument(
         "--wav_file",
         # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
-
+        default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
         # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
-        default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
         # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
@@ -133,18 +142,29 @@ def main():
     )
     frame_step = infer.config.hop_size
 
-
+    vad_info = infer.infer(signal)
+    speech_probs = vad_info["probs"]
+    lsnr = vad_info["lsnr"]
+
+    lsnr = lsnr / np.max(np.abs(lsnr))
 
-
+    speech_probs = speech_probs.tolist()
+    lsnr = lsnr.tolist()
 
     speech_probs = process_speech_probs(
         signal=signal,
         speech_probs=speech_probs,
         frame_step=frame_step,
     )
+    lsnr = process_speech_probs(
+        signal=signal,
+        speech_probs=lsnr,
+        frame_step=frame_step,
+    )
 
     # plot
     make_visualization(signal, speech_probs, SAMPLE_RATE)
+    make_visualization(signal, lsnr, SAMPLE_RATE)
     return
 
 
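Both inference scripts now pass the frame-level `probs` and `lsnr` tracks through `process_speech_probs` with the model's `hop_size` before plotting them over the waveform. That function's implementation is not part of this diff; a plausible minimal version of the frame-to-sample expansion, shown only to illustrate the shape contract it would have to satisfy, is:

    import numpy as np


    def process_speech_probs(signal: np.ndarray, speech_probs, frame_step: int) -> np.ndarray:
        """Hypothetical sketch: repeat each frame-level value frame_step times,
        then pad/trim so the output aligns sample-for-sample with the signal."""
        probs = np.repeat(np.asarray(speech_probs, dtype=np.float32), frame_step)
        if len(probs) < len(signal):
            probs = np.pad(probs, (0, len(signal) - len(probs)), mode="edge")
        return probs[:len(signal)]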
toolbox/torchaudio/models/vad/silero_vad/inference_silero_vad.py  CHANGED
@@ -60,7 +60,7 @@ class InferenceSileroVad(object):
         shutil.rmtree(model_path)
         return config, model
 
-    def infer(self, signal: torch.Tensor) ->
+    def infer(self, signal: torch.Tensor) -> dict:
         # signal shape: [num_samples,], value between -1 and 1.
 
         inputs = torch.tensor(signal, dtype=torch.float32)
@@ -68,16 +68,25 @@
         # inputs shape: [1, num_samples,]
 
         with torch.no_grad():
-            logits, probs = self.model.forward(inputs)
+            logits, probs, lsnr = self.model.forward(inputs)
 
         # probs shape: [b, t, 1]
         probs = torch.squeeze(probs, dim=-1)
         # probs shape: [b, t]
-
         probs = probs.numpy()
         probs = probs[0]
-
-
+
+        # lsnr shape: [b, t, 1]
+        lsnr = torch.squeeze(lsnr, dim=-1)
+        # lsnr shape: [b, t]
+        lsnr = lsnr.numpy()
+        lsnr = lsnr[0]
+
+        result = {
+            "probs": probs,
+            "lsnr": lsnr,
+        }
+        return result
 
     def post_process(self, probs: List[float]):
         return
@@ -87,11 +96,29 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--wav_file",
-        default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
+        # default=(project_path / "data/examples/ai_agent/chinese-4.wav").as_posix(),
         # default=(project_path / "data/examples/ai_agent/chinese-5.wav").as_posix(),
         # default=(project_path / "data/examples/hado/b556437e-c68b-4f6d-9eed-2977c29db887.wav").as_posix(),
         # default=(project_path / "data/examples/hado/eae93a33-8ee0-4d86-8f85-cac5116ae6ef.wav").as_posix(),
-        # default=
+        # default=(project_path / "data/examples/speech/active_media_r_0ba69730-66a4-4ecd-8929-ef58f18f4612_2.wav").as_posix(),
+        # default=(project_path / "data/examples/speech/active_media_r_2a2f472b-a0b8-4fd5-b1c4-1aedc5d2ce57_0.wav").as_posix(),
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_w_8b6e28e2-a238-4c8c-b2e3-426b1fca149b_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0a56f035-40f6-4530-b852-613f057d718d_6.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ae70b76-3651-4a71-bc0c-9e1429e4c854_5.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d483249-57f8-4d45-b4c6-bda82d6816ae_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0d952885-5bc2-4633-81b6-e0e809e113f1_2.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-SG\2025-05-19\active_media_r_0ddac777-d986-4a5c-9c7c-ff64be0a463d_11.wav",
+
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0b8a8e80-52af-423b-8877-03a78b1e6e43_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0ebffb68-6490-4a8b-8eb6-eb82443d7d75_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_0f6ec933-90df-447b-aca4-6ddc149452ab_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_0.wav",
+        default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aac396f-1661-4f26-ab49-1a4879684567_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1aff518b-4749-42fc-adfe-64046f9baeb6_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1b16f2a3-a8c9-4739-9a76-59faf1c64d79_1.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1bb1f22e-9c3a-4aea-b53f-71cc6547a6ee_0.wav",
+        # default=r"D:\Users\tianx\HuggingDatasets\nx_noise\data\noise\en-SG\2025-05-19\active_media_r_1dab161b-2a76-4491-abd1-60dba6172f8d_2.wav",
         type=str,
     )
     args = parser.parse_args()
@@ -115,18 +142,29 @@ def main():
     )
     frame_step = infer.model.hop_size
 
-
+    vad_info = infer.infer(signal)
+    speech_probs = vad_info["probs"]
+    lsnr = vad_info["lsnr"]
+
+    lsnr = lsnr / np.max(np.abs(lsnr))
 
-
+    speech_probs = speech_probs.tolist()
+    lsnr = lsnr.tolist()
 
     speech_probs = process_speech_probs(
         signal=signal,
         speech_probs=speech_probs,
        frame_step=frame_step,
     )
+    lsnr = process_speech_probs(
+        signal=signal,
+        speech_probs=lsnr,
+        frame_step=frame_step,
+    )
 
     # plot
     make_visualization(signal, speech_probs, SAMPLE_RATE)
+    make_visualization(signal, lsnr, SAMPLE_RATE)
     return
 
 
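One caveat shared by both updated scripts and by `when_click_vad_button` in main.py: `lsnr = lsnr / np.max(np.abs(lsnr))` rescales the LSNR track into [-1, 1] so it plots on the same axis as the waveform and the probabilities, but it divides by zero if the track is identically zero. A defensive variant (a suggestion, not what this commit ships) would clamp the denominator:

    import numpy as np


    def normalize_for_plot(track: np.ndarray, eps: float = 1e-8) -> np.ndarray:
        # Max-abs scaling into [-1, 1]; eps guards the all-zero edge case.
        return track / max(float(np.max(np.abs(track))), eps)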