File size: 4,308 Bytes
746a4fa
 
 
 
 
 
d69789d
73431e1
7af2ba0
d69789d
48070ab
 
0efa60f
 
746a4fa
d69789d
746a4fa
c450417
 
 
 
 
1fdb610
0601709
c450417
 
 
 
 
7001051
0efa60f
 
 
 
 
 
 
 
7001051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18c8531
 
2bbe7e3
ea5c419
7001051
 
 
 
2bbe7e3
ea5c419
7001051
ea5c419
7001051
ea5c419
7001051
 
 
 
1fdb610
2bbe7e3
ea5c419
1fdb610
2bbe7e3
 
1fdb610
ea5c419
a1a60de
1fdb610
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import shlex
import subprocess
import spaces
import torch

# install packages for mamba
def install_mamba():
    """Install the prebuilt ``mamba_ssm`` wheel and pin numpy before the app imports them.

    Each package is installed with ``sys.executable -m pip`` so it lands in the
    interpreter that is running this app, not in whichever ``pip`` happens to
    be first on ``PATH``. A failed install is reported on stderr but does not
    abort start-up, which keeps the original best-effort behaviour.
    """
    import sys

    packages = [
        # Prebuilt mamba-ssm 2.2.2 wheel (cu122 / torch2.3 / cp310, per its file name).
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "numpy==1.26.4",
    ]
    for package in packages:
        # Argument list (no shell), so nothing in the URL is shell-interpreted.
        result = subprocess.run([sys.executable, "-m", "pip", "install", package])
        if result.returncode != 0:
            print(
                f"warning: 'pip install {package}' exited with code {result.returncode}",
                file=sys.stderr,
            )


install_mamba()


import gradio as gr
import torch
import yaml
import librosa
import librosa.display
import matplotlib
from huggingface_hub import hf_hub_download
from models.stfts    import mag_phase_stft, mag_phase_istft
from models.generator import SEMamba
from models.pcs400   import cal_pcs

# The checkpoint and recipe are read from the local repository checkout. To
# fetch them from the Hub instead, call hf_hub_download on the repo
# "rc19477/Speech_Enhancement_Mamba" with these same relative paths.
ckpt = "ckpts/SEMamba_advanced.pth"
cfg_f = "recipes/SEMamba_advanced.yaml"


# Parse the YAML recipe holding the STFT and model hyper-parameters.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = (
    stft_cfg["sampling_rate"],
    stft_cfg["n_fft"],
    stft_cfg["hop_size"],
    stft_cfg["win_size"],
)
compress_ff = model_cfg["compress_factor"]

# Build the generator on the GPU when one is visible, restore its weights from
# the checkpoint's "generator" entry, and put it in inference (eval) mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()

   
@spaces.GPU
def enhance(audio):
    """Denoise a clip with SEMamba and return it together with its spectrogram.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was supplied.

    Returns:
        ``((sample_rate, enhanced_samples), figure)``. The enhanced audio is
        resampled back to the caller's original rate, and ``figure`` is a
        matplotlib Figure of its dB spectrogram. Returns ``(None, None)`` when
        ``audio`` is ``None`` or contains no samples.
    """
    # Neither name is imported at module level. The object-oriented Figure API
    # avoids the global pyplot state, so figures do not pile up across requests.
    import numpy as np
    from matplotlib.figure import Figure

    if audio is None:
        return None, None
    orig_sr, wav_np = audio
    wav_np = np.asarray(wav_np)
    if wav_np.size == 0:
        return None, None

    # Gradio typically hands back integer PCM, but librosa.resample only accepts
    # floating-point audio, so rescale integer samples to [-1, 1].
    if np.issubdtype(wav_np.dtype, np.integer):
        full_scale = np.iinfo(wav_np.dtype).max
        wav_np = wav_np.astype(np.float32) / full_scale
    else:
        wav_np = wav_np.astype(np.float32)
    # Gradio returns multi-channel audio as (samples, channels); downmix to mono.
    if wav_np.ndim > 1:
        wav_np = wav_np.mean(axis=1)

    if orig_sr != sr:
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)
    wav = torch.from_numpy(wav_np).float().to(device)

    # Scale to unit RMS before inference. The clamp keeps an all-zero clip from
    # dividing by zero, which would otherwise turn the whole output into NaN.
    energy = torch.sum(wav ** 2).clamp_min(1e-10)
    norm = torch.sqrt(len(wav) / energy)
    wav = (wav * norm).unsqueeze(0)

    # Inference only: no_grad saves memory, and without it the model's output
    # still tracks gradients, so the .numpy() call below would raise.
    with torch.no_grad():
        amp, pha, _ = mag_phase_stft(wav, **stft_cfg, compress_factor=model_cfg["compress_factor"])
        amp_g, pha_g = model(amp, pha)
        out = mag_phase_istft(amp_g, pha_g, **stft_cfg, compress_factor=model_cfg["compress_factor"])
    out = (out / norm).squeeze().cpu().numpy()
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    # Spectrogram of the enhanced signal, in dB relative to its peak.
    D = librosa.stft(out, n_fft=1024, hop_length=512)
    S_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    img = librosa.display.specshow(S_db, sr=orig_sr, hop_length=512, x_axis='time', y_axis='hz', ax=ax)
    ax.set_title("Enhanced Spectrogram")
    # Pass the mappable explicitly; pyplot's "current image" is not set by specshow.
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    return (orig_sr, out), fig

# --- Layout with Blocks ---
with gr.Blocks(css=".gr-box {border: none !important}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>🎧 <a href='https://github.com/RoyChao19477/SEMamba' target='_blank'>SEMamba</a>: Speech Enhancement</h1>")
    gr.Markdown("Enhance real-world noisy speech using Mamba. Upload or record an audio clip and view the spectrogram.")

    with gr.Row():
        # Left column: input clip and trigger button.
        with gr.Column():
            audio_input = gr.Audio(sources=["upload", "microphone"], type="numpy", label="Upload or Record", elem_id="input-audio")
            run_btn = gr.Button("Enhance Now 🚀", variant="primary")

        # Right column: enhanced audio and its spectrogram, both filled by enhance().
        with gr.Column():
            enhanced_audio = gr.Audio(label="Enhanced Output", type="numpy")
            spec_plot = gr.Plot(label="Spectrogram")

    run_btn.click(enhance, inputs=audio_input, outputs=[enhanced_audio, spec_plot])

    # cache_examples=True makes Gradio run enhance() on the example files when
    # the app is built, so clicking an example shows precomputed results.
    gr.Examples(
        examples=[
            ["examples/noisy_sample_16k.wav"],
        ],
        inputs=audio_input,
        outputs=[enhanced_audio, spec_plot],
        fn=enhance,
        cache_examples=True,
        label="📂 Try These Examples"
    )

    gr.Markdown("<p style='text-align: center'><a href='https://arxiv.org/abs/2405.15144' target='_blank'>📄 SEMamba: Mamba for Long-Context Speech Enhancement (SLT 2024)</a></p>")

demo.launch()