# coding=utf-8
# Step 3: remove the voice-over (vocals) from an audio track.
import logging
import argparse
import os
import shutil

import librosa
import numpy as np
import soundfile as sf
import torch


# Uncomment to make the utils module importable when running under an
# embedded Python build:
# import sys
# current_dir = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(current_dir)

from third_party.MusicSourceSeparationTraining.utils import demix, load_config, normalize_audio, denormalize_audio, draw_spectrogram
from third_party.MusicSourceSeparationTraining.utils import prefer_target_instrument, apply_tta, load_start_checkpoint
from third_party.MusicSourceSeparationTraining.models.bs_roformer import BSRoformer
import warnings

warnings.filterwarnings("ignore")

model_base_dir = "pretrained/remove_vo/checkpoints"
# Each model type maps to its (checkpoint, config) pair.
MODEL_PATHS = {
    "bs_roformer": [
        f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755.ckpt",
        f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755.yaml",
    ]
}


class Step3:
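    """Voice-over removal: separate vocals from a mix with BSRoformer and
    keep the instrumental track."""
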
    def __init__(self, model_type="bs_roformer"):
        model_path, config_path = MODEL_PATHS[model_type]
        
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        # Prefer CUDA, then Apple MPS, and fall back to CPU.
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.log.warning('Neither CUDA nor MPS is available; running on CPU')
        
        self.model_type = model_type

        # self.model, self.config = get_model_from_config(model_type, config_path)
        self.config = load_config(model_type, config_path)
        self.model = BSRoformer(**dict(self.config.model))
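        # load_start_checkpoint expects argparse-style args; build a minimal
        # namespace with only the fields it reads.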
        args = argparse.Namespace()
        args.start_check_point = model_path
        args.model_type = model_type
        args.lora_checkpoint = ''
        load_start_checkpoint(args, self.model, type_='inference')
        self.model = self.model.to(self.device)
        self.sample_rate = getattr(self.config.audio, 'sample_rate', 44100)

        
    def run(self,
            input_audio_path,
            temp_store_dir,  # directory for intermediate separation results
            output_dir,  # directory for the final output file
            disable_detailed_pbar: bool = False,
            use_tta: bool = False,
            extract_instrumental: bool = True,
            codec="wav",
            subtype="FLOAT",
            draw_spectro=0,
            ):
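        """Separate the vocals from ``input_audio_path``, write all stems to
        ``temp_store_dir``, and copy the instrumental stem to ``output_dir``
        as ``<name>.step3.<codec>``. Returns the path of the final file."""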
        
        self.log.info("Step3: Remove voice-over from audio.")
        
        os.makedirs(output_dir, exist_ok=True)
        
        detailed_pbar = not disable_detailed_pbar

        # Copy the target-instrument list so 'instrumental' can be appended
        # later without mutating the shared config.
        instruments = prefer_target_instrument(self.config)[:]
        
        mix, sr = librosa.load(input_audio_path, sr=self.sample_rate, mono=False)
        # Mono input: add a channel axis, and duplicate the channel if the
        # model expects stereo.
        if len(mix.shape) == 1:
            mix = np.expand_dims(mix, axis=0)
            if 'num_channels' in self.config.audio and self.config.audio['num_channels'] == 2:
                self.log.info('Converting mono track to stereo...')
                mix = np.concatenate([mix, mix], axis=0)

        mix_orig = mix.copy()
        norm_params = None
        if 'normalize' in self.config.inference and self.config.inference['normalize']:
            mix, norm_params = normalize_audio(mix)

        # demix returns a dict mapping each instrument name to its separated
        # waveform with shape (channels, samples).
        waveforms_orig = demix(self.config, self.model, mix, self.device, model_type=self.model_type, pbar=detailed_pbar)
        if use_tta:
            # Test-time augmentation: average predictions over augmented
            # versions of the mix for a small quality gain at extra cost.
            waveforms_orig = apply_tta(self.config, self.model, mix, waveforms_orig, self.device, self.model_type)

        if extract_instrumental:
            # Derive the instrumental by subtracting the separated vocals
            # from the original (un-normalized) mix.
            instr = 'vocals' if 'vocals' in instruments else instruments[0]
            waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
            if 'instrumental' not in instruments:
                instruments.append('instrumental')

        # Strip the '.step1' suffix left by the previous pipeline stage.
        file_name = os.path.splitext(os.path.basename(input_audio_path))[0].replace(".step1", "")
        temp_output_dir = os.path.join(temp_store_dir, file_name)
        os.makedirs(temp_output_dir, exist_ok=True)

        for instr in instruments:
            estimates = waveforms_orig[instr]
            if 'normalize' in self.config.inference and self.config.inference['normalize']:
                estimates = denormalize_audio(estimates, norm_params)

            output_path = os.path.join(temp_output_dir, f"{instr}.{codec}")
            sf.write(output_path, estimates.T, sr, subtype=subtype)
            if draw_spectro > 0:
                output_img_path = os.path.join(temp_output_dir, f"{instr}.jpg")
                draw_spectrogram(estimates.T, sr, draw_spectro, output_img_path)


        # Copy the instrumental stem to the final output directory; use
        # shutil instead of spawning `cp` so this also works on Windows, and
        # keep the extension consistent with the codec used above.
        instrumental_file = os.path.join(temp_output_dir, f'instrumental.{codec}')
        step3_audio_path = os.path.join(output_dir, f"{file_name}.step3.{codec}")
        shutil.copyfile(instrumental_file, step3_audio_path)

        self.log.info(f"Voice-over removed; the instrumental audio is saved to {step3_audio_path}")
        self.log.info("Finished Step3 successfully.\n")
        return step3_audio_path
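

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original pipeline entry point;
    # the paths below are hypothetical placeholders.
    logging.basicConfig(level=logging.INFO)
    step3 = Step3(model_type="bs_roformer")
    result = step3.run(
        input_audio_path="data/example.step1.wav",  # hypothetical input file
        temp_store_dir="temp/step3",                # intermediate stems
        output_dir="output",                        # final results
    )
    print(f"Instrumental saved to: {result}")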