update
examples/mpnet_aishell/step_2_train_model.py
CHANGED
@@ -167,11 +167,26 @@ def main():
     discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
 
     # resume training
+    epoch_max = -1
     for epoch_i in serialization_dir.glob("epoch-*"):
         epoch_i = Path(epoch_i)
         epoch_idx = epoch_i.stem.split("-")[1]
-
-
+        if epoch_idx > epoch_max:
+            epoch_max = epoch_idx
+
+    if epoch_max != -1:
+        logger.info(f"resume from epoch-{epoch_max}.")
+        generator_pt = serialization_dir / f"epoch-{epoch_max}/generator.pt"
+        discriminator_pt = serialization_dir / f"epoch-{epoch_max}/discriminator.pt"
+
+        logger.info(f"load state dict for generator.")
+        with open(generator_pt.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+            generator.load_state_dict(state_dict, strict=True)
+        logger.info(f"load state dict for discriminator.")
+        with open(discriminator_pt.as_posix(), "rb") as f:
+            state_dict = torch.load(f, map_location="cpu")
+            discriminator.load_state_dict(state_dict, strict=True)
 
     # optimizer
     logger.info("prepare optimizer, lr_scheduler")

@@ -285,6 +300,8 @@ def main():
 
         # evaluation
         generator.eval()
+        discriminator.eval()
+
         torch.cuda.empty_cache()
         total_pesq_score = 0.
         total_mag_err = 0.

@@ -361,6 +378,7 @@ def main():
             best_idx_epoch = idx_epoch
             best_metric = pesq_metric
         elif pesq_metric > best_metric:
+            # great is better.
            best_idx_epoch = idx_epoch
             best_metric = pesq_metric
         else:
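For readers who want to reuse the resume logic from the first hunk outside this script, here is a minimal, self-contained sketch. The helper names (find_last_epoch, resume_if_possible) are hypothetical; serialization_dir, generator, discriminator and the epoch-N/generator.pt layout are taken from the diff. One assumption the sketch makes explicit: epoch_i.stem.split("-")[1] yields a string, so it is cast to int before being compared against epoch_max = -1.

from pathlib import Path

import torch


def find_last_epoch(serialization_dir: Path) -> int:
    """Return the largest epoch index found on disk, or -1 if none exists."""
    epoch_max = -1
    for epoch_dir in serialization_dir.glob("epoch-*"):
        # stem looks like "epoch-12"; cast to int so the comparison is numeric
        # (as strings, "9" > "12" would be True).
        epoch_idx = int(epoch_dir.stem.split("-")[1])
        if epoch_idx > epoch_max:
            epoch_max = epoch_idx
    return epoch_max


def resume_if_possible(serialization_dir: Path, generator, discriminator) -> int:
    """Load the newest generator/discriminator checkpoints, if any exist."""
    epoch_max = find_last_epoch(serialization_dir)
    if epoch_max != -1:
        epoch_dir = serialization_dir / f"epoch-{epoch_max}"
        generator.load_state_dict(
            torch.load(epoch_dir / "generator.pt", map_location="cpu"), strict=True
        )
        discriminator.load_state_dict(
            torch.load(epoch_dir / "discriminator.pt", map_location="cpu"), strict=True
        )
    return epoch_max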
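The last hunk only adds a comment, but it documents a convention worth keeping in mind: PESQ is a higher-is-better metric, so the checkpoint with the largest pesq_metric is kept as the best one. A small illustration of that bookkeeping with made-up scores (the real values come from the evaluation loop):

best_metric = None
best_idx_epoch = None

for idx_epoch, pesq_metric in enumerate([2.1, 2.4, 2.3]):  # dummy PESQ scores
    if best_metric is None:
        # first epoch: take it unconditionally
        best_idx_epoch = idx_epoch
        best_metric = pesq_metric
    elif pesq_metric > best_metric:
        # greater is better: a larger PESQ score replaces the previous best
        best_idx_epoch = idx_epoch
        best_metric = pesq_metric
    # else: keep the previous best checkpoint

print(best_idx_epoch, best_metric)  # -> 1 2.4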