Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/run.sh
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
--max_epochs 400
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
--max_epochs 400
|
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -48,11 +48,10 @@ def get_args():
|
|
48 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
49 |
parser.add_argument("--seed", default=1234, type=int)
|
50 |
|
51 |
-
parser.add_argument("--eval_steps", default=
|
52 |
|
53 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
54 |
|
55 |
-
|
56 |
args = parser.parse_args()
|
57 |
return args
|
58 |
|
@@ -240,7 +239,7 @@ def main():
|
|
240 |
total_mr_stft_loss = 0.
|
241 |
total_batches = 0.
|
242 |
total_steps = 0
|
243 |
-
|
244 |
desc="Training; epoch-{}".format(idx_epoch),
|
245 |
)
|
246 |
for train_batch in train_data_loader:
|
@@ -286,8 +285,8 @@ def main():
|
|
286 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
287 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
288 |
|
289 |
-
|
290 |
-
|
291 |
"pesq_score": average_pesq_score,
|
292 |
"loss": average_loss,
|
293 |
"ae_loss": average_ae_loss,
|
@@ -309,7 +308,8 @@ def main():
|
|
309 |
total_neg_stoi_loss = 0.
|
310 |
total_batches = 0.
|
311 |
|
312 |
-
|
|
|
313 |
desc="Evaluation; step-{}".format(total_steps),
|
314 |
)
|
315 |
with torch.no_grad():
|
@@ -348,16 +348,21 @@ def main():
|
|
348 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
349 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
350 |
|
351 |
-
|
352 |
-
|
353 |
"pesq_score": average_pesq_score,
|
354 |
"loss": average_loss,
|
355 |
"ae_loss": average_ae_loss,
|
356 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
357 |
"neg_stoi_loss": average_neg_stoi_loss,
|
358 |
"mr_stft_loss": average_mr_stft_loss,
|
359 |
-
|
360 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
361 |
|
362 |
# save path
|
363 |
save_dir = serialization_dir / "steps-{}".format(total_steps)
|
|
|
48 |
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
49 |
parser.add_argument("--seed", default=1234, type=int)
|
50 |
|
51 |
+
parser.add_argument("--eval_steps", default=25000, type=int)
|
52 |
|
53 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
54 |
|
|
|
55 |
args = parser.parse_args()
|
56 |
return args
|
57 |
|
|
|
239 |
total_mr_stft_loss = 0.
|
240 |
total_batches = 0.
|
241 |
total_steps = 0
|
242 |
+
progress_bar_train = tqdm(
|
243 |
desc="Training; epoch-{}".format(idx_epoch),
|
244 |
)
|
245 |
for train_batch in train_data_loader:
|
|
|
285 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
286 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
287 |
|
288 |
+
progress_bar_train.update(1)
|
289 |
+
progress_bar_train.set_postfix({
|
290 |
"pesq_score": average_pesq_score,
|
291 |
"loss": average_loss,
|
292 |
"ae_loss": average_ae_loss,
|
|
|
308 |
total_neg_stoi_loss = 0.
|
309 |
total_batches = 0.
|
310 |
|
311 |
+
progress_bar_train.close()
|
312 |
+
progress_bar_eval = tqdm(
|
313 |
desc="Evaluation; step-{}".format(total_steps),
|
314 |
)
|
315 |
with torch.no_grad():
|
|
|
348 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
349 |
average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
|
350 |
|
351 |
+
progress_bar_eval.update(1)
|
352 |
+
progress_bar_eval.set_postfix({
|
353 |
"pesq_score": average_pesq_score,
|
354 |
"loss": average_loss,
|
355 |
"ae_loss": average_ae_loss,
|
356 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
357 |
"neg_stoi_loss": average_neg_stoi_loss,
|
358 |
"mr_stft_loss": average_mr_stft_loss,
|
|
|
359 |
})
|
360 |
+
progress_bar_eval.close()
|
361 |
+
progress_bar_train = tqdm(
|
362 |
+
initial=progress_bar_train.n,
|
363 |
+
postfix=progress_bar_train.postfix,
|
364 |
+
desc=progress_bar_train.desc,
|
365 |
+
)
|
366 |
|
367 |
# save path
|
368 |
save_dir = serialization_dir / "steps-{}".format(total_steps)
|