HoneyTian commited on
Commit
5831850
·
1 Parent(s): bd866e6
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -849,6 +849,11 @@ class SpectrumDfNet(nn.Module):
849
  spec_e = torch.squeeze(spec_e, dim=1)
850
  spec_e = spec_e.permute(0, 2, 1, 3)
851
  # spec_e shape: [batch_size, spec_bins, time_steps, 2]
 
 
 
 
 
852
  return spec_e, mask, lsnr
853
 
854
 
 
849
  spec_e = torch.squeeze(spec_e, dim=1)
850
  spec_e = spec_e.permute(0, 2, 1, 3)
851
  # spec_e shape: [batch_size, spec_bins, time_steps, 2]
852
+
853
+ mask = torch.squeeze(mask, dim=1)
854
+ mask = mask.permute(0, 2, 1)
855
+ # mask shape: [batch_size, spec_bins, time_steps]
856
+
857
  return spec_e, mask, lsnr
858
 
859