HoneyTian commited on
Commit
db3e977
·
1 Parent(s): 19f90ec
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -244,8 +244,6 @@ def main():
244
  desc="Training; epoch-{}".format(idx_epoch),
245
  )
246
  for batch in train_data_loader:
247
- total_steps += 1
248
-
249
  clean_audios, noisy_audios = batch
250
  clean_audios = clean_audios.to(device)
251
  noisy_audios = noisy_audios.to(device)
@@ -299,6 +297,7 @@ def main():
299
  })
300
 
301
  # evaluation
 
302
  if total_steps % args.eval_steps:
303
  model.eval()
304
  torch.cuda.empty_cache()
@@ -311,7 +310,7 @@ def main():
311
  total_batches = 0.
312
 
313
  progress_bar = tqdm(
314
- desc="Evaluation; epoch-{}".format(idx_epoch),
315
  )
316
  with torch.no_grad():
317
  for batch in valid_data_loader:
@@ -361,28 +360,28 @@ def main():
361
  })
362
 
363
  # save path
364
- epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
365
- epoch_dir.mkdir(parents=True, exist_ok=False)
366
 
367
  # save models
368
- model.save_pretrained(epoch_dir.as_posix())
369
 
370
- model_list.append(epoch_dir)
371
  if len(model_list) >= args.num_serialized_models_to_keep:
372
  model_to_delete: Path = model_list.pop(0)
373
  shutil.rmtree(model_to_delete.as_posix())
374
 
375
  # save optim
376
- torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
377
 
378
  # save metric
379
  if best_metric is None:
380
  best_idx_epoch = idx_epoch
381
- best_metric = average_loss
382
- elif average_loss < best_metric:
383
  # great is better.
384
  best_idx_epoch = idx_epoch
385
- best_metric = average_loss
386
  else:
387
  pass
388
 
@@ -395,7 +394,7 @@ def main():
395
  "neg_si_snr_loss": average_neg_si_snr_loss,
396
  "neg_stoi_loss": average_neg_stoi_loss,
397
  }
398
- metrics_filename = epoch_dir / "metrics_epoch.json"
399
  with open(metrics_filename, "w", encoding="utf-8") as f:
400
  json.dump(metrics, f, indent=4, ensure_ascii=False)
401
 
@@ -404,7 +403,7 @@ def main():
404
  if best_idx_epoch == idx_epoch:
405
  if best_dir.exists():
406
  shutil.rmtree(best_dir)
407
- shutil.copytree(epoch_dir, best_dir)
408
 
409
  # early stop
410
  early_stop_flag = False
 
244
  desc="Training; epoch-{}".format(idx_epoch),
245
  )
246
  for batch in train_data_loader:
 
 
247
  clean_audios, noisy_audios = batch
248
  clean_audios = clean_audios.to(device)
249
  noisy_audios = noisy_audios.to(device)
 
297
  })
298
 
299
  # evaluation
300
+ total_steps += 1
301
  if total_steps % args.eval_steps:
302
  model.eval()
303
  torch.cuda.empty_cache()
 
310
  total_batches = 0.
311
 
312
  progress_bar = tqdm(
313
+ desc="Evaluation; step-{}".format(total_steps),
314
  )
315
  with torch.no_grad():
316
  for batch in valid_data_loader:
 
360
  })
361
 
362
  # save path
363
+ save_dir = serialization_dir / "steps-{}".format(total_steps)
364
+ save_dir.mkdir(parents=True, exist_ok=False)
365
 
366
  # save models
367
+ model.save_pretrained(save_dir.as_posix())
368
 
369
+ model_list.append(save_dir)
370
  if len(model_list) >= args.num_serialized_models_to_keep:
371
  model_to_delete: Path = model_list.pop(0)
372
  shutil.rmtree(model_to_delete.as_posix())
373
 
374
  # save optim
375
+ torch.save(optimizer.state_dict(), (save_dir / "optimizer.pth").as_posix())
376
 
377
  # save metric
378
  if best_metric is None:
379
  best_idx_epoch = idx_epoch
380
+ best_metric = average_pesq_score
381
+ elif average_pesq_score > best_metric:
382
  # great is better.
383
  best_idx_epoch = idx_epoch
384
+ best_metric = average_pesq_score
385
  else:
386
  pass
387
 
 
394
  "neg_si_snr_loss": average_neg_si_snr_loss,
395
  "neg_stoi_loss": average_neg_stoi_loss,
396
  }
397
+ metrics_filename = save_dir / "metrics_epoch.json"
398
  with open(metrics_filename, "w", encoding="utf-8") as f:
399
  json.dump(metrics, f, indent=4, ensure_ascii=False)
400
 
 
403
  if best_idx_epoch == idx_epoch:
404
  if best_dir.exists():
405
  shutil.rmtree(best_dir)
406
+ shutil.copytree(save_dir, best_dir)
407
 
408
  # early stop
409
  early_stop_flag = False