roychao19477 committed on
Commit
ea5c419
·
1 Parent(s): 18c8531
Files changed (1) hide show
  1. app.py +25 -27
app.py CHANGED
@@ -56,47 +56,45 @@ model.eval()
56
 
57
 
58
  @spaces.GPU
59
- def enhance(audio, do_pcs):
 
60
  orig_sr, wav_np = audio
61
- # 1) resample to 16 kHz if needed
62
  if orig_sr != sr:
63
  wav_np = librosa.resample(wav_np, orig_sr, sr)
64
  wav = torch.from_numpy(wav_np).float().to(device)
 
 
65
 
66
- # normalize
67
- norm = torch.sqrt(len(wav) / torch.sum(wav**2))
68
- wav = (wav * norm).unsqueeze(0)
69
-
70
- # STFT → model → ISTFT
71
- amp, pha, _ = mag_phase_stft(wav, n_fft, hop_size, win_size, compress_ff)
72
  amp_g, pha_g = model(amp, pha)
73
- out = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_ff)
74
  out = (out / norm).squeeze().cpu().numpy()
75
 
76
- # optional PCS filter
77
- if do_pcs:
78
- out = cal_pcs(out)
79
-
80
- # 2) resample back to original rate
81
  if orig_sr != sr:
82
  out = librosa.resample(out, sr, orig_sr)
83
 
84
- return orig_sr, out
 
 
 
 
 
85
 
86
- with gr.Blocks() as demo:
87
- gr.Markdown("## SEMamba Speech Enhancement demo")
88
- with gr.Row():
89
- upload = gr.Audio(label="Upload WAV", type="numpy")
90
- record = gr.Audio(label="Record via mic", type="numpy")
91
- pcs = gr.Checkbox(label="Apply PCS post-processing", value=False)
92
- btn = gr.Button("Enhance")
93
- out = gr.Audio(label="Enhanced WAV", type="numpy")
94
 
95
- @spaces.GPU
96
- def runner(up, rec, do_pcs):
97
- return enhance(up if up else rec, do_pcs)
 
 
 
 
 
98
 
99
- btn.click(runner, [upload, record, pcs], out)
 
 
100
 
 
101
 
102
  demo.launch()
 
56
 
57
 
58
  @spaces.GPU
59
+ def enhance_and_plot(audio):
60
+ if audio is None: return None, None
61
  orig_sr, wav_np = audio
 
62
  if orig_sr != sr:
63
  wav_np = librosa.resample(wav_np, orig_sr, sr)
64
  wav = torch.from_numpy(wav_np).float().to(device)
65
+ norm = torch.sqrt(len(wav)/torch.sum(wav**2))
66
+ wav = (wav * norm).unsqueeze(0)
67
 
68
+ amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
 
 
 
 
 
69
  amp_g, pha_g = model(amp, pha)
70
+ out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
71
  out = (out / norm).squeeze().cpu().numpy()
72
 
 
 
 
 
 
73
  if orig_sr != sr:
74
  out = librosa.resample(out, sr, orig_sr)
75
 
76
+ D = librosa.stft(out, n_fft=1024, hop_length=512)
77
+ S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
78
+ fig, ax = plt.subplots()
79
+ img = librosa.display.specshow(S_db, sr=orig_sr, hop_length=512, x_axis='time', y_axis='hz', ax=ax)
80
+ plt.colorbar(img, ax=ax, format="%+2.0f dB")
81
+ ax.set_title("Enhanced Output Spectrum")
82
 
83
+ return (orig_sr, out), fig
 
 
 
 
 
 
 
84
 
85
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
86
+ gr.Markdown("# 🎧 SEMamba Speech Enhancement")
87
+ gr.Markdown("Upload or record a noisy audio sample to enhance it and view the spectrogram.")
88
+
89
+ with gr.Row():
90
+ with gr.Column():
91
+ input_audio = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Noisy Input")
92
+ btn = gr.Button("Enhance")
93
 
94
+ with gr.Column():
95
+ output_audio = gr.Audio(label="Enhanced Output", type="numpy")
96
+ spectrum = gr.Plot(label="Spectrogram")
97
 
98
+ btn.click(enhance_and_plot, inputs=input_audio, outputs=[output_audio, spectrum])
99
 
100
  demo.launch()