pengdaqian committed · Commit d7659a0 · 1 Parent(s): d853526

fix

Files changed: app.py (+1, -1) · torchspleeter/estimator.py (+16, -13)
app.py
CHANGED
@@ -196,7 +196,7 @@ def svc_main(sid, input_audio):
     if not os.path.exists(tmpfile_path):
         os.makedirs(tmpfile_path)
 
-    split_to_parts(input_audio_tmp_file, tmpfile_path
+    split_to_parts(input_audio_tmp_file, tmpfile_path)
 
     curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
     vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
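Note on the app.py change: it is a one-character syntax fix. The call to split_to_parts was missing its closing parenthesis, so the module could not be parsed at all. A minimal sketch of the repaired call site follows; the stand-in split_to_parts and the name locate_vocals are illustrative only (the Space's real split_to_parts, defined elsewhere in the repo, runs the torchspleeter separator and writes the stems):

import os

def split_to_parts(input_file, out_dir):
    # Stand-in only: the real split_to_parts runs the torchspleeter
    # Estimator and writes vocals.wav (and accompaniment) into a folder
    # named after the input file, which the code below relies on.
    stem_dir = os.path.join(out_dir, os.path.splitext(input_file)[0])
    os.makedirs(stem_dir, exist_ok=True)

def locate_vocals(input_audio_tmp_file, tmpfile_path):
    # Mirrors the fixed region of svc_main().
    if not os.path.exists(tmpfile_path):
        os.makedirs(tmpfile_path)

    split_to_parts(input_audio_tmp_file, tmpfile_path)  # closing paren restored

    curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
    return os.path.join(curr_tmp_path, 'vocals.wav')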
torchspleeter/estimator.py
CHANGED
@@ -7,12 +7,15 @@ import tqdm
 # from torchaudio.functional import istft
 
 from torchspleeter.unet import UNet
-#from .util import tf2pytorch
+# from .util import tf2pytorch
 
 import os
+
 dirname = os.path.dirname(__file__)
 defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
 defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def load_ckpt(model, ckpt):
     state_dict = model.state_dict()
@@ -39,7 +42,7 @@ def pad_and_partition(tensor, T):
         tensor of size (B*[L/T] x C x F x T)
     """
     old_size = tensor.size(3)
-    new_size = math.ceil(old_size/T) * T
+    new_size = math.ceil(old_size / T) * T
     tensor = F.pad(tensor, [0, new_size - old_size])
     [b, c, t, f] = tensor.shape
     split = new_size // T
@@ -50,29 +53,29 @@ class Estimator(nn.Module):
     def __init__(self, num_instrumments=2, checkpoint_path=None):
         super(Estimator, self).__init__()
         if checkpoint_path is None:
-            checkpoint_path=[defaultmodel0,defaultmodel1]
+            checkpoint_path = [defaultmodel0, defaultmodel1]
         else:
-            if len(checkpoint_path)<1:
-                checkpoint_path=[defaultmodel0,defaultmodel1]
+            if len(checkpoint_path) < 1:
+                checkpoint_path = [defaultmodel0, defaultmodel1]
         # stft config
         self.F = 1024
         self.T = 512
         self.win_length = 4096
         self.hop_length = 1024
         self.win = nn.Parameter(
-            torch.hann_window(self.win_length),
+            torch.hann_window(self.win_length, device=device),
             requires_grad=False
         )
 
-        ckpts=[]
+        ckpts = []
         if len(checkpoint_path) != num_instrumments:
             raise ValueError("You must submit as many models as there are instruments!")
         for ckpt_path in checkpoint_path:
             ckpts.append(torch.load(ckpt_path))
 
-        #self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
+        # self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
 
-        #ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
+        # ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
 
         # filter
         self.instruments = nn.ModuleList()
@@ -109,7 +112,7 @@ class Estimator(nn.Module):
         pad = self.win_length // 2 + 1 - stft.size(1)
         stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
         wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
-
+                          window=self.win)
         return wav.detach()
 
     def separate(self, wav):
@@ -145,14 +148,14 @@ class Estimator(nn.Module):
 
         wavs = []
         for mask in tqdm.tqdm(masks):
-            mask = (mask ** 2 + 1e-10/2)/(mask_sum)
+            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
             mask = mask.transpose(2, 3)  # B x 2 X F x T
 
             mask = torch.cat(
                 torch.split(mask, 1, dim=0), dim=3)
 
-            mask = mask.squeeze(0)[
-            stft_masked = stft *
+            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
             wav_masked = self.inverse_stft(stft_masked)
 
             wavs.append(wav_masked)
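Note on the estimator.py changes: besides the whitespace and comment cleanups, the two behavioral fixes are the device-aware Hann window and the completed torch.istft call. The window tensor must live on the same device as the spectrogram, otherwise torch.istft raises a device-mismatch RuntimeError. A self-contained sketch (the stft sizes match the diff's config; the signal is a dummy):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

win_length, hop_length = 4096, 1024                 # stft config from the diff
win = torch.hann_window(win_length, device=device)  # the fixed line

wav = torch.randn(win_length * 4, device=device)    # dummy mono signal
spec = torch.stft(wav, n_fft=win_length, hop_length=hop_length,
                  window=win, return_complex=True)

# With win on the CPU and spec on CUDA this call would fail; creating
# the window on `device` up front keeps the round trip consistent.
out = torch.istft(spec, n_fft=win_length, hop_length=hop_length,
                  window=win, center=True)
assert out.shape == wav.shape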
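Note on the last hunk: it completes two lines that were previously truncated mid-expression in separate(). The normalized soft mask is sliced back to the original frame count L and then actually applied to the STFT. A shape-only sketch of that step, using the layout the diff's own comments give (mask C x F x L x 1, STFT C x F x L x 2 as real/imag pairs); the sizes and the single-stem mask_sum stand-in are illustrative, not the model's:

import torch

C, F_bins, L = 2, 1025, 512                     # channels, freq bins, frames
stft = torch.randn(C, F_bins, L, 2)             # real/imag layout of the old torch.stft

mask = torch.rand(1, C, L, F_bins)              # B x C x T x F, one soft mask per stem
mask_sum = mask ** 2 + 1e-10                    # stand-in for the sum over all stems
mask = (mask ** 2 + 1e-10 / 2) / mask_sum       # mirrors the fixed normalization line
mask = mask.transpose(2, 3)                     # B x C x F x T

mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # C x F x L x 1, the completed slice
stft_masked = stft * mask                       # broadcasts over the real/imag axis
assert stft_masked.shape == stft.shape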