HoneyTian commited on
Commit
19f90ec
·
1 Parent(s): aa9e11e
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -48,6 +48,8 @@ def get_args():
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
49
  parser.add_argument("--seed", default=1234, type=int)
50
 
 
 
51
  parser.add_argument("--config_file", default="config.yaml", type=str)
52
 
53
 
@@ -237,10 +239,13 @@ def main():
237
  total_neg_stoi_loss = 0.
238
  total_mr_stft_loss = 0.
239
  total_batches = 0.
 
240
  progress_bar = tqdm(
241
  desc="Training; epoch-{}".format(idx_epoch),
242
  )
243
  for batch in train_data_loader:
 
 
244
  clean_audios, noisy_audios = batch
245
  clean_audios = clean_audios.to(device)
246
  noisy_audios = noisy_audios.to(device)
@@ -293,128 +298,126 @@ def main():
293
  "mr_stft_loss": average_mr_stft_loss,
294
  })
295
 
296
- # evaluation
297
- model.eval()
298
- torch.cuda.empty_cache()
299
-
300
- total_pesq_score = 0.
301
- total_loss = 0.
302
- total_ae_loss = 0.
303
- total_neg_si_snr_loss = 0.
304
- total_neg_stoi_loss = 0.
305
- total_batches = 0.
306
-
307
- progress_bar = tqdm(
308
- desc="Evaluation; epoch-{}".format(idx_epoch),
309
- )
310
- with torch.no_grad():
311
- for batch in valid_data_loader:
312
- clean_audios, noisy_audios = batch
313
- clean_audios = clean_audios.to(device)
314
- noisy_audios = noisy_audios.to(device)
315
-
316
- denoise_audios = model.forward(noisy_audios)
317
- denoise_audios = torch.squeeze(denoise_audios, dim=1)
318
-
319
- ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
320
- neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
321
- neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
322
- mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
323
-
324
- # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
325
- loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
326
-
327
- denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
328
- clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
329
- pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
330
-
331
- total_pesq_score += pesq_score
332
- total_loss += loss.item()
333
- total_ae_loss += ae_loss.item()
334
- total_neg_si_snr_loss += neg_si_snr_loss.item()
335
- total_neg_stoi_loss += neg_stoi_loss.item()
336
- total_mr_stft_loss += mr_stft_loss.item()
337
- total_batches += 1
338
-
339
- average_pesq_score = round(total_pesq_score / total_batches, 4)
340
- average_loss = round(total_loss / total_batches, 4)
341
- average_ae_loss = round(total_ae_loss / total_batches, 4)
342
- average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
343
- average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
344
- average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
345
-
346
- progress_bar.update(1)
347
- progress_bar.set_postfix({
348
- "pesq_score": average_pesq_score,
349
- "loss": average_loss,
350
- "ae_loss": average_ae_loss,
351
- "neg_si_snr_loss": average_neg_si_snr_loss,
352
- "neg_stoi_loss": average_neg_stoi_loss,
353
- "mr_stft_loss": average_mr_stft_loss,
354
-
355
- })
356
-
357
- # scheduler
358
- lr_scheduler.step()
359
-
360
- # save path
361
- epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
362
- epoch_dir.mkdir(parents=True, exist_ok=False)
363
-
364
- # save models
365
- model.save_pretrained(epoch_dir.as_posix())
366
-
367
- model_list.append(epoch_dir)
368
- if len(model_list) >= args.num_serialized_models_to_keep:
369
- model_to_delete: Path = model_list.pop(0)
370
- shutil.rmtree(model_to_delete.as_posix())
371
-
372
- # save optim
373
- torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
374
-
375
- # save metric
376
- if best_metric is None:
377
- best_idx_epoch = idx_epoch
378
- best_metric = average_loss
379
- elif average_loss < best_metric:
380
- # great is better.
381
- best_idx_epoch = idx_epoch
382
- best_metric = average_loss
383
- else:
384
- pass
385
-
386
- metrics = {
387
- "idx_epoch": idx_epoch,
388
- "best_idx_epoch": best_idx_epoch,
389
- "pesq_score": average_pesq_score,
390
- "loss": average_loss,
391
- "ae_loss": average_ae_loss,
392
- "neg_si_snr_loss": average_neg_si_snr_loss,
393
- "neg_stoi_loss": average_neg_stoi_loss,
394
- }
395
- metrics_filename = epoch_dir / "metrics_epoch.json"
396
- with open(metrics_filename, "w", encoding="utf-8") as f:
397
- json.dump(metrics, f, indent=4, ensure_ascii=False)
398
-
399
- # save best
400
- best_dir = serialization_dir / "best"
401
- if best_idx_epoch == idx_epoch:
402
- if best_dir.exists():
403
- shutil.rmtree(best_dir)
404
- shutil.copytree(epoch_dir, best_dir)
405
-
406
- # early stop
407
- early_stop_flag = False
408
- if best_idx_epoch == idx_epoch:
409
- patience_count = 0
410
- else:
411
- patience_count += 1
412
- if patience_count >= args.patience:
413
- early_stop_flag = True
414
-
415
- # early stop
416
- if early_stop_flag:
417
- break
418
 
419
  return
420
 
 
48
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
49
  parser.add_argument("--seed", default=1234, type=int)
50
 
51
+ parser.add_argument("--eval_steps", default=5000, type=int)
52
+
53
  parser.add_argument("--config_file", default="config.yaml", type=str)
54
 
55
 
 
239
  total_neg_stoi_loss = 0.
240
  total_mr_stft_loss = 0.
241
  total_batches = 0.
242
+ total_steps = 0
243
  progress_bar = tqdm(
244
  desc="Training; epoch-{}".format(idx_epoch),
245
  )
246
  for batch in train_data_loader:
247
+ total_steps += 1
248
+
249
  clean_audios, noisy_audios = batch
250
  clean_audios = clean_audios.to(device)
251
  noisy_audios = noisy_audios.to(device)
 
298
  "mr_stft_loss": average_mr_stft_loss,
299
  })
300
 
301
+ # evaluation
302
+ if total_steps % args.eval_steps == 0:
303
+ model.eval()
304
+ torch.cuda.empty_cache()
305
+
306
+ total_pesq_score = 0.
307
+ total_loss = 0.
308
+ total_ae_loss = 0.
309
+ total_neg_si_snr_loss = 0.
310
+ total_neg_stoi_loss = 0.
311
+ total_mr_stft_loss = 0.
+ total_batches = 0.
312
+
313
+ progress_bar = tqdm(
314
+ desc="Evaluation; epoch-{}".format(idx_epoch),
315
+ )
316
+ with torch.no_grad():
317
+ for batch in valid_data_loader:
318
+ clean_audios, noisy_audios = batch
319
+ clean_audios = clean_audios.to(device)
320
+ noisy_audios = noisy_audios.to(device)
321
+
322
+ denoise_audios = model.forward(noisy_audios)
323
+ denoise_audios = torch.squeeze(denoise_audios, dim=1)
324
+
325
+ ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
326
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
327
+ neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
328
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
329
+
330
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
331
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
332
+
333
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
334
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
335
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=8000, mode="nb")
336
+
337
+ total_pesq_score += pesq_score
338
+ total_loss += loss.item()
339
+ total_ae_loss += ae_loss.item()
340
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
341
+ total_neg_stoi_loss += neg_stoi_loss.item()
342
+ total_mr_stft_loss += mr_stft_loss.item()
343
+ total_batches += 1
344
+
345
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
346
+ average_loss = round(total_loss / total_batches, 4)
347
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
348
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
349
+ average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
350
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
351
+
352
+ progress_bar.update(1)
353
+ progress_bar.set_postfix({
354
+ "pesq_score": average_pesq_score,
355
+ "loss": average_loss,
356
+ "ae_loss": average_ae_loss,
357
+ "neg_si_snr_loss": average_neg_si_snr_loss,
358
+ "neg_stoi_loss": average_neg_stoi_loss,
359
+ "mr_stft_loss": average_mr_stft_loss,
360
+
361
+ })
362
+
363
+ # save path
364
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
365
+ epoch_dir.mkdir(parents=True, exist_ok=True)
366
+
367
+ # save models
368
+ model.save_pretrained(epoch_dir.as_posix())
369
+
370
+ model_list.append(epoch_dir)
371
+ if len(model_list) >= args.num_serialized_models_to_keep:
372
+ model_to_delete: Path = model_list.pop(0)
373
+ shutil.rmtree(model_to_delete.as_posix())
374
+
375
+ # save optim
376
+ torch.save(optimizer.state_dict(), (epoch_dir / "optimizer.pth").as_posix())
377
+
378
+ # save metric
379
+ if best_metric is None:
380
+ best_idx_epoch = idx_epoch
381
+ best_metric = average_loss
382
+ elif average_loss < best_metric:
383
+ # lower loss is better.
384
+ best_idx_epoch = idx_epoch
385
+ best_metric = average_loss
386
+ else:
387
+ pass
388
+
389
+ metrics = {
390
+ "idx_epoch": idx_epoch,
391
+ "best_idx_epoch": best_idx_epoch,
392
+ "pesq_score": average_pesq_score,
393
+ "loss": average_loss,
394
+ "ae_loss": average_ae_loss,
395
+ "neg_si_snr_loss": average_neg_si_snr_loss,
396
+ "neg_stoi_loss": average_neg_stoi_loss,
397
+ }
398
+ metrics_filename = epoch_dir / "metrics_epoch.json"
399
+ with open(metrics_filename, "w", encoding="utf-8") as f:
400
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
401
+
402
+ # save best
403
+ best_dir = serialization_dir / "best"
404
+ if best_idx_epoch == idx_epoch:
405
+ if best_dir.exists():
406
+ shutil.rmtree(best_dir)
407
+ shutil.copytree(epoch_dir, best_dir)
408
+
409
+ # early stop
410
+ early_stop_flag = False
411
+ if best_idx_epoch == idx_epoch:
412
+ patience_count = 0
413
+ else:
414
+ patience_count += 1
415
+ if patience_count >= args.patience:
416
+ early_stop_flag = True
417
+
418
+ # early stop
419
+ if early_stop_flag:
420
+ break
 
 
421
 
422
  return
423