HoneyTian commited on
Commit
c3e8e98
·
1 Parent(s): 7e91720
examples/spectrum_unet_irm_aishell/run.sh CHANGED
@@ -12,6 +12,10 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
 
 
 
 
15
 
16
  END
17
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
+
19
 
20
  END
21
 
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -290,6 +290,10 @@ def main():
290
  snr_db_target = snr_db.to(device)
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
 
 
 
 
293
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
294
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
295
  loss = irm_loss + 0.01 * snr_loss
@@ -325,6 +329,10 @@ def main():
325
 
326
  with torch.no_grad():
327
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
 
 
 
 
328
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
329
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
330
  loss = irm_loss + 0.01 * snr_loss
 
290
  snr_db_target = snr_db.to(device)
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
+ if torch.any(torch.isnan(speech_irm_prediction)):
294
+ raise AssertionError("nan in speech_irm_prediction")
295
+ if torch.any(torch.isnan(lsnr_prediction)):
296
+ raise AssertionError("nan in lsnr_prediction")
297
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
  loss = irm_loss + 0.01 * snr_loss
 
329
 
330
  with torch.no_grad():
331
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
332
+ if torch.any(torch.isnan(speech_irm_prediction)):
333
+ raise AssertionError("nan in speech_irm_prediction")
334
+ if torch.any(torch.isnan(lsnr_prediction)):
335
+ raise AssertionError("nan in lsnr_prediction")
336
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
337
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
338
  loss = irm_loss + 0.01 * snr_loss