update
examples/spectrum_dfnet_aishell/step_2_train_model.py
CHANGED
```diff
@@ -313,22 +313,22 @@ def main():
         snr_db_target = snr_db.to(device)
 
         speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
-
-
+        if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
+            raise AssertionError("nan or inf in speech_spec_prediction")
         if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
             raise AssertionError("nan or inf in speech_irm_prediction")
         if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
             raise AssertionError("nan or inf in lsnr_prediction")
 
-
+        speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
         irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
         snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
 
         # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
         #     raise AssertionError("nan or inf in snr_loss")
 
-
-        loss = irm_loss + snr_loss
+        loss = speech_loss + irm_loss + snr_loss
+        # loss = irm_loss + snr_loss
 
         total_loss += loss.item()
         total_examples += mix_complex_spec.size(0)
@@ -361,19 +361,19 @@ def main():
 
         with torch.no_grad():
             speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_complex_spec)
-
-
+            if torch.any(torch.isnan(speech_spec_prediction)) or torch.any(torch.isinf(speech_spec_prediction)):
+                raise AssertionError("nan or inf in speech_spec_prediction")
             if torch.any(torch.isnan(speech_irm_prediction)) or torch.any(torch.isinf(speech_irm_prediction)):
                 raise AssertionError("nan or inf in speech_irm_prediction")
             if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
                 raise AssertionError("nan or inf in lsnr_prediction")
 
-
+            speech_loss = speech_mse_loss.forward(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
             snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
 
-
-            loss = irm_loss + snr_loss
+            loss = speech_loss + irm_loss + snr_loss
+            # loss = irm_loss + snr_loss
 
             total_loss += loss.item()
             total_examples += mix_complex_spec.size(0)
```
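In effect, the commit turns the two-term objective into a three-term one: the enhanced complex spectrum returned by the model is now nan/inf-checked and supervised with an MSE against the clean spectrum, compared as real/imaginary pairs via `torch.view_as_real`, alongside the existing IRM and local-SNR terms. Below is a minimal, self-contained sketch of that pattern; the tensor shapes, the random stand-in tensors, and the plain `nn.MSELoss` instances are illustrative assumptions, since the actual script builds these from its data loader and model:

```python
import torch
import torch.nn as nn

# Illustrative shapes (assumptions, not taken from the repo's config).
batch_size, spec_bins, time_steps = 4, 257, 100

# Random stand-ins for the model outputs and targets named in the diff.
speech_spec_prediction = torch.randn(batch_size, spec_bins, time_steps, 2)
speech_complex_spec = torch.randn(batch_size, spec_bins, time_steps, dtype=torch.complex64)
speech_irm_prediction = torch.rand(batch_size, spec_bins, time_steps)
speech_irm_target = torch.rand(batch_size, spec_bins, time_steps)
lsnr_prediction = torch.randn(batch_size, 1, time_steps)
snr_db_target = torch.randn(batch_size, 1, time_steps)

speech_mse_loss = nn.MSELoss()
irm_mse_loss = nn.MSELoss()
snr_mse_loss = nn.MSELoss()


def check_finite(name: str, tensor: torch.Tensor) -> None:
    # Same fail-fast guard as the script: stop before a nan poisons the epoch average.
    if torch.any(torch.isnan(tensor)) or torch.any(torch.isinf(tensor)):
        raise AssertionError(f"nan or inf in {name}")


check_finite("speech_spec_prediction", speech_spec_prediction)
check_finite("speech_irm_prediction", speech_irm_prediction)
check_finite("lsnr_prediction", lsnr_prediction)

# torch.view_as_real maps a complex64 tensor [..., bins, frames] to a float32
# tensor [..., bins, frames, 2], so the MSE compares real and imaginary parts
# of the predicted and clean spectra elementwise.
speech_loss = speech_mse_loss(speech_spec_prediction, torch.view_as_real(speech_complex_spec))
irm_loss = irm_mse_loss(speech_irm_prediction, speech_irm_target)
snr_loss = snr_mse_loss(lsnr_prediction, snr_db_target)

loss = speech_loss + irm_loss + snr_loss  # unweighted sum, as in the commit
```

Keeping the old `loss = irm_loss + snr_loss` line as a comment makes it easy to ablate the new spectrum term later.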
toolbox/torchaudio/models/spectrum_dfnet/modeling_spectrum_dfnet.py
CHANGED
```diff
@@ -803,31 +803,30 @@ class SpectrumDfNet(nn.Module):
     def forward(self,
                 spec_complex: torch.Tensor,
                 ):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        # spec
-
-        # spec
-
-        # spec
-
-        # spec = spec.detach()
+        feat_power = torch.square(torch.abs(spec_complex))
+        feat_power = feat_power.unsqueeze(1).permute(0, 1, 3, 2)
+        # feat_power shape: [batch_size, spec_bins, time_steps]
+        # feat_power shape: [batch_size, 1, spec_bins, time_steps]
+        # feat_power shape: [batch_size, 1, time_steps, spec_bins]
+        feat_power = feat_power.detach()
+
+        # spec shape: [batch_size, spec_bins, time_steps]
+        feat_spec = torch.view_as_real(spec_complex)
+        # spec shape: [batch_size, spec_bins, time_steps, 2]
+        feat_spec = feat_spec.permute(0, 3, 2, 1)
+        # feat_spec shape: [batch_size, 2, time_steps, spec_bins]
+        feat_spec = feat_spec[..., :self.df_decoder.df_bins]
+        # feat_spec shape: [batch_size, 2, time_steps, df_bins]
+        feat_spec = feat_spec.detach()
+
+        # spec shape: [batch_size, spec_bins, time_steps]
+        spec = torch.unsqueeze(spec_complex, dim=1)
+        # spec shape: [batch_size, 1, spec_bins, time_steps]
+        spec = spec.permute(0, 1, 3, 2)
+        # spec shape: [batch_size, 1, time_steps, spec_bins]
+        spec = torch.view_as_real(spec)
+        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
+        spec = spec.detach()
 
         e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
 
@@ -836,31 +835,30 @@ class SpectrumDfNet(nn.Module):
         if torch.any(mask > 1) or torch.any(mask < 0):
             raise AssertionError
 
-
+        spec_m = self.mask.forward(spec, mask)
 
         # lsnr shape: [batch_size, time_steps, 1]
         lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
         # lsnr shape: [batch_size, 1, time_steps]
 
-
-
-        #
-
-
-        #
-
-
-
-
-
-        #
+        df_coefs = self.df_decoder.forward(emb, c0)
+        df_coefs = self.df_out_transform(df_coefs)
+        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
+
+        spec_e = self.df_op.forward(spec.clone(), df_coefs)
+        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
+
+        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
+
+        spec_e = torch.squeeze(spec_e, dim=1)
+        spec_e = spec_e.permute(0, 2, 1, 3)
+        # spec_e shape: [batch_size, spec_bins, time_steps, 2]
 
         mask = torch.squeeze(mask, dim=1)
         mask = mask.permute(0, 2, 1)
         # mask shape: [batch_size, spec_bins, time_steps]
 
-
-        return None, mask, lsnr
+        return spec_e, mask, lsnr
 
 
 class SpectrumDfNetPretrainedModel(SpectrumDfNet):
```
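The rewritten `forward` prepares three detached views of the input STFT (a power spectrum and a real/imaginary feature for the encoder, plus a real-pair view of the full spectrum for filtering), then combines a mask-based estimate `spec_m` with a deep-filtering estimate `spec_e`, copying `spec_m` into the bins above `df_bins` that the filter does not cover. The diff does not show `df_op` itself; the sketch below is the standard causal deep-filtering operation from the DeepFilterNet literature that an op with these coefficient shapes typically implements. The function name, the no-lookahead variant, and the complex dtypes are assumptions (the repo carries real/imag pairs with a trailing dim of 2 instead, convertible with `torch.view_as_complex`):

```python
import torch


def deep_filter(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
    # spec:  [batch_size, 1, time_steps, df_bins], complex
    # coefs: [batch_size, df_order, time_steps, df_bins], complex
    time_steps = spec.shape[2]
    df_order = coefs.shape[1]

    # Left-pad the time axis with df_order - 1 zero frames so the filter is causal.
    zeros = torch.zeros_like(spec[:, :, : df_order - 1, :])
    padded = torch.cat([zeros, spec], dim=2)

    # Each output frame t is a complex combination of input frames
    # t, t-1, ..., t-(df_order-1) at the same frequency bin;
    # tap i multiplies the frame delayed by i steps.
    out = torch.zeros_like(spec)
    for i in range(df_order):
        shift = df_order - 1 - i
        out = out + coefs[:, i : i + 1, :, :] * padded[:, :, shift : shift + time_steps, :]
    return out


batch_size, df_order, time_steps, df_bins = 2, 5, 50, 96
spec = torch.randn(batch_size, 1, time_steps, df_bins, dtype=torch.complex64)
coefs = torch.randn(batch_size, df_order, time_steps, df_bins, dtype=torch.complex64)
spec_e = deep_filter(spec, coefs)
print(spec_e.shape)  # torch.Size([2, 1, 50, 96])
```

Because only the first `df_bins` frequency bins receive filter taps, blending in the mask output for the remaining bins (as the diff does with `spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]`) yields a full-band estimate while keeping the filter cheap.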