MSR_UFormers / inference.py
yongyizang's picture
Upload folder using huggingface_hub
6485dde verified
#!/usr/bin/env python
import io, os, torch, numpy as np, soundfile as sf
from huggingface_hub import snapshot_download
from model import UFormer, UFormerConfig
# ——————————————————————
# 1) Setup
# ——————————————————————
# Hugging Face Hub repository that is snapshotted at import time.
REPO_ID = "yongyizang/MSR_UFormers"

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Result of the snapshot download; checkpoints are looked up beneath it
# as "<local_dir>/checkpoints/<name>.pth".
local_dir = snapshot_download(REPO_ID)

# One default-constructed configuration shared by every model instance.
config = UFormerConfig()

# checkpoint name -> already-loaded model, filled lazily by _get_model().
_model_cache = dict()

# Checkpoint names accepted by _get_model() and offered by the CLI / UI.
VALID_CKPTS = [
    "acoustic_guitar",
    "bass",
    "electric_guitar",
    "guitars",
    "keyboards",
    "orchestra",
    "rhythm_section",
    "synth",
    "vocals",
]
def _get_model(ckpt_name: str):
    """Return the model for ``ckpt_name``, loading it on first use.

    The name is validated against ``VALID_CKPTS`` before anything else.
    A model that was loaded earlier is served from ``_model_cache``;
    otherwise a fresh ``UFormer`` is built on ``device`` in eval mode, its
    weights are read from ``<local_dir>/checkpoints/<ckpt_name>.pth`` and
    the instance is cached for later calls.

    Raises:
        ValueError: if ``ckpt_name`` is not one of ``VALID_CKPTS``.
    """
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")

    cached = _model_cache.get(ckpt_name)
    if cached is not None:
        return cached

    weights_file = os.path.join(local_dir, "checkpoints", f"{ckpt_name}.pth")
    net = UFormer(config).to(device).eval()
    # Deserialize on the CPU; load_state_dict copies into the model's tensors.
    state = torch.load(weights_file, map_location="cpu")
    net.load_state_dict(state)

    _model_cache[ckpt_name] = net
    return net
# ——————————————————————
# 2) Overlap-add helper
# ——————————————————————
def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float=5., hop_s: float=2.5):
    """Run ``model`` over ``x`` in overlapping windows and blend the results.

    The signal is padded (by reflection) so that it is tiled exactly by
    windows of ``chunk_s`` seconds spaced ``hop_s`` seconds apart. Every
    window is passed through ``model``, weighted by a Hann taper, summed into
    an accumulator and finally divided by the summed taper weights.

    Args:
        model: callable applied to a float32 tensor of shape (1, C, chunk)
            placed on the module-level ``device``. Its output, after
            ``squeeze(0)``, must broadcast against (C, chunk) — presumably
            the same shape as its input; confirm against the model.
        x: audio array of shape (C, T).
        sr: sample rate in Hz, used to convert seconds to samples.
        chunk_s: window length in seconds.
        hop_s: distance between consecutive window starts, in seconds.

    Returns:
        float64 array of shape (C, T).

    Raises:
        ValueError: if the window is empty or the hop is not in (0, chunk].
    """
    C, T = x.shape
    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
    if chunk <= 0 or not 0 < hop <= chunk:
        # hop == 0 would divide by zero below; hop > chunk would leave
        # samples that no window covers.
        raise ValueError(
            f"need chunk > 0 and 0 < hop <= chunk, got chunk={chunk}, hop={hop} samples"
        )
    if T == 0:
        return np.zeros((C, 0), dtype=np.float64)

    if T > chunk:
        # Extend so that (length - chunk) is a whole number of hops.
        pad = -(T - chunk) % hop
    else:
        # Shorter than one window: extend to exactly one window so that at
        # least one chunk is processed (previously none was, giving NaNs).
        pad = chunk - T
    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
    total = x_pad.shape[1]

    # np.hanning(n) is exactly 0 at both endpoints, which made the summed
    # weight 0 at the first (and possibly last) sample and produced 0/0 = NaN.
    # Dropping the two zero endpoints of a slightly longer window keeps the
    # same taper shape while every weight stays strictly positive.
    win = np.hanning(chunk + 2)[1:-1][None, :]

    # Float accumulators regardless of the input dtype, so integer input
    # does not break the in-place additions.
    out = np.zeros(x_pad.shape, dtype=np.float64)
    norm = np.zeros((1, total), dtype=np.float64)

    with torch.no_grad():
        for s in range(0, total - chunk + 1, hop):
            seg = x_pad[:, s:s + chunk].astype(np.float32)  # model input is float32
            y = model(torch.from_numpy(seg[None]).to(device)).squeeze(0).cpu().numpy()
            out[:, s:s + chunk] += y * win
            norm[:, s:s + chunk] += win

    # Every sample is covered by at least one strictly positive weight.
    return (out / norm)[:, :T]
# ——————————————————————
# 3) HF Inference entry-point
# ——————————————————————
def inference(input_bytes: bytes, checkpoint: str = "guitars") -> bytes:
    """
    audio_bytes in → restored_bytes out.
    Pass {"inputs": <bytes>, "parameters": {"checkpoint": "<name>"}} to choose.

    Args:
        input_bytes: an encoded audio file readable by ``soundfile``.
        checkpoint: name of the model to use; validated by ``_get_model``.

    Returns:
        The processed audio encoded as a WAV file, at the input's sample
        rate. Mono input is duplicated to two channels before processing,
        so the result for mono input has two channels.

    Raises:
        ValueError: if ``checkpoint`` is not a known checkpoint name.
    """
    # Decode directly as float32. soundfile's default is float64, which the
    # short-clip branch used to hand to the model unchanged while the
    # chunked branch cast its segments to float32.
    audio, sr = sf.read(io.BytesIO(input_bytes), dtype="float32")
    if audio.ndim == 1:
        # Mono: duplicate the single channel into two identical columns.
        audio = np.stack([audio, audio], axis=1)
    x = audio.T  # (C,T)

    model = _get_model(checkpoint)
    if x.shape[1] <= sr * 5:
        # Up to 5 seconds: a single forward pass over the whole clip.
        with torch.no_grad():
            y = model(torch.from_numpy(x[None]).to(device)).squeeze(0).cpu().numpy()
    else:
        # Longer clips are processed window by window and blended.
        y = _overlap_add(model, x, sr)

    buf = io.BytesIO()
    sf.write(buf, y.T, sr, format="WAV")
    return buf.getvalue()
# ——————————————————————
# 4) CLI & Gradio
# ——————————————————————
if __name__ == "__main__":
    # Command-line entry point: either restore one file (-i/-o) or, with
    # --serve, launch a Gradio web UI around inference().
    import argparse

    parser = argparse.ArgumentParser("UFormer RESTORE")
    parser.add_argument("-i", "--input", type=str, help="noisy WAV")
    parser.add_argument("-o", "--output", type=str, help="restored WAV")
    parser.add_argument("-c", "--checkpoint", type=str, default="guitars",
                        choices=VALID_CKPTS)
    parser.add_argument("--serve", action="store_true", help="launch Gradio")
    args = parser.parse_args()

    if args.serve:
        import tempfile

        import gradio as gr

        def _gr(path, ckpt):
            """Restore the uploaded file and return the path of the result."""
            with open(path, "rb") as src:
                restored = inference(src.read(), checkpoint=ckpt)
            # The output component is declared with type="filepath", so hand
            # back a path to a WAV file rather than the raw bytes.
            # NOTE: delete=False leaves the file on disk after the request.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as dst:
                dst.write(restored)
            return dst.name

        gr.Interface(
            fn=_gr,
            inputs=[
                gr.Audio(source="upload", type="filepath"),
                # Pre-select the same default as the CLI so the callback
                # never receives an empty checkpoint choice.
                gr.Dropdown(VALID_CKPTS, value="guitars", label="Checkpoint"),
            ],
            outputs=gr.Audio(type="filepath"),
            title="🎵 Music Source Restoration Restoration",
            description="Choose which instrument/group model to run."
        ).launch()
    else:
        # Explicit check instead of assert, which is removed under `python -O`.
        if not (args.input and args.output):
            parser.error("--input and --output are required unless --serve is given")
        with open(args.input, "rb") as src:
            out = inference(src.read(), checkpoint=args.checkpoint)
        with open(args.output, "wb") as dst:
            dst.write(out)
        print(f"✅ Restored → {args.output} using {args.checkpoint}")