HoneyTian commited on
Commit
f1a5461
·
1 Parent(s): 9192cea
Files changed (1) hide show
  1. main.py +22 -4
main.py CHANGED
@@ -1,6 +1,7 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
 
4
  from functools import lru_cache
5
  import logging
6
  from pathlib import Path
@@ -8,6 +9,7 @@ import platform
8
  import shutil
9
  from typing import Tuple
10
  import zipfile
 
11
 
12
  import gradio as gr
13
  from huggingface_hub import snapshot_download
@@ -18,6 +20,8 @@ from project_settings import environment, project_path, log_directory
18
  from toolbox.os.command import Command
19
  from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
20
  from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
 
 
21
 
22
  log.setup_size_rotating(log_directory=log_directory)
23
 
@@ -63,7 +67,7 @@ def shell(cmd: str):
63
 
64
  denoise_engines = {
65
  "dfnet-nx-dns3": {
66
- "infer_cls": InferenceFRCRN,
67
  "kwargs": {
68
  "pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet-nx-dns3.zip").as_posix()
69
  }
@@ -99,6 +103,7 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
99
  noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t
100
 
101
  sample_rate, signal = noisy_audio_t
 
102
 
103
  # Test: 使用 microphone 时,显示采样率是 44100,但 signal 实际是按 8000 的采样率的。
104
  logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
@@ -114,13 +119,25 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
114
  kwargs = infer_engine_param["kwargs"]
115
  infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)
116
 
 
117
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
 
 
 
 
 
 
 
 
 
 
 
118
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
119
  except Exception as e:
120
  raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")
121
 
122
  enhanced_audio_t = (sample_rate, enhanced_audio)
123
- return enhanced_audio_t
124
 
125
 
126
  def main():
@@ -177,16 +194,17 @@ def main():
177
  dn_button = gr.Button(variant="primary")
178
  with gr.Column(variant="panel", scale=5):
179
  dn_enhanced_audio = gr.Audio(label="enhanced_audio")
 
180
 
181
  dn_button.click(
182
  when_click_denoise_button,
183
  inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
184
- outputs=[dn_enhanced_audio]
185
  )
186
  gr.Examples(
187
  examples=examples,
188
  inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
189
- outputs=[dn_enhanced_audio],
190
  fn=when_click_denoise_button,
191
  # cache_examples=True,
192
  # cache_mode="lazy",
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import argparse
4
+ import json
5
  from functools import lru_cache
6
  import logging
7
  from pathlib import Path
 
9
  import shutil
10
  from typing import Tuple
11
  import zipfile
12
+ import time
13
 
14
  import gradio as gr
15
  from huggingface_hub import snapshot_download
 
20
  from toolbox.os.command import Command
21
  from toolbox.torchaudio.models.mpnet.inference_mpnet import InferenceMPNet
22
  from toolbox.torchaudio.models.frcrn.inference_frcrn import InferenceFRCRN
23
+ from toolbox.torchaudio.models.dfnet.inference_dfnet import InferenceDfNet
24
+
25
 
26
  log.setup_size_rotating(log_directory=log_directory)
27
 
 
67
 
68
  denoise_engines = {
69
  "dfnet-nx-dns3": {
70
+ "infer_cls": InferenceDfNet,
71
  "kwargs": {
72
  "pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet-nx-dns3.zip").as_posix()
73
  }
 
103
  noisy_audio_t: Tuple = noisy_audio_file_t or noisy_audio_microphone_t
104
 
105
  sample_rate, signal = noisy_audio_t
106
+ audio_duration = signal.shape[-1] // 8000
107
 
108
  # Test: 使用 microphone 时,显示采样率是 44100,但 signal 实际是按 8000 的采样率的。
109
  logger.info(f"run denoise; engine: {engine}, sample_rate: {sample_rate}, signal dtype: {signal.dtype}, signal shape: {signal.shape}")
 
119
  kwargs = infer_engine_param["kwargs"]
120
  infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)
121
 
122
+ begin = time.time()
123
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
124
+ time_cost = time.time() - begin
125
+
126
+ fpr = time_cost / audio_duration
127
+
128
+ info = {
129
+ "time_cost": round(time_cost, 4),
130
+ "audio_duration": round(audio_duration, 4),
131
+ "fpr": round(fpr, 4)
132
+ }
133
+ message = json.dumps(info, ensure_ascii=False, indent=4)
134
+
135
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
136
  except Exception as e:
137
  raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")
138
 
139
  enhanced_audio_t = (sample_rate, enhanced_audio)
140
+ return enhanced_audio_t, message
141
 
142
 
143
  def main():
 
194
  dn_button = gr.Button(variant="primary")
195
  with gr.Column(variant="panel", scale=5):
196
  dn_enhanced_audio = gr.Audio(label="enhanced_audio")
197
+ dn_message = gr.Textbox(lines=1, max_lines=20, label="message")
198
 
199
  dn_button.click(
200
  when_click_denoise_button,
201
  inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
202
+ outputs=[dn_enhanced_audio, dn_message]
203
  )
204
  gr.Examples(
205
  examples=examples,
206
  inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
207
+ outputs=[dn_enhanced_audio, dn_message],
208
  fn=when_click_denoise_button,
209
  # cache_examples=True,
210
  # cache_mode="lazy",