import shlex
import subprocess

import spaces
import torch

# Install the Mamba SSM kernels at startup from a prebuilt wheel
# (CUDA 12.2 / torch 2.3 / Python 3.10) and pin numpy for ABI compatibility.
def install_mamba():
    subprocess.run(shlex.split(
        "pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
    ))
    subprocess.run(shlex.split("pip install numpy==1.26.4"))

install_mamba()

import gradio as gr
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import yaml

from models.stfts import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400 import cal_pcs

# Model checkpoint and config shipped with the Space.
# Alternatively, download them from the Hub:
#   from huggingface_hub import hf_hub_download
#   ckpt  = hf_hub_download("rc19477/Speech_Enhancement_Mamba", "ckpts/SEMamba_advanced.pth")
#   cfg_f = hf_hub_download("rc19477/Speech_Enhancement_Mamba", "recipes/SEMamba_advanced.yaml")
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Load config.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

stft_cfg = cfg["stft_cfg"]
model_cfg = cfg["model_cfg"]
sr = stft_cfg["sampling_rate"]
n_fft = stft_cfg["n_fft"]
hop_size = stft_cfg["hop_size"]
win_size = stft_cfg["win_size"]
compress_factor = model_cfg["compress_factor"]

# Init model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()


@spaces.GPU
def enhance(filepath):
    # Load at the native rate, then resample to the model's rate if needed.
    wav_np, orig_sr = librosa.load(filepath, sr=None)
    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)
    wav = torch.from_numpy(wav_np).float().to(device)

    # Normalize to unit RMS; the inverse scaling is applied after enhancement.
    norm = torch.sqrt(len(wav) / torch.sum(wav ** 2))
    wav = (wav * norm).unsqueeze(0)

    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=compress_factor)
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=compress_factor)
    out = (out / norm).squeeze().cpu().numpy()

    # Resample back to the original rate and write the result to disk.
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)
    sf.write("enhanced.wav", out, orig_sr)

    # Spectrogram of the enhanced signal for display.
    D = librosa.stft(out, n_fft=1024, hop_length=512)
    S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    fig, ax = plt.subplots(figsize=(6, 3))
    img = librosa.display.specshow(
        S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    return "enhanced.wav", fig


# — Custom CSS —
CSS = """
#title {text-align:center; margin-bottom:0.2em;}
#subtitle {text-align:center; color:#555; margin-bottom:1.5em;}
.duplicate-button {display:block; margin:0 auto 1.5em;}
#audio-in {border:2px dashed #aaa; border-radius:8px; padding:1em;}
#run-btn {width:100%; margin-top:0.5em;}
#out-audio, #spec-plot {margin-top:1em;}
"""

# — Blocks layout —
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(
        """
        <h1 id="title">SEMamba Speech Enhancement</h1>
        <p id="subtitle">
          Upload or record your noisy clip, then click Enhance to boost
          clarity and view its spectrogram.
        </p>
        """
    )
") gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button") with gr.Row(): with gr.Column(scale=1): audio_in = gr.Audio(sources=["upload","microphone"], type="filepath", label="Your Noisy Audio", elem_id="audio-in") run_btn = gr.Button("Enhance Now 🚀", variant="primary", elem_id="run-btn") with gr.Column(scale=1): audio_out = gr.Audio(type="filepath", label="Enhanced Audio", elem_id="out-audio") spec_plot = gr.Plot(label="Spectrogram", elem_id="spec-plot") run_btn.click(enhance, inputs=audio_in, outputs=[audio_out, spec_plot]) demo.launch()