README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: NX Denoise
 emoji: 🐢
 colorFrom: purple
 colorTo: blue
examples/spectrum_unet_irm_aishell/run.sh CHANGED
@@ -12,7 +12,7 @@ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name fi
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"

-sh run.sh --stage
+sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"

examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument("--max_epochs", default=100, type=int)

     parser.add_argument("--batch_size", default=64, type=int)
-    parser.add_argument("--learning_rate", default=1e-
+    parser.add_argument("--learning_rate", default=1e-4, type=float)
     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
     parser.add_argument("--patience", default=5, type=int)
     parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
@@ -303,7 +303,8 @@ def main():
         snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
         if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
             raise AssertionError("nan or inf in snr_loss")
-        loss = irm_loss + 0.1 * snr_loss
+        # loss = irm_loss + 0.1 * snr_loss
+        loss = irm_loss + 0.05 * snr_loss
         # loss = irm_loss

         total_loss += loss.item()
@@ -345,7 +346,8 @@ def main():
         if torch.max(lsnr_prediction) > 1 or torch.min(lsnr_prediction) < 0:
             raise AssertionError(f"expected lsnr_prediction between 0 and 1.")
         snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
-        loss = irm_loss + 0.1 * snr_loss
+        # loss = irm_loss + 0.1 * snr_loss
+        loss = irm_loss + 0.05 * snr_loss
         # loss = irm_loss

         total_loss += loss.item()
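
For context on the training-script change: the commit keeps the joint objective (IRM mask loss plus a scaled LSNR loss) but halves the SNR term's weight from 0.1 to 0.05. Below is a minimal Python sketch of that combination, assuming irm_loss and snr_loss are already computed per batch as in step_2_train_model.py; the combine_losses helper name, its default weight argument, and the toy values are illustrative assumptions, not part of the repository.

import torch


def combine_losses(irm_loss: torch.Tensor,
                   snr_loss: torch.Tensor,
                   snr_weight: float = 0.05) -> torch.Tensor:
    # Guard against a diverging SNR branch, mirroring the script's nan/inf check.
    if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
        raise AssertionError("nan or inf in snr_loss")
    # Weighted sum: the commit lowers the SNR contribution from 0.1 to 0.05.
    return irm_loss + snr_weight * snr_loss


if __name__ == "__main__":
    # Toy scalar losses standing in for real batch values.
    irm_loss = torch.tensor(0.8)
    snr_loss = torch.tensor(2.0)
    print(combine_losses(irm_loss, snr_loss))  # tensor(0.9000) = 0.8 + 0.05 * 2.0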