roychao19477 committed on
Commit
8bb81da
·
1 Parent(s): dd22f75
Files changed (1) hide show
  1. app.py +15 -89
app.py CHANGED
@@ -1,94 +1,20 @@
1
- import shlex
2
- import subprocess
3
- import spaces
4
- import torch
5
  import gradio as gr
6
-
7
- # install packages for mamba
8
- def install_mamba():
9
- #subprocess.run(shlex.split("pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118"))
10
- #subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.4.0/causal_conv1d-1.4.0+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
11
- subprocess.run("pip install gradio --upgrade --force", shell=True)
12
- subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
13
- subprocess.run(shlex.split("pip install numpy==1.26.4"))
14
-
15
- install_mamba()
16
-
17
- ABOUT = """
18
- # SEMamba: Speech Enhancement
19
- A Mamba-based model that denoises real-world audio.
20
- Upload or record a noisy clip and click **Enhance** to hear + see its spectrogram.
21
- """
22
-
23
-
24
- import torch
25
- import yaml
26
- import librosa
27
- import librosa.display
28
- import matplotlib
29
- from models.stfts import mag_phase_stft, mag_phase_istft
30
- from models.generator import SEMamba
31
- from models.pcs400 import cal_pcs
32
-
33
- ckpt = "ckpts/SEMamba_advanced.pth"
34
- cfg_f = "recipes/SEMamba_advanced.yaml"
35
-
36
- # load config
37
- with open(cfg_f, 'r') as f:
38
- cfg = yaml.safe_load(f)
39
-
40
-
41
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- device = "cuda"
43
- model = SEMamba(cfg).to(device)
44
- sdict = torch.load(ckpt, map_location=device)
45
- model.load_state_dict(sdict["generator"])
46
- model.eval()
47
-
48
 
49
  @spaces.GPU
50
- def enhance(filepath):
51
- with torch.no_grad():
52
- # load & (if needed) resample to model SR
53
- wav, orig_sr = librosa.load(filepath, sr=None)
54
-
55
- if orig_sr != 16000:
56
- wav = librosa.resample(wav, orig_sr, 16000)
57
- # normalize → tensor
58
- x = torch.from_numpy(wav).float().to(device)
59
- norm = torch.sqrt(len(x)/torch.sum(x**2))
60
- x = (x*norm).unsqueeze(0)
61
- # STFT → model → ISTFT
62
- amp ,pha , _ = mag_phase_stft(x, 400, 100, 400, 0.3)
63
- with torch.no_grad():
64
- amp2, pha2, comp = model(amp, pha)
65
- out = mag_phase_istft(amp2, pha2, 400, 100, 400, 0.3)
66
- out = (out/norm).squeeze().cpu().numpy()
67
- # back to original rate
68
- if orig_sr != 16000:
69
- out = librosa.resample(out, 16000, orig_sr, 'PCM_16')
70
- # write file
71
- sf.write("enhanced.wav", out, orig_sr)
72
- # build spectrogram
73
-
74
- D = librosa.stft(out, n_fft=1024, hop_length=512)
75
- S = librosa.amplitude_to_db(np.abs(D), ref=np.max)
76
- fig, ax = plt.subplots(figsize=(6,3))
77
- librosa.display.specshow(S, sr=orig_sr, hop_length=512, x_axis="time", y_axis="hz", ax=ax)
78
- ax.set_title("Enhanced Spectrogram")
79
- plt.colorbar(format="%+2.0f dB", ax=ax)
80
-
81
- return "enhanced.wav"#, fig
82
-
83
 
84
  with gr.Blocks() as demo:
85
- gr.Markdown(ABOUT)
86
- input_audio = gr.Audio(label="Input Audio", type="filepath")
87
- enhance_btn = gr.Button("Enhance")
88
- output_audio = gr.Audio(label="Enhanced Audio", type="filepath")
89
-
90
- #enhance_btn.click(fn=enhance, inputs=input_audio, outputs=output_audio)
91
-
92
- #demo.queue().launch()
93
- if __name__ == "__main__":
94
- demo.launch()
 
 
 
 
 
1
  import gradio as gr
2
+ import spaces
3
+ import numpy as np
4
+ import soundfile as sf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  @spaces.GPU
7
+ def dummy_enhance(audio_path):
8
+ print("Audio received:", audio_path)
9
+ # Return the same file as a dummy operation
10
+ return audio_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  with gr.Blocks() as demo:
13
+ gr.Markdown("# SEMamba: ZeroGPU Upload Test")
14
+ audio_input = gr.Audio(type="filepath", label="Upload Audio", interactive=True)
15
+ submit = gr.Button("Run Enhancement")
16
+ audio_output = gr.Audio(type="filepath", label="Enhanced Output")
17
+
18
+ submit.click(dummy_enhance, inputs=[audio_input], outputs=[audio_output])
19
+
20
+ demo.queue().launch() # No ssr_mode=False