Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,249 Bytes
746a4fa d69789d 73431e1 7af2ba0 d69789d 48070ab 0efa60f 746a4fa d69789d 746a4fa c450417 1fdb610 0601709 c450417 7001051 0efa60f 7001051 18c8531 c58812e ff0b6ec 7001051 ff0b6ec 2bbe7e3 ff0b6ec 1fdb610 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import shlex
import subprocess
import spaces
import torch
# Install the Mamba runtime dependencies at startup. This runs before the
# project model modules are imported further down (they presumably need
# mamba_ssm -- confirm against models/generator.py).
def install_mamba():
    """Install the prebuilt ``mamba_ssm`` wheel and pin ``numpy`` to 1.26.4.

    Every install runs through ``sys.executable -m pip`` so the packages land
    in the interpreter that is running this app, not whichever ``pip`` happens
    to be first on PATH. ``check=True`` makes a failed install abort startup
    immediately instead of surfacing later as a confusing ImportError.

    Earlier revisions also pinned torch 2.2.2 (cu118 index) and installed a
    causal-conv1d 1.4.0 wheel; those steps are intentionally disabled.

    Raises:
        subprocess.CalledProcessError: If any pip invocation exits non-zero.
    """
    import sys

    # Prebuilt wheel; its tags (cu122, torch2.3, cp310) must match the runtime.
    mamba_wheel = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    )
    for requirement in (mamba_wheel, "numpy==1.26.4"):
        subprocess.run(
            [sys.executable, "-m", "pip", "install", requirement],
            check=True,
        )


install_mamba()
import gradio as gr
import torch
import yaml
import librosa
import librosa.display
import matplotlib
from huggingface_hub import hf_hub_download
from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Model assets are read from paths relative to the working directory. They
# could instead be fetched with huggingface_hub.hf_hub_download from the
# "rc19477/Speech_Enhancement_Mamba" repo (same relative file names).
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Load the YAML recipe. The STFT settings and compression factor are unpacked
# into module-level names so the inference code can use them directly.
with open(cfg_f, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
stft_cfg = cfg["stft_cfg"]
model_cfg = cfg["model_cfg"]
sr = stft_cfg["sampling_rate"]
n_fft = stft_cfg["n_fft"]
hop_size = stft_cfg["hop_size"]
win_size = stft_cfg["win_size"]
compress_ff = model_cfg["compress_factor"]

# Build the generator on the best available device and load its weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SEMamba(cfg).to(device)
# NOTE: torch.load unpickles the checkpoint. That is acceptable only because
# it is a file shipped with this app, never user-supplied data.
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def _write_wav_pcm16(audio, rate):
    """Write mono float audio to a new temporary 16-bit PCM WAV file.

    Samples are clipped to [-1, 1] before conversion so loud peaks saturate
    instead of wrapping around.

    Args:
        audio: 1-D float NumPy array of samples (mono).
        rate: Sampling rate in Hz recorded in the WAV header.

    Returns:
        Path of the written file. A fresh file is created per call so that
        concurrent requests never overwrite each other's output.
    """
    import tempfile
    import wave

    import numpy as np

    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        path = tmp.name
    pcm = (np.clip(audio, -1.0, 1.0) * 32767.0).astype("<i2")
    with wave.open(path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample -> 16-bit PCM
        wav_file.setframerate(int(rate))
        wav_file.writeframes(pcm.tobytes())
    return path


def _spectrogram_figure(audio, rate):
    """Return a matplotlib Figure showing the dB spectrogram of ``audio``.

    The Figure is built directly instead of through pyplot, so it is not added
    to pyplot's global figure registry (which would otherwise keep one figure
    alive per request until explicitly closed).

    Args:
        audio: 1-D float NumPy array of samples.
        rate: Sampling rate in Hz, used to label the axes.
    """
    import numpy as np
    from matplotlib.figure import Figure

    spec_db = librosa.amplitude_to_db(
        np.abs(librosa.stft(audio, n_fft=1024, hop_length=512)), ref=np.max
    )
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    mesh = librosa.display.specshow(
        spec_db, sr=rate, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    return fig


@spaces.GPU
def enhance(filepath):
    """Denoise an audio file with SEMamba and plot the result's spectrogram.

    The clip is loaded at its native sampling rate and resampled to the
    model rate ``sr`` if needed. It is then normalised to unit RMS and passed
    through the model in the magnitude/phase STFT domain. Finally it is
    de-normalised and resampled back to the original rate.

    Args:
        filepath: Path to the input audio, as produced by the
            ``gr.Audio(type="filepath")`` input component (``None`` when the
            user has not provided any audio).

    Returns:
        ``(wav_path, figure)``: the path of a newly written 16-bit PCM WAV
        file with the enhanced audio, and a matplotlib Figure of its dB
        spectrogram.

    Raises:
        gr.Error: If no audio was provided.
    """
    if not filepath:
        raise gr.Error("Please upload or record an audio clip first.")

    wav_np, orig_sr = librosa.load(filepath, sr=None)
    if orig_sr != sr:
        # librosa >= 0.10 only accepts the rates as keyword arguments.
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    wav = torch.from_numpy(wav_np).float().to(device)
    # Scale to unit RMS; the epsilon keeps an all-silent clip from producing
    # inf/NaN through a division by zero.
    norm = torch.sqrt(len(wav) / (torch.sum(wav ** 2) + 1e-8))
    wav = (wav * norm).unsqueeze(0)

    # Pass the STFT settings explicitly (the module-level values read from the
    # recipe) rather than **stft_cfg, which would also forward "sampling_rate".
    # NOTE(review): assumes mag_phase_stft/istft take these keyword names --
    # they match the config keys the original code unpacked.
    stft_kwargs = dict(
        n_fft=n_fft,
        hop_size=hop_size,
        win_size=win_size,
        compress_factor=compress_ff,
    )
    # Inference only: without no_grad the outputs carry autograd history,
    # wasting memory and making ``.numpy()`` below raise.
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(wav, **stft_kwargs)
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, **stft_kwargs)
    out = (out / norm).squeeze().cpu().numpy()

    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    return _write_wav_pcm16(out, orig_sr), _spectrogram_figure(out, orig_sr)
# --- Custom CSS ---
CSS = """
#title {text-align:center; margin-bottom:0.2em;}
#subtitle {text-align:center; color:#555; margin-bottom:1.5em;}
.duplicate-button {display:block; margin:0 auto 1.5em;}
#audio-in {border:2px dashed #aaa; border-radius:8px; padding:1em;}
#run-btn {width:100%; margin-top:0.5em;}
#out-audio, #spec-plot {margin-top:1em;}
"""

# --- Blocks layout ---
# Two equal-width columns: noisy input + trigger button on the left, enhanced
# audio + spectrogram on the right. The elem_id/elem_classes values match the
# selectors in CSS above.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML("<h1 id='title'>SEMamba: Speech Enhancement</h1>")
    gr.HTML("<p id='subtitle'>Upload or record your noisy clip, then click Enhance to boost clarity and view its spectrogram.</p>")
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    with gr.Row():
        with gr.Column(scale=1):
            audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath",
                                label="Your Noisy Audio", elem_id="audio-in")
            run_btn = gr.Button("Enhance Now", variant="primary", elem_id="run-btn")
        with gr.Column(scale=1):
            audio_out = gr.Audio(type="filepath", label="Enhanced Audio", elem_id="out-audio")
            spec_plot = gr.Plot(label="Spectrogram", elem_id="spec-plot")
    # Event listeners must be registered inside the Blocks context.
    run_btn.click(enhance, inputs=audio_in, outputs=[audio_out, spec_plot])

# Launch only when run as a script (how Spaces starts app.py), so importing
# this module does not start a server.
if __name__ == "__main__":
    demo.launch()
|