HoneyTian committed
Commit 46cf9fb · 1 Parent(s): 0a82c5b
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -229,6 +229,7 @@ def main():
 print(f"noisy_audios shape: {noisy_audios.shape}, dtype: {noisy_audios.dtype}")
 
 enhanced_audios = model.forward(noisy_audios)
+enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
 print(f"enhanced_audios shape: {enhanced_audios.shape}, dtype: {enhanced_audios.dtype}")
 
 ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
@@ -291,6 +292,7 @@ def main():
 
 enhanced_audios = model.forward(noisy_audios)
 enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
+
 ae_loss = ae_loss_fn(enhanced_audios, enhanced_audios)
 sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
 
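The first hunk squeezes the channel dimension out of the model output before the loss computation in the training loop, matching what the evaluation loop already did. Below is a minimal sketch of that shape handling, assuming a CleanUNet-style forward pass that returns waveforms shaped (batch, 1, num_samples) while the loss functions expect (batch, num_samples); the model output and ae_loss_fn are stood in by a plain unsqueeze and nn.L1Loss here, since the repo's actual modules are not shown in this diff.

import torch
import torch.nn as nn

batch_size, num_samples = 4, 16000
clean_audios = torch.randn(batch_size, num_samples)
noisy_audios = clean_audios + 0.1 * torch.randn(batch_size, num_samples)

# Stand-in for model.forward(noisy_audios): the output keeps a channel dim of 1.
enhanced_audios = noisy_audios.unsqueeze(1)              # (batch, 1, num_samples)

# The added line: drop the channel dimension before computing losses.
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)  # (batch, num_samples)

# Hypothetical stand-in for ae_loss_fn from the training loop.
ae_loss_fn = nn.L1Loss()
ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
print(enhanced_audios.shape, ae_loss.item())

Without the squeeze, an elementwise loss would broadcast (batch, 1, num_samples) against (batch, num_samples) to (batch, batch, num_samples), so the reduction would mix unrelated pairs of utterances instead of comparing each enhanced waveform with its own clean reference.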