roychao19477 commited on
Commit
23813e6
·
1 Parent(s): b753bba
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -10,9 +10,6 @@ def install_mamba():
10
  subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
11
  subprocess.run(shlex.split("pip install numpy==1.26.4"))
12
 
13
- subprocess.run(shlex.split("ls"))
14
-
15
-
16
  install_mamba()
17
 
18
 
@@ -36,7 +33,6 @@ from models.pcs400 import cal_pcs
36
  ckpt = "ckpts/SEMamba_advanced.pth"
37
  cfg_f = "recipes/SEMamba_advanced.yaml"
38
 
39
-
40
  # load config
41
  with open(cfg_f) as f:
42
  cfg = yaml.safe_load(f)
@@ -49,12 +45,17 @@ hop_size = stft_cfg["hop_size"]
49
  win_size = stft_cfg["win_size"]
50
  compress_ff = model_cfg["compress_factor"]
51
 
52
- # init model
53
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
- model = SEMamba(cfg).to(device)
55
- sdict = torch.load(ckpt, map_location=device)
56
- model.load_state_dict(sdict["generator"])
57
- model.eval()
 
 
 
 
 
58
 
59
 
60
  @spaces.GPU
 
10
  subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.2/mamba_ssm-2.2.2+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"))
11
  subprocess.run(shlex.split("pip install numpy==1.26.4"))
12
 
 
 
 
13
  install_mamba()
14
 
15
 
 
33
  ckpt = "ckpts/SEMamba_advanced.pth"
34
  cfg_f = "recipes/SEMamba_advanced.yaml"
35
 
 
36
  # load config
37
  with open(cfg_f) as f:
38
  cfg = yaml.safe_load(f)
 
45
  win_size = stft_cfg["win_size"]
46
  compress_ff = model_cfg["compress_factor"]
47
 
48
+
49
+ @spaces.GPU
50
+ def load_model():
51
+ # init model
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+ model = SEMamba(cfg).to(device)
54
+ sdict = torch.load(ckpt, map_location=device)
55
+ model.load_state_dict(sdict["generator"])
56
+ model.eval()
57
+
58
+ load_model()
59
 
60
 
61
  @spaces.GPU