#!/usr/bin/env python
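"""Gradio demo for music source restoration with per-instrument UFormer checkpoints."""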
import os
import torch
import numpy as np
import soundfile as sf
import gradio as gr
from model import UFormer, UFormerConfig
# ──────────────────────
# 1) Setup & model loading from local checkpoints
# ──────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"
config = UFormerConfig()
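# Cache loaded models so each checkpoint is read from disk only once.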
_model_cache = {}
VALID_CKPTS = [
"acoustic_guitar","bass","electric_guitar","guitars","keyboards",
"orchestra","rhythm_section","synth","vocals"
]
def _get_model(ckpt_name: str):
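    """Return the UFormer for `ckpt_name`, loading it from
    checkpoints/<ckpt_name>.pth on first use and caching it afterwards."""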
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
    if ckpt_name in _model_cache:
        return _model_cache[ckpt_name]
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
    model = UFormer(config).to(DEVICE).eval()
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)
    _model_cache[ckpt_name] = model
    return model
# ──────────────────────
# 2) Overlap-add for long audio
# ──────────────────────
def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=2.5):
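    """Process a long (channels, samples) array in overlapping chunks.

    Chunks of `chunk_s` seconds are run through the model at hops of
    `hop_s` seconds; each output is weighted by a Hann window and the
    overlaps are blended by dividing by the summed window envelope.
    """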
    C, T = x.shape
    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
    pad = (-(T - chunk) % hop) if T > chunk else 0
    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
    win = np.hanning(chunk)[None, :]
    out = np.zeros_like(x_pad)
    norm = np.zeros((1, x_pad.shape[1]))
    n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
    print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")
    for i in range(n_chunks):
        s = i * hop
        seg = x_pad[:, s:s + chunk].astype(np.float32)
        with torch.no_grad():
            y = model(torch.from_numpy(seg[None]).to(DEVICE)).squeeze(0).cpu().numpy()
        out[:, s:s + chunk] += y * win
        norm[:, s:s + chunk] += win
    eps = 1e-8
    return (out / (norm + eps))[:, :T]
# ──────────────────────
# 3) Restore function for Gradio
# ──────────────────────
def restore_fn(audio_path, checkpoint):
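    """Gradio handler: read the uploaded file, upmix mono to stereo,
    restore it with the chosen checkpoint, and return the output path."""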
    audio, sr = sf.read(audio_path)
    if audio.ndim == 1:
        audio = np.stack([audio, audio], axis=1)
    x = audio.T  # (C, T)
    model = _get_model(checkpoint)
    if x.shape[1] <= sr * 5:
        seg = x.astype(np.float32)[None]
        with torch.no_grad():
            y = model(torch.from_numpy(seg).to(DEVICE)).squeeze(0).cpu().numpy()
    else:
        y = _overlap_add(model, x, sr)
    tmp = "restored.wav"
    sf.write(tmp, y.T, sr, format="WAV")
    return tmp
# ──────────────────────
# 4) Gradio App
# ──────────────────────
demo = gr.Interface(
    fn=restore_fn,
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Your Input"),
        gr.Dropdown(VALID_CKPTS, label="Checkpoint"),
    ],
    outputs=gr.Audio(type="filepath", label="Restored Output"),
    title="🎵 Music Source Restoration",
    description=(
        "Upload a (stereo) audio file and choose an instrument/group checkpoint "
        "to restore. Note that these are baseline models for demonstration "
        "purposes only, and most of them do not perform especially well."
    ),
    allow_flagging="never",
)
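# Launch directly when run as a script; when imported (e.g., by a hosting
# platform), serve on all interfaces using the port from $PORT.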
if __name__ == "__main__":
    demo.launch()
else:
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))