Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/run.sh
CHANGED
@@ -12,6 +12,10 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
|
|
|
|
|
|
|
|
15 |
|
16 |
END
|
17 |
|
|
|
12 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
13 |
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
14 |
|
15 |
+
sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
|
16 |
+
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
17 |
+
--speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
|
18 |
+
|
19 |
|
20 |
END
|
21 |
|
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -290,6 +290,10 @@ def main():
|
|
290 |
snr_db_target = snr_db.to(device)
|
291 |
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
|
|
|
|
|
|
|
|
293 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
294 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
295 |
loss = irm_loss + 0.01 * snr_loss
|
@@ -325,6 +329,10 @@ def main():
|
|
325 |
|
326 |
with torch.no_grad():
|
327 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
|
|
|
|
|
|
|
|
328 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
329 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
330 |
loss = irm_loss + 0.01 * snr_loss
|
|
|
290 |
snr_db_target = snr_db.to(device)
|
291 |
|
292 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
293 |
+
if torch.any(torch.isnan(speech_irm_prediction)):
|
294 |
+
raise AssertionError("nan in speech_irm_prediction")
|
295 |
+
if torch.any(torch.isnan(lsnr_prediction)):
|
296 |
+
raise AssertionError("nan in lsnr_prediction")
|
297 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
loss = irm_loss + 0.01 * snr_loss
|
|
|
329 |
|
330 |
with torch.no_grad():
|
331 |
speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
|
332 |
+
if torch.any(torch.isnan(speech_irm_prediction)):
|
333 |
+
raise AssertionError("nan in speech_irm_prediction")
|
334 |
+
if torch.any(torch.isnan(lsnr_prediction)):
|
335 |
+
raise AssertionError("nan in lsnr_prediction")
|
336 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
337 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
338 |
loss = irm_loss + 0.01 * snr_loss
|