moseca / app /demucs_runner.py
Fabio Grasso
init repo
59e5631
raw
history blame
5.98 kB
import argparse
import sys
from pathlib import Path
from typing import List
from dora.log import fatal
import torch as th
from demucs.apply import apply_model, BagOfModels
from demucs.audio import save_audio
from demucs.pretrained import get_model_from_args, ModelLoadingError
from demucs.separate import load_track
def separator(
tracks: List[Path],
out: Path,
model: str,
device: str,
shifts: int,
overlap: float,
stem: str,
int24: bool,
float32: bool,
clip_mode: str,
mp3: bool,
mp3_bitrate: int,
jobs: int,
verbose: bool,
):
"""Separate the sources for the given tracks
Args:
tracks (Path): Path to tracks
out (Path): Folder where to put extracted tracks. A subfolder with the model name will be created.
model (str): Model name
device (str): Device to use, default is cuda if available else cpu
shifts (int): Number of random shifts for equivariant stabilization. Increase separation time but improves quality for Demucs. 10 was used in the original paper.
overlap (float): Overlap
stem (str): Only separate audio into {STEM} and no_{STEM}.
int24 (bool): Save wav output as 24 bits wav.
float32 (bool): Save wav output as float32 (2x bigger).
clip_mode (str): Strategy for avoiding clipping: rescaling entire signal if necessary (rescale) or hard clipping (clamp).
mp3 (bool): Convert the output wavs to mp3.
mp3_bitrate (int): Bitrate of converted mp3.
jobs (int): Number of jobs. This can increase memory usage but will be much faster when multiple cores are available.
verbose (bool): Verbose
"""
args = argparse.Namespace()
args.tracks = tracks
args.out = out
args.model = model
args.device = device
args.shifts = shifts
args.overlap = overlap
args.stem = stem
args.int24 = int24
args.float32 = float32
args.clip_mode = clip_mode
args.mp3 = mp3
args.mp3_bitrate = mp3_bitrate
args.jobs = jobs
args.verbose = verbose
args.filename = "{track}/{stem}.{ext}"
args.split = True
args.segment = None
args.name = model
args.repo = None
try:
model = get_model_from_args(args)
except ModelLoadingError as error:
fatal(error.args[0])
if args.segment is not None and args.segment < 8:
fatal("Segment must greater than 8. ")
if ".." in args.filename.replace("\\", "/").split("/"):
fatal('".." must not appear in filename. ')
if isinstance(model, BagOfModels):
print(
f"Selected model is a bag of {len(model.models)} models. "
"You will see that many progress bars per track."
)
if args.segment is not None:
for sub in model.models:
sub.segment = args.segment
else:
if args.segment is not None:
model.segment = args.segment
model.cpu()
model.eval()
if args.stem is not None and args.stem not in model.sources:
fatal(
'error: stem "{stem}" is not in selected model. STEM must be one of {sources}.'.format(
stem=args.stem, sources=", ".join(model.sources)
)
)
out = args.out / args.name
out.mkdir(parents=True, exist_ok=True)
print(f"Separated tracks will be stored in {out.resolve()}")
for track in args.tracks:
if not track.exists():
print(
f"File {track} does not exist. If the path contains spaces, "
'please try again after surrounding the entire path with quotes "".',
file=sys.stderr,
)
continue
print(f"Separating track {track}")
wav = load_track(track, model.audio_channels, model.samplerate)
ref = wav.mean(0)
wav = (wav - ref.mean()) / ref.std()
sources = apply_model(
model,
wav[None],
device=args.device,
shifts=args.shifts,
split=args.split,
overlap=args.overlap,
progress=True,
num_workers=args.jobs,
)[0]
sources = sources * ref.std() + ref.mean()
if args.mp3:
ext = "mp3"
else:
ext = "wav"
kwargs = {
"samplerate": model.samplerate,
"bitrate": args.mp3_bitrate,
"clip": args.clip_mode,
"as_float": args.float32,
"bits_per_sample": 24 if args.int24 else 16,
}
if args.stem is None:
for source, name in zip(sources, model.sources):
stem = out / args.filename.format(
track=track.name.rsplit(".", 1)[0],
trackext=track.name.rsplit(".", 1)[-1],
stem=name,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(source, str(stem), **kwargs)
else:
sources = list(sources)
stem = out / args.filename.format(
track=track.name.rsplit(".", 1)[0],
trackext=track.name.rsplit(".", 1)[-1],
stem=args.stem,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(sources.pop(model.sources.index(args.stem)), str(stem), **kwargs)
# Warning : after poping the stem, selected stem is no longer in the list 'sources'
other_stem = th.zeros_like(sources[0])
for i in sources:
other_stem += i
stem = out / args.filename.format(
track=track.name.rsplit(".", 1)[0],
trackext=track.name.rsplit(".", 1)[-1],
stem="no_" + args.stem,
ext=ext,
)
stem.parent.mkdir(parents=True, exist_ok=True)
save_audio(other_stem, str(stem), **kwargs)