File size: 4,249 Bytes
746a4fa
 
 
 
 
 
d69789d
73431e1
7af2ba0
d69789d
48070ab
 
0efa60f
 
746a4fa
d69789d
746a4fa
c450417
 
 
 
 
1fdb610
0601709
c450417
 
 
 
 
7001051
0efa60f
 
 
 
 
 
 
 
7001051
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18c8531
 
c58812e
 
ff0b6ec
7001051
ff0b6ec
 
 
 
 
 
 
 
2bbe7e3
ff0b6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1fdb610
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import shlex
import subprocess
import spaces
import torch

# Install runtime packages for the Mamba-based model at start-up, before the
# `models` package is imported further down.
def install_mamba():
    """Pip-install the prebuilt mamba-ssm wheel and pin numpy for this interpreter.

    Runs once at import time. A failed install is reported on stderr but does
    not abort start-up (best-effort, as before); presumably a genuinely missing
    mamba-ssm will then surface when the model modules are imported.
    """
    import sys

    requirements = (
        # The wheel's build label (cu122, torch2.3, cxx11abiFALSE, cp310) must
        # match the runtime's CUDA, torch, C++ ABI and Python versions.
        "https://github.com/state-spaces/mamba/releases/download/v2.2.2/"
        "mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
        "numpy==1.26.4",
    )
    for requirement in requirements:
        # `python -m pip` installs into the environment of the interpreter that
        # runs this app; a bare `pip` found on PATH may belong to another one.
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "install", requirement]
        )
        if completed.returncode != 0:
            print(
                f"install_mamba: 'pip install {requirement}' exited with "
                f"code {completed.returncode}",
                file=sys.stderr,
            )


install_mamba()


import tempfile

import gradio as gr
import librosa
import librosa.display
import matplotlib
import numpy as np
import soundfile as sf
import torch
import yaml
from huggingface_hub import hf_hub_download
from matplotlib.figure import Figure

from models.generator import SEMamba
from models.pcs400 import cal_pcs
from models.stfts import mag_phase_istft, mag_phase_stft

# Model assets are read from paths relative to the working directory. To pull
# them from the Hub instead, replace the two paths below with:
#   hf_hub_download("rc19477/Speech_Enhancement_Mamba", "ckpts/SEMamba_advanced.pth")
#   hf_hub_download("rc19477/Speech_Enhancement_Mamba", "recipes/SEMamba_advanced.yaml")
ckpt, cfg_f = "ckpts/SEMamba_advanced.pth", "recipes/SEMamba_advanced.yaml"

# Parse the YAML recipe; safe_load only builds plain Python data.
with open(cfg_f) as f:
    cfg = yaml.safe_load(f)

# Expose the STFT and model hyper-parameters as module-level names.
stft_cfg, model_cfg = cfg["stft_cfg"], cfg["model_cfg"]
sr, n_fft, hop_size, win_size = [
    stft_cfg[key] for key in ("sampling_rate", "n_fft", "hop_size", "win_size")
]
compress_ff = model_cfg["compress_factor"]

# Build the generator on the GPU when one is visible, restore its weights from
# the checkpoint's "generator" entry, and put it in evaluation mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SEMamba(cfg).to(device)
sdict = torch.load(ckpt, map_location=device)
model.load_state_dict(sdict["generator"])
model.eval()

   
@spaces.GPU
def enhance(filepath):
    """Denoise an audio file with SEMamba and plot the enhanced spectrogram.

    Args:
        filepath: Path to the input audio, as delivered by the ``gr.Audio``
            component (``type="filepath"``); ``None`` when nothing was provided.

    Returns:
        ``(wav_path, figure)``: the path of a freshly created WAV file with the
        enhanced audio at the input's original sample rate, and a matplotlib
        figure of its dB-scaled spectrogram.

    Raises:
        gr.Error: If no audio was provided or the decoded audio is empty.
    """
    if filepath is None:
        raise gr.Error("Please upload or record an audio clip first.")

    # Load at the file's native rate, then resample to the model's rate `sr`.
    wav_np, orig_sr = librosa.load(filepath, sr=None)
    if wav_np.size == 0:
        raise gr.Error("The uploaded audio is empty.")
    if orig_sr != sr:
        # librosa >= 0.10 only accepts the rates as keyword arguments.
        wav_np = librosa.resample(wav_np, orig_sr=orig_sr, target_sr=sr)

    wav = torch.from_numpy(wav_np).float().to(device)
    # Scale to unit mean-square energy; the same factor is undone on the output.
    # clamp_min avoids a division by zero (inf/NaN) on an all-silent clip.
    energy = torch.sum(wav ** 2).clamp_min(1e-8)
    norm = torch.sqrt(len(wav) / energy)
    wav = (wav * norm).unsqueeze(0)

    # Inference only: without no_grad the output tracks gradients and the
    # `.numpy()` call below raises.
    with torch.no_grad():
        # NOTE(review): only n_fft/hop_size/win_size are forwarded (not the
        # whole stft_cfg, which also carries `sampling_rate`); confirm against
        # models.stfts signatures.
        amp, pha, _ = mag_phase_stft(
            wav, n_fft=n_fft, hop_size=hop_size, win_size=win_size,
            compress_factor=compress_ff,
        )
        # Take the first two outputs (magnitude, phase) so extra return values
        # from the generator do not break unpacking.
        outputs = model(amp, pha)
        amp_g, pha_g = outputs[0], outputs[1]
        out = mag_phase_istft(
            amp_g, pha_g, n_fft=n_fft, hop_size=hop_size, win_size=win_size,
            compress_factor=compress_ff,
        )
    out = (out / norm).squeeze().cpu().numpy()
    if orig_sr != sr:
        out = librosa.resample(out, orig_sr=sr, target_sr=orig_sr)

    # One file per request so concurrent users never overwrite each other's output.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, out, orig_sr)

    D = librosa.stft(out, n_fft=1024, hop_length=512)
    S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
    # A standalone Figure (not pyplot) is not kept in pyplot's global registry,
    # so figures do not accumulate across requests.
    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    mesh = librosa.display.specshow(
        S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax
    )
    ax.set_title("Enhanced Spectrogram")
    fig.colorbar(mesh, ax=ax, format="%+2.0f dB")
    return out_path, fig

# — Custom CSS —
# Rules are attached to the components below through elem_id / elem_classes.
CSS = """
#title {text-align:center; margin-bottom:0.2em;}
#subtitle {text-align:center; color:#555; margin-bottom:1.5em;}
.duplicate-button {display:block; margin:0 auto 1.5em;}
#audio-in {border:2px dashed #aaa; border-radius:8px; padding:1em;}
#run-btn {width:100%; margin-top:0.5em;}
#out-audio, #spec-plot {margin-top:1em;}
"""

# — Blocks layout —
# Two columns: noisy input and trigger on the left, results on the right.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML("<h1 id='title'>🎧 SEMamba: Speech Enhancement</h1>")
    gr.HTML("<p id='subtitle'>Upload or record your noisy clip, then click Enhance to boost clarity and view its spectrogram.</p>")
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    with gr.Row():
        with gr.Column(scale=1):
            audio_in  = gr.Audio(sources=["upload","microphone"], type="filepath",
                                 label="Your Noisy Audio", elem_id="audio-in")
            # Label previously held mojibake ("πŸš€") for the rocket emoji.
            run_btn   = gr.Button("Enhance Now 🚀", variant="primary", elem_id="run-btn")

        with gr.Column(scale=1):
            audio_out = gr.Audio(type="filepath", label="Enhanced Audio", elem_id="out-audio")
            spec_plot = gr.Plot(label="Spectrogram", elem_id="spec-plot")

    # enhance() returns (audio path, figure), matching these two outputs in order.
    run_btn.click(enhance, inputs=audio_in, outputs=[audio_out, spec_plot])

demo.launch()