Yongyi Zang committed on
Commit
3617711
·
1 Parent(s): 7acb2e5

Change Files

Browse files
Files changed (2) hide show
  1. __pycache__/model.cpython-313.pyc +0 -0
  2. app.py +6 -4
__pycache__/model.cpython-313.pyc ADDED
Binary file (22 kB). View file
 
app.py CHANGED
@@ -24,7 +24,7 @@ def _get_model(ckpt_name: str):
24
  raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
25
  if ckpt_name in _model_cache:
26
  return _model_cache[ckpt_name]
27
- ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pt")
28
  model = UFormer(config).to(DEVICE).eval()
29
  state = torch.load(ckpt_path, map_location=DEVICE)
30
  model.load_state_dict(state)
@@ -43,6 +43,7 @@ def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=
43
  out = np.zeros_like(x_pad)
44
  norm = np.zeros((1, x_pad.shape[1]))
45
  n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
 
46
 
47
  for i in range(n_chunks):
48
  s = i * hop
@@ -52,7 +53,8 @@ def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=
52
  out[:, s:s+chunk] += y * win
53
  norm[:, s:s+chunk] += win
54
 
55
- return (out / norm)[:, :T]
 
56
 
57
  # ——————————————————————
58
  # 3) Restore function for Gradio
@@ -81,12 +83,12 @@ def restore_fn(audio_path, checkpoint):
81
  demo = gr.Interface(
82
  fn=restore_fn,
83
  inputs=[
84
- gr.Audio(source="upload", type="filepath", label="Your Input"),
85
  gr.Dropdown(VALID_CKPTS, label="Checkpoint")
86
  ],
87
  outputs=gr.Audio(type="filepath", label="Restored Output"),
88
  title="🎵 Music Source Restoration",
89
- description="Upload a WAV file and choose an instrument/group checkpoint to restore.",
90
  allow_flagging="never"
91
  )
92
 
 
24
  raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
25
  if ckpt_name in _model_cache:
26
  return _model_cache[ckpt_name]
27
+ ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
28
  model = UFormer(config).to(DEVICE).eval()
29
  state = torch.load(ckpt_path, map_location=DEVICE)
30
  model.load_state_dict(state)
 
43
  out = np.zeros_like(x_pad)
44
  norm = np.zeros((1, x_pad.shape[1]))
45
  n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
46
+ print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")
47
 
48
  for i in range(n_chunks):
49
  s = i * hop
 
53
  out[:, s:s+chunk] += y * win
54
  norm[:, s:s+chunk] += win
55
 
56
+ eps = 1e-8
57
+ return (out / (norm + eps))[:, :T]
58
 
59
  # ——————————————————————
60
  # 3) Restore function for Gradio
 
83
  demo = gr.Interface(
84
  fn=restore_fn,
85
  inputs=[
86
+ gr.Audio(sources="upload", type="filepath", label="Your Input"),
87
  gr.Dropdown(VALID_CKPTS, label="Checkpoint")
88
  ],
89
  outputs=gr.Audio(type="filepath", label="Restored Output"),
90
  title="🎵 Music Source Restoration",
91
+ description="Upload an (stereo) audio file and choose an instrument/group checkpoint to restore. Please note that these are baseline models for demonstration purposes only, and most of them don't perform really well...",
92
  allow_flagging="never"
93
  )
94