HoneyTian committed on
Commit
f69c753
·
1 Parent(s): e2f2829
examples/mpnet_aishell/step_2_train_model.py CHANGED
@@ -32,7 +32,7 @@ from tqdm import tqdm
32
 
33
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
34
  from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
35
- from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminator, batch_pesq
36
  from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
37
  from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
38
 
@@ -164,14 +164,14 @@ def main():
164
  # models
165
  logger.info(f"prepare models. config_file: {args.config_file}")
166
  generator = MPNetPretrainedModel(config).to(device)
167
- discriminator = MetricDiscriminator().to(device)
168
 
169
  # optimizer
170
- logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
171
  num_params = 0
172
  for p in generator.parameters():
173
  num_params += p.numel()
174
- print("Total Parameters (generator): {:.3f}M".format(num_params/1e6))
175
 
176
  optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
177
  optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
@@ -180,8 +180,24 @@ def main():
180
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)
181
 
182
  # training loop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  logger.info("training")
184
  for idx_epoch in range(args.max_epochs):
 
185
  generator.train()
186
  discriminator.train()
187
 
@@ -251,12 +267,16 @@ def main():
251
  total_loss_g += loss_gen_all.item()
252
  total_batches += 1
253
 
 
 
 
254
  progress_bar.update(1)
255
  progress_bar.set_postfix({
256
- "loss_d": round(total_loss_d / total_batches, 4),
257
- "loss_g": round(total_loss_g / total_batches, 4),
258
  })
259
 
 
260
  generator.eval()
261
  torch.cuda.empty_cache()
262
  total_pesq_score = 0.
@@ -297,18 +317,87 @@ def main():
297
 
298
  total_batches += 1
299
 
 
 
 
 
 
 
300
  progress_bar.update(1)
301
  progress_bar.set_postfix({
302
- "pesq_score": round(total_pesq_score / total_batches, 4),
303
- "mag_err": round(total_mag_err / total_batches, 4),
304
- "pha_err": round(total_pha_err / total_batches, 4),
305
- "com_err": round(total_com_err / total_batches, 4),
306
- "stft_err": round(total_stft_err / total_batches, 4),
307
-
308
  })
309
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
  return
311
 
312
 
313
- if __name__ == '__main__':
314
  main()
 
32
 
33
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
34
  from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
35
+ from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel, batch_pesq
36
  from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
37
  from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
38
 
 
164
  # models
165
  logger.info(f"prepare models. config_file: {args.config_file}")
166
  generator = MPNetPretrainedModel(config).to(device)
167
+ discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
168
 
169
  # optimizer
170
+ logger.info("prepare optimizer, lr_scheduler")
171
  num_params = 0
172
  for p in generator.parameters():
173
  num_params += p.numel()
174
+ logger.info("Total Parameters (generator): {:.3f}M".format(num_params/1e6))
175
 
176
  optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
177
  optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
 
180
  scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)
181
 
182
  # training loop
183
+
184
+ # state
185
+ loss_d = 10000000000
186
+ loss_g = 10000000000
187
+ pesq_metric = 10000000000
188
+ mag_err = 10000000000
189
+ pha_err = 10000000000
190
+ com_err = 10000000000
191
+ stft_err = 10000000000
192
+
193
+ model_list = list()
194
+ best_idx_epoch = None
195
+ best_metric = None
196
+ patience_count = 0
197
+
198
  logger.info("training")
199
  for idx_epoch in range(args.max_epochs):
200
+ # train
201
  generator.train()
202
  discriminator.train()
203
 
 
267
  total_loss_g += loss_gen_all.item()
268
  total_batches += 1
269
 
270
+ loss_d = round(total_loss_d / total_batches, 4)
271
+ loss_g = round(total_loss_g / total_batches, 4)
272
+
273
  progress_bar.update(1)
274
  progress_bar.set_postfix({
275
+ "loss_d": loss_d,
276
+ "loss_g": loss_g,
277
  })
278
 
279
+ # evaluation
280
  generator.eval()
281
  torch.cuda.empty_cache()
282
  total_pesq_score = 0.
 
317
 
318
  total_batches += 1
319
 
320
+ pesq_metric = round(total_pesq_score / total_batches, 4)
321
+ mag_err = round(total_mag_err / total_batches, 4)
322
+ pha_err = round(total_pha_err / total_batches, 4)
323
+ com_err = round(total_com_err / total_batches, 4)
324
+ stft_err = round(total_stft_err / total_batches, 4)
325
+
326
  progress_bar.update(1)
327
  progress_bar.set_postfix({
328
+ "pesq_metric": pesq_metric,
329
+ "mag_err": mag_err,
330
+ "pha_err": pha_err,
331
+ "com_err": com_err,
332
+ "stft_err": stft_err,
 
333
  })
334
 
335
+ # scheduler
336
+ scheduler_g.step()
337
+ scheduler_d.step()
338
+
339
+ # save path
340
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
341
+ epoch_dir.mkdir(parents=True, exist_ok=False)
342
+
343
+ # save models
344
+ generator.save_pretrained(epoch_dir.as_posix())
345
+ discriminator.save_pretrained(epoch_dir.as_posix())
346
+
347
+ model_list.append(epoch_dir)
348
+ if len(model_list) >= args.num_serialized_models_to_keep:
349
+ model_to_delete: Path = model_list.pop(0)
350
+ shutil.rmtree(model_to_delete.as_posix())
351
+
352
+ # save metric
353
+ if best_metric is None:
354
+ best_idx_epoch = idx_epoch
355
+ best_metric = pesq_metric
356
+ elif pesq_metric < best_metric:
357
+ best_idx_epoch = idx_epoch
358
+ best_metric = pesq_metric
359
+ else:
360
+ pass
361
+
362
+ metrics = {
363
+ "idx_epoch": idx_epoch,
364
+ "best_idx_epoch": best_idx_epoch,
365
+ "loss_d": loss_d,
366
+ "loss_g": loss_g,
367
+
368
+ "pesq_metric": pesq_metric,
369
+ "mag_err": mag_err,
370
+ "pha_err": pha_err,
371
+ "com_err": com_err,
372
+ "stft_err": stft_err,
373
+
374
+ }
375
+ metrics_filename = epoch_dir / "metrics_epoch.json"
376
+ with open(metrics_filename, "w", encoding="utf-8") as f:
377
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
378
+
379
+ # save best
380
+ best_dir = serialization_dir / "best"
381
+ if best_idx_epoch == idx_epoch:
382
+ if best_dir.exists():
383
+ shutil.rmtree(best_dir)
384
+ shutil.copytree(epoch_dir, best_dir)
385
+
386
+ # early stop
387
+ early_stop_flag = False
388
+ if best_idx_epoch == idx_epoch:
389
+ patience_count = 0
390
+ else:
391
+ patience_count += 1
392
+ if patience_count >= args.patience:
393
+ early_stop_flag = True
394
+
395
+ # early stop
396
+ if early_stop_flag:
397
+ break
398
+
399
  return
400
 
401
 
402
+ if __name__ == "__main__":
403
  main()
toolbox/torchaudio/models/mpnet/configuation_mpnet.py CHANGED
@@ -33,6 +33,9 @@ class MPNetConfig(PretrainedConfig):
33
 
34
  dist_config: dict = None,
35
 
 
 
 
36
  **kwargs
37
  ):
38
  super(MPNetConfig, self).__init__(**kwargs)
@@ -63,6 +66,9 @@ class MPNetConfig(PretrainedConfig):
63
  "world_size": 1
64
  }
65
 
 
 
 
66
 
67
  if __name__ == "__main__":
68
  pass
 
33
 
34
  dist_config: dict = None,
35
 
36
+ discriminator_dim: int = 32,
37
+ discriminator_in_channel: int = 2,
38
+
39
  **kwargs
40
  ):
41
  super(MPNetConfig, self).__init__(**kwargs)
 
66
  "world_size": 1
67
  }
68
 
69
+ self.discriminator_dim = discriminator_dim
70
+ self.discriminator_in_channel = discriminator_in_channel
71
+
72
 
73
  if __name__ == "__main__":
74
  pass
toolbox/torchaudio/models/mpnet/discriminator.py CHANGED
@@ -1,5 +1,8 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
 
 
 
3
  import torch
4
  import torch.nn as nn
5
  import numpy as np
@@ -7,6 +10,8 @@ import torch.nn.functional as F
7
  from pesq import pesq
8
  from joblib import Parallel, delayed
9
 
 
 
10
  from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
11
 
12
 
@@ -38,8 +43,12 @@ def metric_loss(metric_ref, metrics_gen):
38
 
39
 
40
  class MetricDiscriminator(nn.Module):
41
- def __init__(self, dim=16, in_channel=2):
42
  super(MetricDiscriminator, self).__init__()
 
 
 
 
43
  self.layers = nn.Sequential(
44
  nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
45
  nn.InstanceNorm2d(dim, affine=True),
@@ -67,5 +76,54 @@ class MetricDiscriminator(nn.Module):
67
  return self.layers(xy)
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  if __name__ == '__main__':
71
  pass
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
+ import os
4
+ from typing import Optional, Union
5
+
6
  import torch
7
  import torch.nn as nn
8
  import numpy as np
 
10
  from pesq import pesq
11
  from joblib import Parallel, delayed
12
 
13
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
+ from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
15
  from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
16
 
17
 
 
43
 
44
 
45
  class MetricDiscriminator(nn.Module):
46
+ def __init__(self, config: MPNetConfig):
47
  super(MetricDiscriminator, self).__init__()
48
+
49
+ dim = config.discriminator_dim
50
+ in_channel = config.discriminator_in_channel
51
+
52
  self.layers = nn.Sequential(
53
  nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
54
  nn.InstanceNorm2d(dim, affine=True),
 
76
  return self.layers(xy)
77
 
78
 
79
+ MODEL_FILE = "discriminator.pt"
80
+
81
+
82
+ class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
83
+ def __init__(self,
84
+ config: MPNetConfig,
85
+ ):
86
+ super(MetricDiscriminatorPretrainedModel, self).__init__(
87
+ config=config,
88
+ )
89
+
90
+ @classmethod
91
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
92
+ config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
93
+
94
+ model = cls(config)
95
+
96
+ if os.path.isdir(pretrained_model_name_or_path):
97
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
98
+ else:
99
+ ckpt_file = pretrained_model_name_or_path
100
+
101
+ with open(ckpt_file, "rb") as f:
102
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
103
+ model.load_state_dict(state_dict, strict=True)
104
+ return model
105
+
106
+ def save_pretrained(self,
107
+ save_directory: Union[str, os.PathLike],
108
+ state_dict: Optional[dict] = None,
109
+ ):
110
+
111
+ model = self
112
+
113
+ if state_dict is None:
114
+ state_dict = model.state_dict()
115
+
116
+ os.makedirs(save_directory, exist_ok=True)
117
+
118
+ # save state dict
119
+ model_file = os.path.join(save_directory, MODEL_FILE)
120
+ torch.save(state_dict, model_file)
121
+
122
+ # save config
123
+ config_file = os.path.join(save_directory, CONFIG_FILE)
124
+ self.config.to_yaml_file(config_file)
125
+ return save_directory
126
+
127
+
128
  if __name__ == '__main__':
129
  pass
toolbox/torchaudio/models/mpnet/modeling_mpnet.py CHANGED
@@ -183,7 +183,7 @@ class MPNet(nn.Module):
183
  return denoised_amp, denoised_pha, denoised_com
184
 
185
 
186
- MODEL_FILE = "model.pt"
187
 
188
 
189
  class MPNetPretrainedModel(MPNet):
 
183
  return denoised_amp, denoised_pha, denoised_com
184
 
185
 
186
+ MODEL_FILE = "generator.pt"
187
 
188
 
189
  class MPNetPretrainedModel(MPNet):