HoneyTian commited on
Commit
cdaf8e7
·
1 Parent(s): d085023
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -227,12 +227,8 @@ def main():
227
  clean_audios = clean_audios.to(device)
228
  noisy_audios = noisy_audios.to(device)
229
 
230
- print(f"clean_audios shape: {clean_audios.shape}, dtype: {clean_audios.dtype}")
231
- print(f"noisy_audios shape: {noisy_audios.shape}, dtype: {noisy_audios.dtype}")
232
-
233
  enhanced_audios = model.forward(noisy_audios)
234
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
235
- print(f"enhanced_audios shape: {enhanced_audios.shape}, dtype: {enhanced_audios.dtype}")
236
 
237
  ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
238
  sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
@@ -248,7 +244,7 @@ def main():
248
  optimizer.step()
249
  lr_scheduler.step()
250
 
251
- total_pesq_metric += pesq_metric.item()
252
  total_loss += loss.item()
253
  total_ae_loss += ae_loss.item()
254
  total_sc_loss += sc_loss.item()
 
227
  clean_audios = clean_audios.to(device)
228
  noisy_audios = noisy_audios.to(device)
229
 
 
 
 
230
  enhanced_audios = model.forward(noisy_audios)
231
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
 
232
 
233
  ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
234
  sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
 
244
  optimizer.step()
245
  lr_scheduler.step()
246
 
247
+ total_pesq_metric += 0 if pesq_metric is None else pesq_metric.item()
248
  total_loss += loss.item()
249
  total_ae_loss += ae_loss.item()
250
  total_sc_loss += sc_loss.item()