HoneyTian commited on
Commit
7a4199e
·
1 Parent(s): 6c6c36a
examples/rnnoise/step_2_train_model.py CHANGED
@@ -249,7 +249,6 @@ def main():
249
 
250
  step_idx = 0 if last_step_idx == -1 else last_step_idx
251
 
252
- logger.info("training")
253
  early_stop_flag = False
254
  for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
255
  if early_stop_flag:
@@ -274,6 +273,7 @@ def main():
274
  noisy_audios: torch.Tensor = noisy_audios.to(device)
275
 
276
  denoise_audios, _, _ = model.forward(noisy_audios)
 
277
 
278
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
279
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
@@ -339,6 +339,7 @@ def main():
339
  noisy_audios: torch.Tensor = noisy_audios.to(device)
340
 
341
  denoise_audios, _, _ = model.forward(noisy_audios)
 
342
 
343
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
344
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
 
249
 
250
  step_idx = 0 if last_step_idx == -1 else last_step_idx
251
 
 
252
  early_stop_flag = False
253
  for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
254
  if early_stop_flag:
 
273
  noisy_audios: torch.Tensor = noisy_audios.to(device)
274
 
275
  denoise_audios, _, _ = model.forward(noisy_audios)
276
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
277
 
278
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
279
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
 
339
  noisy_audios: torch.Tensor = noisy_audios.to(device)
340
 
341
  denoise_audios, _, _ = model.forward(noisy_audios)
342
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
343
 
344
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
345
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)