HoneyTian committed on
Commit
d0b38bd
·
1 Parent(s): e3e5274
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py CHANGED
@@ -181,6 +181,8 @@ class NXCleanUNet(nn.Module):
181
  )
182
 
183
  def forward(self, noisy_audios: torch.Tensor):
 
 
184
  # noisy_audios shape: [batch_size, 1, n_samples]
185
 
186
  n_samples = noisy_audios.shape[-1]
@@ -208,6 +210,9 @@ class NXCleanUNet(nn.Module):
208
 
209
  enhanced_audios = enhanced_audios[:, :, :n_samples]
210
  # enhanced_audios shape: [batch_size, 1, n_samples]
 
 
 
211
  return enhanced_audios
212
 
213
 
@@ -310,11 +315,11 @@ def main():
310
 
311
  # shape: [batch_size, channels, num_samples]
312
  # min length: 94, stride: 32, 32 == 2**5
313
- # x = torch.ones([4, 1, 94])
314
- # x = torch.ones([4, 1, 126])
315
- # x = torch.ones([4, 1, 158])
316
- # x = torch.ones([4, 1, 190])
317
- x = torch.ones([4, 1, 16000])
318
 
319
  model = NXCleanUNet(config)
320
  enhanced_audios = model.forward(x)
 
181
  )
182
 
183
  def forward(self, noisy_audios: torch.Tensor):
184
+ # noisy_audios shape: [batch_size, n_samples]
185
+ noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
186
  # noisy_audios shape: [batch_size, 1, n_samples]
187
 
188
  n_samples = noisy_audios.shape[-1]
 
210
 
211
  enhanced_audios = enhanced_audios[:, :, :n_samples]
212
  # enhanced_audios shape: [batch_size, 1, n_samples]
213
+
214
+ enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
215
+
216
  return enhanced_audios
217
 
218
 
 
315
 
316
  # shape: [batch_size, num_samples] (forward() adds the channel axis internally)
317
  # min length: 94, stride: 32, 32 == 2**5
318
+ # x = torch.ones([4, 94])
319
+ # x = torch.ones([4, 126])
320
+ # x = torch.ones([4, 158])
321
+ # x = torch.ones([4, 190])
322
+ x = torch.ones([4, 16000])
323
 
324
  model = NXCleanUNet(config)
325
  enhanced_audios = model.forward(x)