Commit: update
examples/mpnet_aishell/step_2_train_model.py
CHANGED
@@ -32,7 +32,7 @@ from tqdm import tqdm
 
 from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
 from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
-from toolbox.torchaudio.models.mpnet.discriminator import
+from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel, batch_pesq
 from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNet, MPNetPretrainedModel, phase_losses, pesq_score
 from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft
 

@@ -164,14 +164,14 @@ def main():
     # models
     logger.info(f"prepare models. config_file: {args.config_file}")
     generator = MPNetPretrainedModel(config).to(device)
-    discriminator =
+    discriminator = MetricDiscriminatorPretrainedModel(config).to(device)
 
     # optimizer
-    logger.info("prepare optimizer, lr_scheduler
+    logger.info("prepare optimizer, lr_scheduler")
     num_params = 0
     for p in generator.parameters():
         num_params += p.numel()
-
+    logger.info("Total Parameters (generator): {:.3f}M".format(num_params/1e6))
 
     optim_g = torch.optim.AdamW(generator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
     optim_d = torch.optim.AdamW(discriminator.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])

@@ -180,8 +180,24 @@ def main():
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=-1)
 
     # training loop
+
+    # state
+    loss_d = 10000000000
+    loss_g = 10000000000
+    pesq_metric = 10000000000
+    mag_err = 10000000000
+    pha_err = 10000000000
+    com_err = 10000000000
+    stft_err = 10000000000
+
+    model_list = list()
+    best_idx_epoch = None
+    best_metric = None
+    patience_count = 0
+
     logger.info("training")
     for idx_epoch in range(args.max_epochs):
+        # train
         generator.train()
         discriminator.train()
 

@@ -251,12 +267,16 @@ def main():
             total_loss_g += loss_gen_all.item()
             total_batches += 1
 
+            loss_d = round(total_loss_d / total_batches, 4)
+            loss_g = round(total_loss_g / total_batches, 4)
+
             progress_bar.update(1)
             progress_bar.set_postfix({
-                "loss_d":
-                "loss_g":
+                "loss_d": loss_d,
+                "loss_g": loss_g,
             })
 
+        # evaluation
         generator.eval()
         torch.cuda.empty_cache()
         total_pesq_score = 0.

@@ -297,18 +317,87 @@ def main():
 
                 total_batches += 1
 
+                pesq_metric = round(total_pesq_score / total_batches, 4)
+                mag_err = round(total_mag_err / total_batches, 4)
+                pha_err = round(total_pha_err / total_batches, 4)
+                com_err = round(total_com_err / total_batches, 4)
+                stft_err = round(total_stft_err / total_batches, 4)
+
                 progress_bar.update(1)
                 progress_bar.set_postfix({
-                    "
-                    "mag_err":
-                    "pha_err":
-                    "com_err":
-                    "stft_err":
-
+                    "pesq_metric": pesq_metric,
+                    "mag_err": mag_err,
+                    "pha_err": pha_err,
+                    "com_err": com_err,
+                    "stft_err": stft_err,
                 })
 
+        # scheduler
+        scheduler_g.step()
+        scheduler_d.step()
+
+        # save path
+        epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
+        epoch_dir.mkdir(parents=True, exist_ok=False)
+
+        # save models
+        generator.save_pretrained(epoch_dir.as_posix())
+        discriminator.save_pretrained(epoch_dir.as_posix())
+
+        model_list.append(epoch_dir)
+        if len(model_list) >= args.num_serialized_models_to_keep:
+            model_to_delete: Path = model_list.pop(0)
+            shutil.rmtree(model_to_delete.as_posix())
+
+        # save metric
+        if best_metric is None:
+            best_idx_epoch = idx_epoch
+            best_metric = pesq_metric
+        elif pesq_metric < best_metric:
+            best_idx_epoch = idx_epoch
+            best_metric = pesq_metric
+        else:
+            pass
+
+        metrics = {
+            "idx_epoch": idx_epoch,
+            "best_idx_epoch": best_idx_epoch,
+            "loss_d": loss_d,
+            "loss_g": loss_g,
+
+            "pesq_metric": pesq_metric,
+            "mag_err": mag_err,
+            "pha_err": pha_err,
+            "com_err": com_err,
+            "stft_err": stft_err,
+
+        }
+        metrics_filename = epoch_dir / "metrics_epoch.json"
+        with open(metrics_filename, "w", encoding="utf-8") as f:
+            json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+        # save best
+        best_dir = serialization_dir / "best"
+        if best_idx_epoch == idx_epoch:
+            if best_dir.exists():
+                shutil.rmtree(best_dir)
+            shutil.copytree(epoch_dir, best_dir)
+
+        # early stop
+        early_stop_flag = False
+        if best_idx_epoch == idx_epoch:
+            patience_count = 0
+        else:
+            patience_count += 1
+            if patience_count >= args.patience:
+                early_stop_flag = True
+
+        # early stop
+        if early_stop_flag:
+            break
+
     return
 
 
-if __name__ ==
+if __name__ == "__main__":
     main()
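Note: the import hunk above pulls in batch_pesq, but none of the hunks shown here call it. For orientation, below is a minimal sketch of the MetricGAN-style discriminator update that training loops of this shape typically run once per batch. The helper name discriminator_step, the tensor names, and the convention that batch_pesq returns PESQ scores normalized to [0, 1] (or None when PESQ fails on the whole batch) are assumptions for illustration, not code from this commit.

import torch
import torch.nn.functional as F

def discriminator_step(discriminator, clean_mag, enhanced_mag, batch_pesq_score, optim_d):
    # Clean/clean pairs should score 1.0; clean/enhanced pairs should score
    # the normalized PESQ of the enhanced audio.
    one_labels = torch.ones(clean_mag.shape[0], device=clean_mag.device)

    metric_r = discriminator(clean_mag, clean_mag)
    metric_g = discriminator(clean_mag, enhanced_mag.detach())

    loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
    if batch_pesq_score is not None:
        loss_disc_g = F.mse_loss(batch_pesq_score.to(clean_mag.device), metric_g.flatten())
    else:
        # PESQ unavailable for this batch: skip the generated-side term.
        loss_disc_g = torch.tensor(0.0, device=clean_mag.device)

    loss_disc_all = loss_disc_r + loss_disc_g
    optim_d.zero_grad()
    loss_disc_all.backward()
    optim_d.step()
    return loss_disc_all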
toolbox/torchaudio/models/mpnet/configuation_mpnet.py
CHANGED
@@ -33,6 +33,9 @@ class MPNetConfig(PretrainedConfig):
 
                  dist_config: dict = None,
 
+                 discriminator_dim: int = 32,
+                 discriminator_in_channel: int = 2,
+
                  **kwargs
                  ):
         super(MPNetConfig, self).__init__(**kwargs)

@@ -63,6 +66,9 @@ class MPNetConfig(PretrainedConfig):
             "world_size": 1
         }
 
+        self.discriminator_dim = discriminator_dim
+        self.discriminator_in_channel = discriminator_in_channel
+
 
 if __name__ == "__main__":
     pass
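Note: a short sketch of how the two new fields flow from the config into the discriminator. The construction below passes only the new keyword arguments and assumes the remaining MPNetConfig parameters keep their defaults.

from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminator

config = MPNetConfig(discriminator_dim=32, discriminator_in_channel=2)

# __init__ stores both values on the instance; MetricDiscriminator now reads
# them back from the config instead of taking dim/in_channel as arguments.
discriminator = MetricDiscriminator(config)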
toolbox/torchaudio/models/mpnet/discriminator.py
CHANGED
@@ -1,5 +1,8 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
+import os
+from typing import Optional, Union
+
 import torch
 import torch.nn as nn
 import numpy as np

@@ -7,6 +10,8 @@ import torch.nn.functional as F
 from pesq import pesq
 from joblib import Parallel, delayed
 
+from toolbox.torchaudio.configuration_utils import CONFIG_FILE
+from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
 from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
 
 

@@ -38,8 +43,12 @@ def metric_loss(metric_ref, metrics_gen):
 
 
 class MetricDiscriminator(nn.Module):
-    def __init__(self,
+    def __init__(self, config: MPNetConfig):
         super(MetricDiscriminator, self).__init__()
+
+        dim = config.discriminator_dim
+        in_channel = config.discriminator_in_channel
+
         self.layers = nn.Sequential(
             nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
             nn.InstanceNorm2d(dim, affine=True),

@@ -67,5 +76,54 @@ class MetricDiscriminator(nn.Module):
         return self.layers(xy)
 
 
+MODEL_FILE = "discriminator.pt"
+
+
+class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
+    def __init__(self,
+                 config: MPNetConfig,
+                 ):
+        super(MetricDiscriminatorPretrainedModel, self).__init__(
+            config=config,
+        )
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+        model = cls(config)
+
+        if os.path.isdir(pretrained_model_name_or_path):
+            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
+        else:
+            ckpt_file = pretrained_model_name_or_path
+
+        with open(ckpt_file, "rb") as f:
+            state_dict = torch.load(f, map_location="cpu", weights_only=True)
+        model.load_state_dict(state_dict, strict=True)
+        return model
+
+    def save_pretrained(self,
+                        save_directory: Union[str, os.PathLike],
+                        state_dict: Optional[dict] = None,
+                        ):
+
+        model = self
+
+        if state_dict is None:
+            state_dict = model.state_dict()
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        # save state dict
+        model_file = os.path.join(save_directory, MODEL_FILE)
+        torch.save(state_dict, model_file)
+
+        # save config
+        config_file = os.path.join(save_directory, CONFIG_FILE)
+        self.config.to_yaml_file(config_file)
+        return save_directory
+
+
 if __name__ == '__main__':
     pass
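Note: a minimal save/load round trip for the new wrapper. The directory name is illustrative; also, save_pretrained reads self.config, so the MetricDiscriminator base class is assumed to keep a reference to the config it receives.

from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.discriminator import MetricDiscriminatorPretrainedModel

config = MPNetConfig()
discriminator = MetricDiscriminatorPretrainedModel(config)

# Writes discriminator.pt plus the yaml config file named by CONFIG_FILE.
save_dir = discriminator.save_pretrained("serialization_dir/epoch-0")

# from_pretrained accepts either that directory or a direct path to the .pt file.
restored = MetricDiscriminatorPretrainedModel.from_pretrained(save_dir)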
toolbox/torchaudio/models/mpnet/modeling_mpnet.py
CHANGED
@@ -183,7 +183,7 @@ class MPNet(nn.Module):
         return denoised_amp, denoised_pha, denoised_com
 
 
-MODEL_FILE = "
+MODEL_FILE = "generator.pt"
 
 
 class MPNetPretrainedModel(MPNet):
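Note: with distinct MODEL_FILE constants ("generator.pt" here, "discriminator.pt" in discriminator.py), both models can be checkpointed into the same epoch directory by the training loop in step_2_train_model.py. A sketch of the resulting layout (the config file name comes from CONFIG_FILE, whose value this diff does not show):

serialization_dir/
    epoch-0/
        generator.pt        # MPNetPretrainedModel.save_pretrained
        discriminator.pt    # MetricDiscriminatorPretrainedModel.save_pretrained
        metrics_epoch.json  # per-epoch metrics from the training loop
        <CONFIG_FILE>       # shared MPNetConfig, saved as yaml
    best/                   # copy of the best epoch's directory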