lym0302
our
1fd4e9c
# coding=utf-8
# Remove voice-over
import argparse
import logging
import os
import shutil
import subprocess

import librosa
import numpy as np
import soundfile as sf
import torch
# Using the embedded version of Python can also correctly import the utils module.
# current_dir = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(current_dir)
from third_party.MusicSourceSeparationTraining.utils import demix, load_config, normalize_audio, denormalize_audio, draw_spectrogram
from third_party.MusicSourceSeparationTraining.utils import prefer_target_instrument, apply_tta, load_start_checkpoint
from third_party.MusicSourceSeparationTraining.models.bs_roformer import BSRoformer
import warnings
# Suppress all warnings process-wide; this runs as soon as the module is imported.
warnings.filterwarnings("ignore")

# Directory holding the pretrained vocal-removal checkpoints (relative path).
model_base_dir = "pretrained/remove_vo/checkpoints"

# model_type -> [checkpoint path, YAML config path]; unpacked in Step3.__init__.
MODEL_PATHS = {
    "bs_roformer": [
        f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755{suffix}"
        for suffix in (".ckpt", ".yaml")
    ],
}
class Step3:
    """Pipeline step 3: remove the voice-over from an audio file.

    A BS-RoFormer source-separation model splits the input into stems. The
    ``instrumental`` stem (input minus the vocals stem) is then copied to the
    output directory as ``<name>.step3.<codec>``.
    """

    def __init__(self, model_type="bs_roformer"):
        """Load the separation model and move it to the best available device.

        Args:
            model_type: key into ``MODEL_PATHS`` selecting the checkpoint and
                YAML config. The network itself is always built as ``BSRoformer``.

        Raises:
            KeyError: if ``model_type`` is not a key of ``MODEL_PATHS``.
        """
        model_path, config_path = MODEL_PATHS[model_type]
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)

        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.device = 'cpu'
            self.log.warning('CUDA/MPS are not available, running on CPU')

        self.model_type = model_type
        self.config = load_config(model_type, config_path)
        self.model = BSRoformer(**dict(self.config.model))

        # load_start_checkpoint takes CLI-style args; build the minimal namespace it reads.
        ckpt_args = argparse.Namespace(
            start_check_point=model_path,
            model_type=model_type,
            lora_checkpoint='',
        )
        load_start_checkpoint(ckpt_args, self.model, type_='inference')

        # Inference only: eval() disables train-mode behaviour (e.g. any dropout
        # layers) so separation is deterministic. Nothing else here calls it.
        self.model = self.model.to(self.device).eval()
        self.sample_rate = getattr(self.config.audio, 'sample_rate', 44100)

    def _load_mix(self, input_audio_path):
        """Load ``input_audio_path`` at the model sample rate, without downmixing.

        Mono input (a 1-D array from librosa) is expanded to one channel and,
        when the config requests ``num_channels == 2``, duplicated to stereo.

        Returns:
            ``(mix, sr)``: the multi-channel array and its sample rate.
        """
        mix, sr = librosa.load(input_audio_path, sr=self.sample_rate, mono=False)
        if mix.ndim == 1:
            mix = np.expand_dims(mix, axis=0)
            if 'num_channels' in self.config.audio and self.config.audio['num_channels'] == 2:
                self.log.info('Convert mono track to stereo...')
                mix = np.concatenate([mix, mix], axis=0)
        return mix, sr

    def _separate(self, mix, instruments, use_tta, detailed_pbar):
        """Run ``demix`` (plus optional TTA) and return stems on the input's scale.

        When the config enables ``inference.normalize``, the model runs on a
        normalized copy of ``mix``. Every stem listed in ``instruments`` is then
        denormalized before it is returned, so callers can combine it directly
        with the un-normalized mix.
        """
        normalize = ('normalize' in self.config.inference
                     and self.config.inference['normalize'] is True)
        norm_params = None
        if normalize:
            mix, norm_params = normalize_audio(mix)

        waveforms = demix(self.config, self.model, mix, self.device,
                          model_type=self.model_type, pbar=detailed_pbar)
        if use_tta:
            waveforms = apply_tta(self.config, self.model, mix, waveforms,
                                  self.device, self.model_type)

        if normalize:
            for name in instruments:
                waveforms[name] = denormalize_audio(waveforms[name], norm_params)
        return waveforms

    def run(self,
            input_audio_path,
            temp_store_dir,  # for remove result dir
            output_dir,  # for final dir
            disable_detailed_pbar: bool = False,
            use_tta: bool = False,
            extract_instrumental: bool = True,
            codec="wav",
            subtype="FLOAT",
            draw_spectro=0,
            ):
        """Separate the voice-over from ``input_audio_path`` and save the instrumental.

        Each stem is written to ``<temp_store_dir>/<name>/<stem>.<codec>``. When
        ``draw_spectro > 0``, a spectrogram ``<stem>.jpg`` is written next to it.
        The instrumental stem is then copied to ``<output_dir>/<name>.step3.<codec>``.
        Here ``<name>`` is the input basename without its extension and with
        ``.step1`` removed.

        Args:
            input_audio_path: audio file to process.
            temp_store_dir: parent directory for the per-file stem outputs.
            output_dir: directory for the final step-3 file (created if missing).
            disable_detailed_pbar: hide the progress bar passed to ``demix``.
            use_tta: apply test-time augmentation via ``apply_tta``.
            extract_instrumental: add an ``instrumental`` stem computed as the
                input minus ``vocals`` (or minus the first instrument when the
                model has no ``vocals`` stem).
            codec: file extension used for the stem files and the step-3 file.
            subtype: sample subtype passed to ``soundfile.write``.
            draw_spectro: forwarded to ``draw_spectrogram`` when greater than 0.

        Returns:
            Path of the step-3 instrumental file. If no instrumental stem was
            produced, nothing is copied, a warning is logged, and the returned
            path does not exist.

        Raises:
            OSError: if copying the instrumental stem to ``output_dir`` fails.
        """
        os.makedirs(output_dir, exist_ok=True)
        detailed_pbar = not disable_detailed_pbar
        # Copy the list: 'instrumental' may be appended to it below.
        instruments = prefer_target_instrument(self.config)[:]

        mix, sr = self._load_mix(input_audio_path)
        mix_orig = mix.copy()  # untouched reference for the instrumental subtraction
        waveforms = self._separate(mix, instruments, use_tta, detailed_pbar)

        if extract_instrumental:
            instr = 'vocals' if 'vocals' in instruments else instruments[0]
            # Both operands are on the original scale (stems were denormalized in
            # _separate), so this is exactly "input minus the removed stem".
            waveforms['instrumental'] = mix_orig - waveforms[instr]
            if 'instrumental' not in instruments:
                instruments.append('instrumental')

        file_name = os.path.splitext(os.path.basename(input_audio_path))[0].replace(".step1", "")
        temp_output_dir = os.path.join(temp_store_dir, file_name)
        os.makedirs(temp_output_dir, exist_ok=True)

        for name in instruments:
            estimates = waveforms[name]
            output_path = os.path.join(temp_output_dir, f"{name}.{codec}")
            sf.write(output_path, estimates.T, sr, subtype=subtype)
            if draw_spectro > 0:
                output_img_path = os.path.join(temp_output_dir, f"{name}.jpg")
                draw_spectrogram(estimates.T, sr, draw_spectro, output_img_path)

        instrumental_file = os.path.join(temp_output_dir, f"instrumental.{codec}")
        step3_audio_path = os.path.join(output_dir, f"{file_name}.step3.{codec}")
        if 'instrumental' in instruments:
            shutil.copyfile(instrumental_file, step3_audio_path)
        else:
            self.log.warning(
                "No instrumental stem was produced (extract_instrumental=False); "
                f"{step3_audio_path} was not written")
            return step3_audio_path

        self.log.info(f"The voice-over has been removed, and the audio is saved in {step3_audio_path}")
        self.log.info("Finish Step3 successfully.\n")
        return step3_audio_path