roychao19477 committed on
Commit e231b3a · 1 Parent(s): c58812e
Files changed (1)
  1. app.py +17 -22
app.py CHANGED
@@ -59,50 +59,45 @@ model.eval()
 
 @spaces.GPU
 def enhance(filepath):
-    if filepath is None: return None, None
+    # load & resample to model SR
     wav_np, orig_sr = librosa.load(filepath, sr=None)
-    if orig_sr != sr:
-        wav_np = librosa.resample(wav_np, orig_sr, sr)
+    if orig_sr != SR:
+        wav_np = librosa.resample(wav_np, orig_sr, SR)
+    # to tensor + normalize
     wav = torch.from_numpy(wav_np).float().to(device)
     norm = torch.sqrt(len(wav) / torch.sum(wav**2))
     wav = (wav * norm).unsqueeze(0)
-
+    # STFT → model → ISTFT
     amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
     amp_g, pha_g = model(amp, pha)
     out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
     out = (out / norm).squeeze().cpu().numpy()
-    if orig_sr != sr:
-        out = librosa.resample(out, sr, orig_sr)
-
-    # write output to temp file
-    out_path = "enhanced_output.wav"
+    # resample back
+    if orig_sr != SR:
+        out = librosa.resample(out, SR, orig_sr)
+    # write to temp file
+    out_path = "enhanced.wav"
     sf.write(out_path, out, orig_sr)
-
-    # plot spectrum
+    # compute spectrogram plot
     D = librosa.stft(out, n_fft=1024, hop_length=512)
     S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
-    fig, ax = plt.subplots(figsize=(6, 3))
-    librosa.display.specshow(S_db, sr=orig_sr, hop_length=512, x_axis='time', y_axis='hz', ax=ax)
+    fig, ax = plt.subplots(figsize=(6,3))
+    librosa.display.specshow(S_db, sr=orig_sr, hop_length=512,
+                             x_axis="time", y_axis="hz", ax=ax)
     ax.set_title("Enhanced Spectrogram")
     plt.colorbar(format="%+2.0f dB", ax=ax)
-
     return out_path, fig
 
-# --- Gradio Interface ---
 demo = gr.Interface(
     fn=enhance,
-    inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Input Audio"),
+    inputs=gr.Audio(sources=["upload","microphone"], type="filepath", label="Input Audio"),
     outputs=[
         gr.Audio(label="Enhanced Audio", type="filepath"),
         gr.Plot(label="Spectrogram")
     ],
     title="<a href='https://github.com/RoyChao19477/SEMamba' target='_blank'>SEMamba</a>: Speech Enhancement",
-    description="Upload or record a noisy audio file and SEMamba will enhance the speech and show its spectrogram.",
-    article="<p style='text-align: center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>📄 SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>",
-    examples=[
-        ["examples/noisy_sample_16k.wav"]
-    ],
-    cache_examples=True
+    description="Upload or record noisy speech; SEMamba enhances it and shows the spectrogram.",
+    article="<p style='text-align:center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>"
 )
 
 demo.launch()
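
A note on the resampling lines in this revision: librosa.resample(wav_np, orig_sr, SR) uses the old positional signature, and newer librosa releases (0.10+) expect keyword arguments (orig_sr=, target_sr=). If the Space pins a recent librosa, the keyword form would be needed. A minimal sketch, not part of the commit, assuming the same SR constant that app.py defines elsewhere; the helper name resample_if_needed is hypothetical:

import librosa

def resample_if_needed(wav, orig_sr, target_sr):
    # no-op when the rates already match
    if orig_sr == target_sr:
        return wav
    # keyword arguments work on both older and newer librosa releases
    return librosa.resample(wav, orig_sr=orig_sr, target_sr=target_sr)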
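The norm factor wrapped around the model call is just the inverse RMS of the input: sqrt(len(wav) / sum(wav**2)) scales the waveform to unit root-mean-square level before the STFT, and dividing the enhanced signal by the same factor afterwards restores the original loudness. A small NumPy-only sketch of that round trip, for illustration and independent of the model:

import numpy as np

def rms_normalize(wav):
    # norm = 1 / RMS(wav), so wav * norm has unit RMS
    norm = np.sqrt(len(wav) / np.sum(wav ** 2))
    return wav * norm, norm

wav = 0.05 * np.random.randn(16000).astype(np.float32)  # quiet dummy signal
normed, norm = rms_normalize(wav)
print(np.sqrt(np.mean(normed ** 2)))   # ~1.0
restored = normed / norm               # undo the scaling after enhancement
print(np.allclose(restored, wav))      # True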
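The spectrogram panel added here is a plain magnitude STFT rendered in dB. The same figure can be reproduced outside Gradio using the calls that appear in the diff; a standalone sketch, where "path/to/audio.wav" is a placeholder input and the colorbar is attached explicitly to the mesh returned by specshow:

import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt

y, sr = librosa.load("path/to/audio.wav", sr=None)        # placeholder file, native sample rate
D = librosa.stft(y, n_fft=1024, hop_length=512)           # complex STFT
S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)     # magnitude in dB, 0 dB at the peak

fig, ax = plt.subplots(figsize=(6, 3))
img = librosa.display.specshow(S_db, sr=sr, hop_length=512,
                               x_axis="time", y_axis="hz", ax=ax)
ax.set_title("Enhanced Spectrogram")
fig.colorbar(img, ax=ax, format="%+2.0f dB")
plt.show()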