HoneyTian committed
Commit fc1879e · 1 Parent(s): ff1dc38
examples/mpnet_aishell/run.sh CHANGED
@@ -109,7 +109,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
   $verbose && echo "stage 2: train model"
   cd "${work_dir}" || exit 1
   python3 step_2_train_model.py \
-  --train_dataset "${train_dataset}" \
+  --train_dataset "${valid_dataset}" \
   --valid_dataset "${valid_dataset}" \
   --serialization_dir "${file_dir}" \
   --config_file "${config_file}" \
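
For reference, the flags forwarded here form the command-line interface of step_2_train_model.py; note that after this change both --train_dataset and --valid_dataset receive "${valid_dataset}". A minimal sketch of how that interface is presumably consumed (argparse, the types, and required=True are assumptions; only the flag names come from the diff):

import argparse


def get_args():
    # Flag names mirrored from run.sh; everything else is an assumption.
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", type=str, required=True)
    parser.add_argument("--valid_dataset", type=str, required=True)
    parser.add_argument("--serialization_dir", type=str, required=True)
    parser.add_argument("--config_file", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    print(args.train_dataset, args.valid_dataset)
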
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -166,6 +166,19 @@ def main():
     generator = MPNetPretrainedModel(config).to(device)
     discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
 
+    # optimizer
+    logger.info("prepare optimizer, lr_scheduler")
+    num_params = 0
+    for p in generator.parameters():
+        num_params += p.numel()
+    logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
+
+    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
+    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
+
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
+
     # resume training
     last_epoch = -1
     for epoch_i in serialization_dir.glob("epoch-*"):
@@ -179,6 +192,8 @@ def main():
         logger.info(f"resume from epoch-{last_epoch}.")
         generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
         discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
+        optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
+        optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
 
         logger.info(f"load state dict for generator.")
         with open(generator_pt.as_posix(), "rb") as f:
@@ -189,18 +204,14 @@ def main():
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
             discriminator.load_state_dict(state_dict, strict=True)
 
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler")
-    num_params = 0
-    for p in generator.parameters():
-        num_params += p.numel()
-    logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
-
-    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
-    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
-
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
+        logger.info(f"load state dict for optim_g.")
+        with open(optim_g_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+            optim_g.load_state_dict(state_dict, strict=True)
+        logger.info(f"load state dict for optim_d.")
+        with open(optim_d_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+            optim_d.load_state_dict(state_dict, strict=True)
 
     # training loop
 
@@ -369,6 +380,10 @@ def main():
         generator.save_pretrained(epoch_dir.as_posix())
         discriminator.save_pretrained(epoch_dir.as_posix())
 
+        # save optim
+        torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
+        torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
+
         model_list.append(epoch_dir)
         if len(model_list) >= args.num_serialized_models_to_keep:
             model_to_delete: Path = model_list.pop(0)
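
Taken together, the Python-side change writes the optimizer states (optim_g.pth, optim_d.pth) next to the model weights at every epoch and reloads them when a run is resumed. Two details matter when reusing this pattern: torch.optim.Optimizer.load_state_dict takes only the state dict (unlike nn.Module.load_state_dict, it has no strict keyword), and ExponentialLR(..., last_epoch=last_epoch) needs last_epoch to already hold the restored epoch index (or -1 for a fresh run) and an initial_lr entry in the loaded param groups, so the restore has to happen before the schedulers are built. Below is a minimal, self-contained sketch of that save/resume cycle under those constraints; a tiny nn.Linear stands in for the MPNet generator and its save_pretrained call, and the learning rate, betas, and gamma are placeholders, not values from the real config_file.

# Minimal sketch of the epoch save/resume cycle added in this commit.
from pathlib import Path

import torch
import torch.nn as nn


def save_epoch(epoch_dir: Path, model: nn.Module, optim: torch.optim.Optimizer) -> None:
    epoch_dir.mkdir(parents=True, exist_ok=True)
    # model weights and optimizer state are written side by side, as in the diff
    torch.save(model.state_dict(), (epoch_dir / "generator.pt").as_posix())
    torch.save(optim.state_dict(), (epoch_dir / "optim_g.pth").as_posix())


def resume(serialization_dir: Path, model: nn.Module, optim: torch.optim.Optimizer) -> int:
    # find the newest epoch-<idx> directory; -1 means "start from scratch"
    last_epoch = -1
    for epoch_i in serialization_dir.glob("epoch-*"):
        last_epoch = max(last_epoch, int(epoch_i.stem.split("-")[1]))

    if last_epoch != -1:
        epoch_dir = serialization_dir / f"epoch-{last_epoch}"
        with open((epoch_dir / "generator.pt").as_posix(), "rb") as f:
            model.load_state_dict(torch.load(f, map_location="cpu", weights_only=True), strict=True)
        with open((epoch_dir / "optim_g.pth").as_posix(), "rb") as f:
            # Optimizer.load_state_dict accepts only the state dict, no strict flag
            optim.load_state_dict(torch.load(f, map_location="cpu", weights_only=True))
    return last_epoch


if __name__ == "__main__":
    serialization_dir = Path("serialization_dir")  # stand-in for --serialization_dir
    model = nn.Linear(8, 8)                        # stand-in for the generator
    optim_g = torch.optim.AdamW(model.parameters(), lr=2e-4, betas=(0.8, 0.99))

    # restore first, so last_epoch (and the optimizer's initial_lr, saved with its
    # state dict) is known before the scheduler is constructed
    last_epoch = resume(serialization_dir, model, optim_g)
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=0.99, last_epoch=last_epoch)

    for epoch_idx in range(last_epoch + 1, last_epoch + 3):
        # ... one epoch of generator/discriminator training would run here ...
        scheduler_g.step()
        save_epoch(serialization_dir / f"epoch-{epoch_idx}", model, optim_g)

Running the sketch twice shows the effect of the commit: the second run picks up the newest epoch-<idx> directory, restores both the weights and the AdamW state, and continues decaying the learning rate from where the first run stopped instead of resetting it.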