HoneyTian commited on
Commit
fa467b8
·
1 Parent(s): 19b9289
Files changed (1) hide show
  1. main.py +22 -16
main.py CHANGED
@@ -58,7 +58,20 @@ def shell(cmd: str):
58
  return Command.popen(cmd)
59
 
60
 
61
- denoise_engines = dict()
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  def when_click_denoise_button(noisy_audio_t, engine: str):
@@ -67,11 +80,15 @@ def when_click_denoise_button(noisy_audio_t, engine: str):
67
 
68
  noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
69
 
70
- infer_engine = denoise_engines.get(engine)
71
- if infer_engine is None:
72
  raise gr.Error(f"invalid denoise engine: {engine}.")
73
 
74
  try:
 
 
 
 
75
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
76
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
77
  except Exception as e:
@@ -96,17 +113,6 @@ def main():
96
  token=args.hf_token,
97
  )
98
 
99
- # engines
100
- global denoise_engines
101
- denoise_engines = {
102
- "mpnet-aishell-1-epoch": InferenceMPNet(
103
- pretrained_model_path_or_zip_file=(project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix(),
104
- ),
105
- "mpnet-aishell-11-epoch": InferenceMPNet(
106
- pretrained_model_path_or_zip_file=(project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix(),
107
- ),
108
- }
109
-
110
  # choices
111
  denoise_engine_choices = list(denoise_engines.keys())
112
 
@@ -150,8 +156,8 @@ def main():
150
  inputs=[dn_noisy_audio, dn_engine],
151
  outputs=[dn_enhanced_audio],
152
  fn=when_click_denoise_button,
153
- cache_examples=True,
154
- cache_mode="lazy",
155
  )
156
 
157
  with gr.TabItem("shell"):
 
58
  return Command.popen(cmd)
59
 
60
 
61
+ denoise_engines = {
62
+ "mpnet-aishell-1-epoch": {
63
+ "infer_cls": InferenceMPNet,
64
+ "kwargs": {
65
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix()
66
+ }
67
+ },
68
+ "mpnet-aishell-11-epoch": {
69
+ "infer_cls": InferenceMPNet,
70
+ "kwargs": {
71
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
72
+ }
73
+ },
74
+ }
75
 
76
 
77
  def when_click_denoise_button(noisy_audio_t, engine: str):
 
80
 
81
  noisy_audio = np.array(signal / (1 << 15), dtype=np.float32)
82
 
83
+ infer_engine_param = denoise_engines.get(engine)
84
+ if infer_engine_param is None:
85
  raise gr.Error(f"invalid denoise engine: {engine}.")
86
 
87
  try:
88
+ infer_cls = infer_engine_param["infer_cls"]
89
+ kwargs = infer_engine_param["kwargs"]
90
+ infer_engine = infer_cls(**kwargs)
91
+
92
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
93
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
94
  except Exception as e:
 
113
  token=args.hf_token,
114
  )
115
 
 
 
 
 
 
 
 
 
 
 
 
116
  # choices
117
  denoise_engine_choices = list(denoise_engines.keys())
118
 
 
156
  inputs=[dn_noisy_audio, dn_engine],
157
  outputs=[dn_enhanced_audio],
158
  fn=when_click_denoise_button,
159
+ # cache_examples=True,
160
+ # cache_mode="lazy",
161
  )
162
 
163
  with gr.TabItem("shell"):