HoneyTian commited on
Commit
dc94aa4
·
1 Parent(s): 63dd56a
examples/spectrum_dfnet_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/spectrum_dfnet_aishell/step_2_train_model.py CHANGED
@@ -324,9 +324,6 @@ def main():
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
327
- # if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
328
- # raise AssertionError("nan or inf in snr_loss")
329
-
330
  loss = speech_loss + irm_loss + snr_loss
331
 
332
  total_loss += loss.item()
 
324
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
325
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
326
 
 
 
 
327
  loss = speech_loss + irm_loss + snr_loss
328
 
329
  total_loss += loss.item()
examples/spectrum_dfnet_aishell/step_3_evaluation.py CHANGED
@@ -19,7 +19,7 @@ import torch.nn as nn
19
  import torchaudio
20
  from tqdm import tqdm
21
 
22
- from toolbox.torchaudio.models.spectrum_unet_irm.modeling_spectrum_unet_irm import SpectrumUnetIRMPretrainedModel
23
 
24
 
25
  def get_args():
@@ -152,7 +152,7 @@ def main():
152
  logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
153
 
154
  logger.info("prepare model")
155
- model = SpectrumUnetIRMPretrainedModel.from_pretrained(
156
  pretrained_model_name_or_path=args.model_dir,
157
  )
158
  model.to(device)
 
19
  import torchaudio
20
  from tqdm import tqdm
21
 
22
+ from toolbox.torchaudio.models.spectrum_dfnet.modeling_spectrum_dfnet import SpectrumDfNetPretrainedModel
23
 
24
 
25
  def get_args():
 
152
  logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
153
 
154
  logger.info("prepare model")
155
+ model = SpectrumDfNetPretrainedModel.from_pretrained(
156
  pretrained_model_name_or_path=args.model_dir,
157
  )
158
  model.to(device)