# Hugging Face Spaces status: Running on Zero (ZeroGPU hardware).
import shlex
import subprocess
import sys

import spaces
import torch


# install packages for mamba
def install_mamba():
    """Install the prebuilt mamba-ssm wheel and pin numpy at startup.

    Packages are installed with the pip that belongs to the running
    interpreter (``sys.executable -m pip``), so they land in the same
    environment this script imports from. Each install is run with
    ``check=True``, so a failed install stops startup with
    ``subprocess.CalledProcessError`` instead of surfacing later as a
    confusing import error.
    """
    pip_install = [sys.executable, "-m", "pip", "install"]
    # mamba-ssm 2.2.2 wheel built for CUDA 12.2 / torch 2.3 / CPython 3.10 on
    # x86-64 Linux (per the wheel filename); the runtime must match these tags.
    subprocess.run(
        pip_install
        + [
            "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
            "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
        ],
        check=True,
    )
    # NOTE(review): presumably pinned to stay on NumPy 1.x for binary
    # compatibility with the prebuilt extensions — confirm.
    subprocess.run(pip_install + ["numpy==1.26.4"], check=True)
    # Debug aid: list the working directory in the Space logs.
    subprocess.run(["ls"])


install_mamba()
import gradio as gr
import librosa
import librosa.display
import matplotlib
import matplotlib.figure
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from scipy.io import wavfile

from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs
# Model checkpoint and training recipe, given as local paths relative to the
# working directory. The same files are also published in the Hub repo
# "rc19477/Speech_Enhancement_Mamba" and could be fetched with
# hf_hub_download(repo_id, "<path below>") instead.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"
# Parse the recipe YAML into STFT front-end settings and model hyper-parameters.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]

# Individual STFT settings, unpacked in the recipe's key order.
sr, n_fft, hop_size, win_size = (
    stft_cfg[key] for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
)
# Magnitude compression factor from the model section of the recipe.
compress_ff = model_cfg["compress_factor"]
# Pick the compute device, build SEMamba from the recipe, restore the trained
# generator weights and switch the network to eval mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SEMamba(cfg)
model.to(device)  # nn.Module.to moves parameters in place
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()
def enhance(filepath):
    """Denoise an audio clip with SEMamba and plot the enhanced spectrogram.

    The clip is loaded at its native sample rate and resampled to the model
    rate ``sr`` if the rates differ. It is then scaled to unit RMS and enhanced
    in the magnitude/phase STFT domain. Finally it is un-scaled and resampled
    back to the original rate.

    Args:
        filepath: Path to the noisy input audio, as delivered by the
            ``gr.Audio(type="filepath")`` input (``None`` when nothing was given).

    Returns:
        Tuple ``(wav_path, fig)``:
            * ``wav_path`` is ``"enhanced.wav"``, holding the enhanced audio as
              16-bit PCM at the input's sample rate.
            * ``fig`` is a matplotlib ``Figure`` showing its dB spectrogram.

    Raises:
        gr.Error: If no audio was provided or the clip is empty or silent.
    """
    # NOTE(review): `spaces` is imported but never used; on ZeroGPU hardware this
    # function probably needs the `@spaces.GPU` decorator — confirm.
    if filepath is None:
        raise gr.Error("Please upload or record an audio clip first.")

    wav_np, orig_sr = librosa.load(filepath, sr=None)
    if orig_sr != sr:
        # librosa >= 0.10 only accepts the rates as keyword arguments.
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)
    if wav_np.size == 0 or not np.any(wav_np):
        # The RMS normalisation below divides by the signal energy.
        raise gr.Error("The audio clip is empty or completely silent.")

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        wav = torch.from_numpy(wav_np).float().to(device)
        # Scale so the mean power is 1; the scaling is undone on the output.
        norm = torch.sqrt(len(wav) / torch.sum(wav ** 2))
        wav = (wav * norm).unsqueeze(0)  # add a batch dimension
        # Pass only the STFT settings; stft_cfg also holds "sampling_rate".
        amp, pha, _ = mag_phase_stft(
            wav,
            n_fft=n_fft,
            hop_size=hop_size,
            win_size=win_size,
            compress_factor=compress_ff,
        )
        # Keep magnitude and phase; any further model outputs are ignored.
        amp_g, pha_g = model(amp, pha)[:2]
        out = mag_phase_istft(
            amp_g,
            pha_g,
            n_fft=n_fft,
            hop_size=hop_size,
            win_size=win_size,
            compress_factor=compress_ff,
        )
        out = (out / norm).squeeze().cpu().numpy()

    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    # Store as 16-bit PCM WAV. Clip to [-1, 1] first so the int16 cast cannot wrap.
    out_path = "enhanced.wav"
    pcm = (np.clip(out, -1.0, 1.0) * 32767).astype(np.int16)
    wavfile.write(out_path, orig_sr, pcm)

    spec_db = librosa.amplitude_to_db(
        np.abs(librosa.stft(out, n_fft=1024, hop_length=512)), ref=np.max
    )
    # Build the figure without pyplot so repeated calls do not pile up open
    # figures in pyplot's global state.
    fig = matplotlib.figure.Figure(figsize=(6, 3))
    ax = fig.subplots()
    mesh = librosa.display.specshow(
        spec_db, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    # Give the colorbar its mappable explicitly rather than relying on pyplot's
    # "current image".
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    return out_path, fig
# — Custom CSS —
# Styles keyed to the elem_id / elem_classes values used in the layout below.
CSS = """
#title {text-align:center; margin-bottom:0.2em;}
#subtitle {text-align:center; color:#555; margin-bottom:1.5em;}
.duplicate-button {display:block; margin:0 auto 1.5em;}
#audio-in {border:2px dashed #aaa; border-radius:8px; padding:1em;}
#run-btn {width:100%; margin-top:0.5em;}
#out-audio, #spec-plot {margin-top:1em;}
"""

# — Blocks layout —
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML("<h1 id='title'>🎧 SEMamba: Speech Enhancement</h1>")
    gr.HTML("<p id='subtitle'>Upload or record your noisy clip, then click Enhance to boost clarity and view its spectrogram.</p>")
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    with gr.Row():
        # Left column: noisy input and the trigger button.
        with gr.Column(scale=1):
            audio_in = gr.Audio(sources=["upload", "microphone"], type="filepath",
                                label="Your Noisy Audio", elem_id="audio-in")
            run_btn = gr.Button("Enhance Now 🚀", variant="primary", elem_id="run-btn")
        # Right column: enhanced audio and its spectrogram.
        with gr.Column(scale=1):
            audio_out = gr.Audio(type="filepath", label="Enhanced Audio", elem_id="out-audio")
            spec_plot = gr.Plot(label="Spectrogram", elem_id="spec-plot")
    # Register the event inside the Blocks context; enhance returns (wav path, figure).
    run_btn.click(enhance, inputs=audio_in, outputs=[audio_out, spec_plot])

demo.launch()