HoneyTian commited on
Commit
bb37ac1
·
1 Parent(s): 45e8101
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -233,6 +233,7 @@ def main():
233
  average_neg_stoi_loss = 1000000000
234
 
235
  model_list = list()
 
236
  best_steps = None
237
  best_metric = None
238
  patience_count = 0
@@ -409,10 +410,12 @@ def main():
409
 
410
  # save metric
411
  if best_metric is None:
 
412
  best_steps = total_steps
413
  best_metric = average_pesq_score
414
  elif average_pesq_score > best_metric:
415
  # great is better.
 
416
  best_steps = total_steps
417
  best_metric = average_pesq_score
418
  else:
@@ -420,6 +423,7 @@ def main():
420
 
421
  metrics = {
422
  "idx_epoch": idx_epoch,
 
423
  "best_steps": best_steps,
424
  "pesq_score": average_pesq_score,
425
  "loss": average_loss,
@@ -433,14 +437,14 @@ def main():
433
 
434
  # save best
435
  best_dir = serialization_dir / "best"
436
- if best_idx_epoch == idx_epoch:
437
  if best_dir.exists():
438
  shutil.rmtree(best_dir)
439
  shutil.copytree(save_dir, best_dir)
440
 
441
  # early stop
442
  early_stop_flag = False
443
- if best_idx_epoch == idx_epoch:
444
  patience_count = 0
445
  else:
446
  patience_count += 1
 
233
  average_neg_stoi_loss = 1000000000
234
 
235
  model_list = list()
236
+ best_idx_epoch = None
237
  best_steps = None
238
  best_metric = None
239
  patience_count = 0
 
410
 
411
  # save metric
412
  if best_metric is None:
413
+ best_idx_epoch = idx_epoch
414
  best_steps = total_steps
415
  best_metric = average_pesq_score
416
  elif average_pesq_score > best_metric:
417
  # great is better.
418
+ best_idx_epoch = idx_epoch
419
  best_steps = total_steps
420
  best_metric = average_pesq_score
421
  else:
 
423
 
424
  metrics = {
425
  "idx_epoch": idx_epoch,
426
+ "best_idx_epoch": best_idx_epoch,
427
  "best_steps": best_steps,
428
  "pesq_score": average_pesq_score,
429
  "loss": average_loss,
 
437
 
438
  # save best
439
  best_dir = serialization_dir / "best"
440
+ if best_idx_epoch == idx_epoch and best_steps == total_steps:
441
  if best_dir.exists():
442
  shutil.rmtree(best_dir)
443
  shutil.copytree(save_dir, best_dir)
444
 
445
  # early stop
446
  early_stop_flag = False
447
+ if best_idx_epoch == idx_epoch and best_steps == total_steps:
448
  patience_count = 0
449
  else:
450
  patience_count += 1