examples/mpnet_aishell/step_2_train_model.py  CHANGED
@@ -167,18 +167,18 @@ def main():
     discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
 
     # resume training
-
+    last_epoch = -1
     for epoch_i in serialization_dir.glob("epoch-*"):
         epoch_i = Path(epoch_i)
         epoch_idx = epoch_i.stem.split("-")[1]
         epoch_idx = int(epoch_idx)
-        if epoch_idx >
-
+        if epoch_idx > last_epoch:
+            last_epoch = epoch_idx
 
-    if
-        logger.info(f"resume from epoch-{
-        generator_pt = serialization_dir / f"epoch-{
-        discriminator_pt = serialization_dir / f"epoch-{
+    if last_epoch != -1:
+        logger.info(f"resume from epoch-{last_epoch}.")
+        generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
+        discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
 
     logger.info(f"load state dict for generator.")
     with open(generator_pt.as_posix(), "rb") as f:
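The checkpoint lookup added in this hunk can be sketched on its own. The sketch below is illustrative rather than taken from the repo: it assumes a serialization_dir containing epoch-<n>/generator.pt and epoch-<n>/discriminator.pt folders, as the diff suggests, and simply keeps the highest epoch index found.

# Illustrative sketch of the resume lookup above (paths are assumptions):
# scan serialization_dir for "epoch-<n>" folders and keep the largest <n>.
from pathlib import Path

serialization_dir = Path("serialization_dir")  # hypothetical location

last_epoch = -1
for epoch_dir in serialization_dir.glob("epoch-*"):
    epoch_idx = int(epoch_dir.stem.split("-")[1])  # "epoch-12" -> 12
    if epoch_idx > last_epoch:
        last_epoch = epoch_idx

if last_epoch != -1:
    generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
    discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
    print(f"resume from epoch-{last_epoch}")
else:
    print("no checkpoint found; training starts from scratch")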
@@ -199,8 +199,8 @@ def main():
     optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
     optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
 
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
 
     # training loop
 
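For context, a minimal standalone sketch of how last_epoch feeds ExponentialLR follows; the model, learning rate and gamma are placeholders, not values from the config. With last_epoch=-1 (no checkpoint found) the scheduler starts from the optimizer's base learning rate and multiplies it by gamma on every step(). When a positive last_epoch is passed instead, recent PyTorch versions expect each optimizer param group to already carry an "initial_lr" entry, which is worth keeping in mind when resuming.

# Minimal sketch (placeholder model and hyperparameters, not from the repo).
import torch

model = torch.nn.Linear(4, 4)
optim_g = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=[0.8, 0.99])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=0.99, last_epoch=-1)

for _ in range(3):
    optim_g.step()      # normally preceded by a forward/backward pass
    scheduler_g.step()  # learning rate shrinks by a factor of gamma

print(scheduler_g.get_last_lr())  # roughly 1e-3 * 0.99 ** 3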
@@ -219,7 +219,7 @@ def main():
     patience_count = 0
 
     logger.info("training")
-    for idx_epoch in range(args.max_epochs):
+    for idx_epoch in range(max(0, last_epoch), args.max_epochs):
         # train
         generator.train()
         discriminator.train()
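The loop-start change can be checked in isolation; the values below are made up purely for illustration.

# With no checkpoint, last_epoch == -1 and training starts at epoch 0;
# after resuming from e.g. "epoch-7", the loop continues at epoch 7.
max_epochs = 10  # stand-in for args.max_epochs
for last_epoch in (-1, 7):
    print(list(range(max(0, last_epoch), max_epochs)))
# prints [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and then [7, 8, 9]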