#!/usr/bin/env python
import os

import torch
import numpy as np
import soundfile as sf
import gradio as gr

from model import UFormer, UFormerConfig

# ——————————————————————
# 1) Setup & model loading from local checkpoints
# ——————————————————————
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CHECKPOINT_DIR = "checkpoints"

config = UFormerConfig()
_model_cache = {}

VALID_CKPTS = [
    "acoustic_guitar", "bass", "electric_guitar", "guitars", "keyboards",
    "orchestra", "rhythm_section", "synth", "vocals",
]


def _get_model(ckpt_name: str):
    """Load the UFormer checkpoint for `ckpt_name`, caching so each is loaded only once."""
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")
    if ckpt_name in _model_cache:
        return _model_cache[ckpt_name]
    ckpt_path = os.path.join(CHECKPOINT_DIR, f"{ckpt_name}.pth")
    model = UFormer(config).to(DEVICE).eval()
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state)
    _model_cache[ckpt_name] = model
    return model


# ——————————————————————
# 2) Overlap-add for long audio
# ——————————————————————
def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float = 5.0, hop_s: float = 2.5):
    """Run the model over overlapping windows and blend the outputs.

    Each chunk is weighted by a Hann window; the accumulated window sum is
    divided out at the end so overlapping regions are averaged smoothly.
    """
    C, T = x.shape
    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
    # Pad so the final window lands on a hop-aligned boundary.
    pad = (-(T - chunk) % hop) if T > chunk else 0
    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
    win = np.hanning(chunk)[None, :]
    out = np.zeros_like(x_pad)
    norm = np.zeros((1, x_pad.shape[1]))

    n_chunks = 1 + (x_pad.shape[1] - chunk) // hop
    print(f"Processing {n_chunks} chunks of size {chunk} with hop {hop}...")
    for i in range(n_chunks):
        s = i * hop
        seg = x_pad[:, s:s + chunk].astype(np.float32)
        with torch.no_grad():
            y = model(torch.from_numpy(seg[None]).to(DEVICE)).squeeze(0).cpu().numpy()
        out[:, s:s + chunk] += y * win
        norm[:, s:s + chunk] += win
    eps = 1e-8  # avoid division by zero where the window sum vanishes
    return (out / (norm + eps))[:, :T]


# ——————————————————————
# 3) Restore function for Gradio
# ——————————————————————
def restore_fn(audio_path, checkpoint):
    audio, sr = sf.read(audio_path)  # (T,) mono or (T, C)
    if audio.ndim == 1:
        # Duplicate the mono channel so the model always receives stereo input.
        audio = np.stack([audio, audio], axis=1)
    x = audio.T  # (C, T)
    model = _get_model(checkpoint)
    if x.shape[1] <= sr * 5:
        # Short clips fit in a single forward pass.
        seg = x.astype(np.float32)[None]
        with torch.no_grad():
            y = model(torch.from_numpy(seg).to(DEVICE)).squeeze(0).cpu().numpy()
    else:
        y = _overlap_add(model, x, sr)
    tmp = "restored.wav"
    sf.write(tmp, y.T, sr, format="WAV")
    return tmp


# ——————————————————————
# 4) Gradio App
# ——————————————————————
demo = gr.Interface(
    fn=restore_fn,
    inputs=[
        gr.Audio(sources=["upload"], type="filepath", label="Your Input"),
        gr.Dropdown(VALID_CKPTS, label="Checkpoint"),
    ],
    outputs=gr.Audio(type="filepath", label="Restored Output"),
    title="🎵 Music Source Restoration",
    description=(
        "Upload a (stereo) audio file and choose an instrument/group checkpoint "
        "to restore. Note that these are baseline models for demonstration "
        "purposes only; most of them do not perform especially well."
    ),
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
else:
    # When imported (e.g. by a hosting platform), bind to all interfaces.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))