roychao19477 commited on
Commit
6da993c
Β·
1 Parent(s): 306a1c8
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -39,19 +39,18 @@ hop_size = stft_cfg["hop_size"]
39
  win_size = stft_cfg["win_size"]
40
  compress_ff = model_cfg["compress_factor"]
41
 
 
 
 
 
 
 
 
 
 
42
 
43
  @spaces.GPU
44
  def enhance(filepath):
45
- global model
46
- if model is None:
47
- print("loading model")
48
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
- model = SEMamba(cfg).to(device)
50
- sdict = torch.load(ckpt, map_location=device)
51
- model.load_state_dict(sdict["generator"])
52
- model.eval()
53
- print("Finished.")
54
-
55
  # load & (if needed) resample to model SR
56
  wav, orig_sr = librosa.load(filepath, sr=None)
57
  if orig_sr != SR:
@@ -62,7 +61,8 @@ def enhance(filepath):
62
  x = (x*norm).unsqueeze(0)
63
  # STFT β†’ model β†’ ISTFT
64
  amp,pha,_ = mag_phase_stft(x, **stft_cfg, compress_factor=model_cfg["compress_factor"])
65
- amp2,pha2 = model(amp, pha)
 
66
  out = mag_phase_istft(amp2, pha2, **stft_cfg, compress_factor=model_cfg["compress_factor"])
67
  out = (out/norm).squeeze().cpu().numpy()
68
  # back to original rate
 
39
  win_size = stft_cfg["win_size"]
40
  compress_ff = model_cfg["compress_factor"]
41
 
42
+
43
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
+ device = "cuda"
45
+ model = SEMamba(cfg).to(device)
46
+ sdict = torch.load(ckpt, map_location=device)
47
+ model.load_state_dict(sdict["generator"])
48
+ model.eval()
49
+
50
+
51
 
52
  @spaces.GPU
53
  def enhance(filepath):
 
 
 
 
 
 
 
 
 
 
54
  # load & (if needed) resample to model SR
55
  wav, orig_sr = librosa.load(filepath, sr=None)
56
  if orig_sr != SR:
 
61
  x = (x*norm).unsqueeze(0)
62
  # STFT β†’ model β†’ ISTFT
63
  amp,pha,_ = mag_phase_stft(x, **stft_cfg, compress_factor=model_cfg["compress_factor"])
64
+ with torch.no_grad():
65
+ amp2, pha2, comp = model(amp, pha)
66
  out = mag_phase_istft(amp2, pha2, **stft_cfg, compress_factor=model_cfg["compress_factor"])
67
  out = (out/norm).squeeze().cpu().numpy()
68
  # back to original rate