HoneyTian commited on
Commit
6c34ab4
·
1 Parent(s): b7562b3
examples/conv_tasnet/run.sh CHANGED
@@ -3,7 +3,7 @@
3
  : <<'END'
4
 
5
 
6
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
9
  --max_epochs 400
 
3
  : <<'END'
4
 
5
 
6
+ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
9
  --max_epochs 400
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -48,11 +48,10 @@ def get_args():
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
49
  parser.add_argument("--seed", default=1234, type=int)
50
 
51
- parser.add_argument("--eval_steps", default=5000, type=int)
52
 
53
  parser.add_argument("--config_file", default="config.yaml", type=str)
54
 
55
-
56
  args = parser.parse_args()
57
  return args
58
 
@@ -240,7 +239,7 @@ def main():
240
  total_mr_stft_loss = 0.
241
  total_batches = 0.
242
  total_steps = 0
243
- progress_bar = tqdm(
244
  desc="Training; epoch-{}".format(idx_epoch),
245
  )
246
  for train_batch in train_data_loader:
@@ -286,8 +285,8 @@ def main():
286
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
287
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
288
 
289
- progress_bar.update(1)
290
- progress_bar.set_postfix({
291
  "pesq_score": average_pesq_score,
292
  "loss": average_loss,
293
  "ae_loss": average_ae_loss,
@@ -309,7 +308,8 @@ def main():
309
  total_neg_stoi_loss = 0.
310
  total_batches = 0.
311
 
312
- progress_bar = tqdm(
 
313
  desc="Evaluation; step-{}".format(total_steps),
314
  )
315
  with torch.no_grad():
@@ -348,16 +348,21 @@ def main():
348
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
349
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
350
 
351
- progress_bar.update(1)
352
- progress_bar.set_postfix({
353
  "pesq_score": average_pesq_score,
354
  "loss": average_loss,
355
  "ae_loss": average_ae_loss,
356
  "neg_si_snr_loss": average_neg_si_snr_loss,
357
  "neg_stoi_loss": average_neg_stoi_loss,
358
  "mr_stft_loss": average_mr_stft_loss,
359
-
360
  })
 
 
 
 
 
 
361
 
362
  # save path
363
  save_dir = serialization_dir / "steps-{}".format(total_steps)
 
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
49
  parser.add_argument("--seed", default=1234, type=int)
50
 
51
+ parser.add_argument("--eval_steps", default=25000, type=int)
52
 
53
  parser.add_argument("--config_file", default="config.yaml", type=str)
54
 
 
55
  args = parser.parse_args()
56
  return args
57
 
 
239
  total_mr_stft_loss = 0.
240
  total_batches = 0.
241
  total_steps = 0
242
+ progress_bar_train = tqdm(
243
  desc="Training; epoch-{}".format(idx_epoch),
244
  )
245
  for train_batch in train_data_loader:
 
285
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
286
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
287
 
288
+ progress_bar_train.update(1)
289
+ progress_bar_train.set_postfix({
290
  "pesq_score": average_pesq_score,
291
  "loss": average_loss,
292
  "ae_loss": average_ae_loss,
 
308
  total_neg_stoi_loss = 0.
309
  total_batches = 0.
310
 
311
+ progress_bar_train.close()
312
+ progress_bar_eval = tqdm(
313
  desc="Evaluation; step-{}".format(total_steps),
314
  )
315
  with torch.no_grad():
 
348
  average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
349
  average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
350
 
351
+ progress_bar_eval.update(1)
352
+ progress_bar_eval.set_postfix({
353
  "pesq_score": average_pesq_score,
354
  "loss": average_loss,
355
  "ae_loss": average_ae_loss,
356
  "neg_si_snr_loss": average_neg_si_snr_loss,
357
  "neg_stoi_loss": average_neg_stoi_loss,
358
  "mr_stft_loss": average_mr_stft_loss,
 
359
  })
360
+ progress_bar_eval.close()
361
+ progress_bar_train = tqdm(
362
+ initial=progress_bar_train.n,
363
+ postfix=progress_bar_train.postfix,
364
+ desc=progress_bar_train.desc,
365
+ )
366
 
367
  # save path
368
  save_dir = serialization_dir / "steps-{}".format(total_steps)