HoneyTian committed on
Commit
9828a2d
·
1 Parent(s): f69c753
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -171,7 +171,7 @@ def main():
171
  num_params = 0
172
  for p in generator.parameters():
173
  num_params += p.numel()
174
- logger.info("Total Parameters (generator): {:.3f}M".format(num_params/1e6))
175
 
176
  optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
177
  optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
@@ -210,9 +210,9 @@ def main():
210
  )
211
  for batch in train_data_loader:
212
  clean_audio, noisy_audio = batch
213
- clean_audio = torch.autograd.Variable(clean_audio.to(device, non_blocking=True))
214
- noisy_audio = torch.autograd.Variable(noisy_audio.to(device, non_blocking=True))
215
- one_labels = torch.ones(config.batch_size).to(device, non_blocking=True)
216
 
217
  clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
218
  noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
@@ -293,8 +293,8 @@ def main():
293
  with torch.no_grad():
294
  for batch in valid_data_loader:
295
  clean_audio, noisy_audio = batch
296
- clean_audio = torch.autograd.Variable(clean_audio.to(device, non_blocking=True))
297
- noisy_audio = torch.autograd.Variable(noisy_audio.to(device, non_blocking=True))
298
 
299
  clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
300
  noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
 
171
  num_params = 0
172
  for p in generator.parameters():
173
  num_params += p.numel()
174
+ logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
175
 
176
  optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
177
  optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
 
210
  )
211
  for batch in train_data_loader:
212
  clean_audio, noisy_audio = batch
213
+ clean_audio = clean_audio.to(device)
214
+ noisy_audio = noisy_audio.to(device)
215
+ one_labels = torch.ones(config.batch_size).to(device)
216
 
217
  clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
218
  noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
 
293
  with torch.no_grad():
294
  for batch in valid_data_loader:
295
  clean_audio, noisy_audio = batch
296
+ clean_audio = clean_audio.to(device)
297
+ noisy_audio = noisy_audio.to(device)
298
 
299
  clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
300
  noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)