Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,437 Bytes
746a4fa f299c53 746a4fa d69789d 73431e1 7af2ba0 d69789d 48070ab d69789d 746a4fa f299c53 c450417 1fdb610 0601709 c450417 0efa60f 7001051 6da993c 18c8531 9657939 81e7d3e 9657939 81e7d3e 6da993c 81e7d3e 9657939 81e7d3e ff0b6ec 81e7d3e 2bbe7e3 ff0b6ec 81e7d3e 179e88e 9657939 f299c53 b9afdac 9657939 ecbb90e c90ec4a ab7d954 feda893 ecbb90e bb56b46 feda893 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import shlex
import subprocess
import spaces
import torch
import gradio as gr
# Install the packages mamba needs at start-up.
def install_mamba():
    """Install the prebuilt mamba_ssm wheel and the pinned numpy release.

    Each requirement is installed with the pip that belongs to the running
    interpreter. Installation stays best-effort: a failing pip call does not
    raise, but its exit code is reported on stderr instead of being dropped.

    Earlier revisions of this function also pinned torch 2.2.2 (cu118) and
    installed the causal-conv1d 1.4.0 wheel; both steps are disabled.
    """
    import sys

    requirements = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "numpy==1.26.4",
    )
    for requirement in requirements:
        # "python -m pip" guarantees the package lands in this interpreter's
        # environment, whatever "pip" happens to be first on PATH.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", requirement]
        )
        if result.returncode != 0:
            print(
                f"pip install failed (exit code {result.returncode}): {requirement}",
                file=sys.stderr,
            )
# Run the installation at import time, before the remaining imports below.
# NOTE(review): presumably the project modules imported further down need
# mamba_ssm to be installed first — confirm before moving this call.
install_mamba()
# Markdown description of the demo.
# NOTE(review): ABOUT is not referenced anywhere else in the visible code.
ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""
import torch
import yaml
import librosa
import librosa.display
import matplotlib
from huggingface_hub import hf_hub_download
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Checkpoint with the trained weights and the recipe describing the model.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the YAML recipe into a plain dictionary.
with open(cfg_f) as recipe_file:
    cfg = yaml.safe_load(recipe_file)

# Split the recipe into its two sections and pull out the individual settings.
stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = (
    stft_cfg[key] for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
)
compress_ff = model_cfg["compress_factor"]

# CUDA is requested unconditionally; there is no CPU fallback.
device = "cuda"

# Build the network on the device, restore the "generator" weights from the
# checkpoint, and switch the model to evaluation mode.
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def _spectrogram_figure(audio, rate):
    """Draw a dB-scaled spectrogram of *audio* and return the matplotlib figure.

    Not called while the Plot output of the demo is disabled; building a
    figure per request without returning or closing it would only accumulate
    open figures. Kept so the plot can be re-enabled without rewriting it.

    Args:
        audio: One-dimensional numpy array of samples.
        rate: Sampling rate of ``audio``, used to label the axes.

    Returns:
        The matplotlib figure holding the spectrogram and its colour bar.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    magnitude = np.abs(librosa.stft(audio, n_fft=1024, hop_length=512))
    level_db = librosa.amplitude_to_db(magnitude, ref=np.max)
    fig, ax = plt.subplots(figsize=(6, 3))
    mesh = librosa.display.specshow(
        level_db, sr=rate, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    # Hand the mappable to the colour bar explicitly instead of relying on
    # pyplot's "current image" state.
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    return fig


@spaces.GPU
def enhance(filepath, apply_enhancement=True):
    """Denoise an audio file with the loaded model and return the result path.

    Args:
        filepath: Path of the input clip, as delivered by the Gradio audio
            input. May be ``None`` when nothing was uploaded or recorded.
        apply_enhancement: Value of the "Apply Speech Enhancement" checkbox.
            When false the input file is passed through unchanged.

    Returns:
        ``"enhanced.wav"`` after a successful enhancement; the original
        ``filepath`` when enhancement is switched off or the clip has no
        signal energy; ``None`` when no input was supplied.
    """
    # soundfile is imported here because the module never imports it.
    import soundfile as sf

    if not filepath:
        return None
    if not apply_enhancement:
        return filepath

    # Load at the native rate and, if needed, resample to the model rate.
    wav, orig_sr = librosa.load(filepath, sr=None)
    if orig_sr != sr:
        wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=sr)

    # Scale the waveform to unit RMS before it goes through the model.
    x = torch.from_numpy(wav).float().to(device)
    energy = torch.sum(x ** 2)
    if x.numel() == 0 or energy.item() == 0:
        # Empty or all-zero audio would make the scale factor infinite and
        # the output NaN; there is nothing to enhance, so pass it through.
        return filepath
    norm = torch.sqrt(len(x) / energy)
    x = (x * norm).unsqueeze(0)

    # STFT -> model -> inverse STFT.
    # NOTE(review): stft_cfg is forwarded wholesale, including its
    # "sampling_rate" entry — confirm that mag_phase_stft / mag_phase_istft
    # accept every key of that section.
    amp, pha, _ = mag_phase_stft(
        x, **stft_cfg, compress_factor=model_cfg["compress_factor"]
    )
    with torch.no_grad():
        amp2, pha2, _ = model(amp, pha)
    out = mag_phase_istft(
        amp2, pha2, **stft_cfg, compress_factor=model_cfg["compress_factor"]
    )
    # Undo the input scaling and move the result back to host memory.
    out = (out / norm).squeeze().cpu().numpy()

    # Return to the rate of the uploaded clip.
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    # NOTE(review): the fixed output name is shared by all requests.
    sf.write("enhanced.wav", out, orig_sr)
    return "enhanced.wav"
# Gradio front end: an audio clip and a checkbox go in, an audio clip comes
# out. Each input component supplies one positional argument to ``fn``.
se_demo = gr.Interface(
    title="SEMamba",
    fn=enhance,
    inputs=[
        gr.Audio(type="filepath", label="Input Audio"),
        gr.Checkbox(value=True, label="Apply Speech Enhancement"),
    ],
    outputs=[
        gr.Audio(type="filepath", label="Output Audio"),
    ],
    cache_examples=False,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    se_demo.launch()
|