|
import os |
|
import sys |
|
import onnx |
|
import torch |
|
import platform |
|
import onnx2torch |
|
|
|
import numpy as np |
|
import onnxruntime as ort |
|
|
|
from tqdm import tqdm |
|
|
|
# The project's `main.*` packages are imported relative to the working directory,
# so it has to be importable. Only add it when it is missing so that repeated
# imports/reloads of this module do not pile up duplicate sys.path entries.
now_dir = os.getcwd()
if now_dir not in sys.path:
    sys.path.append(now_dir)
|
|
|
from main.configs.config import Config |
|
from main.library.uvr5_separator import spec_utils |
|
from main.library.uvr5_separator.common_separator import CommonSeparator |
|
|
|
|
|
# Message table looked up by key for the log and error messages in this module.
translations = Config().translations
|
|
|
class MDXSeparator(CommonSeparator):
    """Stem separator that runs an MDX ONNX model over overlapping chunks of a stereo mix.

    The model is executed through ONNX Runtime when the configured segment
    size equals the model's ``dim_t``; otherwise it is converted to a PyTorch
    module with ``onnx2torch`` (see ``load_model``).

    Attributes such as ``logger``, ``model_data``, ``model_path``,
    ``torch_device`` and the stem names are read from ``self`` and are not
    defined in this class; presumably they are set up by ``CommonSeparator``.
    """

    def __init__(self, common_config, arch_config):
        """Read architecture settings and model metadata, then load the model.

        Args:
            common_config: Passed unchanged to ``CommonSeparator.__init__`` as ``config``.
            arch_config: Mapping queried with ``.get`` for ``segment_size``,
                ``overlap``, ``batch_size`` (default 1), ``hop_length`` and
                ``enable_denoise``.
        """
        super().__init__(config=common_config)

        # User-tunable architecture parameters.
        self.segment_size = arch_config.get("segment_size")
        self.overlap = arch_config.get("overlap")
        self.batch_size = arch_config.get("batch_size", 1)
        self.hop_length = arch_config.get("hop_length")
        self.enable_denoise = arch_config.get("enable_denoise")
        self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
        self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))

        # Parameters that come with the model itself.
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)
        self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
        self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")

        self.load_model()

        # Derived sizes and the STFT helper are filled in by initialize_model_settings().
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None

        # Per-file state, populated by separate().
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None

    def load_model(self):
        """Create ``self.model_run``, the callable that maps a spectrogram to a prediction.

        When ``segment_size`` equals ``dim_t`` an ONNX Runtime session is used
        and ``model_run`` returns a NumPy array. Otherwise the ONNX model is
        converted with ``onnx2torch`` and ``model_run`` is a PyTorch module
        moved to ``self.torch_device`` in eval mode.
        """
        self.logger.debug(translations["load_model_onnx"])

        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()
            # Quiet ONNX Runtime (severity 3) unless log_level is 10 or lower, in which case use severity 0.
            ort_session_options.log_severity_level = 3 if self.log_level > 10 else 0

            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug(translations["load_model_onnx_success"])
        else:
            if platform.system() == 'Windows':
                # On Windows the model is loaded with onnx first and the loaded object is converted.
                onnx_model = onnx.load(self.model_path)
                self.model_run = onnx2torch.convert(onnx_model)
            else:
                self.model_run = onnx2torch.convert(self.model_path)

            self.model_run.to(self.torch_device).eval()
            self.logger.warning(translations["onnx_to_pytorch"])

    def separate(self, audio_file_path):
        """Separate one audio file into its primary and secondary stems and write them out.

        Args:
            audio_file_path: Path of the audio file to process.

        Returns:
            List of the output file names that were handed to ``final_process``
            (secondary stem first, then primary), limited to the stem selected
            by ``output_single_stem`` when that is set.
        """
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug(translations["normalization_demix"])
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)

        source = self.demix(mix)
        self.logger.debug(translations["mix_success"])

        output_files = []
        self.logger.debug(translations["process_output_file"])

        # Sources that are already arrays are reused as they are.
        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(translations["primary_source"])
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T

        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(translations["secondary_source"])
            # Second pass in match-mix mode: the mix goes through the same STFT round trip without the model.
            raw_mix = self.demix(mix, is_match_mix=True)

            if self.invert_using_spec:
                self.logger.debug(translations["invert_using_spec"])
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug(translations["invert_using_spec_2"])
                self.secondary_source = mix.T - source.T

        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}"
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}"

            if not isinstance(self.primary_source, np.ndarray):
                self.primary_source = source.T

            # The same message template is used for both stems; only the formatted values differ.
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        return output_files

    def initialize_model_settings(self):
        """Derive the chunking sizes from ``n_fft``/``hop_length``/``segment_size`` and build the STFT helper."""
        self.logger.debug(translations["starting_model"])

        self.n_bins = self.n_fft // 2 + 1
        self.trim = self.n_fft // 2

        # gen_size is the part of a chunk that remains after dropping `trim` samples at each end.
        self.chunk_size = self.hop_length * (self.segment_size - 1)
        self.gen_size = self.chunk_size - 2 * self.trim

        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
        self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")

    def initialize_mix(self, mix, is_ckpt=False):
        """Pad a two-channel mix and cut it into ``chunk_size`` windows spaced ``gen_size`` apart.

        Args:
            mix: Array whose first axis has length 2 (channels) and whose last axis is time.
            is_ckpt: Selects between the two padding/chunking schemes implemented below.

        Returns:
            Tuple ``(mix_waves_tensor, pad)``: the stacked chunks as a float32
            tensor on ``self.torch_device`` and the number of padding samples
            appended after the mix.

        Raises:
            ValueError: If ``mix`` does not have exactly two channels.
        """
        self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))

        if mix.shape[0] != 2:
            error_message = translations["!=2"].format(shape=mix.shape[0])
            self.logger.error(error_message)
            raise ValueError(error_message)

        if is_ckpt:
            self.logger.debug(translations["process_check"])
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"{translations['cache']}: {pad}")

            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)

            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))

            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            self.logger.debug(translations["process_no_check"])
            mix_waves = []
            n_sample = mix.shape[1]

            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))

            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")

            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + self.chunk_size])
                mix_waves.append(waves)

                self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
                i += self.gen_size

        # Stack into one ndarray first: building a tensor straight from a list of ndarrays is very slow.
        mix_waves_tensor = torch.tensor(np.array(mix_waves), dtype=torch.float32).to(self.torch_device)
        self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))

        return mix_waves_tensor, pad

    def demix(self, mix, is_match_mix=False):
        """Run the mix through the model chunk by chunk and overlap-add the results.

        Args:
            mix: Two-channel array with time on the last axis.
            is_match_mix: When true, ``run_model`` bypasses the network, a fixed
                overlap of 0.02 is used and no compensation gain is applied.

        Returns:
            Array with the channels on the first axis and at most
            ``mix.shape[-1]`` samples, scaled by ``self.compensate`` unless
            ``is_match_mix`` is set.
        """
        self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
        self.initialize_model_settings()

        self.logger.debug(f"{translations['mix_shape']}: {mix.shape}")

        if is_match_mix:
            chunk_size = self.hop_length * (self.segment_size - 1)
            overlap = 0.02
            self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
        else:
            chunk_size = self.chunk_size
            overlap = self.overlap
            self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))

        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"{translations['calc_size']}: {gen_size}")

        # `trim` zeros in front, and enough zeros behind to reach a multiple of gen_size plus `trim`.
        pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
        self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")

        step = int((1 - overlap) * chunk_size)
        self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))

        # `result` accumulates the (windowed) outputs, `divider` the window weights for normalisation.
        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)

        total_chunks = (mixture.shape[-1] + step - 1) // step
        self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")

        for total, start in enumerate(tqdm(range(0, mixture.shape[-1], step)), start=1):
            end = min(start + chunk_size, mixture.shape[-1])
            self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))

            chunk_size_actual = end - start
            window = None

            if overlap != 0:
                window = np.hanning(chunk_size_actual)
                window = np.tile(window[None, None, :], (1, 2, 1))
                self.logger.debug(translations["window"])

            mix_part_ = mixture[:, start:end]

            # The last chunk may be short; zero-pad it to the full chunk size for the model.
            if end != start + chunk_size:
                pad_size = (start + chunk_size) - end
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            # Add the batch axis on the ndarray rather than wrapping it in a Python list (slow tensor construction).
            mix_part = torch.tensor(mix_part_[None], dtype=torch.float32).to(self.torch_device)

            mix_waves = mix_part.split(self.batch_size)
            total_batches = len(mix_waves)
            self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")

            with torch.no_grad():
                for batches_processed, mix_wave in enumerate(mix_waves, start=1):
                    self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")

                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)

                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else:
                        divider[..., start:end] += 1

                    result[..., start:end] += tar_waves[..., : end - start]

        self.logger.debug(translations["normalization_2"])
        # np.hanning is zero at both endpoints, so a sample covered only by window edges has
        # divider == 0. Leave such samples at 0 instead of producing 0/0 = NaN and a RuntimeWarning.
        normalized = np.divide(result, divider, out=np.zeros_like(result), where=divider != 0)

        # Drop the `trim` padding at both ends, remove the leading batch axis and cut back to the input length.
        trimmed = normalized[:, :, self.trim : -self.trim]
        tar_waves = np.concatenate(trimmed, axis=-1)[:, : mix.shape[-1]]

        source = tar_waves
        self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")

        if not is_match_mix:
            source *= self.compensate
            self.logger.debug(translations["mix_match"])

        self.logger.debug(translations["mix_success"])
        return source

    def run_model(self, mix, is_match_mix=False):
        """Transform a batch of chunks to the spectral domain, apply the model, and transform back.

        Args:
            mix: Tensor of audio chunks; it is moved to ``self.torch_device`` before the STFT.
            is_match_mix: When true the model is skipped and the spectrum is inverted directly.

        Returns:
            NumPy array holding the inverse STFT of the predicted spectrum.
        """
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(translations["stft_2"].format(shape=spek.shape))

        # Zero the three lowest frequency bins before prediction.
        spek[:, :, :3, :] *= 0

        if is_match_mix:
            # No model pass; the spectrum stays a tensor, so no CPU/NumPy round trip is needed.
            spec_pred = spek
            self.logger.debug(translations["is_match_mix"])
        else:
            if self.enable_denoise:
                # Average the prediction for the spectrum with the negated prediction for its negation.
                spec_pred_neg = self.model_run(-spek)
                spec_pred_pos = self.model_run(spek)
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
                self.logger.debug(translations["enable_denoise"])
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug(translations["no_denoise"])

        # model_run yields an ndarray (ONNX Runtime) or a tensor (onnx2torch); only the former needs wrapping.
        # Calling torch.tensor() on an existing tensor copies it and emits a UserWarning.
        if not isinstance(spec_pred, torch.Tensor):
            spec_pred = torch.tensor(spec_pred)

        result = self.stft.inverse(spec_pred.to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"{translations['stft']}: {result.shape}")

        return result
|
|
|
class STFT:
    """Forward/inverse short-time Fourier transform with real and imaginary parts stored as channels.

    The forward transform turns ``(..., channels, time)`` audio into
    ``(..., channels * 2, dim_f, frames)`` where each audio channel contributes
    a real plane followed by an imaginary plane, and only the lowest ``dim_f``
    frequency bins are kept. ``inverse`` undoes this, zero-filling the bins
    that were cut off.
    """

    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        """Store the transform parameters and build a periodic Hann window of length ``n_fft``.

        Args:
            logger: Stored on the instance; not used by the methods of this class.
            n_fft: FFT size.
            hop_length: Hop between successive frames.
            dim_f: Number of frequency bins kept by the forward transform.
            device: Device that results are moved back to when the input was
                on a device other than CUDA or CPU.
        """
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        """Compute the STFT of ``input_tensor`` (shape ``(..., channels, time)``).

        Returns:
            Tensor of shape ``(..., channels * 2, dim_f, frames)``.
        """
        # Inputs on devices other than CUDA/CPU are processed on the CPU and moved back afterwards.
        is_non_standard_device = input_tensor.device.type not in ("cuda", "cpu")
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions = input_tensor.shape[:-2]
        channel_dim, time_dim = input_tensor.shape[-2:]

        # torch.stft takes (batch, time), so fold every leading axis into one.
        reshaped_tensor = input_tensor.reshape([-1, time_dim])

        # return_complex=False is deprecated; view_as_real yields the same trailing (real, imag) axis.
        complex_stft = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=True)
        stft_output = torch.view_as_real(complex_stft)

        # (N, freq, frames, 2) -> (N, 2, freq, frames)
        permuted_stft_output = stft_output.permute([0, 3, 1, 2])
        frame_count = permuted_stft_output.shape[-1]

        unfolded = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, frame_count])
        final_output = unfolded.reshape([*batch_dimensions, channel_dim * 2, -1, frame_count])

        if is_non_standard_device:
            final_output = final_output.to(self.device)

        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        """Append zero-valued bins so the frequency axis (second to last) has ``num_freq_bins`` entries."""
        # Allocate directly on the input's device instead of creating on the CPU and moving.
        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim], device=input_tensor.device)
        padded_tensor = torch.cat([input_tensor, freq_padding], -2)

        return padded_tensor

    def calculate_inverse_dimensions(self, input_tensor):
        """Split the shape of a ``(..., channels, freq, frames)`` tensor and add the full bin count.

        Returns:
            ``(batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)``
            where ``num_freq_bins`` is ``n_fft // 2 + 1``.
        """
        batch_dimensions = input_tensor.shape[:-3]
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]

        num_freq_bins = self.n_fft // 2 + 1

        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        """Regroup the real/imaginary channel pairs into one complex tensor of shape ``(N, freq, frames)``."""
        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])

        # (N, 2, freq, frames) -> (N, freq, frames, 2) so the last axis holds (real, imag).
        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
        complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

        return complex_tensor

    def inverse(self, input_tensor):
        """Invert a spectrogram produced by ``__call__`` back to audio.

        Args:
            input_tensor: Tensor of shape ``(..., channels * 2, freq, frames)``.

        Returns:
            Tensor of shape ``(..., channels, time)``.
        """
        is_non_standard_device = input_tensor.device.type not in ("cuda", "cpu")
        if is_non_standard_device:
            input_tensor = input_tensor.cpu()

        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)

        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)

        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)

        # Each audio channel occupies two spectrogram channels (real, imag); equals 2 for stereo input.
        final_output = istft_result.reshape([*batch_dimensions, channel_dim // 2, -1])

        if is_non_standard_device:
            final_output = final_output.to(self.device)

        return final_output