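# Scraped from the Hugging Face Space file view: uploaded by roychao19477,
# commit 0601709 ("Upload"), 4.31 kB. ("raw", "history", "blame" were page links.)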
import shlex
import subprocess
import spaces
import torch
# Install packages for mamba that are not part of the base environment.
def install_mamba():
    """Install the prebuilt mamba-ssm wheel, pin numpy, then list the CWD.

    ``sys.executable -m pip`` is used so the packages land in the interpreter
    running this script rather than whichever ``pip`` is first on ``PATH``.

    Raises:
        subprocess.CalledProcessError: if either ``pip install`` fails, so the
            problem is reported here instead of as a later, confusing import error.
    """
    import sys  # local import keeps this block self-contained

    # Prebuilt wheel for CUDA 12.2 / torch 2.3 / CPython 3.10 / linux x86_64.
    mamba_wheel = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    )
    pip_install = [sys.executable, "-m", "pip", "install"]
    subprocess.run([*pip_install, mamba_wheel], check=True)
    subprocess.run([*pip_install, "numpy==1.26.4"], check=True)
    # Debug aid: show the working-directory contents in the Space logs.
    subprocess.run(shlex.split("ls"))
# Run the installer at import time, before the third-party/model imports below.
install_mamba()
# These imports must stay after install_mamba() so freshly installed packages are used.
import gradio as gr
from huggingface_hub import hf_hub_download
import librosa
import librosa.display
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import yaml

from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Checkpoint and recipe are read from paths relative to the current working
# directory. (They could instead be fetched with hf_hub_download from the
# "rc19477/Speech_Enhancement_Mamba" repo.)
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the recipe and expose the STFT / model settings as module globals.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)
stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = (
    stft_cfg[key] for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
)
compress_ff = model_cfg["compress_factor"]

# Build the generator (on CUDA when available), load its weights, and put it
# in eval mode for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SEMamba(cfg).to(device)
# NOTE(review): torch.load unpickles the file; fine only because the
# checkpoint ships with the Space and is trusted.
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def _to_mono_float(wav_np):
    """Return *wav_np* as a 1-D float32 array suitable for librosa/torch.

    Integer PCM is scaled by its dtype's maximum (librosa rejects non-float
    audio). A 2-D array is assumed to be (samples, channels), per Gradio's
    numpy audio format, and is averaged down to mono.
    """
    wav_np = np.asarray(wav_np)
    if np.issubdtype(wav_np.dtype, np.integer):
        scale = np.iinfo(wav_np.dtype).max
        wav_np = wav_np.astype(np.float32) / scale
    if wav_np.ndim > 1:
        wav_np = wav_np.mean(axis=1)
    return wav_np.astype(np.float32, copy=False)


def _spectrogram_figure(signal, sample_rate):
    """Return a matplotlib Figure showing the dB spectrogram of *signal*."""
    spec = librosa.stft(signal, n_fft=1024, hop_length=512)
    s_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)
    fig, ax = plt.subplots(figsize=(6, 3))
    mesh = librosa.display.specshow(
        s_db, sr=sample_rate, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    # Drop pyplot's reference so figures don't pile up across requests;
    # the Figure object itself remains renderable by gr.Plot.
    plt.close(fig)
    return fig


@spaces.GPU
def enhance(audio):
    """Denoise a Gradio numpy audio clip with SEMamba and plot the result.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``,
            or ``None`` when no clip was provided.

    Returns:
        ``((orig_sr, enhanced), figure)``. ``enhanced`` is a float32 array at the
        input's sample rate, and ``figure`` is its spectrogram. Returns
        ``(None, None)`` for missing or empty input.
    """
    if audio is None:
        return None, None
    orig_sr, wav_np = audio
    wav_np = _to_mono_float(wav_np)
    if wav_np.size == 0:
        return None, None
    if orig_sr != sr:
        # librosa >= 0.10 makes orig_sr / target_sr keyword-only.
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    wav = torch.from_numpy(wav_np).float().to(device)
    # Scale to unit mean power (undone on the output). The clamp keeps an
    # all-zero clip from producing an infinite norm and NaN output.
    energy = torch.sum(wav ** 2).clamp_min(torch.finfo(wav.dtype).tiny)
    norm = torch.sqrt(len(wav) / energy)
    wav = (wav * norm).unsqueeze(0)

    # stft_cfg also carries "sampling_rate", which the STFT helpers presumably
    # do not accept (TODO confirm against models/stfts.py). Pass only the STFT
    # parameters. no_grad: otherwise the output requires grad and .numpy() fails.
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(
            wav, n_fft=n_fft, hop_size=hop_size, win_size=win_size,
            compress_factor=compress_ff,
        )
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(
            amp_g, pha_g, n_fft=n_fft, hop_size=hop_size, win_size=win_size,
            compress_factor=compress_ff,
        )
    out = (out / norm).squeeze().cpu().numpy()
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    return (orig_sr, out), _spectrogram_figure(out, orig_sr)
# --- Layout with Blocks ---
# Input column (upload/record + button) beside an output column (enhanced audio
# + spectrogram), followed by a cached example and a paper link.
with gr.Blocks(css=".gr-box {border: none !important}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>🎧 <a href='https://github.com/RoyChao19477/SEMamba' target='_blank'>SEMamba</a>: Speech Enhancement</h1>")
    gr.Markdown("Enhance real-world noisy speech using Mamba. Upload or record an audio clip and view the spectrogram.")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record", elem_id="input-audio")
            run_btn = gr.Button("Enhance Now 🚀", variant="primary")
        with gr.Column():
            enhanced_audio = gr.Audio(label="Enhanced Output", type="numpy")
            spec_plot = gr.Plot(label="Spectrogram")
    run_btn.click(enhance, inputs=audio_input, outputs=[enhanced_audio, spec_plot])
    # cache_examples=True runs enhance() on the example once at startup.
    gr.Examples(
        examples=[
            ["examples/noisy_sample_16k.wav"],
        ],
        inputs=audio_input,
        outputs=[enhanced_audio, spec_plot],
        fn=enhance,
        cache_examples=True,
        label="📂 Try These Examples",
    )
    gr.Markdown("<p style='text-align: center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>📄 SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>")
demo.launch()