Spaces:
Running
Running
update
Browse files
examples/dtln/step_2_train_model.py
CHANGED
@@ -323,6 +323,7 @@ def main():
|
|
323 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
324 |
|
325 |
denoise_audios = model.forward(noisy_audios)
|
|
|
326 |
|
327 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
328 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
323 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
324 |
|
325 |
denoise_audios = model.forward(noisy_audios)
|
326 |
+
denoise_audios = torch.squeeze(denoise_audios, dim=1)
|
327 |
|
328 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
329 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|