# Provenance: copied from a Hugging Face file page (uploaded by AkitoP,
# "Upload 87 files", commit 0065413, verified, 8.38 kB).
# This code is modified from https://github.com/ZFTurbo/
import pdb
import librosa
from tqdm import tqdm
import os
import torch
import numpy as np
import soundfile as sf
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
from bs_roformer.bs_roformer import BSRoformer
class BsRoformer_Loader:
    """Load a BSRoformer checkpoint and split audio files into vocals/instrumental."""

    def get_model_from_config(self):
        """Build a freshly initialized BSRoformer with this loader's fixed architecture.

        These hyper-parameters must match the checkpoint given to ``__init__``,
        because that checkpoint's state dict is loaded into the returned model.
        """
        config = {
            "attn_dropout": 0.1,
            "depth": 12,
            "dim": 512,
            "dim_freqs_in": 1025,
            "dim_head": 64,
            "ff_dropout": 0.1,
            "flash_attn": True,
            "freq_transformer_depth": 1,
            "freqs_per_bands": (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 12, 12, 12, 12, 12, 12, 12, 12, 24, 24, 24, 24, 24, 24, 24, 24, 48, 48, 48, 48, 48, 48, 48, 48, 128, 129),
            "heads": 8,
            "linear_transformer_depth": 0,
            "mask_estimator_depth": 2,
            "multi_stft_hop_size": 147,
            "multi_stft_normalized": False,
            "multi_stft_resolution_loss_weight": 1.0,
            "multi_stft_resolutions_window_sizes": (4096, 2048, 1024, 512, 256),
            "num_stems": 1,
            "stereo": True,
            "stft_hop_length": 441,
            "stft_n_fft": 2048,
            "stft_normalized": False,
            "stft_win_length": 2048,
            "time_transformer_depth": 1,
        }
        return BSRoformer(**config)

    @staticmethod
    def _chunk_window(chunk_size, fade_size, is_first, is_last, overlapping):
        """Return the per-sample weight applied to one chunk's model output.

        With overlapping chunks, interior chunk edges get linear fade ramps so
        neighbouring chunks cross-fade; the track's outer edges (start of the
        first chunk, end of the last chunk) stay at full weight. Without overlap
        every sample is covered by exactly one chunk, so a fade would only
        produce zero-weight samples (0/0 after normalization); a flat window is
        returned instead.
        """
        window = torch.ones(chunk_size)
        # fade_size == 0 must be skipped: window[-0:] would select the whole array.
        if not overlapping or fade_size <= 0:
            return window
        if not is_first:
            window[:fade_size] *= torch.linspace(0, 1, fade_size)
        if not is_last:
            window[-fade_size:] *= torch.linspace(1, 0, fade_size)
        return window

    def demix_track(self, model, mix, device, chunk_size=352800, num_overlap=1, batch_size=4):
        """Run ``model`` over ``mix`` in fixed-size chunks and stitch the outputs.

        Args:
            model: separation network, called on a stacked batch of chunks.
            mix: CPU tensor indexed as ``mix[:, start:end]`` (samples on the last
                axis), e.g. the (channels, samples) tensor built in run_folder.
            device: device each chunk is moved to before inference.
            chunk_size: samples per chunk (352800 is 8 s at the 44.1 kHz rate
                run_folder loads audio with).
            num_overlap: how many chunks cover each sample; 1 means no overlap.
            batch_size: number of chunks per forward pass.

        Returns:
            ``{'vocals': array}`` where the array has the same shape as ``mix``.
        """
        fade_size = chunk_size // 10
        step = chunk_size // num_overlap
        border = chunk_size - step
        overlapping = border > 0
        length_init = mix.shape[-1]

        # With overlap, reflect-pad both ends so the first/last real samples are
        # covered by as many chunks as the middle. Reflect padding needs the
        # input to be longer than the pad on each side.
        padded = overlapping and length_init > 2 * border
        if padded:
            mix = nn.functional.pad(mix, (border, border), mode='reflect')
        total_length = mix.shape[1]

        req_shape = (1,) + tuple(mix.shape)
        result = torch.zeros(req_shape, dtype=torch.float32)
        counter = torch.zeros(req_shape, dtype=torch.float32)

        starts = list(range(0, total_length, step))
        progress_bar = tqdm(total=len(starts))
        progress_bar.set_description("Processing")
        try:
            with torch.amp.autocast('cuda'), torch.inference_mode():
                for first in range(0, len(starts), batch_size):
                    batch_starts = starts[first:first + batch_size]
                    batch_data = []
                    batch_lengths = []
                    for start in batch_starts:
                        part = mix[:, start:start + chunk_size].to(device)
                        length = part.shape[-1]
                        if length < chunk_size:
                            # Long tails are reflect-padded (reflect needs pad < length);
                            # short tails are zero-padded.
                            if length > chunk_size // 2 + 1:
                                part = nn.functional.pad(input=part, pad=(0, chunk_size - length), mode='reflect')
                            else:
                                part = nn.functional.pad(input=part, pad=(0, chunk_size - length, 0, 0), mode='constant', value=0)
                        if self.is_half:
                            part = part.half()
                        batch_data.append(part)
                        batch_lengths.append(length)
                        progress_bar.update(1)

                    x = model(torch.stack(batch_data, dim=0))

                    # The window is chosen per chunk (not per batch) so only the
                    # real track edges skip their fade.
                    for j, (start, length) in enumerate(zip(batch_starts, batch_lengths)):
                        window = self._chunk_window(
                            chunk_size,
                            fade_size,
                            is_first=(start == 0),
                            is_last=(start + step >= total_length),
                            overlapping=overlapping,
                        )
                        result[..., start:start + length] += x[j][..., :length].cpu().float() * window[:length]
                        counter[..., start:start + length] += window[:length]
        finally:
            progress_bar.close()

        estimated_sources = (result / counter).numpy()
        # Any sample with zero accumulated weight is 0/0 = NaN; replace with 0.
        np.nan_to_num(estimated_sources, copy=False, nan=0.0)
        if padded:
            estimated_sources = estimated_sources[..., border:-border]
        # The accumulator has a leading axis of size 1, so only 'vocals' exists.
        return {'vocals': estimated_sources[0]}

    @staticmethod
    def _track_stem(path):
        """Return the file name of ``path`` without its directory or extension."""
        return os.path.splitext(os.path.basename(path))[0]

    @staticmethod
    def _convert_with_ffmpeg(wav_path, extension):
        """Convert ``wav_path`` to ``extension`` with ffmpeg; delete the WAV on success.

        The WAV is kept when ffmpeg is missing or exits with an error, so the
        separated audio is not lost.
        """
        import subprocess

        if not os.path.exists(wav_path):
            return
        target = os.path.splitext(wav_path)[0] + "." + extension
        # ffmpeg ignores options placed after the last output file, so -y and
        # -q:a must come before it. An argument list (no shell) keeps quotes or
        # other special characters in file names harmless.
        command = ["ffmpeg", "-y", "-i", wav_path, "-vn", "-q:a", "2", target]
        try:
            completed = subprocess.run(command, stdin=subprocess.DEVNULL)
        except OSError as e:
            print('ffmpeg failed for {}: {}'.format(wav_path, e))
            return
        if completed.returncode == 0 and os.path.exists(target):
            try:
                os.remove(wav_path)
            except OSError:
                # Best effort: a leftover WAV is harmless.
                pass

    def run_folder(self, input, vocal_root, others_root, format):
        """Separate one audio file and write its vocal and instrumental stems.

        Args:
            input: path of the audio file to separate (a single file).
            vocal_root: directory for ``<name>_vocals.<format>``.
            others_root: directory for ``<name>_instrumental.<format>``
                (the mix minus the estimated vocals).
            format: output extension. ``wav``/``flac`` are written directly with
                soundfile; anything else is written as WAV and converted with
                the ``ffmpeg`` executable.

        If the file cannot be read, the error is printed and nothing is written.
        """
        self.model.eval()
        path = input
        # makedirs also creates missing parents and tolerates existing dirs.
        os.makedirs(vocal_root, exist_ok=True)
        os.makedirs(others_root, exist_ok=True)
        try:
            mix, sr = librosa.load(path, sr=44100, mono=False)
        except Exception as e:
            print("Can't read track: {}".format(path))
            print('Error message: {}'.format(str(e)))
            return
        # Duplicate a mono signal into two channels.
        if mix.ndim == 1:
            mix = np.stack([mix, mix], axis=0)
        mixture = torch.tensor(mix, dtype=torch.float32)  # copies; `mix` stays intact
        vocals = self.demix_track(self.model, mixture, self.device)['vocals'].T
        instrumental = mix.T - vocals

        stem = self._track_stem(path)
        if format in ["wav", "flac"]:
            sf.write(os.path.join(vocal_root, "{}_vocals.{}".format(stem, format)), vocals, sr)
            sf.write(os.path.join(others_root, "{}_instrumental.{}".format(stem, format)), instrumental, sr)
            return

        path_vocal = os.path.join(vocal_root, "{}_vocals.wav".format(stem))
        path_other = os.path.join(others_root, "{}_instrumental.wav".format(stem))
        sf.write(path_vocal, vocals, sr)
        sf.write(path_other, instrumental, sr)
        for wav_path in (path_vocal, path_other):
            self._convert_with_ffmpeg(wav_path, format)

    def __init__(self, model_path, device, is_half):
        """Load the checkpoint at ``model_path`` and move the model to ``device``.

        Args:
            model_path: checkpoint containing the model's state dict.
            device: torch device the model and each input chunk run on.
            is_half: if truthy, both the model weights and the input chunks
                (see demix_track) are converted to fp16.
        """
        self.device = device
        # Not read within this class.
        self.extract_instrumental = True
        model = self.get_model_from_config()
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (torch >= 1.13 accepts weights_only=True to restrict this).
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        self.is_half = is_half
        # Same truthiness test as demix_track, so weights and inputs always
        # agree on precision.
        if is_half:
            model = model.half()
        self.model = model.to(device)

    def _path_audio_(self, input, others_root, vocal_root, format, is_hp3=False):
        """Wrapper around run_folder taking (input, others_root, vocal_root, format).

        Note the root order is swapped relative to run_folder; ``is_hp3`` is not
        used by this implementation.
        """
        self.run_folder(input, vocal_root, others_root, format)