pengdaqian committed
Commit d853526
1 Parent(s): 6b780da
app.py CHANGED
@@ -14,10 +14,9 @@ import gradio as gr
 import librosa
 import numpy as np
 import soundfile
-from spleeter.separator import Separator
 from pydub import AudioSegment
 import uuid
-
+from torchspleeter import split_to_parts
 import logging
 
 logging.getLogger('numba').setLevel(logging.WARNING)
@@ -84,11 +83,11 @@ model = SynthesizerInfer(
 load_svc_model("vits_pretrain/sovits5.0-48k-debug.pth", model)
 model.eval()
 model.to(device)
-separator = Separator('spleeter:2stems')
 whisper_model = whisper.inference.load_model(os.path.join("whisper_pretrain", "medium.pt"))
 
+
 # warm up
-separator.separate_to_file('warm.wav', '/tmp/warm')
+# separator.separate_to_file('warm.wav', '/tmp/warm')
 
 
 def svc_change(argswave, argsspk):
@@ -196,7 +195,8 @@ def svc_main(sid, input_audio):
 soundfile.write(input_audio_tmp_file, audio, sampling_rate, format="wav")
 if not os.path.exists(tmpfile_path):
 os.makedirs(tmpfile_path)
-separator.separate_to_file(input_audio_tmp_file, tmpfile_path)
+
+split_to_parts(input_audio_tmp_file, tmpfile_path, models='torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt')
 
 curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
 vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
@@ -346,8 +346,8 @@ def main():
 
 app.queue(max_size=3, api_open=False).launch()
 except KeyboardInterrupt:
-separator._get_session().close()
 app.close()
+separator._get_session().close()
 sys.exit(0)
 
 
torchspleeter/__init__.py ADDED
@@ -0,0 +1,4 @@
+from torchspleeter.command_interface import *
+
+
+version="0.1.5"
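
For reference, a minimal usage sketch of the new package (the file names here are placeholders, not part of the commit). app.py above imports split_to_parts this same way; the function itself is defined in torchspleeter/command_interface.py later in this commit.

from torchspleeter import split_to_parts

# With the defaults, the two 2-stem checkpoints bundled with the package are used,
# and a mono input is duplicated to stereo before separation.
output_files = split_to_parts("some_song.mp3", "separated/")  # placeholder paths
for path in output_files:
    print(path)  # e.g. separated/<uuid>-out_0.wav and separated/<uuid>-out_1.wav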
torchspleeter/checkpoints/2stems/audio_example.mp3 ADDED
Binary file (263 kB)
 
torchspleeter/checkpoints/2stems/testcheckpoint0.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:385ea3372c6a3ceee01f6ded5504bb7ee1e9f0101950ae58869dc18382deb75c
+size 59050239
torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3e4d6ede1ecad091468550773e77934aac3f1e039c0697fc9039aba9b935e344
+size 59033471
torchspleeter/command_interface.py ADDED
@@ -0,0 +1,98 @@
+"""
+
+This provides an interface to interact with the spleeter system on
+
+
+"""
+
+import os
+from torchspleeter.estimator import Estimator
+import argparse
+import uuid
+import numpy as np
+import librosa
+import soundfile
+import torch
+import pydub
+import os
+import shutil
+
+
+def split_to_parts(inputaudiofile, outputdir, instruments=2, models=[]):
+    filedata = pydub.AudioSegment.from_file(inputaudiofile)
+    sr = filedata.frame_rate
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    # es = Estimator(2, './checkpoints/2stems/model').to(device)
+    # es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
+    es = Estimator()
+    es.eval()
+
+    # load wav audio
+    testaudiofile = inputaudiofile
+    channels = filedata.channels
+    mono_selection = False
+    if not os.path.exists(outputdir):
+        os.makedirs(outputdir, exist_ok=True)
+    if channels == 1:
+        mono_selection = True
+        multichannel = pydub.AudioSegment.from_mono_audiosegments(filedata, filedata)
+        os.makedirs(os.path.join(outputdir, 'tmp'), exist_ok=True)
+        testaudiofile = os.path.join(outputdir, "tmp" + str(uuid.uuid4()) + "." + testaudiofile.split('.')[-1])
+        # testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
+        multichannel.export(out_f=testaudiofile, format=testaudiofile.split('.')[-1])
+    print(mono_selection)
+    print(channels)
+    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast', sr=sr)
+    wav = torch.Tensor(wav).to(device)
+    if mono_selection:
+        shutil.rmtree(os.path.join(outputdir, "tmp"))
+        # os.remove(testaudiofile)
+    wavs = es.separate(wav)
+    outputname = str(uuid.uuid4())
+    returnarray = []
+    for i in range(len(wavs)):
+        finaloutput = os.path.join(outputdir, outputname)
+        fname = '-out_{}.wav'.format(i)
+        fname = finaloutput + fname
+        print('Writing ', fname)
+        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
+        returnarray.append(fname)
+        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
+    return returnarray
+
+
+def get_file_list(dirname):
+    outputfilelist = []
+    for subdir, dirs, files in os.walk(dirname):
+        for file in files:
+            outputfilelist.append(os.path.join(subdir, file))
+
+    return outputfilelist
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='torchspleeter allows you to separate instrumentals from audio (vocals, instruments, background noise, etc) in a simple, cross platform manner')
+    parser.add_argument('-i', '--inputfile', help='Input Audio File to split into instrumentals', required=True)
+    parser.add_argument('-o', '--output', help='Output directory to deposit split audio', required=True)
+    parser.add_argument('-n', '--number', help="Number of instruments in the model (default 2)", required=False,
+                        default=2, type=int)
+    parser.add_argument('-m', '--modeldir',
+                        help="directory containing number of pre-converted torch compatible model components",
+                        required=False)
+    args = vars(parser.parse_args())
+    print(args)
+    if args['modeldir'] is not None:
+        modelfiles = get_file_list(args['modeldir'])
+        if len(modelfiles) != args['number']:
+            raise ValueError("You must have the same number of models as you do number of instruments!")
+    else:
+        args['modeldir'] = []
+    outputfiles = split_to_parts(args['inputfile'], args['output'], args['number'], args['modeldir'])
+    print("Your output files are:")
+    for item in outputfiles:
+        print(item)
+
+
+if __name__ == "__main__":
+    main()
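
command_interface.py also defines a main() console entry point; a hedged sketch of driving it programmatically (the paths are placeholders, and the flags simply mirror the argparse definition above):

import sys
from torchspleeter.command_interface import main

# -i input file, -o output directory, -n number of instruments (default 2),
# -m optional directory of pre-converted checkpoints (omitted here, so the bundled defaults apply).
sys.argv = ["torchspleeter", "-i", "input.mp3", "-o", "separated/", "-n", "2"]
main()  # prints the paths of the separated stems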
torchspleeter/estimator.py ADDED
@@ -0,0 +1,160 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+import tqdm
+# from torchaudio.functional import istft
+
+from torchspleeter.unet import UNet
+#from .util import tf2pytorch
+
+import os
+dirname = os.path.dirname(__file__)
+defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
+defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
+
+def load_ckpt(model, ckpt):
+    state_dict = model.state_dict()
+    for k, v in ckpt.items():
+        if k in state_dict:
+            target_shape = state_dict[k].shape
+            assert target_shape == v.shape
+            state_dict.update({k: torch.from_numpy(v)})
+        else:
+            print('Ignore ', k)
+
+    model.load_state_dict(state_dict)
+    return model
+
+
+def pad_and_partition(tensor, T):
+    """
+    pads zero and partition tensor into segments of length T
+
+    Args:
+        tensor(Tensor): BxCxFxL
+
+    Returns:
+        tensor of size (B*[L/T] x C x F x T)
+    """
+    old_size = tensor.size(3)
+    new_size = math.ceil(old_size/T) * T
+    tensor = F.pad(tensor, [0, new_size - old_size])
+    [b, c, t, f] = tensor.shape
+    split = new_size // T
+    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
+
+
+class Estimator(nn.Module):
+    def __init__(self, num_instrumments=2, checkpoint_path=None):
+        super(Estimator, self).__init__()
+        if checkpoint_path is None:
+            checkpoint_path=[defaultmodel0,defaultmodel1]
+        else:
+            if len(checkpoint_path)<1:
+                checkpoint_path=[defaultmodel0,defaultmodel1]
+        # stft config
+        self.F = 1024
+        self.T = 512
+        self.win_length = 4096
+        self.hop_length = 1024
+        self.win = nn.Parameter(
+            torch.hann_window(self.win_length),
+            requires_grad=False
+        )
+
+        ckpts=[]
+        if len(checkpoint_path) != num_instrumments:
+            raise ValueError("You must submit as many models as there are instruments!")
+        for ckpt_path in checkpoint_path:
+            ckpts.append(torch.load(ckpt_path))
+
+        #self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
+
+        #ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
+
+        # filter
+        self.instruments = nn.ModuleList()
+        for i in range(num_instrumments):
+            print('Loading model for instrumment {}'.format(i))
+            net = UNet(2)
+            ckpt = ckpts[i]
+            net = load_ckpt(net, ckpt)
+            net.eval()  # change mode to eval
+            self.instruments.append(net)
+
+    def compute_stft(self, wav):
+        """
+        Computes stft feature from wav
+
+        Args:
+            wav (Tensor): B x L
+        """
+
+        stft = torch.stft(wav, n_fft=self.win_length, hop_length=self.hop_length, window=self.win,
+                          center=True, return_complex=False, pad_mode='constant')
+
+        # only keep freqs smaller than self.F
+        stft = stft[:, :self.F, :, :]
+        real = stft[:, :, :, 0]
+        im = stft[:, :, :, 1]
+        mag = torch.sqrt(real ** 2 + im ** 2)
+
+        return stft, mag
+
+    def inverse_stft(self, stft):
+        """Inverses stft to wave form"""
+
+        pad = self.win_length // 2 + 1 - stft.size(1)
+        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
+        wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
+                          window=self.win)
+        return wav.detach()
+
+    def separate(self, wav):
+        """
+        Separates stereo wav into different tracks corresponding to different instruments
+
+        Args:
+            wav (tensor): 2 x L
+        """
+
+        # stft - 2 X F x L x 2
+        # stft_mag - 2 X F x L
+        stft, stft_mag = self.compute_stft(wav)
+
+        L = stft.size(2)
+
+        # 1 x 2 x F x T
+        stft_mag = stft_mag.unsqueeze(-1).permute([3, 0, 1, 2])
+        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
+        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F
+
+        B = stft_mag.shape[0]
+
+        # compute instruments' mask
+        masks = []
+        for net in self.instruments:
+            mask = net(stft_mag)
+            masks.append(mask)
+
+        # compute denominator
+        mask_sum = sum([m ** 2 for m in masks])
+        mask_sum += 1e-10
+
+        wavs = []
+        for mask in tqdm.tqdm(masks):
+            mask = (mask ** 2 + 1e-10/2)/(mask_sum)
+            mask = mask.transpose(2, 3)  # B x 2 X F x T
+
+            mask = torch.cat(
+                torch.split(mask, 1, dim=0), dim=3)
+
+            mask = mask.squeeze(0)[:,:,:L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
+            wav_masked = self.inverse_stft(stft_masked)
+
+            wavs.append(wav_masked)
+
+        return wavs
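
A small shape check of pad_and_partition as defined above, with illustrative sizes (not part of the commit):

import torch
from torchspleeter.estimator import pad_and_partition

# B x C x F x L spectrogram whose length L=700 is not a multiple of T=512.
x = torch.rand(1, 2, 1024, 700)
chunks = pad_and_partition(x, 512)
# L is zero-padded to 1024 and split into segments of T frames stacked on the batch axis.
print(chunks.shape)  # torch.Size([2, 2, 1024, 512])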
torchspleeter/test/test_estimator.py ADDED
@@ -0,0 +1,45 @@
+import numpy as np
+import librosa
+import soundfile
+import torch
+import pydub
+import os
+from torchspleeter.estimator import Estimator
+dirname = os.path.dirname(__file__)
+testfilename = os.path.join(dirname, 'checkpoints/2stems/audio_example.mp3')
+
+if __name__ == '__main__':
+    sr = 44100
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    #es = Estimator(2, './checkpoints/2stems/model').to(device)
+    #es = Estimator(2, ['./checkpoints/2stems/testcheckpoint0.ckpt','./checkpoints/2stems/testcheckpoint1.ckpt']).to(device)
+    es=Estimator()
+    es.eval()
+
+    # load wav audio
+    testaudiofile=testfilename
+    filestats=pydub.AudioSegment.from_file(testaudiofile)
+    channels=filestats.channels
+    mono_selection=False
+    if channels==1:
+        mono_selection=True
+        multichannel=pydub.AudioSegment.from_mono_audiosegments(filestats,filestats)
+        testaudiofile=testaudiofile.split('.')[0]+"-stereo."+testaudiofile.split('.')[-1]
+        multichannel.export(out_f=testaudiofile,format=testaudiofile.split('.')[-1])
+    print(mono_selection)
+    print(channels)
+    wav, _ = librosa.load(testaudiofile, mono=False, res_type='kaiser_fast',sr=sr)
+    wav = torch.Tensor(wav).to(device)
+    if mono_selection:
+        os.remove(testaudiofile)
+
+
+    # normalize audio
+    # wav_torch = wav / (wav.max() + 1e-8)
+
+    wavs = es.separate(wav)
+    for i in range(len(wavs)):
+        fname = 'output/out_{}.wav'.format(i)
+        print('Writing ',fname)
+        soundfile.write(fname, wavs[i].cpu().detach().numpy().T, sr, "PCM_16")
+        # write_wav(fname, np.asfortranarray(wavs[i].squeeze().numpy()), sr)
torchspleeter/unet.py ADDED
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+class CustomPad(nn.Module):
+    def __init__(self, padding_setting=(1, 2, 1, 2)):
+        super(CustomPad, self).__init__()
+        self.padding_setting = padding_setting
+
+    def forward(self, x):
+        return F.pad(x, self.padding_setting, "constant", 0)
+
+
+class CustomTransposedPad(nn.Module):
+    def __init__(self, padding_setting=(1, 2, 1, 2)):
+        super(CustomTransposedPad, self).__init__()
+        self.padding_setting = padding_setting
+
+    def forward(self, x):
+        l,r,t,b = self.padding_setting
+        return x[:,:,l:-r,t:-b]
+
+
+def down_block(in_filters, out_filters):
+    return nn.Sequential(CustomPad(),
+                         nn.Conv2d(in_filters, out_filters, kernel_size=5, stride=2,padding=0)), \
+           nn.Sequential(
+               nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
+               nn.LeakyReLU(0.2))
+
+
+def up_block(in_filters, out_filters, dropout=False):
+    layers = [
+        nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,stride=2),
+        CustomTransposedPad(),
+        nn.ReLU(),
+        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
+    ]
+    if dropout:
+        layers.append(nn.Dropout(0.5))
+
+    return nn.Sequential(*layers)
+
+
+class UNet(nn.Module):
+    def __init__(self, in_channels=2):
+        super(UNet, self).__init__()
+        self.down1_conv, self.down1_act = down_block(in_channels, 16)
+        self.down2_conv, self.down2_act = down_block(16, 32)
+        self.down3_conv, self.down3_act = down_block(32, 64)
+        self.down4_conv, self.down4_act = down_block(64, 128)
+        self.down5_conv, self.down5_act = down_block(128, 256)
+        self.down6_conv, self.down6_act = down_block(256, 512)
+
+        self.up1 = up_block(512, 256, dropout=True)
+        self.up2 = up_block(512, 128, dropout=True)
+        self.up3 = up_block(256, 64, dropout=True)
+        self.up4 = up_block(128, 32)
+        self.up5 = up_block(64, 16)
+        self.up6 = up_block(32, 1)
+        self.up7 = nn.Sequential(
+            nn.Conv2d(1, in_channels, kernel_size=4, dilation=2, padding=3),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        d1_conv = self.down1_conv(x)
+        d1 = self.down1_act(d1_conv)
+
+        d2_conv = self.down2_conv(d1)
+        d2 = self.down2_act(d2_conv)
+
+        d3_conv = self.down3_conv(d2)
+        d3 = self.down3_act(d3_conv)
+
+        d4_conv = self.down4_conv(d3)
+        d4 = self.down4_act(d4_conv)
+
+        d5_conv = self.down5_conv(d4)
+        d5 = self.down5_act(d5_conv)
+
+        d6_conv = self.down6_conv(d5)
+        d6 = self.down6_act(d6_conv)
+
+        u1 = self.up1(d6_conv)
+        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
+        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
+        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
+        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
+        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
+        u7 = self.up7(u6)
+        return u7 * x
+
+
+if __name__ == '__main__':
+    net = UNet(14)
+    print(net(torch.rand(1, 14, 20, 48)).shape)
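
A shape sketch of the UNet acting as a soft mask, using the B x 2 x T x F layout that Estimator.separate feeds it (T=512, F=1024); untrained weights, illustration only and not part of the commit:

import torch
from torchspleeter.unet import UNet

net = UNet(2).eval()                    # untrained weights, shape check only
spec_mag = torch.rand(1, 2, 512, 1024)  # B x 2 x T x F, as prepared in Estimator.separate
with torch.no_grad():
    masked = net(spec_mag)              # sigmoid mask (up7) multiplied element-wise with the input
print(masked.shape)                     # torch.Size([1, 2, 512, 1024])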