Yongyi Zang
commited on
Commit
Β·
3617711
1
Parent(s):
7acb2e5
Change Files
Browse files- __pycache__/model.cpython-313.pyc +0 -0
- app.py +6 -4
__pycache__/model.cpython-313.pyc
ADDED
Binary file (22 kB). View file
|
|
app.py
CHANGED
@@ -24,7 +24,7 @@ def _get_model(ckpt_name: str):
|
|
24 |
raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
|
25 |
if ckpt_name in _model_cache:
|
26 |
return _model_cache[ckpt_name]
|
27 |
-
ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.
|
28 |
model = UFormer(config).to(DEVICE).eval()
|
29 |
state = torch.load(ckpt_path, map_location=DEVICE)
|
30 |
model.load_state_dict(state)
|
@@ -43,6 +43,7 @@ def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=
|
|
43 |
out = np.zeros_like(x_pad)
|
44 |
norm = np.zeros((1, x_pad.shape[1]))
|
45 |
n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
|
|
|
46 |
|
47 |
for i in range(n_chunks):
|
48 |
s = i * hop
|
@@ -52,7 +53,8 @@ def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=
|
|
52 |
out[:, s:s+chunk] += y * win
|
53 |
norm[:, s:s+chunk] += win
|
54 |
|
55 |
-
|
|
|
56 |
|
57 |
# ββββββββββββββββββββββ
|
58 |
# 3) Restore function for Gradio
|
@@ -81,12 +83,12 @@ def restore_fn(audio_path, checkpoint):
|
|
81 |
demo = gr.Interface(
|
82 |
fn=restore_fn,
|
83 |
inputs=[
|
84 |
-
gr.Audio(
|
85 |
gr.Dropdown(VALID_CKPTS, label="Checkpoint")
|
86 |
],
|
87 |
outputs=gr.Audio(type="filepath", label="Restored Output"),
|
88 |
title="π΅ Music Source Restoration",
|
89 |
-
description="Upload
|
90 |
allow_flagging="never"
|
91 |
)
|
92 |
|
|
|
24 |
raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
|
25 |
if ckpt_name in _model_cache:
|
26 |
return _model_cache[ckpt_name]
|
27 |
+
ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
|
28 |
model = UFormer(config).to(DEVICE).eval()
|
29 |
state = torch.load(ckpt_path, map_location=DEVICE)
|
30 |
model.load_state_dict(state)
|
|
|
43 |
out = np.zeros_like(x_pad)
|
44 |
norm = np.zeros((1, x_pad.shape[1]))
|
45 |
n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
|
46 |
+
print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")
|
47 |
|
48 |
for i in range(n_chunks):
|
49 |
s = i * hop
|
|
|
53 |
out[:, s:s+chunk] += y * win
|
54 |
norm[:, s:s+chunk] += win
|
55 |
|
56 |
+
eps = 1e-8
|
57 |
+
return (out / (norm + eps))[:, :T]
|
58 |
|
59 |
# ββββββββββββββββββββββ
|
60 |
# 3) Restore function for Gradio
|
|
|
83 |
demo = gr.Interface(
|
84 |
fn=restore_fn,
|
85 |
inputs=[
|
86 |
+
gr.Audio(sources="upload", type="filepath", label="Your Input"),
|
87 |
gr.Dropdown(VALID_CKPTS, label="Checkpoint")
|
88 |
],
|
89 |
outputs=gr.Audio(type="filepath", label="Restored Output"),
|
90 |
title="π΅ Music Source Restoration",
|
91 |
+
description="Upload an (stereo) audio file and choose an instrument/group checkpoint to restore. Please note that these are baseline models for demonstration purposes only, and most of them don't perform really well...",
|
92 |
allow_flagging="never"
|
93 |
)
|
94 |
|