Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -233,6 +233,7 @@ def main():
|
|
233 |
average_neg_stoi_loss = 1000000000
|
234 |
|
235 |
model_list = list()
|
|
|
236 |
best_steps = None
|
237 |
best_metric = None
|
238 |
patience_count = 0
|
@@ -409,10 +410,12 @@ def main():
|
|
409 |
|
410 |
# save metric
|
411 |
if best_metric is None:
|
|
|
412 |
best_steps = total_steps
|
413 |
best_metric = average_pesq_score
|
414 |
elif average_pesq_score > best_metric:
|
415 |
# great is better.
|
|
|
416 |
best_steps = total_steps
|
417 |
best_metric = average_pesq_score
|
418 |
else:
|
@@ -420,6 +423,7 @@ def main():
|
|
420 |
|
421 |
metrics = {
|
422 |
"idx_epoch": idx_epoch,
|
|
|
423 |
"best_steps": best_steps,
|
424 |
"pesq_score": average_pesq_score,
|
425 |
"loss": average_loss,
|
@@ -433,14 +437,14 @@ def main():
|
|
433 |
|
434 |
# save best
|
435 |
best_dir = serialization_dir / "best"
|
436 |
-
if best_idx_epoch == idx_epoch:
|
437 |
if best_dir.exists():
|
438 |
shutil.rmtree(best_dir)
|
439 |
shutil.copytree(save_dir, best_dir)
|
440 |
|
441 |
# early stop
|
442 |
early_stop_flag = False
|
443 |
-
if best_idx_epoch == idx_epoch:
|
444 |
patience_count = 0
|
445 |
else:
|
446 |
patience_count += 1
|
|
|
233 |
average_neg_stoi_loss = 1000000000
|
234 |
|
235 |
model_list = list()
|
236 |
+
best_idx_epoch = None
|
237 |
best_steps = None
|
238 |
best_metric = None
|
239 |
patience_count = 0
|
|
|
410 |
|
411 |
# save metric
|
412 |
if best_metric is None:
|
413 |
+
best_idx_epoch = idx_epoch
|
414 |
best_steps = total_steps
|
415 |
best_metric = average_pesq_score
|
416 |
elif average_pesq_score > best_metric:
|
417 |
# great is better.
|
418 |
+
best_idx_epoch = idx_epoch
|
419 |
best_steps = total_steps
|
420 |
best_metric = average_pesq_score
|
421 |
else:
|
|
|
423 |
|
424 |
metrics = {
|
425 |
"idx_epoch": idx_epoch,
|
426 |
+
"best_idx_epoch": best_idx_epoch,
|
427 |
"best_steps": best_steps,
|
428 |
"pesq_score": average_pesq_score,
|
429 |
"loss": average_loss,
|
|
|
437 |
|
438 |
# save best
|
439 |
best_dir = serialization_dir / "best"
|
440 |
+
if best_idx_epoch == idx_epoch and best_steps == total_steps:
|
441 |
if best_dir.exists():
|
442 |
shutil.rmtree(best_dir)
|
443 |
shutil.copytree(save_dir, best_dir)
|
444 |
|
445 |
# early stop
|
446 |
early_stop_flag = False
|
447 |
+
if best_idx_epoch == idx_epoch and best_steps == total_steps:
|
448 |
patience_count = 0
|
449 |
else:
|
450 |
patience_count += 1
|