update
- examples/conv_tasnet/step_2_train_model.py +36 -18
- examples/conv_tasnet/yaml/config.yaml +8 -0
- toolbox/torchaudio/models/clean_unet/inference_clean_unet.py +2 -1
- toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py +12 -0
- toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py +112 -0
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
-https://github.com/
+https://github.com/kaituoxu/Conv-TasNet/tree/master/src
 """
 import argparse
 import json
@@ -42,14 +42,11 @@ def get_args():
     parser.add_argument("--max_epochs", default=200, type=int)

     parser.add_argument("--batch_size", default=8, type=int)
-    parser.add_argument("--learning_rate", default=1e-3, type=float)
     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
     parser.add_argument("--patience", default=5, type=int)
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
     parser.add_argument("--seed", default=1234, type=int)

-    parser.add_argument("--eval_steps", default=25000, type=int)
-
     parser.add_argument("--config_file", default="config.yaml", type=str)

     args = parser.parse_args()
@@ -171,7 +168,7 @@ def main():

     # optimizer
     logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
-    optimizer = torch.optim.AdamW(model.parameters(),
+    optimizer = torch.optim.AdamW(model.parameters(), config.lr)

     # resume training
     last_epoch = -1
@@ -197,10 +194,21 @@
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
             optimizer.load_state_dict(state_dict)

-    lr_scheduler
-
-
-
+    if config.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer,
+            last_epoch=last_epoch,
+            # T_max=10 * config.eval_steps,
+            # eta_min=0.01 * config.lr,
+            **config.lr_scheduler_kwargs,
+        )
+    elif config.lr_scheduler == "MultiStepLR":
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+            optimizer,
+            milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
+        )
+    else:
+        raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")

     ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
@@ -209,6 +217,8 @@
         fft_size_list=[256, 512, 1024],
         win_size_list=[120, 240, 480],
         hop_size_list=[25, 50, 100],
+        factor_sc=1.5,
+        factor_mag=1.0,
         reduction="mean"
     ).to(device)

@@ -222,7 +232,7 @@
     average_neg_stoi_loss = 1000000000

     model_list = list()
-
+    best_steps = None
     best_metric = None
     patience_count = 0

@@ -260,7 +270,10 @@
             mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)

             # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
-            loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+            # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
+            # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
+            loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss

             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -288,6 +301,7 @@

             progress_bar_train.update(1)
             progress_bar_train.set_postfix({
+                "lr": lr_scheduler.get_last_lr()[0],
                 "pesq_score": average_pesq_score,
                 "loss": average_loss,
                 "ae_loss": average_ae_loss,
@@ -298,7 +312,7 @@

             # evaluation
             total_steps += 1
-            if total_steps %
+            if total_steps % config.eval_steps == 0:
                 with torch.no_grad():
                     torch.cuda.empty_cache()

@@ -311,7 +325,7 @@

                 progress_bar_train.close()
                 progress_bar_eval = tqdm(
-                    desc="Evaluation; step-{}".format(total_steps),
+                    desc="Evaluation; step-{}k".format(int(total_steps/1000)),
                 )
                 for eval_batch in valid_data_loader:
                     clean_audios, noisy_audios = eval_batch
@@ -327,7 +341,10 @@
                     mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)

                     # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
-                    loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+                    # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
+                    # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.5 * mr_stft_loss + 0.3 * neg_stoi_loss
+                    # loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss
+                    loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss

                     denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
                     clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -350,6 +367,7 @@

                     progress_bar_eval.update(1)
                     progress_bar_eval.set_postfix({
+                        "lr": lr_scheduler.get_last_lr()[0],
                         "pesq_score": average_pesq_score,
                         "loss": average_loss,
                         "ae_loss": average_ae_loss,
@@ -373,7 +391,7 @@
                 )

                 # save path
-                save_dir = serialization_dir / "steps-{}".format(total_steps)
+                save_dir = serialization_dir / "steps-{}k".format(int(total_steps/1000))
                 save_dir.mkdir(parents=True, exist_ok=False)

                 # save models
@@ -389,18 +407,18 @@

                 # save metric
                 if best_metric is None:
-
+                    best_steps = total_steps
                     best_metric = average_pesq_score
                 elif average_pesq_score > best_metric:
                     # great is better.
-
+                    best_steps = total_steps
                     best_metric = average_pesq_score
                 else:
                     pass

                 metrics = {
                     "idx_epoch": idx_epoch,
-                    "
+                    "best_steps": best_steps,
                     "pesq_score": average_pesq_score,
                     "loss": average_loss,
                     "ae_loss": average_ae_loss,
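The scheduler selection above is now config-driven. A rough sketch of what the CosineAnnealingLR branch does with the values this commit adds to config.yaml (T_max: 250000, eta_min: 0.00001); the toy Linear model below stands in for the real ConvTasNet:

    import torch

    model = torch.nn.Linear(4, 4)  # stand-in; the script builds a ConvTasNet here
    optimizer = torch.optim.AdamW(model.parameters(), 0.001)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=250000, eta_min=0.00001,
    )
    for step in range(1, 250001):
        optimizer.step()
        lr_scheduler.step()
        if step % 50000 == 0:
            # lr traces a half cosine from 1e-3 down toward 1e-5 over T_max steps
            print(step, lr_scheduler.get_last_lr()[0])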
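Both the training and evaluation loops now use the same re-weighted objective, with the multi-resolution STFT term promoted to the dominant weight. A hedged note on the new factors: factor_sc and factor_mag are assumed here to follow the common multi-resolution STFT formulation, where factor_sc scales the spectral-convergence term and factor_mag the log-STFT-magnitude term at each resolution. A toy computation with made-up per-term values, only to show how the weights combine:

    # dummy magnitudes; real values come from the loss modules in the script
    ae_loss = 0.10
    neg_si_snr_loss = -12.0
    neg_stoi_loss = -0.85
    mr_stft_loss = 0.90

    loss = 2.0 * mr_stft_loss + 0.8 * ae_loss + 0.7 * neg_si_snr_loss + 0.5 * neg_stoi_loss
    print(loss)  # -6.945, up to float rounding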
examples/conv_tasnet/yaml/config.yaml
CHANGED
@@ -15,3 +15,11 @@ sub_blocks_kernel_size: 3
 norm_type: "gLN"
 causal: false
 mask_nonlinear: "relu"
+
+lr: 0.001
+lr_scheduler: "CosineAnnealingLR"
+lr_scheduler_kwargs:
+  T_max: 250000
+  eta_min: 0.00001
+
+eval_steps: 25000
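The new keys are plain YAML, so a quick parse shows what the training script will see (a sketch; the script itself presumably reads the file through ConvTasNetConfig rather than raw PyYAML):

    import yaml

    with open("examples/conv_tasnet/yaml/config.yaml") as f:
        cfg = yaml.safe_load(f)

    print(cfg["lr"])                   # 0.001
    print(cfg["lr_scheduler"])         # CosineAnnealingLR
    print(cfg["lr_scheduler_kwargs"])  # {'T_max': 250000, 'eta_min': 1e-05}
    print(cfg["eval_steps"])           # 25000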
toolbox/torchaudio/models/clean_unet/inference_clean_unet.py
CHANGED
@@ -79,6 +79,7 @@ class InferenceCleanUNet(object):
         # enhanced_audio shape: [channels, num_samples]
         return enhanced_audio

+
 def main():
     model_zip_file = project_path / "trained_models/clean-unet-aishell-18-epoch.zip"
     infer_mpnet = InferenceCleanUNet(model_zip_file)
@@ -100,5 +101,5 @@ def main():
     return


-if __name__ ==
+if __name__ == "__main__":
     main()
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py
CHANGED
@@ -27,6 +27,12 @@ class ConvTasNetConfig(PretrainedConfig):
                  causal: bool = False,
                  mask_nonlinear: str = "relu",

+                 lr: float = 1e-3,
+                 eval_steps: int = 25000,
+
+                 lr_scheduler: str = "CosineAnnealingLR",
+                 lr_scheduler_kwargs: dict = None,
+
                  **kwargs
                  ):
         super(ConvTasNetConfig, self).__init__(**kwargs)
@@ -47,6 +53,12 @@ class ConvTasNetConfig(PretrainedConfig):
         self.causal = causal
         self.mask_nonlinear = mask_nonlinear

+        self.lr = lr
+        self.eval_steps = eval_steps
+
+        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
+

 if __name__ == "__main__":
     pass
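A hedged construction sketch for the new fields, with values mirroring config.yaml; the "lr_scheduler_kwargs or dict()" guard above also means omitting the argument yields an empty dict instead of a shared mutable default:

    from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig

    config = ConvTasNetConfig(
        lr=1e-3,
        eval_steps=25000,
        lr_scheduler="CosineAnnealingLR",
        lr_scheduler_kwargs={"T_max": 250000, "eta_min": 1e-5},
    )
    print(config.lr_scheduler_kwargs)              # {'T_max': 250000, 'eta_min': 1e-05}
    print(ConvTasNetConfig().lr_scheduler_kwargs)  # {}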
toolbox/torchaudio/models/conv_tasnet/inference_conv_tasnet.py
ADDED
@@ -0,0 +1,112 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import logging
+from pathlib import Path
+import shutil
+import tempfile, time
+import zipfile
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
+from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNetPretrainedModel, MODEL_FILE
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceConvTasNet(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.model = model
+        self.model.to(device)
+        self.model.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = ConvTasNetConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model = ConvTasNetPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model.to(self.device)
+        model.eval()
+
+        shutil.rmtree(model_path)
+        return config, model
+
+    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
+        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+        noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+        # noisy_audio shape: [batch_size, n_samples]
+        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
+        # noisy_audio shape: [n_samples,]
+        return enhanced_audio.cpu().numpy()
+
+    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            enhanced_audios = self.model.forward(noisy_audios)
+            # enhanced_audio shape: [batch_size, channels, num_samples]
+            # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
+
+        enhanced_audio = enhanced_audios[0]
+
+        # enhanced_audio shape: [channels, num_samples]
+        return enhanced_audio
+
+
+def main():
+    model_zip_file = project_path / "trained_models/conv-tasnet-dns3-575k-steps.zip"
+    infer_conv_tasnet = InferenceConvTasNet(model_zip_file)
+
+    sample_rate = 8000
+    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
+    noisy_audio, sample_rate = librosa.load(
+        noisy_audio_file.as_posix(),
+        sr=sample_rate,
+    )
+    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
+    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
+    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+    noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+    begin = time.time()
+    enhanced_audio = infer_conv_tasnet.enhancement_by_tensor(noisy_audio)
+    time_cost = time.time() - begin
+    print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
+    filename = "enhanced_audio.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
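A hedged usage sketch for the new class, mirroring main() above but going through enhancement_by_ndarray with synthetic input; the checkpoint path assumes the same trained_models layout:

    import numpy as np

    from project_settings import project_path
    from toolbox.torchaudio.models.conv_tasnet.inference_conv_tasnet import InferenceConvTasNet

    model_zip_file = project_path / "trained_models/conv-tasnet-dns3-575k-steps.zip"
    infer = InferenceConvTasNet(model_zip_file)

    # one second of synthetic noise at the model's 8 kHz rate, kept inside [-1, 1]
    noisy = np.random.uniform(-0.5, 0.5, size=8000).astype(np.float32)
    enhanced = infer.enhancement_by_ndarray(noisy)
    print(enhanced.shape)  # [n_samples,] per the shape comments in the class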