Spaces:
Running
Running
add dfnet2
Browse files
examples/dfnet2/run.sh
CHANGED
@@ -6,7 +6,7 @@ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name f
|
|
6 |
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
|
7 |
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
|
8 |
|
9 |
-
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name
|
10 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
11 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
12 |
|
|
|
6 |
--noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
|
7 |
--speech_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/speech"
|
8 |
|
9 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-dns3 \
|
10 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
|
11 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
|
12 |
|
examples/dfnet2/step_2_train_model.py
CHANGED
@@ -265,6 +265,9 @@ def main():
|
|
265 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
266 |
|
267 |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
|
|
|
|
|
|
|
268 |
|
269 |
mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
|
270 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
|
@@ -336,6 +339,9 @@ def main():
|
|
336 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
337 |
|
338 |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
|
|
|
|
|
|
|
339 |
|
340 |
mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
|
341 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
|
|
|
265 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
266 |
|
267 |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
|
268 |
+
# est_wav shape: [b, 1, n_samples]
|
269 |
+
est_wav = torch.squeeze(est_wav, dim=1)
|
270 |
+
# est_wav shape: [b, n_samples]
|
271 |
|
272 |
mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
|
273 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
|
|
|
339 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
340 |
|
341 |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy_audios)
|
342 |
+
# est_wav shape: [b, 1, n_samples]
|
343 |
+
est_wav = torch.squeeze(est_wav, dim=1)
|
344 |
+
# est_wav shape: [b, n_samples]
|
345 |
|
346 |
mr_stft_loss = mr_stft_loss_fn.forward(est_wav, clean_audios)
|
347 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(est_wav, clean_audios)
|