"""Gradio demo for SEMamba speech enhancement.

At import time this script installs the prebuilt ``mamba_ssm`` wheel, reads
the SEMamba YAML config and checkpoint from local paths, builds the model,
and launches a Gradio Blocks app. ``enhance`` runs the model on one audio
clip and returns the enhanced audio plus a spectrogram figure.
"""
import shlex
import subprocess
import sys

import spaces
import torch

# Packages installed at startup, before the model code below is imported.
# Previously tried alternatives (kept for reference):
#   torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 (cu118 index)
#   causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
_STARTUP_PACKAGES = (
    "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
    "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
    "numpy==1.26.4",
)


def install_mamba():
    """Install the prebuilt ``mamba_ssm`` wheel and pin ``numpy==1.26.4``.

    pip is invoked through ``sys.executable`` so the packages land in the
    interpreter running this script, not whichever ``pip`` is first on PATH.
    Install failures are not raised (``check=False``): this stays best-effort.
    """
    for package in _STARTUP_PACKAGES:
        subprocess.run([sys.executable, "-m", "pip", "install", package], check=False)


install_mamba()

# These imports deliberately come after install_mamba().
import gradio as gr
import yaml
import librosa
import librosa.display
import matplotlib
import numpy as np
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure

from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs

# Local checkpoint/config paths. To fetch them from the Hub instead:
#   hf_hub_download("rc19477/Speech_Enhancement_Mamba", "ckpts/SEMamba_advanced.pth")
#   hf_hub_download("rc19477/Speech_Enhancement_Mamba", "recipes/SEMamba_advanced.yaml")
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# --- Config ---
with open(cfg_f, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

stft_cfg = cfg["stft_cfg"]
model_cfg = cfg["model_cfg"]
sr = stft_cfg["sampling_rate"]
n_fft = stft_cfg["n_fft"]
hop_size = stft_cfg["hop_size"]
win_size = stft_cfg["win_size"]
compress_ff = model_cfg["compress_factor"]

# --- Model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()

# Floor on total signal energy so an all-zero (silent) clip cannot divide by zero.
_ENERGY_EPS = 1e-10
# STFT parameters used only for the displayed spectrogram (not the model).
_SPEC_N_FFT = 1024
_SPEC_HOP = 512


def _to_mono_float32(samples):
    """Return *samples* as a 1-D float32 array.

    Integer PCM is rescaled to [-1, 1) by its dtype's full-scale value,
    because ``librosa.resample`` requires floating-point input. Multi-channel
    input is averaged to mono. This assumes a (samples, channels) layout;
    TODO confirm against the audio input component.
    """
    wav = np.asarray(samples)
    if np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float32) / float(np.iinfo(wav.dtype).max + 1)
    else:
        wav = wav.astype(np.float32)
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    return wav


def _spectrogram_figure(wav, sample_rate):
    """Return a Figure showing the dB-scaled magnitude spectrogram of *wav*.

    The Figure is built directly rather than through pyplot. pyplot keeps
    every figure it creates alive until it is explicitly closed, which leaks
    in a long-running server. pyplot's implicit "current figure/image" state
    is also shared between concurrent requests.
    """
    spec = librosa.stft(wav, n_fft=_SPEC_N_FFT, hop_length=_SPEC_HOP)
    spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)

    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    mesh = librosa.display.specshow(
        spec_db,
        sr=sample_rate,
        hop_length=_SPEC_HOP,
        x_axis="time",
        y_axis="hz",
        ax=ax,
    )
    ax.set_title("Enhanced Spectrogram")
    # Attach the colorbar to the mesh explicitly instead of pyplot's "current image".
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    return fig


@spaces.GPU
def enhance(audio):
    """Enhance one audio clip with SEMamba.

    Args:
        audio: ``(sample_rate, samples)`` tuple, or ``None`` when no audio
            was provided.

    Returns:
        ``((sample_rate, enhanced), figure)``. Here *enhanced* is a float32
        array at the input's original sample rate, and *figure* is its
        spectrogram. Returns ``(None, None)`` for missing or empty input.
    """
    if audio is None:
        return None, None
    orig_sr, wav_np = audio

    wav_np = _to_mono_float32(wav_np)
    if wav_np.size == 0:
        return None, None
    if orig_sr != sr:
        # librosa >= 0.10 accepts the sample rates only as keyword arguments.
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    wav = torch.from_numpy(wav_np).float().to(device)
    # Scale to unit RMS for the model; the scale is undone on the output below.
    energy = torch.clamp(torch.sum(wav ** 2), min=_ENERGY_EPS)
    norm = torch.sqrt(len(wav) / energy)
    wav = (wav * norm).unsqueeze(0)

    with torch.no_grad():  # inference only: do not build an autograd graph
        # Pass the STFT settings explicitly. stft_cfg also carries
        # "sampling_rate", which splatting **stft_cfg would forward as an
        # extra keyword the STFT helpers presumably do not accept.
        amp, pha, _ = mag_phase_stft(
            wav,
            n_fft=n_fft,
            hop_size=hop_size,
            win_size=win_size,
            compress_factor=compress_ff,
        )
        # NOTE(review): keep only magnitude and phase. The slice also tolerates
        # a generator that returns extra outputs (e.g. a complex spectrum);
        # confirm against models/generator.py.
        amp_g, pha_g = model(amp, pha)[:2]
        out = mag_phase_istft(
            amp_g,
            pha_g,
            n_fft=n_fft,
            hop_size=hop_size,
            win_size=win_size,
            compress_factor=compress_ff,
        )

    out = (out / norm).squeeze().cpu().numpy()
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    return (orig_sr, out), _spectrogram_figure(out, orig_sr)


# --- Layout with Blocks ---
with gr.Blocks(css=".gr-box {border: none !important}") as demo:
    # NOTE(review): as written, the layout holds only this title; no input or
    # output components are connected to `enhance`. Confirm the UI wiring.
    gr.Markdown("📄 SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)")

demo.launch()