toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py (CHANGED)
@@ -802,30 +802,31 @@ class SpectrumDfNet(nn.Module):
     def forward(self,
                 spec_complex: torch.Tensor,
                 ):
-        feat_power = torch.square(torch.abs(spec_complex))
-        feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
-        # feat_power shape: [batch_size, spec_bins, time_steps]
-        # feat_power shape: [batch_size, 1, spec_bins, time_steps]
-        # feat_power shape: [batch_size, 1, time_steps, spec_bins]
-        feat_power = feat_power.detach()
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        feat_spec = torch.view_as_real(spec_complex)
-        # spec shape: [batch_size, spec_bins, time_steps, 2]
-        feat_spec = feat_spec.permute(0, 3, 2, 1)
-        # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
-        feat_spec = feat_spec[..., :self.df_decoder.df_bins]
-        # feat_spec shape: [batch_size, 2, time_steps, df_bins]
-        feat_spec = feat_spec.detach()
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        spec = torch.unsqueeze(spec_complex, dim=1)
-        # spec shape: [batch_size, 1, spec_bins, time_steps]
-        spec = spec.permute(0, 1, 3, 2)
-        # spec shape: [batch_size, 1, time_steps, spec_bins]
-        spec = torch.view_as_real(spec)
-        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        spec = spec.detach()
+        with torch.no_grad():
+            feat_power = torch.square(torch.abs(spec_complex))
+            feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
+            # feat_power shape: [batch_size, spec_bins, time_steps]
+            # feat_power shape: [batch_size, 1, spec_bins, time_steps]
+            # feat_power shape: [batch_size, 1, time_steps, spec_bins]
+            feat_power = feat_power.detach()
+
+            # spec shape: [batch_size, spec_bins, time_steps]
+            feat_spec = torch.view_as_real(spec_complex)
+            # spec shape: [batch_size, spec_bins, time_steps, 2]
+            feat_spec = feat_spec.permute(0, 3, 2, 1)
+            # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
+            feat_spec = feat_spec[..., :self.df_decoder.df_bins]
+            # feat_spec shape: [batch_size, 2, time_steps, df_bins]
+            feat_spec = feat_spec.detach()
+
+            # # spec shape: [batch_size, spec_bins, time_steps]
+            # spec = torch.unsqueeze(spec_complex, dim=1)
+            # # spec shape: [batch_size, 1, spec_bins, time_steps]
+            # spec = spec.permute(0, 1, 3, 2)
+            # # spec shape: [batch_size, 1, time_steps, spec_bins]
+            # spec = torch.view_as_real(spec)
+            # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+            # spec = spec.detach()

         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)

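The shape comments in the new feature block can be verified in isolation. Below is a minimal standalone sketch (not part of the change) that traces the same feat_power / feat_spec transformations on a dummy complex spectrogram; batch_size, spec_bins, time_steps, and df_bins are assumed placeholder values, with df_bins standing in for self.df_decoder.df_bins. It also shows that tensors produced under torch.no_grad() already have requires_grad == False, so the trailing .detach() calls are redundant there but harmless.

import torch

# Assumed placeholder sizes; df_bins stands in for self.df_decoder.df_bins.
batch_size, spec_bins, time_steps, df_bins = 2, 256, 100, 96

spec_complex = torch.randn(batch_size, spec_bins, time_steps, dtype=torch.complex64)

with torch.no_grad():
    feat_power = torch.square(torch.abs(spec_complex))
    # feat_power shape: [batch_size, spec_bins, time_steps]
    feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
    # feat_power shape: [batch_size, 1, time_steps, spec_bins]

    feat_spec = torch.view_as_real(spec_complex)
    # feat_spec shape: [batch_size, spec_bins, time_steps, 2]
    feat_spec = feat_spec.permute(0, 3, 2, 1)
    # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
    feat_spec = feat_spec[..., :df_bins]
    # feat_spec shape: [batch_size, 2, time_steps, df_bins]

assert feat_power.shape == (batch_size, 1, time_steps, spec_bins)
assert feat_spec.shape == (batch_size, 2, time_steps, df_bins)
assert feat_power.requires_grad is False  # no_grad() already blocks gradient tracking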
@@ -834,24 +835,24 @@ class SpectrumDfNet(nn.Module):
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError

-        spec_m = self.mask.forward(spec, mask)
-
-        # lsnr shape: [batch_size, time_steps, 1]
-        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
-        # lsnr shape: [batch_size, 1, time_steps]
-
-        df_coefs = self.df_decoder.forward(emb, c0)
-        df_coefs = self.df_out_transform(df_coefs)
-        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
-
-        spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
-
-        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
-
-        spec_e = torch.squeeze(spec_e, dim=1)
-        spec_e = spec_e.permute(0, 2, 1, 3)
-        # spec_e shape: [batch_size, spec_bins, time_steps, 2]
+        # spec_m = self.mask.forward(spec, mask)
+        #
+        # # lsnr shape: [batch_size, time_steps, 1]
+        # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
+        # # lsnr shape: [batch_size, 1, time_steps]
+        #
+        # df_coefs = self.df_decoder.forward(emb, c0)
+        # df_coefs = self.df_out_transform(df_coefs)
+        # # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
+        #
+        # spec_e = self.df_op.forward(spec.clone(), df_coefs)
+        # # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
+        #
+        # spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
+        #
+        # spec_e = torch.squeeze(spec_e, dim=1)
+        # spec_e = spec_e.permute(0, 2, 1, 3)
+        # # spec_e shape: [batch_size, spec_bins, time_steps, 2]

         mask = torch.squeeze(mask, dim=1)
         mask = mask.permute(0, 2, 1)
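For reference, the splice line commented out above, spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :], combines the two estimates along the frequency axis: bins below df_bins keep the deep-filtered output, while bins from df_bins upward are taken from the masked spectrum. A minimal standalone sketch with dummy tensors (placeholder sizes; df_bins stands in for self.df_decoder.df_bins, shapes following the # spec_e comment above):

import torch

# Assumed placeholder sizes.
batch_size, time_steps, spec_bins, df_bins = 2, 100, 256, 96

# Dummy stand-ins: zeros for the deep-filter output, ones for the masked spectrum.
spec_e = torch.zeros(batch_size, 1, time_steps, spec_bins, 2)
spec_m = torch.ones(batch_size, 1, time_steps, spec_bins, 2)

# Same splice as the commented-out line: bins >= df_bins come from spec_m.
spec_e[..., df_bins:, :] = spec_m[..., df_bins:, :]

assert torch.all(spec_e[..., :df_bins, :] == 0)  # lower bins: deep-filter estimate
assert torch.all(spec_e[..., df_bins:, :] == 1)  # upper bins: masked spectrum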