HoneyTian commited on
Commit
60a97c6
·
1 Parent(s): 9828a2d
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -212,7 +212,7 @@ def main():
212
  clean_audio, noisy_audio = batch
213
  clean_audio = clean_audio.to(device)
214
  noisy_audio = noisy_audio.to(device)
215
- one_labels = torch.ones(config.batch_size).to(device)
216
 
217
  clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
218
  noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
@@ -234,7 +234,7 @@ def main():
234
  if batch_pesq_score is not None:
235
  loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
236
  else:
237
- print("pesq is None!")
238
  loss_disc_g = 0
239
 
240
  loss_disc_all = loss_disc_r + loss_disc_g
 
212
  clean_audio, noisy_audio = batch
213
  clean_audio = clean_audio.to(device)
214
  noisy_audio = noisy_audio.to(device)
215
+ one_labels = torch.ones(clean_audio.shape[0]).to(device)
216
 
217
  clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
218
  noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
 
234
  if batch_pesq_score is not None:
235
  loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
236
  else:
237
+ # print("pesq is None!")
238
  loss_disc_g = 0
239
 
240
  loss_disc_all = loss_disc_r + loss_disc_g