HoneyTian committed
Commit c8f41d6 · 1 Parent(s): 657b015
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -802,30 +802,31 @@ class SpectrumDfNet(nn.Module):
     def forward(self,
                 spec_complex: torch.Tensor,
                 ):
-        feat_power = torch.square(torch.abs(spec_complex))
-        feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
-        # feat_power shape: [batch_size, spec_bins, time_steps]
-        # feat_power shape: [batch_size, 1, spec_bins, time_steps]
-        # feat_power shape: [batch_size, 1, time_steps, spec_bins]
-        feat_power = feat_power.detach()
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        feat_spec = torch.view_as_real(spec_complex)
-        # spec shape: [batch_size, spec_bins, time_steps, 2]
-        feat_spec = feat_spec.permute(0, 3, 2, 1)
-        # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
-        feat_spec = feat_spec[..., :self.df_decoder.df_bins]
-        # feat_spec shape: [batch_size, 2, time_steps, df_bins]
-        feat_spec = feat_spec.detach()
-
-        # spec shape: [batch_size, spec_bins, time_steps]
-        spec = torch.unsqueeze(spec_complex, dim=1)
-        # spec shape: [batch_size, 1, spec_bins, time_steps]
-        spec = spec.permute(0, 1, 3, 2)
-        # spec shape: [batch_size, 1, time_steps, spec_bins]
-        spec = torch.view_as_real(spec)
-        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
-        spec = spec.detach()
+        with torch.no_grad():
+            feat_power = torch.square(torch.abs(spec_complex))
+            feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
+            # feat_power shape: [batch_size, spec_bins, time_steps]
+            # feat_power shape: [batch_size, 1, spec_bins, time_steps]
+            # feat_power shape: [batch_size, 1, time_steps, spec_bins]
+            feat_power = feat_power.detach()
+
+            # spec shape: [batch_size, spec_bins, time_steps]
+            feat_spec = torch.view_as_real(spec_complex)
+            # spec shape: [batch_size, spec_bins, time_steps, 2]
+            feat_spec = feat_spec.permute(0, 3, 2, 1)
+            # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
+            feat_spec = feat_spec[..., :self.df_decoder.df_bins]
+            # feat_spec shape: [batch_size, 2, time_steps, df_bins]
+            feat_spec = feat_spec.detach()
+
+        # # spec shape: [batch_size, spec_bins, time_steps]
+        # spec = torch.unsqueeze(spec_complex, dim=1)
+        # # spec shape: [batch_size, 1, spec_bins, time_steps]
+        # spec = spec.permute(0, 1, 3, 2)
+        # # spec shape: [batch_size, 1, time_steps, spec_bins]
+        # spec = torch.view_as_real(spec)
+        # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+        # spec = spec.detach()
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
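Note on the hunk above: the commit replaces the per-tensor .detach() calls on the input features with a single torch.no_grad() context, so autograd never builds a graph for the feature extraction in the first place; the .detach() calls that remain inside the block are now redundant but harmless. A minimal standalone sketch of the pattern and the resulting shapes; the batch size of 2, 257 spectrum bins, 100 frames, and a df_bins cutoff of 96 are made-up example values, not the model's configuration:

    import torch

    # hypothetical complex STFT batch: [batch_size, spec_bins, time_steps]
    spec_complex = torch.randn(2, 257, 100, dtype=torch.complex64, requires_grad=True)
    df_bins = 96  # example cutoff; the model reads self.df_decoder.df_bins

    with torch.no_grad():
        # power-spectrum feature -> [batch_size, 1, time_steps, spec_bins]
        feat_power = torch.square(torch.abs(spec_complex))
        feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)

        # real/imag feature truncated to the deep-filter bins
        # -> [batch_size, 2, time_steps, df_bins]
        feat_spec = torch.view_as_real(spec_complex).permute(0, 3, 2, 1)
        feat_spec = feat_spec[..., :df_bins]

    assert feat_power.shape == (2, 1, 100, 257)
    assert feat_spec.shape == (2, 2, 100, 96)
    assert not feat_power.requires_grad and not feat_spec.requires_grad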
 
@@ -834,24 +835,24 @@ class SpectrumDfNet(nn.Module):
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError
 
-        spec_m = self.mask.forward(spec, mask)
-
-        # lsnr shape: [batch_size, time_steps, 1]
-        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
-        # lsnr shape: [batch_size, 1, time_steps]
-
-        df_coefs = self.df_decoder.forward(emb, c0)
-        df_coefs = self.df_out_transform(df_coefs)
-        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
-
-        spec_e = self.df_op.forward(spec.clone(), df_coefs)
-        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
-
-        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
-
-        spec_e = torch.squeeze(spec_e, dim=1)
-        spec_e = spec_e.permute(0, 2, 1, 3)
-        # spec_e shape: [batch_size, spec_bins, time_steps, 2]
+        # spec_m = self.mask.forward(spec, mask)
+        #
+        # # lsnr shape: [batch_size, time_steps, 1]
+        # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
+        # # lsnr shape: [batch_size, 1, time_steps]
+        #
+        # df_coefs = self.df_decoder.forward(emb, c0)
+        # df_coefs = self.df_out_transform(df_coefs)
+        # # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
+        #
+        # spec_e = self.df_op.forward(spec.clone(), df_coefs)
+        # # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
+        #
+        # spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
+        #
+        # spec_e = torch.squeeze(spec_e, dim=1)
+        # spec_e = spec_e.permute(0, 2, 1, 3)
+        # # spec_e shape: [batch_size, spec_bins, time_steps, 2]
 
         mask = torch.squeeze(mask, dim=1)
         mask = mask.permute(0, 2, 1)
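The hunk above disables the deep-filtering branch, so only the masking path stays active. For reference, a deep-filtering op of the kind self.df_op names in DeepFilterNet-style models applies df_order complex taps per time step to the lowest df_bins frequency bins, accumulating over past frames. The sketch below is an assumed reference implementation for illustration only, not the repository's actual df_op:

    import torch

    def deep_filter(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        # spec:  [batch_size, 1, time_steps, spec_bins, 2] (real/imag pairs)
        # coefs: [batch_size, df_order, time_steps, df_bins, 2]
        df_order, df_bins = coefs.shape[1], coefs.shape[3]
        spec_c = torch.view_as_complex(spec.contiguous())    # [b, 1, t, spec_bins]
        coefs_c = torch.view_as_complex(coefs.contiguous())  # [b, df_order, t, df_bins]
        filtered = torch.zeros_like(spec_c[..., :df_bins])
        for k in range(df_order):
            # tap k multiplies the frame k steps in the past (zero-padded at t < k)
            past = torch.roll(spec_c[..., :df_bins], shifts=k, dims=2)
            past[:, :, :k, :] = 0
            filtered = filtered + coefs_c[:, k:k + 1] * past
        out = spec_c.clone()
        out[..., :df_bins] = filtered
        return torch.view_as_real(out)

As in the commented-out code, the caller would then copy the masked spectrum into spec_e[..., df_bins:, :], since the predicted taps only cover the low-frequency bins.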
 
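After this change the forward pass keeps only the mask estimate: the decoder output is checked to lie in [0, 1], squeezed, and permuted back to [batch_size, spec_bins, time_steps]. The hunks do not show how the mask is applied to the spectrum, so the following is an assumed sketch of the usual bounded-mask application, with made-up shapes:

    import torch

    def apply_mask(spec_complex: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # spec_complex: [batch_size, spec_bins, time_steps], complex-valued
        # mask:         [batch_size, spec_bins, time_steps], real, in [0, 1]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError("mask values must lie in [0, 1]")
        return spec_complex * mask

    noisy = torch.randn(2, 257, 100, dtype=torch.complex64)
    mask = torch.sigmoid(torch.randn(2, 257, 100))  # sigmoid keeps values in (0, 1)
    denoised = apply_mask(noisy, mask)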