HoneyTian commited on
Commit
3188b07
·
1 Parent(s): 5bb28a8
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -167,11 +167,26 @@ def main():
167
  discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
169
  # resume training
 
170
  for epoch_i in serialization_dir.glob("epoch-*"):
171
  epoch_i = Path(epoch_i)
172
  epoch_idx = epoch_i.stem.split("-")[1]
173
- print(f"epoch_idx: {epoch_idx}")
174
- exit(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  # optimizer
177
  logger.info("prepare optimizer, lr_scheduler")
@@ -285,6 +300,8 @@ def main():
285
 
286
  # evaluation
287
  generator.eval()
 
 
288
  torch.cuda.empty_cache()
289
  total_pesq_score = 0.
290
  total_mag_err = 0.
@@ -361,6 +378,7 @@ def main():
361
  best_idx_epoch = idx_epoch
362
  best_metric = pesq_metric
363
  elif pesq_metric > best_metric:
 
364
  best_idx_epoch = idx_epoch
365
  best_metric = pesq_metric
366
  else:
 
167
  discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
169
  # resume training
170
+ epoch_max = -1
171
  for epoch_i in serialization_dir.glob("epoch-*"):
172
  epoch_i = Path(epoch_i)
173
  epoch_idx = epoch_i.stem.split("-")[1]
174
+ if epoch_idx > epoch_max:
175
+ epoch_max = epoch_idx
176
+
177
+ if epoch_max != -1:
178
+ logger.info(f"resume from epoch-{epoch_max}.")
179
+ generator_pt = serialization_dir / f"epoch-{epoch_max}/generator.pt"
180
+ discriminator_pt = serialization_dir / f"epoch-{epoch_max}/discriminator.pt"
181
+
182
+ logger.info(f"load state dict for generator.")
183
+ with open(generator_pt.as_posix(), "rb") as f:
184
+ state_dict = torch.load(f, map_location="cpu")
185
+ generator.load_state_dict(state_dict, strict=True)
186
+ logger.info(f"load state dict for discriminator.")
187
+ with open(discriminator_pt.as_posix(), "rb") as f:
188
+ state_dict = torch.load(f, map_location="cpu")
189
+ discriminator.load_state_dict(state_dict, strict=True)
190
 
191
  # optimizer
192
  logger.info("prepare optimizer, lr_scheduler")
 
300
 
301
  # evaluation
302
  generator.eval()
303
+ discriminator.eval()
304
+
305
  torch.cuda.empty_cache()
306
  total_pesq_score = 0.
307
  total_mag_err = 0.
 
378
  best_idx_epoch = idx_epoch
379
  best_metric = pesq_metric
380
  elif pesq_metric > best_metric:
381
+ # great is better.
382
  best_idx_epoch = idx_epoch
383
  best_metric = pesq_metric
384
  else: