# Hugging Face Space: Running on Zero (ZeroGPU)
import shlex | |
import subprocess | |
import spaces | |
import torch | |
# Install the prebuilt causal_conv1d / mamba_ssm wheels (packages for mamba).
def install():
    """pip-install the bundled wheels into the currently running interpreter.

    Raises:
        subprocess.CalledProcessError: if pip exits non-zero, so a failed
            install stops the app here instead of being silently ignored.
    """
    import sys

    print("Install personal packages", flush=True)
    for wheel in (
        "causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl",
        "mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl",
    ):
        # ``sys.executable -m pip`` installs into this interpreter's
        # environment, unlike whatever bare ``pip`` happens to be on PATH.
        subprocess.run([sys.executable, "-m", "pip", "install", wheel], check=True)


install()
import gradio as gr | |
import torch | |
import yaml | |
import librosa | |
from huggingface_hub import hf_hub_download | |
from models.stfts import mag_phase_stft, mag_phase_istft | |
from models.generator import SEMamba | |
from models.pcs400 import cal_pcs | |
# Fetch the SEMamba "advanced" checkpoint and its matching recipe file from
# the Hugging Face Hub; both paths point into the local download cache.
_HF_REPO = "rc19477/Speech_Enhancement_Mamba"
ckpt = hf_hub_download(repo_id=_HF_REPO, filename="ckpts/SEMamba_advanced.pth")
cfg_f = hf_hub_download(repo_id=_HF_REPO, filename="recipes/SEMamba_advanced.yaml")
# Parse the recipe YAML and expose the STFT / model hyper-parameters that
# enhance() passes to the STFT helpers and the resampler.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)
stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = (
    stft_cfg[key] for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
)
compress_ff = model_cfg["compress_factor"]
# Build SEMamba on CUDA when available (CPU otherwise), load the "generator"
# weights from the downloaded checkpoint, and switch the model to eval mode.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = SEMamba(cfg)
model = model.to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def _to_mono_float32(wav_np):
    """Return *wav_np* as a 1-D float32 waveform scaled to roughly [-1, 1].

    ``gr.Audio(type="numpy")`` typically hands back integer PCM and, for
    multi-channel files, a ``(samples, channels)`` array. librosa only
    resamples floating-point audio (along the last axis), so both cases are
    normalised here before any processing.
    """
    import numpy as np

    wav_np = np.asarray(wav_np)
    if np.issubdtype(wav_np.dtype, np.integer):
        info = np.iinfo(wav_np.dtype)
        # Midpoint is 0 for signed PCM and 2**(bits-1) for unsigned
        # (offset-binary) PCM; dividing by (max - mid + 1) maps to [-1, 1).
        mid = (int(info.max) + int(info.min) + 1) // 2
        wav_np = (wav_np.astype(np.float64) - mid) / (int(info.max) - mid + 1)
    wav_np = wav_np.astype(np.float32, copy=False)
    if wav_np.ndim == 2:
        # Down-mix (samples, channels) to a single channel for the model.
        wav_np = wav_np.mean(axis=1)
    return wav_np


@spaces.GPU
def enhance(audio, do_pcs):
    """Denoise one clip with SEMamba, optionally apply PCS, and return it.

    Args:
        audio: ``(sample_rate, samples)`` tuple from ``gr.Audio(type="numpy")``,
            or ``None`` when the user submits without uploading a file.
        do_pcs: when truthy, the enhanced waveform is passed through ``cal_pcs``.

    Returns:
        ``(orig_sr, enhanced)`` where ``enhanced`` is a 1-D numpy waveform at the
        uploaded clip's sample rate; ``None`` when *audio* is ``None``.
    """
    if audio is None:
        return None
    orig_sr, wav_np = audio
    wav_np = _to_mono_float32(wav_np)
    if wav_np.size == 0:
        return orig_sr, wav_np

    # 1) Resample to the model rate (``sr`` from the recipe) if needed.
    #    librosa >= 0.10 only accepts the rates as keyword arguments.
    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    wav = torch.from_numpy(wav_np).float().to(device)
    # Scale to unit RMS for the model; undone after the ISTFT. clamp_min stops
    # an all-zero clip from producing an infinite gain (and NaN output).
    norm = torch.sqrt(len(wav) / torch.sum(wav ** 2).clamp_min(1e-8))
    wav = (wav * norm).unsqueeze(0)

    # STFT -> model -> ISTFT. no_grad: inference only, and ``.numpy()`` below
    # raises on a tensor that still requires grad.
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(wav, n_fft, hop_size, win_size, compress_ff)
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_ff)
    out = (out / norm).squeeze().cpu().numpy()

    # Optional PCS post-filter (cal_pcs from models.pcs400).
    if do_pcs:
        out = cal_pcs(out)

    # 2) Resample back so the output matches the uploaded clip's rate.
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)
    return orig_sr, out
# Gradio UI: a noisy clip plus a PCS toggle in, the enhanced clip out.
demo = gr.Interface(
    fn=enhance,
    inputs=[
        # Gradio 4 (required on ZeroGPU) replaced ``source="upload"`` with
        # ``sources=[...]``; the old keyword raises a TypeError there.
        gr.Audio(sources=["upload"], type="numpy", label="Noisy wav"),
        gr.Checkbox(label="Apply PCS post-processing", value=False),
    ],
    outputs=gr.Audio(type="numpy", label="Enhanced wav"),
    title="SEMamba Speech Enhancement",
    description="Upload a noisy WAV; tick **Apply PCS** for the pcs400 filter.",
)
demo.launch()