#!/usr/bin/env python
import os
import tempfile
import torch
import numpy as np
import soundfile as sf
import gradio as gr
from model import UFormer, UFormerConfig

# ----------------------
# 1) Setup & model loading from local checkpoints
# ----------------------
DEVICE      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"
config      = UFormerConfig()
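# Loaded models are cached so each checkpoint is read from disk only once.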
_model_cache = {}

VALID_CKPTS = [
    "acoustic_guitar","bass","electric_guitar","guitars","keyboards",
    "orchestra","rhythm_section","synth","vocals"
]

def _get_model(ckpt_name: str):
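    """Return the UFormer model for `ckpt_name`, loading and caching it on first use."""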
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
    if ckpt_name in _model_cache:
        return _model_cache[ckpt_name]
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
    model = UFormer(config).to(DEVICE).eval()
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)
    _model_cache[ckpt_name] = model
    return model

# ----------------------
# 2) Overlap-add for long audio
# ----------------------
def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=2.5):
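    """Run the model over long audio in Hann-windowed chunks and overlap-add the results."""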
    C, T = x.shape
    chunk, hop = int(sr*chunk_s), int(sr*hop_s)
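    # Reflect-pad the end so an integer number of hops covers the whole signal.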
    pad = (-(T - chunk) % hop) if T > chunk else 0
    x_pad = np.pad(x, ((0,0),(0,pad)), mode="reflect")
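    # Hann window cross-fades overlapping chunks; `out`/`norm` accumulate the weighted output and window weight.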
    win   = np.hanning(chunk)[None, :]
    out   = np.zeros_like(x_pad)
    norm  = np.zeros((1, x_pad.shape[1]))
    n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
    print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")

    for i in range(n_chunks):
        s = i * hop
        seg = x_pad[:, s:s+chunk].astype(np.float32)
        with torch.no_grad():
            y = model(torch.from_numpy(seg[None]).to(DEVICE)).squeeze(0).cpu().numpy()
        out[:, s:s+chunk]  += y * win
        norm[:, s:s+chunk] += win

    # Normalize by the accumulated window weight; eps avoids division by ~0 where the window tapers.
    eps = 1e-8
    return (out / (norm + eps))[:, :T]

# ----------------------
# 3) Restore function for Gradio
# ----------------------
def restore_fn(audio_path, checkpoint):
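    """Gradio callback: read the upload, restore it with the chosen checkpoint, and return a WAV path."""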
    audio, sr = sf.read(audio_path)
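    # Duplicate mono input so the model always receives two channels.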
    if audio.ndim == 1:
        audio = np.stack([audio, audio], axis=1)
    x = audio.T  # (C, T)

    model = _get_model(checkpoint)
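    # Clips up to 5 s fit in a single forward pass; longer audio goes through overlap-add.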
    if x.shape[1] <= sr * 5:
        seg = x.astype(np.float32)[None]
        with torch.no_grad():
            y = model(torch.from_numpy(seg).to(DEVICE)).squeeze(0).cpu().numpy()
    else:
        y = _overlap_add(model, x, sr)

    tmp = "restored.wav"
    sf.write(tmp, y.T, sr, format="WAV")
    return tmp

# ----------------------
# 4) Gradio App
# ----------------------
demo = gr.Interface(
    fn=restore_fn,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Your Input"),
        gr.Dropdown(VALID_CKPTS, label="Checkpoint")
    ],
    outputs=gr.Audio(type="filepath", label="Restored Output"),
    title="🎵 Music Source Restoration",
    description="Upload an (stereo) audio file and choose an instrument/group checkpoint to restore. Please note that these are baseline models for demonstration purposes only, and most of them don't perform really well...",
    allow_flagging="never"
)

if __name__ == "__main__":
    demo.launch()
else:
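    # When imported rather than run directly, bind to all interfaces on the configured port.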
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))