# coding=utf-8 # Remove voice-over import logging import argparse import subprocess import librosa import os import torch import soundfile as sf import numpy as np # Using the embedded version of Python can also correctly import the utils module. # current_dir = os.path.dirname(os.path.abspath(__file__)) # sys.path.append(current_dir) from third_party.MusicSourceSeparationTraining.utils import demix, load_config, normalize_audio, denormalize_audio, draw_spectrogram from third_party.MusicSourceSeparationTraining.utils import prefer_target_instrument, apply_tta, load_start_checkpoint from third_party.MusicSourceSeparationTraining.models.bs_roformer import BSRoformer import warnings warnings.filterwarnings("ignore") model_base_dir = "pretrained/remove_vo/checkpoints" MODEL_PATHS = {"bs_roformer": [f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755.ckpt", f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755.yaml"]} class Step3: def __init__(self, model_type="bs_roformer"): model_path, config_path = MODEL_PATHS[model_type] self.log = logging.getLogger(self.__class__.__name__) self.log.setLevel(logging.INFO) self.device = 'cpu' if torch.cuda.is_available(): self.device = 'cuda' elif torch.backends.mps.is_available(): self.device = 'mps' else: self.log.warning('CUDA/MPS are not available, running on CPU') self.model_type = model_type # self.model, self.config = get_model_from_config(model_type, config_path) self.config = load_config(model_type, config_path) self.model = BSRoformer(**dict(self.config.model)) args = argparse.Namespace() args.start_check_point = model_path args.model_type = model_type args.lora_checkpoint = '' load_start_checkpoint(args, self.model, type_='inference') self.model = self.model.to(self.device) self.sample_rate = getattr(self.config.audio, 'sample_rate', 44100) def run(self, input_audio_path, temp_store_dir, # for remove result dir output_dir, # for final dir disable_detailed_pbar: bool=False, use_tta: bool= False, extract_instrumental: bool=True, codec="wav", subtype="FLOAT", draw_spectro=0, ): # self.log.info("Step3: Remove voice-over from audio.") os.makedirs(output_dir, exist_ok=True) if disable_detailed_pbar: detailed_pbar = False else: detailed_pbar = True instruments = prefer_target_instrument(self.config)[:] mix, sr = librosa.load(input_audio_path, sr=self.sample_rate, mono=False) # If mono audio we must adjust it depending on model if len(mix.shape) == 1: mix = np.expand_dims(mix, axis=0) if 'num_channels' in self.config.audio: if self.config.audio['num_channels'] == 2: print(f'Convert mono track to stereo...') mix = np.concatenate([mix, mix], axis=0) mix_orig = mix.copy() if 'normalize' in self.config.inference: if self.config.inference['normalize'] is True: mix, norm_params = normalize_audio(mix) waveforms_orig = demix(self.config, self.model, mix, self.device, model_type=self.model_type, pbar=detailed_pbar) if use_tta: waveforms_orig = apply_tta(self.config, self.model, mix, waveforms_orig, self.device, self.model_type) if extract_instrumental: instr = 'vocals' if 'vocals' in instruments else instruments[0] waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr] if 'instrumental' not in instruments: instruments.append('instrumental') file_name = os.path.splitext(os.path.basename(input_audio_path))[0].replace(".step1", "") temp_output_dir = os.path.join(temp_store_dir, file_name) os.makedirs(temp_output_dir, exist_ok=True) for instr in instruments: estimates = waveforms_orig[instr] if 'normalize' in self.config.inference: if self.config.inference['normalize'] is True: estimates = denormalize_audio(estimates, norm_params) output_path = os.path.join(temp_output_dir, f"{instr}.{codec}") sf.write(output_path, estimates.T, sr, subtype=subtype) if draw_spectro > 0: output_img_path = os.path.join(temp_output_dir, f"{instr}.jpg") draw_spectrogram(estimates.T, sr, draw_spectro, output_img_path) instrumental_file = os.path.join(temp_output_dir, 'instrumental.wav') step3_audio_path = f"{output_dir}/{file_name}.step3.wav" subprocess.run(['cp', instrumental_file, step3_audio_path]) self.log.info(f"The voice-over has been removed, and the audio is saved in {step3_audio_path}") self.log.info("Finish Step3 successfully.\n") return step3_audio_path