HoneyTian commited on
Commit
1042eee
·
1 Parent(s): 46cf9fb
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -233,7 +233,7 @@ def main():
233
  print(f"enhanced_audios shape: {enhanced_audios.shape}, dtype: {enhanced_audios.dtype}")
234
 
235
  ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
236
- sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
237
 
238
  loss = ae_loss + sc_loss + mag_loss
239
 
@@ -294,7 +294,7 @@ def main():
294
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
295
 
296
  ae_loss = ae_loss_fn(enhanced_audios, enhanced_audios)
297
- sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
298
 
299
  loss = ae_loss + sc_loss + mag_loss
300
 
 
233
  print(f"enhanced_audios shape: {enhanced_audios.shape}, dtype: {enhanced_audios.dtype}")
234
 
235
  ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
236
+ sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
237
 
238
  loss = ae_loss + sc_loss + mag_loss
239
 
 
294
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
295
 
296
  ae_loss = ae_loss_fn(enhanced_audios, enhanced_audios)
297
+ sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
298
 
299
  loss = ae_loss + sc_loss + mag_loss
300