HoneyTian commited on
Commit
637d40c
·
1 Parent(s): cdf219b
Files changed (1) hide show
  1. main.py +11 -10
main.py CHANGED
@@ -60,31 +60,30 @@ def shell(cmd: str):
60
 
61
 
62
  denoise_engines = {
63
- "mpnet-aishell-1-epoch": {
64
  "infer_cls": InferenceMPNet,
65
  "kwargs": {
66
- "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix()
 
67
  }
68
  },
69
- "mpnet-aishell-11-epoch": {
70
  "infer_cls": InferenceMPNet,
71
  "kwargs": {
72
- "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
73
  }
74
  },
75
- "mpnet-nx-speech-1-epoch": {
76
  "infer_cls": InferenceMPNet,
77
  "kwargs": {
78
- "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
79
  }
80
  },
81
  }
82
 
83
 
84
  @lru_cache(maxsize=3)
85
- def load_denoise_model(infer_engine_param: dict):
86
- infer_cls = infer_engine_param["infer_cls"]
87
- kwargs = infer_engine_param["kwargs"]
88
  infer_engine = infer_cls(**kwargs)
89
 
90
  return infer_engine
@@ -101,7 +100,9 @@ def when_click_denoise_button(noisy_audio_t, engine: str):
101
  raise gr.Error(f"invalid denoise engine: {engine}.")
102
 
103
  try:
104
- infer_engine = load_denoise_model(infer_engine_param)
 
 
105
 
106
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
107
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
 
60
 
61
 
62
  denoise_engines = {
63
+ "mpnet-nx-speech-1-epoch": {
64
  "infer_cls": InferenceMPNet,
65
  "kwargs": {
66
+ "pretrained_model_path_or_zip_file": (
67
+ project_path / "trained_models/mpnet-nx-speech-1-epoch.zip").as_posix()
68
  }
69
  },
70
+ "mpnet-aishell-1-epoch": {
71
  "infer_cls": InferenceMPNet,
72
  "kwargs": {
73
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-1-epoch.zip").as_posix()
74
  }
75
  },
76
+ "mpnet-aishell-11-epoch": {
77
  "infer_cls": InferenceMPNet,
78
  "kwargs": {
79
+ "pretrained_model_path_or_zip_file": (project_path / "trained_models/mpnet-aishell-11-epoch.zip").as_posix()
80
  }
81
  },
82
  }
83
 
84
 
85
  @lru_cache(maxsize=3)
86
+ def load_denoise_model(infer_cls, **kwargs):
 
 
87
  infer_engine = infer_cls(**kwargs)
88
 
89
  return infer_engine
 
100
  raise gr.Error(f"invalid denoise engine: {engine}.")
101
 
102
  try:
103
+ infer_cls = infer_engine_param["infer_cls"]
104
+ kwargs = infer_engine_param["kwargs"]
105
+ infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)
106
 
107
  enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
108
  enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)