roychao19477 committed on
Commit
2bbe7e3
·
1 Parent(s): 306d2bd
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -56,45 +56,48 @@ model.eval()
56
 
57
 
58
  @spaces.GPU
59
- def enhance_and_plot(audio):
 
60
  if audio is None: return None, None
61
  orig_sr, wav_np = audio
62
  if orig_sr != sr:
63
  wav_np = librosa.resample(wav_np, orig_sr, sr)
64
  wav = torch.from_numpy(wav_np).float().to(device)
65
- norm = torch.sqrt(len(wav)/torch.sum(wav**2))
66
  wav = (wav * norm).unsqueeze(0)
67
 
68
  amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
69
  amp_g, pha_g = model(amp, pha)
70
  out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
71
  out = (out / norm).squeeze().cpu().numpy()
72
-
73
  if orig_sr != sr:
74
  out = librosa.resample(out, sr, orig_sr)
75
 
76
- D = librosa.stft(out, n_fft=512, hop_length=256)
 
77
  S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
78
  fig, ax = plt.subplots()
79
- img = librosa.display.specshow(S_db, sr=orig_sr, hop_length=256, x_axis='time', y_axis='hz', ax=ax)
80
- plt.colorbar(img, ax=ax, format="%+2.0f dB")
81
- ax.set_title("Enhanced Output Spectrum")
82
-
83
  return (orig_sr, out), fig
84
 
85
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
86
- gr.Markdown("# 🎧 SEMamba Speech Enhancement")
87
- gr.Markdown("Upload or record a noisy audio sample to enhance it and view the spectrogram.")
88
-
89
- with gr.Row():
90
- with gr.Column():
91
- input_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Noisy Input")
92
- btn = gr.Button("Enhance")
93
-
94
- with gr.Column():
95
- output_audio = gr.Audio(label="Enhanced Output", type="numpy")
96
- spectrum = gr.Plot(label="Spectrogram")
97
-
98
- btn.click(enhance_and_plot, inputs=input_audio, outputs=[output_audio, spectrum])
99
-
100
- demo.launch()
 
 
 
 
56
 
57
 
58
  @spaces.GPU
59
+ # --- Inference ---
60
+ def enhance(audio):
61
  if audio is None: return None, None
62
  orig_sr, wav_np = audio
63
  if orig_sr != sr:
64
  wav_np = librosa.resample(wav_np, orig_sr, sr)
65
  wav = torch.from_numpy(wav_np).float().to(device)
66
+ norm = torch.sqrt(len(wav) / torch.sum(wav ** 2))
67
  wav = (wav * norm).unsqueeze(0)
68
 
69
  amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
70
  amp_g, pha_g = model(amp, pha)
71
  out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
72
  out = (out / norm).squeeze().cpu().numpy()
 
73
  if orig_sr != sr:
74
  out = librosa.resample(out, sr, orig_sr)
75
 
76
+ # draw spectrum
77
+ D = librosa.stft(out, n_fft=1024, hop_length=512)
78
  S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
79
  fig, ax = plt.subplots()
80
+ librosa.display.specshow(S_db, sr=orig_sr, hop_length=512, x_axis='time', y_axis='hz', ax=ax)
81
+ ax.set_title("Enhanced Spectrogram")
82
+ plt.colorbar(format="%+2.0f dB")
 
83
  return (orig_sr, out), fig
84
 
85
# --- Interface ---
# Components are created first (input, then outputs, matching the original
# construction order) and handed to gr.Interface by name.
_noisy_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Input Audio")
_enhanced_audio = gr.Audio(label="Enhanced Audio", type="numpy")
_spectrogram_plot = gr.Plot(label="Spectrogram")

_interface_options = dict(
    fn=enhance,
    inputs=_noisy_input,
    outputs=[_enhanced_audio, _spectrogram_plot],
    title="<a href='https://github.com/RoyChao19477/SEMamba' target='_blank'>SEMamba</a>: Speech Enhancement",
    description="SEMamba is a state-space model for real-world noisy speech enhancement. Upload or record a noisy sample to hear the result and view the spectrogram.",
    article="<p style='text-align: center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>",
    # One bundled example clip, supplied as a single-input example row.
    examples=[["examples/noisy_sample_16k.wav"]],
    cache_examples=True,
)
se_demo = gr.Interface(**_interface_options)

# --- Launch ---
se_demo.launch()