File size: 2,600 Bytes
746a4fa
 
 
 
 
 
 
 
 
 
 
 
 
 
c450417
 
 
 
 
 
 
 
 
 
7001051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b3ab69
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import shlex
import subprocess
import spaces
import torch


# install packages for mamba
def install(wheels=(
    "causal_conv1d-1.0.0-cp310-cp310-linux_x86_64.whl",
    "mamba_ssm-1.0.1-cp310-cp310-linux_x86_64.whl",
)):
    """Best-effort install of the prebuilt Mamba wheels shipped with the app.

    Args:
        wheels: wheel paths (relative to the current working directory),
            installed one at a time in the given order; ``causal_conv1d``
            comes first, as in the original script.

    A failed install is reported on stdout but does not abort start-up,
    matching the original behaviour of not checking pip's exit status.
    """
    import sys

    print("Install personal packages", flush=True)
    for wheel in wheels:
        # ``sys.executable -m pip`` installs into the interpreter running this
        # app rather than whichever ``pip`` happens to be first on PATH.
        result = subprocess.run([sys.executable, "-m", "pip", "install", wheel])
        if result.returncode != 0:
            print(
                f"pip install {wheel} failed (exit code {result.returncode})",
                flush=True,
            )

install()


import gradio as gr
import torch
import yaml
import librosa
from huggingface_hub import hf_hub_download
from models.stfts    import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400   import cal_pcs

# Hugging Face repo holding both the SEMamba checkpoint and its recipe YAML.
REPO_ID = "rc19477/Speech_Enhancement_Mamba"

# download model files from the HF repo
ckpt  = hf_hub_download(REPO_ID, "ckpts/SEMamba_advanced.pth")
cfg_f = hf_hub_download(REPO_ID, "recipes/SEMamba_advanced.yaml")

# load config; explicit encoding so decoding does not depend on the locale,
# and safe_load so the YAML cannot construct arbitrary Python objects
with open(cfg_f, encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# STFT / model hyper-parameters shared by enhance(); ``sr`` is the rate the
# model operates at, and input audio is resampled to it
stft_cfg    = cfg["stft_cfg"]
model_cfg   = cfg["model_cfg"]
sr          = stft_cfg["sampling_rate"]
n_fft       = stft_cfg["n_fft"]
hop_size    = stft_cfg["hop_size"]
win_size    = stft_cfg["win_size"]
compress_ff = model_cfg["compress_factor"]

# init model: build on the best available device, load the generator weights
# from the checkpoint's "generator" entry, and switch to eval mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = SEMamba(cfg).to(device)
sdict  = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()

def _to_mono_float(wav_np):
    """Return ``wav_np`` as a 1-D float32 waveform.

    Integer PCM is scaled to [-1, 1], which librosa's resampler requires
    because it rejects integer audio. A 2-D ``(samples, channels)`` array is
    averaged to mono because the model processes one channel.
    """
    import numpy as np

    wav = np.asarray(wav_np)
    if np.issubdtype(wav.dtype, np.integer):
        info = np.iinfo(wav.dtype)
        if info.min == 0:
            # unsigned PCM (e.g. 8-bit WAV) is centred at mid-scale
            half = (int(info.max) + 1) / 2
            wav = (wav.astype(np.float32) - half) / half
        else:
            wav = wav.astype(np.float32) / -float(info.min)
    else:
        wav = wav.astype(np.float32, copy=False)
    if wav.ndim > 1:
        # channels are on the last axis for (samples, channels) input
        wav = wav.mean(axis=-1)
    return np.ascontiguousarray(wav)


@spaces.GPU
def enhance(audio, do_pcs):
    """Denoise one uploaded clip with SEMamba.

    Args:
        audio: ``(sample_rate, samples)`` tuple from
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was uploaded.
        do_pcs: when truthy, pass the enhanced waveform through ``cal_pcs``.

    Returns:
        ``(sample_rate, waveform)`` tuple at the input's original sample rate.

    Raises:
        gr.Error: if no audio was provided or the clip is empty.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file.")
    orig_sr, wav_np = audio
    wav_np = _to_mono_float(wav_np)
    if wav_np.size == 0:
        raise gr.Error("The uploaded audio is empty.")

    # 1) resample to the model's rate (``sr`` from the recipe) if needed;
    #    librosa >= 0.10 accepts the rates only as keywords
    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)
    wav = torch.from_numpy(wav_np).float().to(device)

    # normalize to unit RMS; clamp the energy so an all-zero clip does not
    # produce an infinite gain (and NaN output)
    energy = torch.clamp(torch.sum(wav ** 2), min=1e-10)
    norm = torch.sqrt(len(wav) / energy)
    wav = (wav * norm).unsqueeze(0)

    # STFT -> model -> ISTFT. no_grad is required: the model parameters require
    # grad, and .numpy() refuses tensors that are part of an autograd graph.
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(wav, n_fft, hop_size, win_size, compress_ff)
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_ff)
    out = (out / norm).squeeze().cpu().numpy()

    # optional PCS filter
    if do_pcs:
        out = cal_pcs(out)

    # 2) resample back to original rate
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    return orig_sr, out

# Gradio 4 replaced gr.Audio(source="upload") with sources=["upload"] and
# rejects the old keyword; pick whichever form the installed version accepts.
_GRADIO_MAJOR = int(gr.__version__.split(".")[0])
_UPLOAD_KW = {"sources": ["upload"]} if _GRADIO_MAJOR >= 4 else {"source": "upload"}

# Web UI: a noisy clip plus a PCS toggle in, the enhanced clip out.
demo = gr.Interface(
    fn=enhance,
    inputs=[
        gr.Audio(type="numpy", label="Noisy wav", **_UPLOAD_KW),
        gr.Checkbox(label="Apply PCS post-processing", value=False),
    ],
    outputs=gr.Audio(type="numpy", label="Enhanced wav"),
    title="SEMamba Speech Enhancement",
    description="Upload a noisy WAV; tick **Apply PCS** for the pcs400 filter.",
)


demo.launch()