update
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -171,19 +171,19 @@ def main():
     optimizer = torch.optim.AdamW(model.parameters(), config.lr)

     # resume training
+    last_step_idx = -1
     last_epoch = -1
+    for step_idx_str in serialization_dir.glob("steps-*"):
+        step_idx_str = Path(step_idx_str)
+        step_idx = step_idx_str.stem.split("-")[1]
         step_idx = int(step_idx)
+        if step_idx > last_step_idx:
+            last_step_idx = step_idx

+    if last_step_idx != -1:
+        logger.info(f"resume from steps-{last_step_idx}.")
+        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"

         logger.info(f"load state dict for model.")
         with open(model_pt.as_posix(), "rb") as f:
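Note on the resume scan: it recovers the latest checkpoint by parsing the step index out of each `steps-*` directory name. A minimal standalone sketch of the same idea (the helper name and the `k`-suffix handling are assumptions; the save path later in this diff writes `steps-{}k` names, which a bare `int()` would reject):

```python
from pathlib import Path

def find_last_step_idx(serialization_dir: Path) -> int:
    """Return the largest step index among steps-* checkpoints, or -1 if none."""
    last_step_idx = -1
    for ckpt_dir in serialization_dir.glob("steps-*"):
        tag = ckpt_dir.stem.split("-")[1]  # e.g. "3" or "3k"
        # Strip a trailing "k" (thousands) before parsing; int("3k") raises.
        step_idx = int(tag.rstrip("k")) * (1000 if tag.endswith("k") else 1)
        if step_idx > last_step_idx:
            last_step_idx = step_idx
    return last_step_idx
```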
@@ -233,13 +233,13 @@ def main():
     average_neg_stoi_loss = 1000000000

     model_list = list()
+    best_epoch_idx = None
+    best_step_idx = None
     best_metric = None
     patience_count = 0

     logger.info("training")
+    for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
         # train
         model.train()
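Note: `last_epoch` still holds its initializer (-1) in this change, so `range(max(0, last_epoch+1), args.max_epochs)` starts at epoch 0; the `max()` guard only pays off once a later change restores `last_epoch` from a checkpoint. A quick illustration with hypothetical values:

```python
max_epochs = 100                                  # hypothetical config value
assert range(max(0, -1 + 1), max_epochs)[0] == 0  # fresh run: start at epoch 0
assert range(max(0, 4 + 1), max_epochs)[0] == 5   # resumed after epoch 4: start at 5
```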
@@ -251,10 +251,10 @@ def main():
         total_mr_stft_loss = 0.
         total_batches = 0.

+        step_idx = 0 if last_step_idx == -1 else last_step_idx
         progress_bar_train = tqdm(
+            initial=step_idx,
+            desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
             clean_audios, noisy_audios = train_batch
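Note: `initial=step_idx` makes the training bar continue counting from the restored step instead of zero. A self-contained sketch (the resume point is hypothetical):

```python
from tqdm import tqdm

progress_bar = tqdm(initial=3000, desc="Training; epoch-7")
for _ in range(5):
    progress_bar.update(1)  # the counter shows 3001..3005, not 1..5
progress_bar.close()
```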
@@ -314,8 +314,8 @@ def main():
             })

             # evaluation
+            step_idx += 1
+            if step_idx % config.eval_steps == 0:
                 with torch.no_grad():
                     torch.cuda.empty_cache()
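Note: `step_idx` now increments once per training batch before the check, so evaluation fires every `config.eval_steps` batches. For example:

```python
eval_steps = 2500  # hypothetical config value
fired = [step for step in range(1, 10001) if step % eval_steps == 0]
assert fired == [2500, 5000, 7500, 10000]
```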
@@ -328,7 +328,7 @@ def main():

                     progress_bar_train.close()
                     progress_bar_eval = tqdm(
+                        desc="Evaluation; step-{}k".format(int(step_idx/1000)),
                     )
                     for eval_batch in valid_data_loader:
                         clean_audios, noisy_audios = eval_batch
@@ -394,7 +394,7 @@ def main():
                     )

                     # save path
+                    save_dir = serialization_dir / "steps-{}k".format(int(step_idx/1000))
                     save_dir.mkdir(parents=True, exist_ok=False)

                     # save models
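Note: `int(step_idx/1000)` floors the step count, so every evaluation within the same thousand-step bucket maps to one directory name, and `mkdir(..., exist_ok=False)` would raise `FileExistsError` on a second save in that bucket (possible whenever `config.eval_steps < 1000`). A sketch of the collision with hypothetical step values:

```python
tag_a = "steps-{}k".format(int(2500 / 1000))
tag_b = "steps-{}k".format(int(2999 / 1000))
# Both eval steps land in the same directory name.
assert tag_a == tag_b == "steps-2k"
```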
@@ -410,26 +410,27 @@ def main():

                     # save metric
                     if best_metric is None:
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
                         best_metric = average_pesq_score
                     elif average_pesq_score > best_metric:
                         # greater is better.
+                        best_epoch_idx = epoch_idx
+                        best_step_idx = step_idx
                         best_metric = average_pesq_score
                     else:
                         pass

                     metrics = {
+                        "epoch_idx": epoch_idx,
+                        "best_epoch_idx": best_epoch_idx,
+                        "best_step_idx": best_step_idx,
                         "pesq_score": average_pesq_score,
                         "loss": average_loss,
                         "ae_loss": average_ae_loss,
                         "neg_si_snr_loss": average_neg_si_snr_loss,
                         "neg_stoi_loss": average_neg_stoi_loss,
+                        "mr_stft_loss": average_mr_stft_loss,
                     }
                     metrics_filename = save_dir / "metrics_epoch.json"
                     with open(metrics_filename, "w", encoding="utf-8") as f:
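Note: the hunk ends just before the actual write, but the surrounding lines serialize the metrics dict to `metrics_epoch.json`. A minimal standalone equivalent (keys abridged; the values and the `json.dump` arguments are assumptions):

```python
import json

metrics = {
    "epoch_idx": 0,
    "best_epoch_idx": 0,
    "best_step_idx": 2500,
    "pesq_score": 2.31,
}
with open("metrics_epoch.json", "w", encoding="utf-8") as f:
    json.dump(metrics, f, indent=4, ensure_ascii=False)
```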
@@ -437,14 +438,14 @@ def main():

                     # save best
                     best_dir = serialization_dir / "best"
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                         if best_dir.exists():
                             shutil.rmtree(best_dir)
                         shutil.copytree(save_dir, best_dir)

                     # early stop
                     early_stop_flag = False
+                    if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                         patience_count = 0
                     else:
                         patience_count += 1
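Note: both the best-checkpoint copy and the patience reset now key on `best_epoch_idx == epoch_idx and best_step_idx == step_idx`, i.e. "this evaluation just produced a new best". A compact sketch of the resulting patience behaviour (the `patience` threshold is a hypothetical config value; the stop condition itself falls outside this diff):

```python
patience = 5                                    # hypothetical threshold
best_metric, patience_count = None, 0
for score in (2.10, 2.30, 2.25, 2.28, 2.26):    # hypothetical PESQ per evaluation
    if best_metric is None or score > best_metric:
        best_metric, patience_count = score, 0  # new best: reset patience
    else:
        patience_count += 1                     # no improvement
    if patience_count >= patience:
        break                                   # early stop
```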