HoneyTian commited on
Commit
e27a095
·
1 Parent(s): 637d40c
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -2,6 +2,8 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/NVIDIA/CleanUNet/blob/main/train.py
 
 
5
  """
6
  import argparse
7
  import json
@@ -20,6 +22,7 @@ sys.path.append(os.path.join(pwd, "../../"))
20
 
21
  import numpy as np
22
  import torch
 
23
  from torch.nn import functional as F
24
  from torch.utils.data.dataloader import DataLoader
25
  from tqdm import tqdm
@@ -27,6 +30,9 @@ from tqdm import tqdm
27
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
28
  from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
29
  from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
 
 
 
30
 
31
 
32
  def get_args():
@@ -36,6 +42,9 @@ def get_args():
36
 
37
  parser.add_argument("--max_epochs", default=100, type=int)
38
 
 
 
 
39
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
40
  parser.add_argument("--patience", default=5, type=int)
41
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
@@ -158,56 +167,37 @@ def main():
158
  model = CleanUNetPretrainedModel(config).to(device)
159
 
160
  # optimizer
161
- logger.info("prepare optimizer, lr_scheduler")
162
- optim_g = torch.optim.AdamW(model.parameters(), config.learning_rate, betas=[config.adam_b1, config.adam_b2])
163
-
164
- # resume training
165
- last_epoch = -1
166
- for epoch_i in serialization_dir.glob("epoch-*"):
167
- epoch_i = Path(epoch_i)
168
- epoch_idx = epoch_i.stem.split("-")[1]
169
- epoch_idx = int(epoch_idx)
170
- if epoch_idx > last_epoch:
171
- last_epoch = epoch_idx
172
-
173
- if last_epoch != -1:
174
- logger.info(f"resume from epoch-{last_epoch}.")
175
- generator_pt = serialization_dir / f"epoch-{last_epoch}/generator.pt"
176
- discriminator_pt = serialization_dir / f"epoch-{last_epoch}/discriminator.pt"
177
- optim_g_pth = serialization_dir / f"epoch-{last_epoch}/optim_g.pth"
178
- optim_d_pth = serialization_dir / f"epoch-{last_epoch}/optim_d.pth"
179
-
180
- logger.info(f"load state dict for generator.")
181
- with open(generator_pt.as_posix(), "rb") as f:
182
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
183
- generator.load_state_dict(state_dict, strict=True)
184
- logger.info(f"load state dict for discriminator.")
185
- with open(discriminator_pt.as_posix(), "rb") as f:
186
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
187
- discriminator.load_state_dict(state_dict, strict=True)
188
-
189
- logger.info(f"load state dict for optim_g.")
190
- with open(optim_g_pth.as_posix(), "rb") as f:
191
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
192
- optim_g.load_state_dict(state_dict)
193
- logger.info(f"load state dict for optim_d.")
194
- with open(optim_d_pth.as_posix(), "rb") as f:
195
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
196
- optim_d.load_state_dict(state_dict)
197
-
198
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.lr_decay, last_epoch=last_epoch)
199
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.lr_decay, last_epoch=last_epoch)
200
 
201
  # training loop
202
 
203
  # state
204
- loss_d = 10000000000
205
- loss_g = 10000000000
206
- pesq_metric = 10000000000
207
- mag_err = 10000000000
208
- pha_err = 10000000000
209
- com_err = 10000000000
210
- stft_err = 10000000000
211
 
212
  model_list = list()
213
  best_idx_epoch = None
@@ -215,96 +205,74 @@ def main():
215
  patience_count = 0
216
 
217
  logger.info("training")
218
- for idx_epoch in range(max(0, last_epoch+1), args.max_epochs):
219
  # train
220
- generator.train()
221
- discriminator.train()
222
 
223
- total_loss_d = 0.
224
- total_loss_g = 0.
 
 
 
225
  total_batches = 0.
 
226
  progress_bar = tqdm(
227
  total=len(train_data_loader),
228
  desc="Training; epoch: {}".format(idx_epoch),
229
  )
230
  for batch in train_data_loader:
231
- clean_audio, noisy_audio = batch
232
- clean_audio = clean_audio.to(device)
233
- noisy_audio = noisy_audio.to(device)
234
- one_labels = torch.ones(clean_audio.shape[0]).to(device)
235
-
236
- clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
237
- noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
238
-
239
- mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
240
-
241
- audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
242
- mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
243
-
244
- audio_list_r, audio_list_g = list(clean_audio.cpu().numpy()), list(audio_g.detach().cpu().numpy())
245
- batch_pesq_score = batch_pesq(audio_list_r, audio_list_g)
246
-
247
- # Discriminator
248
- optim_d.zero_grad()
249
- metric_r = discriminator.forward(clean_mag, clean_mag)
250
- metric_g = discriminator.forward(clean_mag, mag_g_hat.detach())
251
- loss_disc_r = F.mse_loss(one_labels, metric_r.flatten())
252
-
253
- if batch_pesq_score is not None:
254
- loss_disc_g = F.mse_loss(batch_pesq_score.to(device), metric_g.flatten())
255
- else:
256
- # print("pesq is None!")
257
- loss_disc_g = 0
258
-
259
- loss_disc_all = loss_disc_r + loss_disc_g
260
- loss_disc_all.backward()
261
- optim_d.step()
262
-
263
- # Generator
264
- optim_g.zero_grad()
265
- # L2 Magnitude Loss
266
- loss_mag = F.mse_loss(clean_mag, mag_g)
267
- # Anti-wrapping Phase Loss
268
- loss_ip, loss_gd, loss_iaf = phase_losses(clean_pha, pha_g)
269
- loss_pha = loss_ip + loss_gd + loss_iaf
270
- # L2 Complex Loss
271
- loss_com = F.mse_loss(clean_com, com_g) * 2
272
- # L2 Consistency Loss
273
- loss_stft = F.mse_loss(com_g, com_g_hat) * 2
274
- # Time Loss
275
- loss_time = F.l1_loss(clean_audio, audio_g)
276
- # Metric Loss
277
- metric_g = discriminator.forward(clean_mag, mag_g_hat)
278
- loss_metric = F.mse_loss(metric_g.flatten(), one_labels)
279
-
280
- loss_gen_all = loss_mag * 0.9 + loss_pha * 0.3 + loss_com * 0.1 + loss_stft * 0.1 + loss_metric * 0.05 + loss_time * 0.2
281
-
282
- loss_gen_all.backward()
283
- optim_g.step()
284
-
285
- total_loss_d += loss_disc_all.item()
286
- total_loss_g += loss_gen_all.item()
287
  total_batches += 1
288
 
289
- loss_d = round(total_loss_d / total_batches, 4)
290
- loss_g = round(total_loss_g / total_batches, 4)
 
 
 
291
 
292
  progress_bar.update(1)
293
  progress_bar.set_postfix({
294
- "loss_d": loss_d,
295
- "loss_g": loss_g,
 
 
 
296
  })
297
 
298
  # evaluation
299
- generator.eval()
300
- discriminator.eval()
301
 
302
  torch.cuda.empty_cache()
303
- total_pesq_score = 0.
304
- total_mag_err = 0.
305
- total_pha_err = 0.
306
- total_com_err = 0.
307
- total_stft_err = 0.
 
308
  total_batches = 0.
309
 
310
  progress_bar = tqdm(
@@ -313,61 +281,52 @@ def main():
313
  )
314
  with torch.no_grad():
315
  for batch in valid_data_loader:
316
- clean_audio, noisy_audio = batch
317
- clean_audio = clean_audio.to(device)
318
- noisy_audio = noisy_audio.to(device)
319
-
320
- clean_mag, clean_pha, clean_com = mag_pha_stft(clean_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
321
- noisy_mag, noisy_pha, noisy_com = mag_pha_stft(noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
322
-
323
- mag_g, pha_g, com_g = generator.forward(noisy_mag, noisy_pha)
324
-
325
- audio_g = mag_pha_istft(mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
326
- mag_g_hat, pha_g_hat, com_g_hat = mag_pha_stft(audio_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor)
327
-
328
- total_pesq_score += pesq_score(
329
- torch.split(clean_audio, 1, dim=0),
330
- torch.split(audio_g, 1, dim=0),
331
- config
332
- ).item()
333
- total_mag_err += F.mse_loss(clean_mag, mag_g).item()
334
- val_ip_err, val_gd_err, val_iaf_err = phase_losses(clean_pha, pha_g)
335
- total_pha_err += (val_ip_err + val_gd_err + val_iaf_err).item()
336
- total_com_err += F.mse_loss(clean_com, com_g).item()
337
- total_stft_err += F.mse_loss(com_g, com_g_hat).item()
338
-
339
  total_batches += 1
340
 
341
- pesq_metric = round(total_pesq_score / total_batches, 4)
342
- mag_err = round(total_mag_err / total_batches, 4)
343
- pha_err = round(total_pha_err / total_batches, 4)
344
- com_err = round(total_com_err / total_batches, 4)
345
- stft_err = round(total_stft_err / total_batches, 4)
346
 
347
  progress_bar.update(1)
348
  progress_bar.set_postfix({
349
- "pesq_metric": pesq_metric,
350
- "mag_err": mag_err,
351
- "pha_err": pha_err,
352
- "com_err": com_err,
353
- "stft_err": stft_err,
354
  })
355
 
356
  # scheduler
357
- scheduler_g.step()
358
- scheduler_d.step()
359
 
360
  # save path
361
  epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
362
  epoch_dir.mkdir(parents=True, exist_ok=False)
363
 
364
  # save models
365
- generator.save_pretrained(epoch_dir.as_posix())
366
- discriminator.save_pretrained(epoch_dir.as_posix())
367
-
368
- # save optim
369
- torch.save(optim_d.state_dict(), (epoch_dir / "optim_d.pth").as_posix())
370
- torch.save(optim_g.state_dict(), (epoch_dir / "optim_g.pth").as_posix())
371
 
372
  model_list.append(epoch_dir)
373
  if len(model_list) >= args.num_serialized_models_to_keep:
@@ -377,25 +336,23 @@ def main():
377
  # save metric
378
  if best_metric is None:
379
  best_idx_epoch = idx_epoch
380
- best_metric = pesq_metric
381
- elif pesq_metric > best_metric:
382
  # great is better.
383
  best_idx_epoch = idx_epoch
384
- best_metric = pesq_metric
385
  else:
386
  pass
387
 
388
  metrics = {
389
  "idx_epoch": idx_epoch,
390
  "best_idx_epoch": best_idx_epoch,
391
- "loss_d": loss_d,
392
- "loss_g": loss_g,
393
-
394
- "pesq_metric": pesq_metric,
395
- "mag_err": mag_err,
396
- "pha_err": pha_err,
397
- "com_err": com_err,
398
- "stft_err": stft_err,
399
 
400
  }
401
  metrics_filename = epoch_dir / "metrics_epoch.json"
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/NVIDIA/CleanUNet/blob/main/train.py
5
+
6
+ https://github.com/NVIDIA/CleanUNet/blob/main/configs/DNS-large-full.json
7
  """
8
  import argparse
9
  import json
 
22
 
23
  import numpy as np
24
  import torch
25
+ import torch.nn as nn
26
  from torch.nn import functional as F
27
  from torch.utils.data.dataloader import DataLoader
28
  from tqdm import tqdm
 
30
  from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelDataset
31
  from toolbox.torchaudio.models.clean_unet.configuration_clean_unet import CleanUnetConfig
32
  from toolbox.torchaudio.models.clean_unet.modeling_clean_unet import CleanUNetPretrainedModel
33
+ from toolbox.torchaudio.models.clean_unet.training import LinearWarmupCosineDecay
34
+ from toolbox.torchaudio.models.clean_unet.loss import MultiResolutionSTFTLoss
35
+ from toolbox.torchaudio.models.clean_unet.metrics import batch_pesq
36
 
37
 
38
  def get_args():
 
42
 
43
  parser.add_argument("--max_epochs", default=100, type=int)
44
 
45
+ parser.add_argument("--batch_size", default=64, type=int)
46
+ parser.add_argument("--learning_rate", default=2e-4, type=float)
47
+
48
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
49
  parser.add_argument("--patience", default=5, type=int)
50
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
167
  model = CleanUNetPretrainedModel(config).to(device)
168
 
169
  # optimizer
170
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
171
+ optimizer = torch.optim.AdamW(model.parameters(), config.learning_rate)
172
+ lr_scheduler = LinearWarmupCosineDecay(
173
+ optimizer,
174
+ lr_max=args.learning_rate,
175
+ n_iter=250000,
176
+ iteration=250000,
177
+ divider=25,
178
+ warmup_proportion=0.05,
179
+ phase=("linear", "cosine"),
180
+ )
181
+ # ae_loss_fn = nn.MSELoss(reduction="mean")
182
+ ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
183
+
184
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
185
+ fft_sizes=[512, 1024, 2048],
186
+ hop_sizes=[50, 120, 240],
187
+ win_lengths=[240, 600, 1200],
188
+ sc_lambda=0.5,
189
+ mag_lambda=0.5,
190
+ band="full"
191
+ ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
193
  # training loop
194
 
195
  # state
196
+ average_pesq_metric = 10000000000
197
+ average_loss = 10000000000
198
+ average_ae_loss = 10000000000
199
+ average_sc_loss = 10000000000
200
+ average_mag_loss = 10000000000
 
 
201
 
202
  model_list = list()
203
  best_idx_epoch = None
 
205
  patience_count = 0
206
 
207
  logger.info("training")
208
+ for idx_epoch in range(args.max_epochs):
209
  # train
210
+ model.train()
 
211
 
212
+ total_pesq_metric = 0.
213
+ total_loss = 0.
214
+ total_ae_loss = 0.
215
+ total_sc_loss = 0.
216
+ total_mag_loss = 0.
217
  total_batches = 0.
218
+
219
  progress_bar = tqdm(
220
  total=len(train_data_loader),
221
  desc="Training; epoch: {}".format(idx_epoch),
222
  )
223
  for batch in train_data_loader:
224
+ clean_audios, noisy_audios = batch
225
+ clean_audios = clean_audios.to(device)
226
+ noisy_audios = noisy_audios.to(device)
227
+
228
+ enhanced_audios = model.forward(noisy_audios)
229
+
230
+ ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
231
+ sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
232
+
233
+ loss = ae_loss + sc_loss + mag_loss
234
+
235
+ enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
236
+ clean_audios_list_r = list(clean_audios.cpu().numpy())
237
+ pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
238
+
239
+ optimizer.zero_grad()
240
+ loss.backward()
241
+ optimizer.step()
242
+ lr_scheduler.step()
243
+
244
+ total_pesq_metric += pesq_metric.item()
245
+ total_loss += loss.item()
246
+ total_ae_loss += ae_loss.item()
247
+ total_sc_loss += sc_loss.item()
248
+ total_mag_loss += mag_loss.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  total_batches += 1
250
 
251
+ average_pesq_metric = round(total_pesq_metric / total_batches, 4)
252
+ average_loss = round(total_loss / total_batches, 4)
253
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
254
+ average_sc_loss = round(total_sc_loss / total_batches, 4)
255
+ average_mag_loss = round(total_mag_loss / total_batches, 4)
256
 
257
  progress_bar.update(1)
258
  progress_bar.set_postfix({
259
+ "pesq_metric": average_pesq_metric,
260
+ "loss": average_loss,
261
+ "ae_loss": average_ae_loss,
262
+ "sc_loss": average_sc_loss,
263
+ "mag_loss": average_mag_loss,
264
  })
265
 
266
  # evaluation
267
+ model.eval()
 
268
 
269
  torch.cuda.empty_cache()
270
+
271
+ total_pesq_metric = 0.
272
+ total_loss = 0.
273
+ total_ae_loss = 0.
274
+ total_sc_loss = 0.
275
+ total_mag_loss = 0.
276
  total_batches = 0.
277
 
278
  progress_bar = tqdm(
 
281
  )
282
  with torch.no_grad():
283
  for batch in valid_data_loader:
284
+ clean_audios, noisy_audios = batch
285
+ clean_audios = clean_audios.to(device)
286
+ noisy_audios = noisy_audios.to(device)
287
+
288
+ enhanced_audios = model.forward(noisy_audios)
289
+ enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
290
+ ae_loss = ae_loss_fn(enhanced_audios, enhanced_audios)
291
+ sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios.squeeze(1), clean_audios.squeeze(1))
292
+
293
+ loss = ae_loss + sc_loss + mag_loss
294
+
295
+ enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
296
+ clean_audios_list_r = list(clean_audios.cpu().numpy())
297
+ pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
298
+
299
+ total_pesq_metric += pesq_metric.item()
300
+ total_loss += loss.item()
301
+ total_ae_loss += ae_loss.item()
302
+ total_sc_loss += sc_loss.item()
303
+ total_mag_loss += mag_loss.item()
 
 
 
304
  total_batches += 1
305
 
306
+ average_pesq_metric = round(total_pesq_metric / total_batches, 4)
307
+ average_loss = round(total_loss / total_batches, 4)
308
+ average_ae_loss = round(total_ae_loss / total_batches, 4)
309
+ average_sc_loss = round(total_sc_loss / total_batches, 4)
310
+ average_mag_loss = round(total_mag_loss / total_batches, 4)
311
 
312
  progress_bar.update(1)
313
  progress_bar.set_postfix({
314
+ "pesq_metric": average_pesq_metric,
315
+ "loss": average_loss,
316
+ "ae_loss": average_ae_loss,
317
+ "sc_loss": average_sc_loss,
318
+ "mag_loss": average_mag_loss,
319
  })
320
 
321
  # scheduler
322
+ lr_scheduler.step()
 
323
 
324
  # save path
325
  epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
326
  epoch_dir.mkdir(parents=True, exist_ok=False)
327
 
328
  # save models
329
+ model.save_pretrained(epoch_dir.as_posix())
 
 
 
 
 
330
 
331
  model_list.append(epoch_dir)
332
  if len(model_list) >= args.num_serialized_models_to_keep:
 
336
  # save metric
337
  if best_metric is None:
338
  best_idx_epoch = idx_epoch
339
+ best_metric = average_pesq_metric
340
+ elif average_pesq_metric > best_metric:
341
  # great is better.
342
  best_idx_epoch = idx_epoch
343
+ best_metric = average_pesq_metric
344
  else:
345
  pass
346
 
347
  metrics = {
348
  "idx_epoch": idx_epoch,
349
  "best_idx_epoch": best_idx_epoch,
350
+
351
+ "pesq_metric": average_pesq_metric,
352
+ "loss": average_loss,
353
+ "ae_loss": average_ae_loss,
354
+ "sc_loss": average_sc_loss,
355
+ "mag_loss": average_mag_loss,
 
 
356
 
357
  }
358
  metrics_filename = epoch_dir / "metrics_epoch.json"
toolbox/torchaudio/models/clean_unet/loss.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+ # from distutils.version import LooseVersion
8
+
9
+
10
+ # is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion("1.7")
11
+ is_pytorch_17plus = True
12
+
13
+
14
+ def stft(x, fft_size, hop_size, win_length, window):
15
+ """
16
+ Perform STFT and convert to magnitude spectrogram.
17
+ :param x: Tensor, Input signal tensor (B, T).
18
+ :param fft_size: int, FFT size.
19
+ :param hop_size: int, Hop size.
20
+ :param win_length: int, Window length.
21
+ :param window: str, Window function type.
22
+ :return: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
23
+ """
24
+
25
+ if is_pytorch_17plus:
26
+ x_stft = torch.stft(
27
+ x, fft_size, hop_size, win_length, window, return_complex=False
28
+ )
29
+ else:
30
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
31
+ real = x_stft[..., 0]
32
+ imag = x_stft[..., 1]
33
+
34
+ # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
35
+ return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
36
+
37
+
38
+ class SpectralConvergenceLoss(torch.nn.Module):
39
+ """Spectral convergence loss module."""
40
+
41
+ def __init__(self):
42
+ super(SpectralConvergenceLoss, self).__init__()
43
+
44
+ def forward(self, x_mag, y_mag):
45
+ """
46
+ Calculate forward propagation.
47
+ :param x_mag: Tensor, Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
48
+ :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
49
+ :return: Tensor, Spectral convergence loss value.
50
+ """
51
+ return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro")
52
+
53
+
54
+ class LogSTFTMagnitudeLoss(torch.nn.Module):
55
+ """Log STFT magnitude loss module."""
56
+
57
+ def __init__(self):
58
+ super(LogSTFTMagnitudeLoss, self).__init__()
59
+
60
+ def forward(self, x_mag, y_mag):
61
+ """
62
+ Calculate forward propagation.
63
+ :param x_mag: Tensor, Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
64
+ :param y_mag: Tensor, Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
65
+ :return: Tensor, Log STFT magnitude loss value.
66
+ """
67
+ return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
68
+
69
+
70
+ class STFTLoss(torch.nn.Module):
71
+ """STFT loss module."""
72
+
73
+ def __init__(
74
+ self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
75
+ band="full"
76
+ ):
77
+ super(STFTLoss, self).__init__()
78
+ self.fft_size = fft_size
79
+ self.shift_size = shift_size
80
+ self.win_length = win_length
81
+ self.band = band
82
+
83
+ self.spectral_convergence_loss = SpectralConvergenceLoss()
84
+ self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
85
+ # NOTE(kan-bayashi): Use register_buffer to fix #223
86
+ self.register_buffer("window", getattr(torch, window)(win_length))
87
+
88
+ def forward(self, x, y):
89
+ """
90
+ Calculate forward propagation.
91
+ :param x: Tensor, Predicted signal (B, T).
92
+ :param y: Tensor, Groundtruth signal (B, T).
93
+ :return:
94
+ Tensor, Spectral convergence loss value.
95
+ Tensor, Log STFT magnitude loss value.
96
+ """
97
+ x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
98
+ y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
99
+
100
+ if self.band == "high":
101
+ freq_mask_ind = x_mag.shape[1] // 2 # only select high frequency bands
102
+ sc_loss = self.spectral_convergence_loss(x_mag[:,freq_mask_ind:,:], y_mag[:,freq_mask_ind:,:])
103
+ mag_loss = self.log_stft_magnitude_loss(x_mag[:,freq_mask_ind:,:], y_mag[:,freq_mask_ind:,:])
104
+ elif self.band == "full":
105
+ sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
106
+ mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
107
+ else:
108
+ raise NotImplementedError
109
+
110
+ return sc_loss, mag_loss
111
+
112
+
113
+ class MultiResolutionSTFTLoss(torch.nn.Module):
114
+ """Multi resolution STFT loss module."""
115
+
116
+ def __init__(self,
117
+ fft_sizes=None, hop_sizes=None, win_lengths=None,
118
+ window="hann_window", sc_lambda=0.1, mag_lambda=0.1, band="full",
119
+ ):
120
+ """
121
+ Initialize Multi resolution STFT loss module.
122
+ :param fft_sizes: list, List of FFT sizes.
123
+ :param hop_sizes: list, List of hop sizes.
124
+ :param win_lengths: list, List of window lengths.
125
+ :param window: str, Window function type.
126
+ :param sc_lambda: float, a balancing factor across different losses.
127
+ :param mag_lambda: float, a balancing factor across different losses.
128
+ :param band: str, high-band or full-band loss
129
+ """
130
+ super(MultiResolutionSTFTLoss, self).__init__()
131
+ fft_sizes = fft_sizes or [1024, 2048, 512]
132
+ hop_sizes = hop_sizes or [120, 240, 50]
133
+ win_lengths = win_lengths or [600, 1200, 240]
134
+
135
+ self.sc_lambda = sc_lambda
136
+ self.mag_lambda = mag_lambda
137
+
138
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
139
+ self.stft_losses = torch.nn.ModuleList()
140
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
141
+ self.stft_losses += [STFTLoss(fs, ss, wl, window, band)]
142
+
143
+ def forward(self, x, y):
144
+ """
145
+ Calculate forward propagation.
146
+ :param x: Tensor, Predicted signal (B, T) or (B, #subband, T).
147
+ :param y: Tensor, Groundtruth signal (B, T) or (B, #subband, T).
148
+ :return:
149
+ Tensor, Multi resolution spectral convergence loss value.
150
+ Tensor, Multi resolution log STFT magnitude loss value.
151
+ """
152
+ if len(x.shape) == 3:
153
+ x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
154
+ y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
155
+ sc_loss = 0.0
156
+ mag_loss = 0.0
157
+ for f in self.stft_losses:
158
+ sc_l, mag_l = f(x, y)
159
+ sc_loss += sc_l
160
+ mag_loss += mag_l
161
+
162
+ sc_loss *= self.sc_lambda
163
+ sc_loss /= len(self.stft_losses)
164
+ mag_loss *= self.mag_lambda
165
+ mag_loss /= len(self.stft_losses)
166
+
167
+ return sc_loss, mag_loss
168
+
169
+
170
+ if __name__ == '__main__':
171
+ pass
toolbox/torchaudio/models/clean_unet/metrics.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from joblib import Parallel, delayed
4
+ import numpy as np
5
+ from pesq import pesq
6
+ import torch
7
+
8
+
9
+ def cal_pesq(clean, noisy, sr=16000):
10
+ try:
11
+ pesq_score = pesq(sr, clean, noisy, "wb")
12
+ except Exception as e:
13
+ # print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
14
+ # error can happen due to silent period
15
+ pesq_score = -1
16
+ return pesq_score
17
+
18
+
19
+ def batch_pesq(clean, noisy):
20
+ pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
21
+ pesq_score = np.array(pesq_score)
22
+ if -1 in pesq_score:
23
+ return None
24
+ pesq_score = (pesq_score - 1) / 3.5
25
+ return torch.FloatTensor(pesq_score)
26
+
27
+
28
+ def main():
29
+
30
+ prediction = torch.rand(size=(1, 160000), dtype=torch.float32)
31
+ ground_truth = torch.rand(size=(1, 160000), dtype=torch.float32)
32
+
33
+ prediction_list_r = list(prediction.cpu().numpy())
34
+ ground_truth_list_r = list(ground_truth.cpu().numpy())
35
+
36
+ pesq_score = batch_pesq(prediction_list_r, ground_truth_list_r)
37
+ print(pesq_score)
38
+ return
39
+
40
+
41
+ if __name__ == "__main__":
42
+ main()
toolbox/torchaudio/models/clean_unet/training.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import math
4
+
5
+
6
+ def anneal_linear(start, end, proportion):
7
+ return start + proportion * (end - start)
8
+
9
+
10
+ def anneal_cosine(start, end, proportion):
11
+ cos_val = math.cos(math.pi * proportion) + 1
12
+ return end + (start - end) / 2 * cos_val
13
+
14
+
15
+ class Phase:
16
+ def __init__(self, start, end, n_iter, cur_iter, anneal_fn):
17
+ self.start, self.end = start, end
18
+ self.n_iter = n_iter
19
+ self.anneal_fn = anneal_fn
20
+ self.n = cur_iter
21
+
22
+ def step(self):
23
+ self.n += 1
24
+
25
+ return self.anneal_fn(self.start, self.end, self.n / self.n_iter)
26
+
27
+ def reset(self):
28
+ self.n = 0
29
+
30
+ @property
31
+ def is_done(self):
32
+ return self.n >= self.n_iter
33
+
34
+
35
+ class LinearWarmupCosineDecay(object):
36
+ def __init__(
37
+ self,
38
+ optimizer,
39
+ lr_max,
40
+ n_iter,
41
+ iteration=0,
42
+ divider=25,
43
+ warmup_proportion=0.3,
44
+ phase=('linear', 'cosine'),
45
+ ):
46
+ self.optimizer = optimizer
47
+
48
+ phase1 = int(n_iter * warmup_proportion)
49
+ phase2 = n_iter - phase1
50
+ lr_min = lr_max / divider
51
+
52
+ phase_map = {'linear': anneal_linear, 'cosine': anneal_cosine}
53
+
54
+ cur_iter_phase1 = iteration
55
+ cur_iter_phase2 = max(0, iteration - phase1)
56
+ self.lr_phase = [
57
+ Phase(lr_min, lr_max, phase1, cur_iter_phase1, phase_map[phase[0]]),
58
+ Phase(lr_max, lr_min / 1e4, phase2, cur_iter_phase2, phase_map[phase[1]]),
59
+ ]
60
+
61
+ if iteration < phase1:
62
+ self.phase = 0
63
+ else:
64
+ self.phase = 1
65
+
66
+ def step(self):
67
+ lr = self.lr_phase[self.phase].step()
68
+
69
+ for group in self.optimizer.param_groups:
70
+ group['lr'] = lr
71
+
72
+ if self.lr_phase[self.phase].is_done:
73
+ self.phase += 1
74
+
75
+ if self.phase >= len(self.lr_phase):
76
+ for phase in self.lr_phase:
77
+ phase.reset()
78
+
79
+ self.phase = 0
80
+
81
+ return lr
82
+
83
+
84
+ if __name__ == '__main__':
85
+ pass