HoneyTian commited on
Commit
949a94d
·
1 Parent(s): 20d2f3e
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -166,6 +166,11 @@ def main():
166
  generator = MPNetPretrainedModel(config).to(device)
167
  discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
 
 
 
 
 
169
  # optimizer
170
  logger.info("prepare optimizer, lr_scheduler")
171
  num_params = 0
@@ -353,7 +358,7 @@ def main():
353
  if best_metric is None:
354
  best_idx_epoch = idx_epoch
355
  best_metric = pesq_metric
356
- elif pesq_metric < best_metric:
357
  best_idx_epoch = idx_epoch
358
  best_metric = pesq_metric
359
  else:
 
166
  generator = MPNetPretrainedModel(config).to(device)
167
  discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
169
+ # resume training
170
+ for epoch_i in serialization_dir.glob("epoch-*"):
171
+ print(epoch_i)
172
+ exit(0)
173
+
174
  # optimizer
175
  logger.info("prepare optimizer, lr_scheduler")
176
  num_params = 0
 
358
  if best_metric is None:
359
  best_idx_epoch = idx_epoch
360
  best_metric = pesq_metric
361
+ elif pesq_metric > best_metric:
362
  best_idx_epoch = idx_epoch
363
  best_metric = pesq_metric
364
  else: