Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
@@ -849,6 +849,11 @@ class SpectrumDfNet(nn.Module):
|
|
849 |
spec_e = torch.squeeze(spec_e, dim=1)
|
850 |
spec_e = spec_e.permute(0, 2, 1, 3)
|
851 |
# spec_e shape: [batch_size, spec_bins, time_steps, 2]
|
|
|
|
|
|
|
|
|
|
|
852 |
return spec_e, mask, lsnr
|
853 |
|
854 |
|
|
|
849 |
spec_e = torch.squeeze(spec_e, dim=1)
|
850 |
spec_e = spec_e.permute(0, 2, 1, 3)
|
851 |
# spec_e shape: [batch_size, spec_bins, time_steps, 2]
|
852 |
+
|
853 |
+
mask = torch.squeeze(mask, dim=1)
|
854 |
+
mask = mask.permute(0, 2, 1)
|
855 |
+
# mask shape: [batch_size, spec_bins, time_steps]
|
856 |
+
|
857 |
return spec_e, mask, lsnr
|
858 |
|
859 |
|