HoneyTian commited on
Commit
cdf219b
·
1 Parent(s): fa467b8
main.py CHANGED
@@ -1,6 +1,7 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
 
4
  import logging
5
  from pathlib import Path
6
  import platform
@@ -71,9 +72,24 @@ denoise_engines = {
71
  "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
72
  }
73
  },
 
 
 
 
 
 
74
  }
75
 
76
 
 
 
 
 
 
 
 
 
 
77
  def when_click_denoise_button(noisy_audio_t, engine: str):
78
  sample_rate, signal = noisy_audio_t
79
  logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
@@ -85,9 +101,7 @@ def when_click_denoise_button(noisy_audio_t, engine: str):
85
  raise gr.Error(f"invalid denoise engine: {engine}.")
86
 
87
  try:
88
- infer_cls = infer_engine_param["infer_cls"]
89
- kwargs = infer_engine_param["kwargs"]
90
- infer_engine = infer_cls(**kwargs)
91
 
92
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
93
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
@@ -117,13 +131,14 @@ def main():
117
  denoise_engine_choices = list(denoise_engines.keys())
118
 
119
  # examples
120
- example_zip_file = trained_model_dir / "examples.zip"
121
- with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
122
- out_root = examples_dir
123
- if out_root.exists():
124
- shutil.rmtree(out_root.as_posix())
125
- out_root.mkdir(parents=True, exist_ok=True)
126
- f_zip.extractall(path=out_root)
 
127
 
128
  # examples
129
  examples = list()
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
+ from functools import lru_cache
5
  import logging
6
  from pathlib import Path
7
  import platform
 
72
  "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
73
  }
74
  },
75
+ "mpnet-nx-speech-1-epoch": {
76
+ "infer_cls": InferenceMPNet,
77
+ "kwargs": {
78
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
79
+ }
80
+ },
81
  }
82
 
83
 
84
+ @lru_cache(maxsize=3)
85
+ def load_denoise_model(infer_engine_param: dict):
86
+ infer_cls = infer_engine_param["infer_cls"]
87
+ kwargs = infer_engine_param["kwargs"]
88
+ infer_engine = infer_cls(**kwargs)
89
+
90
+ return infer_engine
91
+
92
+
93
  def when_click_denoise_button(noisy_audio_t, engine: str):
94
  sample_rate, signal = noisy_audio_t
95
  logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
 
101
  raise gr.Error(f"invalid denoise engine: {engine}.")
102
 
103
  try:
104
+ infer_engine = load_denoise_model(infer_engine_param)
 
 
105
 
106
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
107
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
 
131
  denoise_engine_choices = list(denoise_engines.keys())
132
 
133
  # examples
134
+ if not examples_dir.exists():
135
+ example_zip_file = trained_model_dir / "examples.zip"
136
+ with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
137
+ out_root = examples_dir
138
+ if out_root.exists():
139
+ shutil.rmtree(out_root.as_posix())
140
+ out_root.mkdir(parents=True, exist_ok=True)
141
+ f_zip.extractall(path=out_root)
142
 
143
  # examples
144
  examples = list()
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -7,6 +7,10 @@ https://github.com/nvidia/cleanunet
7
 
8
  https://huggingface.co/spaces/fsoft-ai-center/Speech-Enhancement/blob/main/src/model.py
9
 
 
 
 
 
10
  """
11
  import os
12
  from typing import Optional, Union
 
7
 
8
  https://huggingface.co/spaces/fsoft-ai-center/Speech-Enhancement/blob/main/src/model.py
9
 
10
+ 支持流式改造。
11
+
12
+ https://github.com/francislr/clean-unet-inference
13
+
14
  """
15
  import os
16
  from typing import Optional, Union
toolbox/torchaudio/models/conv_tasnet/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == '__main__':
6
+ pass
toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py
5
+ """
6
+
7
+
8
+ if __name__ == '__main__':
9
+ pass