File size: 3,437 Bytes
746a4fa
 
 
 
f299c53
746a4fa
 
d69789d
73431e1
7af2ba0
d69789d
48070ab
 
d69789d
746a4fa
f299c53
 
 
 
 
 
c450417
 
 
 
1fdb610
0601709
c450417
 
 
 
 
0efa60f
 
 
7001051
 
 
 
 
 
 
 
 
 
 
 
6da993c
 
 
 
 
 
 
 
18c8531
 
9657939
81e7d3e
 
9657939
81e7d3e
 
 
 
 
 
 
6da993c
 
81e7d3e
 
 
9657939
 
81e7d3e
ff0b6ec
81e7d3e
2bbe7e3
ff0b6ec
81e7d3e
179e88e
9657939
 
f299c53
b9afdac
9657939
ecbb90e
c90ec4a
ab7d954
feda893
 
 
 
 
 
 
 
 
 
 
 
 
ecbb90e
bb56b46
feda893
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import shlex
import subprocess
import spaces
import torch
import gradio as gr

# install packages for mamba
def install_mamba():
    """Install the prebuilt mamba_ssm wheel and the pinned NumPy at runtime.

    Each requirement is installed with the pip that belongs to the running
    interpreter (``sys.executable -m pip``) so the packages land in the
    environment that will actually import them, rather than whichever ``pip``
    happens to be first on PATH.

    Installation stays best-effort: a failing pip run does not raise, but it
    is reported on stderr instead of being silently ignored, so a later
    ImportError can be traced back to the failed install.
    """
    import sys

    # Previously installed as well, currently disabled:
    #   torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
    #   https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
    requirements = (
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "numpy==1.26.4",
    )
    for requirement in requirements:
        # Argument list (no shell) keeps the URL from being re-tokenised.
        result = subprocess.run(
            [sys.executable, "-m", "pip", "install", requirement]
        )
        if result.returncode != 0:
            print(
                f"WARNING: pip install failed for {requirement!r} "
                f"(exit code {result.returncode})",
                file=sys.stderr,
            )

# Run the installs at import time, ahead of the model imports further down.
# NOTE(review): presumably models.generator needs mamba_ssm to be importable
# by then — confirm.
install_mamba()

# Markdown blurb describing the demo.
# NOTE(review): nothing in the visible part of this file references ABOUT;
# presumably it was meant to be shown in the UI — confirm before removing.
ABOUT = """
# SEMamba: Speech Enhancement
A Mamba-based model that denoises real-world audio.
Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
"""


import torch
import yaml
import librosa
import librosa.display
import matplotlib
from huggingface_hub import hf_hub_download
from models.stfts    import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400   import cal_pcs

# Locations of the trained weights and of the recipe that describes them.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"

# Parse the YAML recipe once at import time.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

# Split the recipe into its STFT and model sections and pull out the
# individual values the rest of the module uses.
stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = (
    stft_cfg[key]
    for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
)
compress_ff = model_cfg["compress_factor"]

# The device is pinned to CUDA instead of being probed with
# torch.cuda.is_available().
# NOTE(review): presumably because the GPU is provided through @spaces.GPU —
# confirm before adding a CPU fallback.
device = "cuda"

# Build the network, restore the "generator" weights from the checkpoint and
# switch to inference mode.
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()

   
def _spectrogram_figure(samples, sample_rate):
    """Return a matplotlib figure with the dB spectrogram of ``samples``.

    Not called from ``enhance`` while the ``gr.Plot`` output is disabled;
    kept so the plot can be re-enabled without rebuilding it. The caller owns
    the returned figure and should close it once it has been rendered.
    """
    import matplotlib.pyplot as plt
    import numpy as np

    spectrum = librosa.stft(samples, n_fft=1024, hop_length=512)
    spectrum_db = librosa.amplitude_to_db(np.abs(spectrum), ref=np.max)
    fig, ax = plt.subplots(figsize=(6, 3))
    img = librosa.display.specshow(
        spectrum_db,
        sr=sample_rate,
        hop_length=512,
        x_axis="time",
        y_axis="hz",
        ax=ax,
    )
    ax.set_title("Enhanced Spectrogram")
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    return fig


@spaces.GPU
def enhance(filepath, apply_enhancement=True):
    """Denoise the audio file at ``filepath`` and return the path of the result.

    Args:
        filepath: Path of the input audio file, as delivered by the
            ``gr.Audio(type="filepath")`` input.
        apply_enhancement: Value of the "Apply Speech Enhancement" checkbox.
            When false the input path is returned unchanged. Defaults to
            True so single-argument callers keep working.

    Returns:
        Path of a WAV file, written at the input's original sampling rate,
        holding the enhanced signal (or ``filepath`` itself when enhancement
        is switched off).

    Raises:
        gr.Error: If no audio was supplied.
    """
    import tempfile

    import soundfile as sf

    if filepath is None:
        raise gr.Error("Please upload or record an audio clip first.")
    if not apply_enhancement:
        return filepath

    # load & (if needed) resample to model SR
    wav, orig_sr = librosa.load(filepath, sr=None)
    if orig_sr != sr:
        # Rates are passed by keyword: librosa >= 0.10 rejects positional ones.
        wav = librosa.resample(wav, orig_sr=orig_sr, target_sr=sr)

    # normalize -> tensor (scale to unit mean power)
    x = torch.from_numpy(wav).float().to(device)
    energy = torch.sum(x ** 2)
    if energy > 0:
        norm = torch.sqrt(len(x) / energy)
    else:
        # All-zero (or empty) input: skip scaling rather than divide by zero.
        norm = torch.ones((), device=x.device)
    x = (x * norm).unsqueeze(0)

    # STFT -> model -> ISTFT
    # NOTE(review): only the framing keys are forwarded. stft_cfg also holds
    # "sampling_rate", which is presumably not an STFT argument — confirm
    # against models.stfts.
    stft_kwargs = dict(
        n_fft=n_fft,
        hop_size=hop_size,
        win_size=win_size,
        compress_factor=compress_ff,
    )
    amp, pha, _ = mag_phase_stft(x, **stft_kwargs)
    with torch.no_grad():
        amp2, pha2, _ = model(amp, pha)
    out = mag_phase_istft(amp2, pha2, **stft_kwargs)
    # Undo the input normalization so the output level matches the input.
    out = (out / norm).squeeze().cpu().numpy()

    # back to original rate
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    # write file — a unique temp file per call, so concurrent requests do not
    # overwrite each other's output.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, out, orig_sr)

    return out_path




# Components are declared up front so the Interface call itself stays short.
# Gradio passes one positional argument to `fn` per input component.
_demo_inputs = [
    gr.Audio(label="Input Audio", type="filepath"),
    gr.Checkbox(label="Apply Speech Enhancement", value=True),
]
_demo_outputs = [
    gr.Audio(label="Output Audio", type="filepath"),
    # gr.Plot(label="Spectrogram")
]

se_demo = gr.Interface(
    fn=enhance,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="SEMamba",
    cache_examples=False,
)

# Only start the web server when executed as a script.
if __name__ == "__main__":
    se_demo.launch()