import shlex
import subprocess
import spaces
import torch
# install packages for mamba
def install_mamba():
    #subprocess.run(shlex.split("pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118"))
    #subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
    subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
    subprocess.run(shlex.split("pip install numpy==1.26.4"))
    subprocess.run(shlex.split("ls"))
install_mamba()
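# Note: mamba_ssm is installed at runtime from a prebuilt wheel (cu122 / torch 2.3 / cp310, per the
# wheel filename above), so the Space does not have to compile the Mamba CUDA kernels on startup.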
import gradio as gr
import torch
import yaml
import librosa
import librosa.display
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Model files are shipped with the Space; alternatively they could be pulled from the HF Hub:
# ckpt  = hf_hub_download("rc19477/Speech_Enhancement_Mamba", "ckpts/SEMamba_advanced.pth")
# cfg_f = hf_hub_download("rc19477/Speech_Enhancement_Mamba", "recipes/SEMamba_advanced.yaml")
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"
# load config
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)
stft_cfg = cfg["stft_cfg"]
model_cfg = cfg["model_cfg"]
sr = stft_cfg["sampling_rate"]
n_fft = stft_cfg["n_fft"]
hop_size = stft_cfg["hop_size"]
win_size = stft_cfg["win_size"]
compress_ff = model_cfg["compress_factor"]
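# For reference, the keys read above imply a recipe file shaped roughly like the sketch below.
# The values are illustrative placeholders, not necessarily those in the shipped SEMamba_advanced.yaml:
#
#   stft_cfg:
#     sampling_rate: 16000
#     n_fft: 400
#     hop_size: 100
#     win_size: 400
#   model_cfg:
#     compress_factor: 0.3
#     ...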
# init model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
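# On ZeroGPU Spaces, the `spaces.GPU` decorator below requests a GPU only for the duration of each
# decorated call, so inference runs on CUDA while the rest of the app stays on shared CPU hardware.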
@spaces.GPU
def enhance(audio):
    if audio is None:
        return None, None
    orig_sr, wav_np = audio
    # Gradio delivers raw PCM (usually int16); convert to float32 in [-1, 1] and downmix to mono.
    if np.issubdtype(wav_np.dtype, np.integer):
        wav_np = wav_np.astype(np.float32) / np.iinfo(wav_np.dtype).max
    else:
        wav_np = wav_np.astype(np.float32)
    if wav_np.ndim > 1:
        wav_np = wav_np.mean(axis=1)
    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)
    wav = torch.from_numpy(wav_np).float().to(device)
    # Normalize to unit RMS so the model sees a consistent input level; undone after enhancement.
    norm = torch.sqrt(len(wav) / torch.sum(wav ** 2))
    wav = (wav * norm).unsqueeze(0)
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
    out = (out / norm).squeeze().cpu().numpy()
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)
    # Spectrogram of the enhanced signal
    D = librosa.stft(out, n_fft=1024, hop_length=512)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    fig, ax = plt.subplots(figsize=(6, 3))
    img = librosa.display.specshow(S_db, sr=orig_sr, hop_length=512, x_axis='time', y_axis='hz', ax=ax)
    ax.set_title("Enhanced Spectrogram")
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    return (orig_sr, out), fig
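# Local smoke test (not executed in the Space): `enhance` takes Gradio's (sample_rate, samples)
# tuple and returns ((sample_rate, enhanced_samples), matplotlib_figure). Assuming the soundfile
# package is available locally, a quick check could look like:
#
#   import soundfile as sf
#   wav, file_sr = sf.read("examples/noisy_sample_16k.wav", dtype="float32")
#   (out_sr, out_wav), fig = enhance((file_sr, wav))
#   sf.write("enhanced.wav", out_wav, out_sr)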
# --- Layout with Blocks ---
with gr.Blocks(css=".gr-box {border: none !important}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>🎧 <a href='https://github.com/RoyChao19477/SEMamba' target='_blank'>SEMamba</a>: Speech Enhancement</h1>")
    gr.Markdown("Enhance real-world noisy speech using Mamba. Upload or record an audio clip and view the spectrogram.")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record", elem_id="input-audio")
            run_btn = gr.Button("Enhance Now", variant="primary")
        with gr.Column():
            enhanced_audio = gr.Audio(label="Enhanced Output", type="numpy")
            spec_plot = gr.Plot(label="Spectrogram")
    run_btn.click(enhance, inputs=audio_input, outputs=[enhanced_audio, spec_plot])
    gr.Examples(
        examples=[
            ["examples/noisy_sample_16k.wav"],
        ],
        inputs=audio_input,
        outputs=[enhanced_audio, spec_plot],
        fn=enhance,
        cache_examples=True,
        label="Try These Examples"
    )
    gr.Markdown("<p style='text-align: center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>")
demo.launch()
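# If the Space sees concurrent traffic, requests can be serialized with Gradio's built-in queue,
# e.g. `demo.queue().launch()` instead of the bare launch above.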