update
examples/mpnet_aishell/run.sh
CHANGED
@@ -109,7 +109,7 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
   $verbose && echo "stage 2: train model"
   cd "${work_dir}" || exit 1
   python3 step_2_train_model.py \
-    --train_dataset "${train_dataset}" \
+    --train_dataset "${valid_dataset}" \
     --valid_dataset "${valid_dataset}" \
     --serialization_dir "${file_dir}" \
     --config_file "${config_file}" \
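After this change both --train_dataset and --valid_dataset receive "${valid_dataset}", so stage 2 trains and validates on the same file list. For orientation, here is a minimal sketch of the argument parsing these flags imply on the Python side; the actual parser in step_2_train_model.py is not shown in this diff, so everything beyond the flag names is an assumption.

# sketch: how step_2_train_model.py presumably consumes the flags passed from run.sh.
# only the flag names come from the diff; types and required-ness are assumptions.
import argparse


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset", type=str, required=True)
    parser.add_argument("--valid_dataset", type=str, required=True)
    parser.add_argument("--serialization_dir", type=str, required=True)
    parser.add_argument("--config_file", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    print(args.train_dataset, args.valid_dataset, args.serialization_dir, args.config_file)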
examples/mpnet_aishell/step_2_train_model.py
CHANGED
@@ -166,6 +166,19 @@ def main():
     generator = MPNetPretrainedModel(config).to(device)
     discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
 
+    # optimizer
+    logger.info("prepare optimizer, lr_scheduler")
+    num_params = 0
+    for p in generator.parameters():
+        num_params += p.numel()
+    logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
+
+    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
+    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
+
+    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
+    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
+
     # resume training
     last_epoch = -1
     for epoch_i in serialization_dir.glob("epoch-*"):
@@ -179,6 +192,8 @@ def main():
         logger.info(f"resume from epoch-{last_epoch}.")
         generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
         discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
+        optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
+        optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
 
         logger.info(f"load state dict for generator.")
         with open(generator_pt.as_posix(), "rb") as f:
@@ -189,18 +204,14 @@ def main():
             state_dict = torch.load(f, map_location="cpu", weights_only=True)
             discriminator.load_state_dict(state_dict, strict=True)
 
-    # optimizer
-    logger.info("prepare optimizer, lr_scheduler")
-    num_params = 0
-    for p in generator.parameters():
-        num_params += p.numel()
-    logger.info("total parameters (generator): {:.3f}M".format(num_params/1e6))
-
-    optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
-    optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
-
-    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
-    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
+        logger.info(f"load state dict for optim_g.")
+        with open(optim_g_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+            optim_g.load_state_dict(state_dict, strict=True)
+        logger.info(f"load state dict for optim_d.")
+        with open(optim_d_pth.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+            optim_d.load_state_dict(state_dict, strict=True)
 
     # training loop
 
@@ -369,6 +380,10 @@ def main():
         generator.save_pretrained(epoch_dir.as_posix())
         discriminator.save_pretrained(epoch_dir.as_posix())
 
+        # save optim
+        torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
+        torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
+
         model_list.append(epoch_dir)
         if len(model_list) >= args.num_serialized_models_to_keep:
             model_to_delete: Path = model_list.pop(0)
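Taken together, the Python-side changes move optimizer and scheduler construction above the resume block, save optim_g.pth / optim_d.pth next to the model weights each epoch, and reload them when resuming. Two details to keep in mind when reusing the pattern: torch.optim.Optimizer.load_state_dict takes only the state dict (there is no strict keyword, unlike nn.Module.load_state_dict), and the schedulers are now built with last_epoch=last_epoch before the resume block assigns last_epoch = -1, so that variable must already be defined at that point for the script to run. Below is a self-contained sketch of the same save/resume pattern with toy modules; the file names mirror the diff, while the module sizes, hyperparameters, and epoch number are illustrative.

# minimal, runnable sketch of the optimizer checkpoint pattern added in this commit.
from pathlib import Path

import torch
import torch.nn as nn

generator = nn.Linear(8, 8)        # stand-in for MPNetPretrainedModel
discriminator = nn.Linear(8, 1)    # stand-in for MetricDiscriminatorPretrainedModel

# build optimizers and schedulers before looking for a checkpoint,
# so that their state can be restored on resume.
optim_g = torch.optim.AdamW(generator.parameters(), lr=2e-4, betas=(0.8, 0.99))
optim_d = torch.optim.AdamW(discriminator.parameters(), lr=2e-4, betas=(0.8, 0.99))
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=0.99, last_epoch=-1)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=0.99, last_epoch=-1)
# (the commit persists only the optimizers; scheduler state itself is not saved.)

# save: one file per optimizer, next to the model weights for that epoch.
epoch_dir = Path("serialization_dir/epoch-10")
epoch_dir.mkdir(parents=True, exist_ok=True)
torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())

# resume: reload the optimizer state dicts.  Optimizer.load_state_dict takes a
# single argument, hence no strict= here.
with open((epoch_dir / "optim_g.pth").as_posix(), "rb") as f:
    state_dict = torch.load(f, map_location="cpu", weights_only=True)
    optim_g.load_state_dict(state_dict)
with open((epoch_dir / "optim_d.pth").as_posix(), "rb") as f:
    state_dict = torch.load(f, map_location="cpu", weights_only=True)
    optim_d.load_state_dict(state_dict)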