HoneyTian commited on
Commit
ce1e2dc
·
1 Parent(s): 1474235
examples/dtln/step_2_train_model.py CHANGED
@@ -323,6 +323,7 @@ def main():
323
  noisy_audios: torch.Tensor = noisy_audios.to(device)
324
 
325
  denoise_audios = model.forward(noisy_audios)
 
326
 
327
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
328
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
 
323
  noisy_audios: torch.Tensor = noisy_audios.to(device)
324
 
325
  denoise_audios = model.forward(noisy_audios)
326
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
327
 
328
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
329
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)