Commit: update
examples/mpnet_aishell/step_2_train_model.py (CHANGED)
@@ -176,9 +176,6 @@ def main():
     optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
     optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
 
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
-
     # resume training
     last_epoch = -1
     for epoch_i in serialization_dir.glob("epoch-*"):
@@ -213,6 +210,9 @@ def main():
         state_dict = torch.load(f, map_location="cpu", weights_only=True)
         optim_d.load_state_dict(state_dict, strict=True)
 
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
+
     # training loop
 
     # state
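The commit moves construction of the two ExponentialLR schedulers from before the checkpoint-resume block to after it, so last_epoch already holds the restored epoch index (or -1 for a fresh run) when the schedulers are built, and the decayed learning rate carries over across restarts instead of resetting to the base rate. A minimal runnable sketch of the corrected ordering, using hypothetical stand-in modules and literal hyperparameters in place of the script's config object:

import torch

# Hypothetical stand-ins for the script's real generator/discriminator.
generator = torch.nn.Linear(4, 4)
discriminator = torch.nn.Linear(4, 1)

learning_rate, lr_decay = 2e-4, 0.99  # assumed values; the script reads these from config
optim_g = torch.optim.AdamW(generator.parameters(), learning_rate, betas=[0.8, 0.99])
optim_d = torch.optim.AdamW(discriminator.parameters(), learning_rate, betas=[0.8, 0.99])

# Resume logic runs first; last_epoch ends up at the restored epoch index,
# or stays -1 when starting from scratch. (Checkpoint loading elided here.)
last_epoch = -1

# When actually resuming (last_epoch > -1), PyTorch schedulers look for an
# "initial_lr" key in every param group, which a freshly built optimizer lacks.
if last_epoch > -1:
    for group in optim_g.param_groups + optim_d.param_groups:
        group.setdefault("initial_lr", learning_rate)

# Only now build the schedulers, so they pick up the decayed learning rate
# at the restored epoch instead of restarting the decay schedule.
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay, last_epoch=last_epoch)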