HoneyTian commited on
Commit
ce34f8c
·
1 Parent(s): 6512ccb
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: VM Sound Classification
3
  emoji: 🐢
4
  colorFrom: purple
5
  colorTo: blue
 
1
  ---
2
+ title: NX Denoise
3
  emoji: 🐢
4
  colorFrom: purple
5
  colorTo: blue
examples/spectrum_unet_irm_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
- sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
 
12
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
13
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
14
 
15
+ sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
  --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
18
 
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -39,7 +39,7 @@ def get_args():
39
  parser.add_argument("--max_epochs", default=100, type=int)
40
 
41
  parser.add_argument("--batch_size", default=64, type=int)
42
- parser.add_argument("--learning_rate", default=1e-3, type=float)
43
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
  parser.add_argument("--patience", default=5, type=int)
45
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
@@ -303,7 +303,8 @@ def main():
303
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
304
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
305
  raise AssertionError("nan or inf in snr_loss")
306
- loss = irm_loss + 0.1 * snr_loss
 
307
  # loss = irm_loss
308
 
309
  total_loss += loss.item()
@@ -345,7 +346,8 @@ def main():
345
  if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
346
  raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
347
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
348
- loss = irm_loss + 0.1 * snr_loss
 
349
  # loss = irm_loss
350
 
351
  total_loss += loss.item()
 
39
  parser.add_argument("--max_epochs", default=100, type=int)
40
 
41
  parser.add_argument("--batch_size", default=64, type=int)
42
+ parser.add_argument("--learning_rate", default=1e-4, type=float)
43
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
  parser.add_argument("--patience", default=5, type=int)
45
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
303
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
304
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
305
  raise AssertionError("nan or inf in snr_loss")
306
+ # loss = irm_loss + 0.1 * snr_loss
307
+ loss = irm_loss + 0.05 * snr_loss
308
  # loss = irm_loss
309
 
310
  total_loss += loss.item()
 
346
  if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
347
  raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
348
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
349
+ # loss = irm_loss + 0.1 * snr_loss
350
+ loss = irm_loss + 0.05 * snr_loss
351
  # loss = irm_loss
352
 
353
  total_loss += loss.item()