Spaces:
Running
on
Zero
Running
on
Zero
roychao19477
commited on
Commit
·
23813e6
1
Parent(s):
b753bba
Upload
Browse files
app.py
CHANGED
@@ -10,9 +10,6 @@ def install_mamba():
|
|
10 |
subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
|
11 |
subprocess.run(shlex.split("pip install numpy==1.26.4"))
|
12 |
|
13 |
-
subprocess.run(shlex.split("ls"))
|
14 |
-
|
15 |
-
|
16 |
install_mamba()
|
17 |
|
18 |
|
@@ -36,7 +33,6 @@ from models.pcs400 import cal_pcs
|
|
36 |
ckpt = "ckpts/SEMamba_advanced.pth"
|
37 |
cfg_f = "recipes/SEMamba_advanced.yaml"
|
38 |
|
39 |
-
|
40 |
# load config
|
41 |
with open(cfg_f) as f:
|
42 |
cfg = yaml.safe_load(f)
|
@@ -49,12 +45,17 @@ hop_size = stft_cfg["hop_size"]
|
|
49 |
win_size = stft_cfg["win_size"]
|
50 |
compress_ff = model_cfg["compress_factor"]
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
model.
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
@spaces.GPU
|
|
|
10 |
subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
|
11 |
subprocess.run(shlex.split("pip install numpy==1.26.4"))
|
12 |
|
|
|
|
|
|
|
13 |
install_mamba()
|
14 |
|
15 |
|
|
|
33 |
ckpt = "ckpts/SEMamba_advanced.pth"
|
34 |
cfg_f = "recipes/SEMamba_advanced.yaml"
|
35 |
|
|
|
36 |
# load config
|
37 |
with open(cfg_f) as f:
|
38 |
cfg = yaml.safe_load(f)
|
|
|
45 |
win_size = stft_cfg["win_size"]
|
46 |
compress_ff = model_cfg["compress_factor"]
|
47 |
|
48 |
+
|
49 |
+
@spaces.GPU
|
50 |
+
def load_model():
|
51 |
+
# init model
|
52 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
53 |
+
model = SEMamba(cfg).to(device)
|
54 |
+
sdict = torch.load(ckpt, map_location=device)
|
55 |
+
model.load_state_dict(sdict["generator"])
|
56 |
+
model.eval()
|
57 |
+
|
58 |
+
load_model()
|
59 |
|
60 |
|
61 |
@spaces.GPU
|