pengdaqian committed
Commit d7659a0 · 1 Parent(s): d853526
Files changed (2):
  1. app.py +1 -1
  2. torchspleeter/estimator.py +16 -13
app.py CHANGED
@@ -196,7 +196,7 @@ def svc_main(sid, input_audio):
     if not os.path.exists(tmpfile_path):
         os.makedirs(tmpfile_path)
 
-    split_to_parts(input_audio_tmp_file, tmpfile_path, models='torchspleeter/checkpoints/2stems/testcheckpoint1.ckpt')
+    split_to_parts(input_audio_tmp_file, tmpfile_path)
 
     curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
     vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
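
Note: the removed call pinned models= to a single checkpoint file, while the 2-stem Estimator in the next file requires one checkpoint per stem and raises a ValueError otherwise; dropping the argument lets both bundled defaults load. A minimal sketch of the fixed call site, assuming split_to_parts is importable from the package root and writes one stem folder per input (neither is shown in this diff):

    import os
    from torchspleeter import split_to_parts  # import path assumed

    input_audio_tmp_file = 'input.wav'  # illustrative input
    tmpfile_path = 'tmp'
    os.makedirs(tmpfile_path, exist_ok=True)

    # No models= argument: Estimator falls back to its two default checkpoints.
    split_to_parts(input_audio_tmp_file, tmpfile_path)

    curr_tmp_path = os.path.join(tmpfile_path, os.path.splitext(input_audio_tmp_file)[0])
    vocals_filepath = os.path.join(curr_tmp_path, 'vocals.wav')
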
torchspleeter/estimator.py CHANGED
@@ -7,12 +7,15 @@ import tqdm
 # from torchaudio.functional import istft
 
 from torchspleeter.unet import UNet
-#from .util import tf2pytorch
+# from .util import tf2pytorch
 
 import os
+
 dirname = os.path.dirname(__file__)
 defaultmodel0 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint0.ckpt')
 defaultmodel1 = os.path.join(dirname, 'checkpoints/2stems/testcheckpoint1.ckpt')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def load_ckpt(model, ckpt):
     state_dict = model.state_dict()
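
The new module-level device selection feeds the torch.hann_window(..., device=device) call in a later hunk. torch.stft and torch.istft require the window tensor to live on the same device as the signal, so creating it on CUDA up front avoids a device-mismatch RuntimeError once the model runs on GPU. A standalone sketch of that constraint (not repo code; it uses the complex-valued STFT API for brevity):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    win_length, hop_length = 4096, 1024                 # mirrors the stft config below

    wav = torch.randn(2, 44100, device=device)          # stereo waveform on the active device
    win = torch.hann_window(win_length, device=device)  # must share wav's device

    spec = torch.stft(wav, n_fft=win_length, hop_length=hop_length,
                      window=win, return_complex=True)
    out = torch.istft(spec, n_fft=win_length, hop_length=hop_length, window=win)

Since self.win is registered as a non-trainable nn.Parameter, Estimator.to(device) would also move it; building it on the device merely makes the module GPU-ready without an explicit .to() call.
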
@@ -39,7 +42,7 @@ def pad_and_partition(tensor, T):
         tensor of size (B*[L/T] x C x F x T)
     """
     old_size = tensor.size(3)
-    new_size = math.ceil(old_size/T) * T
+    new_size = math.ceil(old_size / T) * T
    tensor = F.pad(tensor, [0, new_size - old_size])
    [b, c, t, f] = tensor.shape
    split = new_size // T
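
Only spacing changes here, but this line is the heart of pad_and_partition: round the frame count up to the next multiple of T, zero-pad the time axis, then (per the docstring) stack the T-wide chunks along the batch dimension. A worked example of that arithmetic, with illustrative shapes:

    import math
    import torch
    import torch.nn.functional as F

    T = 512
    x = torch.randn(1, 2, 2049, 700)                     # B x C x F x frames
    new_size = math.ceil(x.size(3) / T) * T              # ceil(700/512) * 512 = 1024
    x = F.pad(x, [0, new_size - x.size(3)])              # zero-pad frames to 1024
    chunks = torch.cat(torch.split(x, T, dim=3), dim=0)  # (B * 2) x C x F x 512
    print(chunks.shape)                                  # torch.Size([2, 2, 2049, 512])
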
@@ -50,29 +53,29 @@ class Estimator(nn.Module):
     def __init__(self, num_instrumments=2, checkpoint_path=None):
         super(Estimator, self).__init__()
         if checkpoint_path is None:
-            checkpoint_path=[defaultmodel0,defaultmodel1]
+            checkpoint_path = [defaultmodel0, defaultmodel1]
         else:
-            if len(checkpoint_path)<1:
-                checkpoint_path=[defaultmodel0,defaultmodel1]
+            if len(checkpoint_path) < 1:
+                checkpoint_path = [defaultmodel0, defaultmodel1]
         # stft config
         self.F = 1024
         self.T = 512
         self.win_length = 4096
         self.hop_length = 1024
         self.win = nn.Parameter(
-            torch.hann_window(self.win_length),
+            torch.hann_window(self.win_length, device=device),
             requires_grad=False
         )
 
-        ckpts=[]
+        ckpts = []
         if len(checkpoint_path) != num_instrumments:
             raise ValueError("You must submit as many models as there are instruments!")
         for ckpt_path in checkpoint_path:
             ckpts.append(torch.load(ckpt_path))
 
-        #self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
+        # self.ckpts = ckpt #torch.load(checkpoint_path)#, num_instrumments)
 
-        #ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
+        # ckpts = #tf2pytorch(checkpoint_path, num_instrumments)
 
         # filter
         self.instruments = nn.ModuleList()
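
Besides the spacing cleanup and the device-aware window covered above, this hunk shows the checkpoint contract: None or an empty list falls back to the two bundled 2-stem checkpoints, and the list length must equal num_instrumments before anything is loaded. A sketch of the three paths (the one-element path is illustrative, and the defaults require the bundled checkpoint files to be present):

    from torchspleeter.estimator import Estimator

    est = Estimator()       # checkpoint_path=None -> [defaultmodel0, defaultmodel1]
    est = Estimator(2, [])  # empty list -> same fallback
    # Estimator(2, ['only_one.ckpt'])  # ValueError, raised before any torch.load
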
@@ -109,7 +112,7 @@ class Estimator(nn.Module):
         pad = self.win_length // 2 + 1 - stft.size(1)
         stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
         wav = torch.istft(stft, self.win_length, hop_length=self.hop_length, center=True,
-                window=self.win)
+                          window=self.win)
         return wav.detach()
 
     def separate(self, wav):
@@ -145,14 +148,14 @@ class Estimator(nn.Module):
 
         wavs = []
         for mask in tqdm.tqdm(masks):
-            mask = (mask ** 2 + 1e-10/2)/(mask_sum)
+            mask = (mask ** 2 + 1e-10 / 2) / (mask_sum)
             mask = mask.transpose(2, 3)  # B x 2 X F x T
 
             mask = torch.cat(
                 torch.split(mask, 1, dim=0), dim=3)
 
-            mask = mask.squeeze(0)[:,:,:L].unsqueeze(-1)  # 2 x F x L x 1
-            stft_masked = stft * mask
+            mask = mask.squeeze(0)[:, :, :L].unsqueeze(-1)  # 2 x F x L x 1
+            stft_masked = stft * mask
             wav_masked = self.inverse_stft(stft_masked)
 
             wavs.append(wav_masked)
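
The reformatted mask line is a Wiener-style soft mask: each stem's magnitude estimate is squared, nudged by half of a 1e-10 epsilon, and divided by mask_sum, which (accumulated outside this hunk) should be the sum of squared estimates plus the full epsilon, so the per-bin masks across stems sum to one. A worked check of that identity under the stated assumption, for two stems with hypothetical estimates:

    import torch

    eps = 1e-10
    est_vocals = torch.rand(1, 2, 512, 1024)  # hypothetical magnitude estimates
    est_accomp = torch.rand(1, 2, 512, 1024)
    mask_sum = est_vocals ** 2 + est_accomp ** 2 + eps  # assumed accumulation

    masks = [(m ** 2 + eps / 2) / mask_sum for m in (est_vocals, est_accomp)]
    total = masks[0] + masks[1]
    print(torch.allclose(total, torch.ones_like(total)))  # True: masks sum to 1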