HoneyTian committed
Commit 9d91461 · 1 Parent(s): bb37ac1
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -171,19 +171,19 @@ def main():
     optimizer = torch.optim.AdamW(model.parameters(), config.lr)
 
     # resume training
-    last_steps = -1
+    last_step_idx = -1
     last_epoch = -1
-    for step_i in serialization_dir.glob("steps-*"):
-        step_i = Path(step_i)
-        step_idx = step_i.stem.split("-")[1]
+    for step_idx_str in serialization_dir.glob("steps-*"):
+        step_idx_str = Path(step_idx_str)
+        step_idx = step_idx_str.stem.split("-")[1]
         step_idx = int(step_idx)
-        if step_idx > last_steps:
-            last_steps = step_idx
+        if step_idx > last_step_idx:
+            last_step_idx = step_idx
 
-    if last_steps != -1:
-        logger.info(f"resume from steps-{last_steps}.")
-        model_pt = serialization_dir / f"steps-{last_steps}/model.pt"
-        optimizer_pth = serialization_dir / f"steps-{last_steps}/optimizer.pth"
+    if last_step_idx != -1:
+        logger.info(f"resume from steps-{last_step_idx}.")
+        model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
+        optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
 
     logger.info(f"load state dict for model.")
     with open(model_pt.as_posix(), "rb") as f:
@@ -233,13 +233,13 @@ def main():
     average_neg_stoi_loss = 1000000000
 
     model_list = list()
-    best_idx_epoch = None
-    best_steps = None
+    best_epoch_idx = None
+    best_step_idx = None
     best_metric = None
     patience_count = 0
 
     logger.info("training")
-    for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
+    for epoch_idx in range(max(0, last_epoch+1), args.max_epochs):
         # train
         model.train()
 
@@ -251,10 +251,10 @@ def main():
         total_mr_stft_loss = 0.
         total_batches = 0.
 
-        total_steps = 0 if last_steps == -1 else last_steps
+        step_idx = 0 if last_step_idx == -1 else last_step_idx
         progress_bar_train = tqdm(
-            initial=total_steps,
-            desc="Training; epoch-{}".format(idx_epoch),
+            initial=step_idx,
+            desc="Training; epoch-{}".format(epoch_idx),
         )
         for train_batch in train_data_loader:
             clean_audios, noisy_audios = train_batch
@@ -314,8 +314,8 @@ def main():
             })
 
             # evaluation
-            total_steps += 1
-            if total_steps % config.eval_steps == 0:
+            step_idx += 1
+            if step_idx % config.eval_steps == 0:
                 with torch.no_grad():
                     torch.cuda.empty_cache()
 
@@ -328,7 +328,7 @@ def main():
 
                 progress_bar_train.close()
                 progress_bar_eval = tqdm(
-                    desc="Evaluation; step-{}k".format(int(total_steps/1000)),
+                    desc="Evaluation; step-{}k".format(int(step_idx/1000)),
                 )
                 for eval_batch in valid_data_loader:
                     clean_audios, noisy_audios = eval_batch
@@ -394,7 +394,7 @@ def main():
                 )
 
                 # save path
-                save_dir = serialization_dir / "steps-{}k".format(int(total_steps/1000))
+                save_dir = serialization_dir / "steps-{}k".format(int(step_idx/1000))
                 save_dir.mkdir(parents=True, exist_ok=False)
 
                 # save models
@@ -410,26 +410,27 @@ def main():
 
                 # save metric
                 if best_metric is None:
-                    best_idx_epoch = idx_epoch
-                    best_steps = total_steps
+                    best_epoch_idx = epoch_idx
+                    best_step_idx = step_idx
                     best_metric = average_pesq_score
                 elif average_pesq_score > best_metric:
                     # great is better.
-                    best_idx_epoch = idx_epoch
-                    best_steps = total_steps
+                    best_epoch_idx = epoch_idx
+                    best_step_idx = step_idx
                     best_metric = average_pesq_score
                 else:
                     pass
 
                 metrics = {
-                    "idx_epoch": idx_epoch,
-                    "best_idx_epoch": best_idx_epoch,
-                    "best_steps": best_steps,
+                    "epoch_idx": epoch_idx,
+                    "best_epoch_idx": best_epoch_idx,
+                    "best_step_idx": best_step_idx,
                     "pesq_score": average_pesq_score,
                     "loss": average_loss,
                     "ae_loss": average_ae_loss,
                     "neg_si_snr_loss": average_neg_si_snr_loss,
                     "neg_stoi_loss": average_neg_stoi_loss,
+                    "mr_stft_loss": average_mr_stft_loss,
                 }
                 metrics_filename = save_dir / "metrics_epoch.json"
                 with open(metrics_filename, "w", encoding="utf-8") as f:
@@ -437,14 +438,14 @@ def main():
 
                 # save best
                 best_dir = serialization_dir / "best"
-                if best_idx_epoch == idx_epoch and best_steps == total_steps:
+                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                     if best_dir.exists():
                         shutil.rmtree(best_dir)
                     shutil.copytree(save_dir, best_dir)
 
                 # early stop
                 early_stop_flag = False
-                if best_idx_epoch == idx_epoch and best_steps == total_steps:
+                if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
                     patience_count = 0
                 else:
                     patience_count += 1
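For reference, the resume logic being renamed here scans serialization_dir for steps-* checkpoint folders and restarts from the highest step index found. A minimal, self-contained sketch of that pattern (the helper name find_last_step_idx and the isdigit() guard are illustrative additions, not part of the commit):

from pathlib import Path

def find_last_step_idx(serialization_dir: Path) -> int:
    # Highest step index among "steps-<N>" folders, or -1 if none exist yet.
    last_step_idx = -1
    for step_dir in serialization_dir.glob("steps-*"):
        suffix = step_dir.stem.split("-")[1]  # e.g. "steps-1000" -> "1000"
        if suffix.isdigit() and int(suffix) > last_step_idx:
            last_step_idx = int(suffix)
    return last_step_idx

Note that save_dir later in this file writes folder names with a "k" suffix (steps-{}k), which a bare int() parse would reject; the isdigit() guard above skips such names instead of raising.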
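The best-checkpoint and early-stop bookkeeping renamed in the last two hunks follows a standard higher-is-better patience pattern: record the epoch and step of the best PESQ score, reset the patience counter on improvement, and count stagnant evaluations otherwise. A compact sketch of the same logic (update_best is a hypothetical helper, not in the commit):

def update_best(best_metric, metric, patience_count):
    # Higher-is-better metric (PESQ): a new best resets patience to zero;
    # otherwise one more evaluation has passed without improvement.
    if best_metric is None or metric > best_metric:
        return metric, 0
    return best_metric, patience_count + 1

Training would stop once patience_count reaches the configured patience, which is what the early_stop_flag branch implements with explicit comparisons against best_epoch_idx and best_step_idx.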