Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,033 Bytes
56efbc8 8bb81da 56efbc8 c7fba2d 56efbc8 18c8531 56efbc8 3af0ebe 56efbc8 9d66cc0 56efbc8 3c23ad1 3af0ebe 3c23ad1 3af0ebe 56efbc8 9d66cc0 3af0ebe 56efbc8 c7fba2d 56efbc8 3af0ebe 56efbc8 3af0ebe 56efbc8 3af0ebe ecbb90e 17efe4f 56efbc8 3af0ebe 56efbc8 3af0ebe 56efbc8 3af0ebe 56efbc8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import shlex
import subprocess
import spaces
import torch
import gradio as gr
# install packages for mamba
def install_mamba():
    """Install the prebuilt ``mamba_ssm`` wheel into the running interpreter.

    pip is invoked as ``<current python> -m pip`` so the wheel lands in the
    environment of the interpreter executing this script, rather than in
    whichever ``pip`` executable happens to be first on ``PATH``.

    Per its filename, the wheel is built for CPython 3.10, CUDA 12.2,
    torch 2.3 and linux x86_64.

    Returns:
        None. A failed install is reported on stderr but does not raise.
    """
    import sys  # function-scope import: only this helper needs it

    # Alternative / additional pins, kept disabled for reference:
    #subprocess.run(shlex.split("pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118"))
    #subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
    #subprocess.run(shlex.split("pip install numpy==1.26.4"))
    wheel_url = "https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"

    # Argument list (no shell, no string splitting) targeting this interpreter.
    result = subprocess.run([sys.executable, "-m", "pip", "install", wheel_url])
    if result.returncode != 0:
        # Surface the failure instead of silently ignoring it; startup
        # continues so any later import error is still reported normally.
        print(
            f"install_mamba: pip exited with status {result.returncode}",
            file=sys.stderr,
        )
# Run the install at import time, before the project-local `models` imports
# further down (presumably they need mamba_ssm to be importable -- confirm).
install_mamba()
# Markdown header text for the demo page; rendered with gr.Markdown in the UI.
ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""
import torch
import yaml
import librosa
import librosa.display
import matplotlib
import soundfile as sf
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Paths of the generator checkpoint and of the YAML recipe describing it.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the recipe; the resulting object is handed to the SEMamba constructor.
with open(cfg_f, 'r') as recipe_file:
    cfg = yaml.safe_load(recipe_file)

# The target device is fixed to CUDA; the availability-based fallback is
# intentionally left disabled.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cuda"

# Build the network, move it to the device, then restore the weights stored
# under the checkpoint's "generator" key and switch to inference mode.
model = SEMamba(cfg)
model = model.to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
@spaces.GPU
def enhance(filepath):
    """Denoise an audio file and plot noisy vs. enhanced spectrograms.

    The clip is loaded at its native sample rate, resampled to 16 kHz when
    necessary, scaled by ``sqrt(N / sum(x**2))``, run through the model in
    4-second segments (the last one zero-padded), rescaled, resampled back
    to the native rate and written to ``enhanced.wav``.

    Args:
        filepath: Path of the input audio file.

    Returns:
        A tuple ``("enhanced.wav", fig)`` where ``fig`` is a matplotlib
        figure with the noisy spectrogram on the left and the enhanced one
        on the right, both at the native sample rate.

    Raises:
        gr.Error: If no file was provided or the clip contains no samples.
    """
    # Function-scope imports: the module does not import numpy, and the
    # figure is built without pyplot so it is not retained in pyplot's
    # global figure registry after every request.
    import numpy as np
    from matplotlib.figure import Figure

    if not filepath:
        raise gr.Error("Please upload or record an audio clip first.")

    # load & resample; keep the native-rate signal for the "noisy" plot
    wav, orig_sr = librosa.load(filepath, sr=None)
    if wav.size == 0:
        raise gr.Error("The audio clip is empty.")
    if orig_sr != 16000:
        wav_16k = librosa.resample(wav, orig_sr=orig_sr, target_sr=16000)
    else:
        wav_16k = wav

    with torch.no_grad():
        x = torch.from_numpy(wav_16k).float().to(device)

        # Scale factor sqrt(N / energy). A silent clip has zero energy, which
        # would give an infinite factor and NaNs, so fall back to 1.
        energy = torch.sum(x ** 2)
        if energy > 0:
            norm = torch.sqrt(len(x) / energy)
        else:
            norm = torch.ones((), device=x.device)
        x = x * norm

        # split into 4s segments (64000 samples)
        segment_len = 4 * 16000
        enhanced_chunks = []
        for chunk in x.split(segment_len):
            if len(chunk) < segment_len:
                # zero-pad the trailing segment to the full segment length
                pad = torch.zeros(segment_len - len(chunk), device=chunk.device)
                chunk = torch.cat([chunk, pad])
            chunk = chunk.unsqueeze(0)
            # The numeric arguments are presumably (n_fft, hop, win, compress
            # factor) -- confirm against models.stfts.
            amp, pha, _ = mag_phase_stft(chunk, 400, 100, 400, 0.3)
            amp2, pha2, _ = model(amp, pha)
            out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
            # undo the input scaling
            enhanced_chunks.append((out / norm).squeeze(0))

        out = torch.cat(enhanced_chunks)[:len(x)].cpu().numpy()  # trim padding

    # back to original rate
    if orig_sr != 16000:
        out = librosa.resample(out, orig_sr=16000, target_sr=orig_sr)

    # write file
    sf.write("enhanced.wav", out, orig_sr)

    # spectrograms
    fig = Figure(figsize=(10, 4))
    axs = fig.subplots(1, 2)

    # noisy: computed from the native-rate signal so that sr=orig_sr yields
    # correct time and frequency axes
    D_noisy = librosa.stft(wav, n_fft=1024, hop_length=512)
    S_noisy = librosa.amplitude_to_db(np.abs(D_noisy), ref=np.max)
    librosa.display.specshow(S_noisy, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=axs[0])
    axs[0].set_title("Noisy Spectrogram")

    # enhanced
    D_clean = librosa.stft(out, n_fft=1024, hop_length=512)
    S_clean = librosa.amplitude_to_db(np.abs(D_clean), ref=np.max)
    librosa.display.specshow(S_clean, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=axs[1])
    axs[1].set_title("Enhanced Spectrogram")

    fig.tight_layout()
    return "enhanced.wav", fig
# Demo page layout: intro text, an audio input, a trigger button, and the
# two outputs (enhanced audio and the spectrogram figure).
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    input_audio = gr.Audio(
        label="Input Audio",
        type="filepath",
        interactive=True,
    )
    enhance_btn = gr.Button("Enhance")
    output_audio = gr.Audio(
        label="Enhanced Audio",
        type="filepath",
    )
    plot_output = gr.Plot(label="Spectrograms")

    # Clicking the button runs enhance() on the selected file and routes its
    # two return values to the audio player and the plot.
    enhance_btn.click(
        fn=enhance,
        inputs=input_audio,
        outputs=[output_audio, plot_output],
    )

# Enable request queuing, then start the app.
demo.queue()
demo.launch()
|