HoneyTian commited on
Commit
1e78a70
·
1 Parent(s): 04e3488
examples/mpnet_aishell/step_3_evaluation.py CHANGED
@@ -85,14 +85,14 @@ def save_audios(noise_audio: torch.Tensor,
85
  output_dir.mkdir(parents=True, exist_ok=True)
86
 
87
  filename = output_dir / "noise_audio.wav"
88
- torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate)
89
  filename = output_dir / "clean_audio.wav"
90
- torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate)
91
  filename = output_dir / "noisy_audio.wav"
92
- torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate)
93
 
94
  filename = output_dir / "enhanced_audio.wav"
95
- torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
96
 
97
  return output_dir.as_posix()
98
 
@@ -159,14 +159,15 @@ def main():
159
  # inference
160
  clean_audio = clean_audio.to(device)
161
  noisy_audio = noisy_audio.to(device)
162
- noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
163
- noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
164
- )
165
- mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
166
- audio_g = mag_pha_istft(
167
- mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
168
- )
169
- enhanced_audio = audio_g.detach()
 
170
 
171
  save_audios(
172
  noise_audio, clean_audio, noisy_audio,
 
85
  output_dir.mkdir(parents=True, exist_ok=True)
86
 
87
  filename = output_dir / "noise_audio.wav"
88
+ torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate, bits_per_sample=16)
89
  filename = output_dir / "clean_audio.wav"
90
+ torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate, bits_per_sample=16)
91
  filename = output_dir / "noisy_audio.wav"
92
+ torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate, bits_per_sample=16)
93
 
94
  filename = output_dir / "enhanced_audio.wav"
95
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate, bits_per_sample=16)
96
 
97
  return output_dir.as_posix()
98
 
 
159
  # inference
160
  clean_audio = clean_audio.to(device)
161
  noisy_audio = noisy_audio.to(device)
162
+ with torch.no_grad():
163
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
164
+ noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
165
+ )
166
+ mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
167
+ audio_g = mag_pha_istft(
168
+ mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
169
+ )
170
+ enhanced_audio = audio_g.detach()
171
 
172
  save_audios(
173
  noise_audio, clean_audio, noisy_audio,
main.py CHANGED
@@ -4,8 +4,11 @@ import argparse
4
  import platform
5
 
6
  import gradio as gr
 
 
7
 
8
  from project_settings import environment, project_path
 
9
 
10
 
11
  def get_args():
@@ -25,12 +28,58 @@ def get_args():
25
  return args
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def main():
29
  args = get_args()
30
 
 
 
 
31
  # ui
32
  with gr.Blocks() as blocks:
33
- gr.Markdown(value="in progress.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # http://127.0.0.1:7864/
36
  blocks.queue().launch(
 
4
  import platform
5
 
6
  import gradio as gr
7
+ import numpy as np
8
+ import torch
9
 
10
  from project_settings import environment, project_path
11
+ from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
12
 
13
 
14
  def get_args():
 
28
  return args
29
 
30
 
31
+ denoise_engines = {
32
+ "mpnet": InferenceMPNet(
33
+ pretrained_model_path_or_zip_file=(project_path / "trained_models/mpnet_aishell_20250221.zip").as_posix(),
34
+ ),
35
+
36
+ }
37
+
38
+
39
+ def when_click_denoise_button(noisy_audio_t, engine: str):
40
+ sample_rate, signal = noisy_audio_t
41
+
42
+ noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
43
+
44
+ infer_engine = denoise_engines.get(engine)
45
+ if infer_engine is None:
46
+ raise gr.Error(f"invalid denoise engine: {engine}.")
47
+
48
+ try:
49
+ enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
50
+ enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
51
+ except Exception as e:
52
+ raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")
53
+
54
+ enhanced_audio_t = (sample_rate, enhanced_audio)
55
+ return enhanced_audio_t, None
56
+
57
+
58
  def main():
59
  args = get_args()
60
 
61
+ # choices
62
+ denoise_engine_choices = list(denoise_engines.keys())
63
+
64
  # ui
65
  with gr.Blocks() as blocks:
66
+ gr.Markdown(value="nx denoise.")
67
+ with gr.Tabs():
68
+ with gr.TabItem("denoise"):
69
+ with gr.Row():
70
+ with gr.Column(variant="panel", scale=5):
71
+ dn_noisy_audio = gr.Audio(label="noisy_audio")
72
+ dn_engine = gr.Dropdown(choices=denoise_engine_choices, value=denoise_engine_choices[0], label="engine")
73
+ dn_button = gr.Button(variant="primary")
74
+ with gr.Column(variant="panel", scale=5):
75
+ dn_enhanced_audio = gr.Audio(label="enhanced_audio")
76
+ dn_clean_audio = gr.Audio(label="clean_audio")
77
+
78
+ dn_button.click(
79
+ when_click_denoise_button,
80
+ inputs=[dn_noisy_audio, dn_engine],
81
+ outputs=[dn_enhanced_audio, dn_clean_audio]
82
+ )
83
 
84
  # http://127.0.0.1:7864/
85
  blocks.queue().launch(
toolbox/torchaudio/models/mpnet/inference_mpnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import logging
4
+ from pathlib import Path
5
+ import shutil
6
+ import tempfile
7
+ import zipfile
8
+
9
+ import librosa
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+
14
+ from project_settings import project_path
15
+ from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
16
+ from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel, MODEL_FILE
17
+ from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
18
+
19
+ logger = logging.getLogger("toolbox")
20
+
21
+
22
+ class InferenceMPNet(object):
23
+ def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
24
+ self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
25
+ self.device = torch.device(device)
26
+
27
+ logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
28
+ config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
29
+ logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
30
+
31
+ self.config = config
32
+ self.generator = generator
33
+ self.generator.to(device)
34
+ self.generator.eval()
35
+
36
+ def load_models(self, model_path: str):
37
+ model_path = Path(model_path)
38
+ if model_path.name.endswith(".zip"):
39
+ with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
40
+ out_root = Path(tempfile.gettempdir()) / "nx_denoise"
41
+ out_root.mkdir(parents=True, exist_ok=True)
42
+ f_zip.extractall(path=out_root)
43
+ model_path = out_root / model_path.stem
44
+
45
+ config = MPNetConfig.from_pretrained(
46
+ pretrained_model_name_or_path=model_path.as_posix(),
47
+ )
48
+ generator = MPNetPretrainedModel.from_pretrained(
49
+ pretrained_model_name_or_path=model_path.as_posix(),
50
+ )
51
+ generator.to(self.device)
52
+ generator.eval()
53
+
54
+ shutil.rmtree(model_path)
55
+ return config, generator
56
+
57
+ def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
58
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
59
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
60
+
61
+ # noisy_audio shape: [batch_size, n_samples]
62
+ noisy_audio = self.enhancement_by_tensor(noisy_audio)
63
+ noisy_audio = noisy_audio[0]
64
+
65
+ return noisy_audio.cpu().numpy()
66
+
67
+ def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
68
+ if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
69
+ raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
70
+
71
+ noisy_audio = noisy_audio.to(self.device)
72
+
73
+ with torch.no_grad():
74
+ noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
75
+ noisy_audio,
76
+ self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
77
+ )
78
+ mag_g, pha_g, com_g = self.generator.forward(noisy_mag, noisy_pha)
79
+ audio_g = mag_pha_istft(
80
+ mag_g, pha_g,
81
+ self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor
82
+ )
83
+ enhanced_audio = audio_g.detach()
84
+
85
+ return enhanced_audio
86
+
87
+ def main():
88
+ model_zip_file = project_path / "trained_models/mpnet_aishell_20250221.zip"
89
+ infer_mpnet = InferenceMPNet(model_zip_file)
90
+
91
+ sample_rate = 8000
92
+ noisy_audio_file = project_path / "data/examples/noisy_audio.wav"
93
+ noisy_audio, _ = librosa.load(
94
+ noisy_audio_file.as_posix(),
95
+ sr=sample_rate,
96
+ )
97
+ noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
98
+ noisy_audio = noisy_audio.unsqueeze(dim=0)
99
+
100
+ enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio)
101
+
102
+ filename = "enhanced_audio.wav"
103
+ torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
104
+
105
+ return
106
+
107
+
108
+ if __name__ == '__main__':
109
+ main()
toolbox/torchaudio/models/mpnet/modeling_mpnet.py CHANGED
@@ -8,6 +8,8 @@ https://huggingface.co/spaces/JacobLinCool/MP-SENet
8
  https://arxiv.org/abs/2305.13686
9
  https://github.com/yxlu-0102/MP-SENet
10
 
 
 
11
  """
12
  import os
13
  from typing import Optional, Union
 
8
  https://arxiv.org/abs/2305.13686
9
  https://github.com/yxlu-0102/MP-SENet
10
 
11
+ 应该是不支持流式改造的。
12
+
13
  """
14
  import os
15
  from typing import Optional, Union