HoneyTian committed on
Commit
4fbb8e0
·
1 Parent(s): 0be4793
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -313,22 +313,22 @@ def main():
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
- # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
- # raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
  raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
- # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
- # loss = speech_loss + irm_loss + snr_loss
331
- loss = irm_loss + snr_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_complex_spec.size(0)
@@ -361,19 +361,19 @@ def main():
361
 
362
  with torch.no_grad():
363
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
364
- # if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
365
- # raise AssertionError("nan or inf in speech_spec_prediction")
366
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
367
  raise AssertionError("nan or inf in speech_irm_prediction")
368
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
369
  raise AssertionError("nan or inf in lsnr_prediction")
370
 
371
- # speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
- # loss = speech_loss + irm_loss + snr_loss
376
- loss = irm_loss + snr_loss
377
 
378
  total_loss += loss.item()
379
  total_examples += mix_complex_spec.size(0)
 
313
  snr_db_target = snr_db.to(device)
314
 
315
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
316
+ if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
317
+ raise AssertionError("nan or inf in speech_spec_prediction")
318
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
319
  raise AssertionError("nan or inf in speech_irm_prediction")
320
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
321
  raise AssertionError("nan or inf in lsnr_prediction")
322
 
323
+ speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
  # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
  # raise AssertionError("nan or inf in snr_loss")
329
 
330
+ loss = speech_loss + irm_loss + snr_loss
331
+ # loss = irm_loss + snr_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_complex_spec.size(0)
 
361
 
362
  with torch.no_grad():
363
  speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
364
+ if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
365
+ raise AssertionError("nan or inf in speech_spec_prediction")
366
  if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
367
  raise AssertionError("nan or inf in speech_irm_prediction")
368
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
369
  raise AssertionError("nan or inf in lsnr_prediction")
370
 
371
+ speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
372
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
373
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
374
 
375
+ loss = speech_loss + irm_loss + snr_loss
376
+ # loss = irm_loss + snr_loss
377
 
378
  total_loss += loss.item()
379
  total_examples += mix_complex_spec.size(0)
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py CHANGED
@@ -803,31 +803,30 @@ class SpectrumDfNet(nn.Module):
803
  def forward(self,
804
  spec_complex: torch.Tensor,
805
  ):
806
- with torch.no_grad():
807
- feat_power = torch.square(torch.abs(spec_complex))
808
- feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
809
- # feat_power shape: [batch_size, spec_bins, time_steps]
810
- # feat_power shape: [batch_size, 1, spec_bins, time_steps]
811
- # feat_power shape: [batch_size, 1, time_steps, spec_bins]
812
- feat_power = feat_power.detach()
813
-
814
- # spec shape: [batch_size, spec_bins, time_steps]
815
- feat_spec = torch.view_as_real(spec_complex)
816
- # spec shape: [batch_size, spec_bins, time_steps, 2]
817
- feat_spec = feat_spec.permute(0, 3, 2, 1)
818
- # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
819
- feat_spec = feat_spec[..., :self.df_decoder.df_bins]
820
- # feat_spec shape: [batch_size, 2, time_steps, df_bins]
821
- feat_spec = feat_spec.detach()
822
-
823
- # # spec shape: [batch_size, spec_bins, time_steps]
824
- # spec = torch.unsqueeze(spec_complex, dim=1)
825
- # # spec shape: [batch_size, 1, spec_bins, time_steps]
826
- # spec = spec.permute(0, 1, 3, 2)
827
- # # spec shape: [batch_size, 1, time_steps, spec_bins]
828
- # spec = torch.view_as_real(spec)
829
- # # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
830
- # spec = spec.detach()
831
 
832
  e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
833
 
@@ -836,31 +835,30 @@ class SpectrumDfNet(nn.Module):
836
  if torch.any(mask > 1) or torch.any(mask < 0):
837
  raise AssertionError
838
 
839
- # spec_m = self.mask.forward(spec, mask)
840
 
841
  # lsnr shape: [batch_size, time_steps, 1]
842
  lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
843
  # lsnr shape: [batch_size, 1, time_steps]
844
 
845
- # df_coefs = self.df_decoder.forward(emb, c0)
846
- # df_coefs = self.df_out_transform(df_coefs)
847
- # # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
848
- #
849
- # spec_e = self.df_op.forward(spec.clone(), df_coefs)
850
- # # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
851
- #
852
- # spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
853
- #
854
- # spec_e = torch.squeeze(spec_e, dim=1)
855
- # spec_e = spec_e.permute(0, 2, 1, 3)
856
- # # spec_e shape: [batch_size, spec_bins, time_steps, 2]
857
 
858
  mask = torch.squeeze(mask, dim=1)
859
  mask = mask.permute(0, 2, 1)
860
  # mask shape: [batch_size, spec_bins, time_steps]
861
 
862
- # return spec_e, mask, lsnr
863
- return None, mask, lsnr
864
 
865
 
866
  class SpectrumDfNetPretrainedModel(SpectrumDfNet):
 
803
  def forward(self,
804
  spec_complex: torch.Tensor,
805
  ):
806
+ feat_power = torch.square(torch.abs(spec_complex))
807
+ feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
808
+ # feat_power shape: [batch_size, spec_bins, time_steps]
809
+ # feat_power shape: [batch_size, 1, spec_bins, time_steps]
810
+ # feat_power shape: [batch_size, 1, time_steps, spec_bins]
811
+ feat_power = feat_power.detach()
812
+
813
+ # spec shape: [batch_size, spec_bins, time_steps]
814
+ feat_spec = torch.view_as_real(spec_complex)
815
+ # spec shape: [batch_size, spec_bins, time_steps, 2]
816
+ feat_spec = feat_spec.permute(0, 3, 2, 1)
817
+ # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
818
+ feat_spec = feat_spec[..., :self.df_decoder.df_bins]
819
+ # feat_spec shape: [batch_size, 2, time_steps, df_bins]
820
+ feat_spec = feat_spec.detach()
821
+
822
+ # spec shape: [batch_size, spec_bins, time_steps]
823
+ spec = torch.unsqueeze(spec_complex, dim=1)
824
+ # spec shape: [batch_size, 1, spec_bins, time_steps]
825
+ spec = spec.permute(0, 1, 3, 2)
826
+ # spec shape: [batch_size, 1, time_steps, spec_bins]
827
+ spec = torch.view_as_real(spec)
828
+ # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
829
+ spec = spec.detach()
 
830
 
831
  e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
832
 
 
835
  if torch.any(mask > 1) or torch.any(mask < 0):
836
  raise AssertionError
837
 
838
+ spec_m = self.mask.forward(spec, mask)
839
 
840
  # lsnr shape: [batch_size, time_steps, 1]
841
  lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
842
  # lsnr shape: [batch_size, 1, time_steps]
843
 
844
+ df_coefs = self.df_decoder.forward(emb, c0)
845
+ df_coefs = self.df_out_transform(df_coefs)
846
+ # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
847
+
848
+ spec_e = self.df_op.forward(spec.clone(), df_coefs)
849
+ # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
850
+
851
+ spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
852
+
853
+ spec_e = torch.squeeze(spec_e, dim=1)
854
+ spec_e = spec_e.permute(0, 2, 1, 3)
855
+ # spec_e shape: [batch_size, spec_bins, time_steps, 2]
856
 
857
  mask = torch.squeeze(mask, dim=1)
858
  mask = mask.permute(0, 2, 1)
859
  # mask shape: [batch_size, spec_bins, time_steps]
860
 
861
+ return spec_e, mask, lsnr
 
862
 
863
 
864
  class SpectrumDfNetPretrainedModel(SpectrumDfNet):