roychao19477 commited on
Commit
81e7d3e
Β·
1 Parent(s): 23813e6
Files changed (1) hide show
  1. app.py +34 -29
app.py CHANGED
@@ -24,12 +24,6 @@ from models.stfts import mag_phase_stft, mag_phase_istft
24
  from models.generator import SEMamba
25
  from models.pcs400 import cal_pcs
26
 
27
- # download model files from your HF repo
28
- #ckpt = hf_hub_download("rc19477/Speech_Enhancement_Mamba",
29
- # "ckpts/SEMamba_advanced.pth")
30
- #cfg_f = hf_hub_download("rc19477/Speech_Enhancement_Mamba",
31
- # "recipes/SEMamba_advanced.yaml")
32
-
33
  ckpt = "ckpts/SEMamba_advanced.pth"
34
  cfg_f = "recipes/SEMamba_advanced.yaml"
35
 
@@ -60,38 +54,49 @@ load_model()
60
 
61
  @spaces.GPU
62
  def enhance(filepath):
63
- wav_np, orig_sr = librosa.load(filepath, sr=None)
 
64
  if orig_sr != SR:
65
- wav_np = librosa.resample(wav_np, orig_sr, SR)
66
- wav = torch.from_numpy(wav_np).float().to(device)
67
- norm = torch.sqrt(len(wav) / torch.sum(wav**2))
68
- wav = (wav * norm).unsqueeze(0)
69
- amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
70
- amp_g, pha_g = model(amp, pha)
71
- out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
72
- out = (out / norm).squeeze().cpu().numpy()
 
 
 
73
  if orig_sr != SR:
74
  out = librosa.resample(out, SR, orig_sr)
 
75
  sf.write("enhanced.wav", out, orig_sr)
76
-
77
  D = librosa.stft(out, n_fft=1024, hop_length=512)
78
  S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
79
- fig, ax = plt.subplots(figsize=(6, 3))
80
  librosa.display.specshow(S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax)
81
  ax.set_title("Enhanced Spectrogram")
82
  plt.colorbar(format="%+2.0f dB", ax=ax)
 
83
 
84
- return out, fig
 
 
 
 
85
 
86
- # --- Build demo ---
87
  with gr.Blocks() as demo:
88
- audio_in = gr.Audio(label="Input Audio", type="filepath", interactive=True)
89
- run_btn = gr.Button("Enhance", variant="primary")
90
- enhanced_audio = gr.Audio(label="Enhanced Audio", type="filepath", interactive=False)
91
- spectrogram = gr.Plot(label="Spectrogram")
92
-
93
- run_btn.click(fn=enhance, inputs=audio_in, outputs=[enhanced_audio, spectrogram])
94
-
95
- gr.Markdown("Unofficial demo by [yourname](https://github.com/RoyChao19477)")
96
-
97
- demo.launch()
 
 
 
24
  from models.generator import SEMamba
25
  from models.pcs400 import cal_pcs
26
 
 
 
 
 
 
 
27
  ckpt = "ckpts/SEMamba_advanced.pth"
28
  cfg_f = "recipes/SEMamba_advanced.yaml"
29
 
 
54
 
55
  @spaces.GPU
56
  def enhance(filepath):
57
+ # load & (if needed) resample to model SR
58
+ wav, orig_sr = librosa.load(filepath, sr=None)
59
  if orig_sr != SR:
60
+ wav = librosa.resample(wav, orig_sr, SR)
61
+ # normalize β†’ tensor
62
+ x = torch.from_numpy(wav).float().to(device)
63
+ norm = torch.sqrt(len(x)/torch.sum(x**2))
64
+ x = (x*norm).unsqueeze(0)
65
+ # STFT β†’ model β†’ ISTFT
66
+ amp,pha,_ = mag_phase_stft(x, **stft_cfg, compress_factor=model_cfg["compress_factor"])
67
+ amp2,pha2 = model(amp, pha)
68
+ out = mag_phase_istft(amp2, pha2, **stft_cfg, compress_factor=model_cfg["compress_factor"])
69
+ out = (out/norm).squeeze().cpu().numpy()
70
+ # back to original rate
71
  if orig_sr != SR:
72
  out = librosa.resample(out, SR, orig_sr)
73
+ # write file
74
  sf.write("enhanced.wav", out, orig_sr)
75
+ # build spectrogram
76
  D = librosa.stft(out, n_fft=1024, hop_length=512)
77
  S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
78
+ fig, ax = plt.subplots(figsize=(6,3))
79
  librosa.display.specshow(S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax)
80
  ax.set_title("Enhanced Spectrogram")
81
  plt.colorbar(format="%+2.0f dB", ax=ax)
82
+ return "enhanced.wav", fig
83
 
84
+ ABOUT = """
85
+ # SEMamba: Speech Enhancement
86
+ A Mamba-based model that denoises real-world audio.
87
+ Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
88
+ """
89
 
 
90
  with gr.Blocks() as demo:
91
+ gr.Markdown(ABOUT)
92
+ audio_in = gr.Audio(sources=["upload","microphone"],
93
+ type="filepath",
94
+ label="Your Noisy Audio",
95
+ interactive=True)
96
+ run_button = gr.Button("Enhance", variant="primary")
97
+ enhanced_out = gr.Audio(label="Enhanced Audio", type="filepath", interactive=False)
98
+ spec_out = gr.Plot(label="Spectrogram")
99
+
100
+ run_button.click(enhance, inputs=audio_in, outputs=[enhanced_out, spec_out])
101
+
102
+ demo.queue().launch()