HoneyTian commited on
Commit
ff1dc38
·
1 Parent(s): d91f843
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -167,18 +167,18 @@ def main():
167
  discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
169
  # resume training
170
- epoch_max = -1
171
  for epoch_i in serialization_dir.glob("epoch-*"):
172
  epoch_i = Path(epoch_i)
173
  epoch_idx = epoch_i.stem.split("-")[1]
174
  epoch_idx = int(epoch_idx)
175
- if epoch_idx > epoch_max:
176
- epoch_max = epoch_idx
177
 
178
- if epoch_max != -1:
179
- logger.info(f"resume from epoch-{epoch_max}.")
180
- generator_pt = serialization_dir / f"epoch-{epoch_max}/generator.pt"
181
- discriminator_pt = serialization_dir / f"epoch-{epoch_max}/discriminator.pt"
182
 
183
  logger.info(f"load state dict for generator.")
184
  with open(generator_pt.as_posix(), "rb") as f:
@@ -199,8 +199,8 @@ def main():
199
  optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
200
  optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
201
 
202
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=-1)
203
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)
204
 
205
  # training loop
206
 
@@ -219,7 +219,7 @@ def main():
219
  patience_count = 0
220
 
221
  logger.info("training")
222
- for idx_epoch in range(args.max_epochs):
223
  # train
224
  generator.train()
225
  discriminator.train()
 
167
  discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
169
  # resume training
170
+ last_epoch = -1
171
  for epoch_i in serialization_dir.glob("epoch-*"):
172
  epoch_i = Path(epoch_i)
173
  epoch_idx = epoch_i.stem.split("-")[1]
174
  epoch_idx = int(epoch_idx)
175
+ if epoch_idx > last_epoch:
176
+ last_epoch = epoch_idx
177
 
178
+ if last_epoch != -1:
179
+ logger.info(f"resume from epoch-{last_epoch}.")
180
+ generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
181
+ discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
182
 
183
  logger.info(f"load state dict for generator.")
184
  with open(generator_pt.as_posix(), "rb") as f:
 
199
  optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
200
  optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
201
 
202
+ scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
203
+ scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
204
 
205
  # training loop
206
 
 
219
  patience_count = 0
220
 
221
  logger.info("training")
222
+ for idx_epoch in range(max(0, last_epoch), args.max_epochs):
223
  # train
224
  generator.train()
225
  discriminator.train()