Commit: update
examples/mpnet_aishell/step_3_evaluation.py
CHANGED
@@ -85,14 +85,14 @@ def save_audios(noise_audio: torch.Tensor,
     output_dir.mkdir(parents=True, exist_ok=True)
 
     filename = output_dir / "noise_audio.wav"
-    torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate)
+    torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate, bits_per_sample=16)
     filename = output_dir / "clean_audio.wav"
-    torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate)
+    torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate, bits_per_sample=16)
     filename = output_dir / "noisy_audio.wav"
-    torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate)
+    torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate, bits_per_sample=16)
 
     filename = output_dir / "enhanced_audio.wav"
-    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate, bits_per_sample=16)
 
     return output_dir.as_posix()
 
@@ -159,14 +159,15 @@ def main():
         # inference
         clean_audio = clean_audio.to(device)
         noisy_audio = noisy_audio.to(device)
-        noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
-            noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
-        )
-        mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
-        audio_g = mag_pha_istft(
-            mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
-        )
-        enhanced_audio = audio_g.detach()
+        with torch.no_grad():
+            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
+                noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
+            )
+            mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
+            audio_g = mag_pha_istft(
+                mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
+            )
+            enhanced_audio = audio_g.detach()
 
         save_audios(
             noise_audio, clean_audio, noisy_audio,
main.py
CHANGED
@@ -4,8 +4,11 @@ import argparse
 import platform
 
 import gradio as gr
+import numpy as np
+import torch
 
 from project_settings import environment, project_path
+from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
 
 
 def get_args():
@@ -25,12 +28,58 @@ def get_args():
     return args
 
 
+denoise_engines = {
+    "mpnet": InferenceMPNet(
+        pretrained_model_path_or_zip_file=(project_path / "trained_models/mpnet_aishell_20250221.zip").as_posix(),
+    ),
+
+}
+
+
+def when_click_denoise_button(noisy_audio_t, engine: str):
+    sample_rate, signal = noisy_audio_t
+
+    noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
+
+    infer_engine = denoise_engines.get(engine)
+    if infer_engine is None:
+        raise gr.Error(f"invalid denoise engine: {engine}.")
+
+    try:
+        enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
+        enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
+    except Exception as e:
+        raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")
+
+    enhanced_audio_t = (sample_rate, enhanced_audio)
+    return enhanced_audio_t, None
+
+
 def main():
     args = get_args()
 
+    # choices
+    denoise_engine_choices = list(denoise_engines.keys())
+
     # ui
     with gr.Blocks() as blocks:
-        gr.Markdown(value="
+        gr.Markdown(value="nx denoise.")
+        with gr.Tabs():
+            with gr.TabItem("denoise"):
+                with gr.Row():
+                    with gr.Column(variant="panel", scale=5):
+                        dn_noisy_audio = gr.Audio(label="noisy_audio")
+                        dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
+                        dn_button = gr.Button(variant="primary")
+                    with gr.Column(variant="panel", scale=5):
+                        dn_enhanced_audio = gr.Audio(label="enhanced_audio")
+                        dn_clean_audio = gr.Audio(label="clean_audio")
+
+        dn_button.click(
+            when_click_denoise_button,
+            inputs=[dn_noisy_audio, dn_engine],
+            outputs=[dn_enhanced_audio, dn_clean_audio]
+        )
 
     # http://127.0.0.1:7864/
     blocks.queue().launch(
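The handler relies on gr.Audio passing audio to the callback as a (sample_rate, ndarray) tuple with 16-bit integer samples, while InferenceMPNet expects float32 in [-1, 1]; hence the symmetric 1 << 15 (= 32768) scaling. It returns None as the second output so the clean_audio widget is cleared on each run. The scaling round trip in isolation:

# Sketch of the int16 <-> float32 scaling used in when_click_denoise_button.
import numpy as np

def int16_to_float32(signal: np.ndarray) -> np.ndarray:
    return np.array(signal / (1 << 15), dtype=np.float32)  # into roughly [-1, 1]

def float32_to_int16(signal: np.ndarray) -> np.ndarray:
    return np.array(signal * (1 << 15), dtype=np.int16)    # back to the 16-bit PCM range

sample = np.array([0, 16384, -32768], dtype=np.int16)
assert (float32_to_int16(int16_to_float32(sample)) == sample).all()
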
toolbox/torchaudio/models/mpnet/inference_mpnet.py
ADDED
@@ -0,0 +1,109 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import logging
+from pathlib import Path
+import shutil
+import tempfile
+import zipfile
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+from project_settings import project_path
+from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
+from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel, MODEL_FILE
+from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceMPNet(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.generator = generator
+        self.generator.to(device)
+        self.generator.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = MPNetConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        generator = MPNetPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        generator.to(self.device)
+        generator.eval()
+
+        shutil.rmtree(model_path)
+        return config, generator
+
+    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
+        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+        noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+        # noisy_audio shape: [batch_size, n_samples]
+        noisy_audio = self.enhancement_by_tensor(noisy_audio)
+        noisy_audio = noisy_audio[0]
+
+        return noisy_audio.cpu().numpy()
+
+    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
+
+        noisy_audio = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
+                noisy_audio,
+                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
+            )
+            mag_g, pha_g, com_g = self.generator.forward(noisy_mag, noisy_pha)
+            audio_g = mag_pha_istft(
+                mag_g, pha_g,
+                self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
+            )
+            enhanced_audio = audio_g.detach()
+
+        return enhanced_audio
+
+def main():
+    model_zip_file = project_path / "trained_models/mpnet_aishell_20250221.zip"
+    infer_mpnet = InferenceMPNet(model_zip_file)
+
+    sample_rate = 8000
+    noisy_audio_file = project_path / "data/examples/noisy_audio.wav"
+    noisy_audio, _ = librosa.load(
+        noisy_audio_file.as_posix(),
+        sr=sample_rate,
+    )
+    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+    noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+    enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio)
+
+    filename = "enhanced_audio.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    return
+
+
+if __name__ == '__main__':
+    main()
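load_models extracts a .zip into the system temp directory, expects the archive to contain a top-level folder named after the zip's stem (here mpnet_aishell_20250221/) with the config and weights inside, and removes that extracted folder again after loading. Note that shutil.rmtree(model_path) also runs when a plain directory is passed instead of a zip, which would delete that directory; the callers added in this commit only ever pass the zip. A usage sketch, assuming the trained-model zip exists at the path the commit already uses:

# Usage sketch for InferenceMPNet; the zip path is the one used elsewhere
# in this commit, and the quiet random input is only illustrative.
import numpy as np
from project_settings import project_path
from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet

engine = InferenceMPNet(
    pretrained_model_path_or_zip_file=(project_path / "trained_models/mpnet_aishell_20250221.zip").as_posix(),
)

noisy = np.random.uniform(-0.1, 0.1, size=8000).astype(np.float32)  # 1 s at 8 kHz, in [-1, 1]
enhanced = engine.enhancement_by_ndarray(noisy)                     # float32 ndarray on CPU
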
toolbox/torchaudio/models/mpnet/modeling_mpnet.py
CHANGED
@@ -8,6 +8,8 @@ https://huggingface.co/spaces/JacobLinCool/MP-SENet
 https://arxiv.org/abs/2305.13686
 https://github.com/yxlu-0102/MP-SENet
 
+Presumably this cannot be adapted for streaming.
+
 """
 import os
 from typing import Optional, Union