# Hugging Face Space status banner: "Running on Zero" (ZeroGPU hardware).
import shlex | |
import subprocess | |
import spaces | |
import torch | |
import gradio as gr | |
# Install packages for mamba.
def install_mamba():
    """Install the prebuilt mamba-ssm wheel and pin numpy to 1.26.4.

    pip is run through the current interpreter (``sys.executable -m pip``)
    so the packages land in the environment that is executing this script,
    not in whichever ``pip`` happens to be first on PATH. A failed install
    is reported on stderr but does not raise, so start-up continues exactly
    as it did before.
    """
    import sys  # local import: this function runs before the later import block

    # Earlier install steps, kept for reference:
    # pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
    # pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
    requirements = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "numpy==1.26.4",
    )
    for requirement in requirements:
        result = subprocess.run([sys.executable, "-m", "pip", "install", requirement])
        if result.returncode != 0:
            # Make the failure visible instead of silently ignoring it.
            print(
                f"warning: 'pip install {requirement}' exited with code {result.returncode}",
                file=sys.stderr,
            )


install_mamba()
# Markdown-formatted description of the demo (heading plus two sentences).
ABOUT = (
    "\n"
    "# SEMamba: Speech Enhancement\n"
    "A Mamba-based model that denoises real-world audio.\n"
    "Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.\n"
)
import librosa
import librosa.display
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import yaml
from huggingface_hub import hf_hub_download

from models.generator import SEMamba
from models.pcs400 import cal_pcs
from models.stfts import mag_phase_stft, mag_phase_istft
# Paths of the checkpoint and of the YAML recipe it is loaded with.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the recipe and pull out the values used at inference time.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr = stft_cfg["sampling_rate"]
n_fft, hop_size, win_size = (
    stft_cfg[key] for key in ("n_fft", "hop_size", "win_size")
)
compress_ff = model_cfg["compress_factor"]

# The device is hard-coded; the CPU fallback below is intentionally disabled.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"

# Build the network, restore the "generator" weights from the checkpoint and
# switch to evaluation mode.
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def _render_spectrogram(signal, rate):
    """Return a matplotlib figure with the dB-scaled spectrogram of *signal*.

    Not called at the moment: the interface has no plot output, so building
    a figure per request would only leak open figures. Call this from
    ``enhance`` when a plot output is added to the interface.
    """
    spec = librosa.stft(signal, n_fft=1024, hop_length=512)
    spec_db = librosa.amplitude_to_db(np.abs(spec), ref=np.max)
    fig, ax = plt.subplots(figsize=(6, 3))
    img = librosa.display.specshow(
        spec_db, sr=rate, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    # Hand the mesh to colorbar explicitly instead of relying on pyplot's
    # "current image", which is not guaranteed to be set when ax= is given.
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    return fig


def enhance(filepath, apply_enhancement=True):
    """Denoise an audio file and return the path of the enhanced WAV.

    Args:
        filepath: Path of the input audio file.
        apply_enhancement: When false, the input path is returned unchanged.
            Defaults to True so single-argument callers keep working.

    Returns:
        ``"enhanced.wav"`` (written at the input's original sample rate), or
        *filepath* itself when enhancement is switched off.

    Raises:
        gr.Error: If no file was supplied or the clip contains no samples.
    """
    if not filepath:
        raise gr.Error("Please upload or record an audio clip first.")
    if not apply_enhancement:
        return filepath

    # Load at the native rate, then resample to the model rate if they differ.
    wav, orig_sr = librosa.load(filepath, sr=None)
    if wav.size == 0:
        raise gr.Error("The audio clip is empty.")
    needs_resample = orig_sr != sr
    if needs_resample:
        # Keyword arguments: recent librosa releases reject positional rates.
        wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=sr)

    # Scale to unit RMS before the model; the scale is undone afterwards.
    x = torch.from_numpy(wav).float().to(device)
    energy = torch.sum(x ** 2)
    if energy > 0:
        norm = torch.sqrt(len(x) / energy)
    else:
        # Silent clip: skip scaling rather than dividing by zero (NaN output).
        norm = torch.ones((), device=x.device)
    x = (x * norm).unsqueeze(0)

    # Forward only the STFT parameters. stft_cfg also holds non-STFT keys
    # such as "sampling_rate".
    # NOTE(review): assumes the STFT helpers take n_fft/hop_size/win_size/
    # compress_factor as keywords - confirm against models/stfts.py.
    stft_args = dict(
        n_fft=n_fft,
        hop_size=hop_size,
        win_size=win_size,
        compress_factor=compress_ff,
    )

    # STFT -> model -> inverse STFT, with autograd disabled throughout.
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(x, **stft_args)
        amp2, pha2, _ = model(amp, pha)
        out = mag_phase_istft(amp2, pha2, **stft_args)
    out = (out / norm).squeeze().cpu().numpy()

    # Back to the original sample rate.
    if needs_resample:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    sf.write("enhanced.wav", out, orig_sr)
    return "enhanced.wav"
# Gradio UI: one audio input plus an on/off checkbox, one audio output.
se_demo = gr.Interface(
    # Register the handler through spaces.GPU so a ZeroGPU Space attaches a
    # GPU while it runs; per the ZeroGPU docs this is a no-op on other hardware.
    fn=spaces.GPU(enhance),
    inputs=[
        gr.Audio(label="Input Audio", type="filepath"),
        gr.Checkbox(label="Apply Speech Enhancement", value=True),
    ],
    outputs=[
        gr.Audio(label="Output Audio", type="filepath"),
        # gr.Plot(label="Spectrogram")
    ],
    title="SEMamba",
    cache_examples=False,
)

if __name__ == "__main__":
    se_demo.launch()