HoneyTian commited on
Commit
7335f6f
·
1 Parent(s): f418b0d
examples/frcrn/run.sh CHANGED
@@ -4,7 +4,7 @@
4
 
5
 
6
  sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
7
- --config_file "yaml/config-20.yaml" \
8
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
9
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
10
 
 
4
 
5
 
6
  sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
7
+ --config_file "yaml/config-10.yaml" \
8
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
9
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
10
 
examples/frcrn/step_2_train_model.py CHANGED
@@ -34,6 +34,7 @@ from tqdm import tqdm
34
 
35
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
36
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
 
37
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
38
  from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
39
  from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
@@ -220,6 +221,14 @@ def main():
220
  raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
221
 
222
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
 
 
 
 
 
 
 
 
223
 
224
  # training loop
225
 
@@ -248,6 +257,7 @@ def main():
248
 
249
  total_pesq_score = 0.
250
  total_loss = 0.
 
251
  total_neg_si_snr_loss = 0.
252
  total_mask_loss = 0.
253
  total_batches = 0.
@@ -264,10 +274,11 @@ def main():
264
  est_spec, est_wav, est_mask = model.forward(noisy_audios)
265
  denoise_audios = est_wav
266
 
 
267
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
268
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
269
 
270
- loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
271
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
272
  logger.info(f"find nan or inf in loss.")
273
  continue
@@ -284,12 +295,14 @@ def main():
284
 
285
  total_pesq_score += pesq_score
286
  total_loss += loss.item()
 
287
  total_neg_si_snr_loss += neg_si_snr_loss.item()
288
  total_mask_loss += mask_loss.item()
289
  total_batches += 1
290
 
291
  average_pesq_score = round(total_pesq_score / total_batches, 4)
292
  average_loss = round(total_loss / total_batches, 4)
 
293
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
294
  average_mask_loss = round(total_mask_loss / total_batches, 4)
295
 
@@ -298,6 +311,7 @@ def main():
298
  "lr": lr_scheduler.get_last_lr()[0],
299
  "pesq_score": average_pesq_score,
300
  "loss": average_loss,
 
301
  "neg_si_snr_loss": average_neg_si_snr_loss,
302
  "mask_loss": average_mask_loss,
303
  })
@@ -311,6 +325,7 @@ def main():
311
 
312
  total_pesq_score = 0.
313
  total_loss = 0.
 
314
  total_neg_si_snr_loss = 0.
315
  total_mask_loss = 0.
316
  total_batches = 0.
@@ -327,10 +342,11 @@ def main():
327
  est_spec, est_wav, est_mask = model.forward(noisy_audios)
328
  denoise_audios = est_wav
329
 
 
330
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
331
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
332
 
333
- loss = 1.0 * neg_si_snr_loss + 1.0 * mask_loss
334
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
335
  logger.info(f"find nan or inf in loss.")
336
  continue
@@ -347,6 +363,7 @@ def main():
347
 
348
  average_pesq_score = round(total_pesq_score / total_batches, 4)
349
  average_loss = round(total_loss / total_batches, 4)
 
350
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
351
  average_mask_loss = round(total_mask_loss / total_batches, 4)
352
 
@@ -355,12 +372,14 @@ def main():
355
  "lr": lr_scheduler.get_last_lr()[0],
356
  "pesq_score": average_pesq_score,
357
  "loss": average_loss,
 
358
  "neg_si_snr_loss": average_neg_si_snr_loss,
359
  "mask_loss": average_mask_loss,
360
  })
361
 
362
  total_pesq_score = 0.
363
  total_loss = 0.
 
364
  total_neg_si_snr_loss = 0.
365
  total_mask_loss = 0.
366
  total_batches = 0.
 
34
 
35
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
36
  from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
37
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
38
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
39
  from toolbox.torchaudio.models.frcrn.configuration_frcrn import FRCRNConfig
40
  from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretrainedModel
 
221
  raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
222
 
223
  neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
224
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
225
+ fft_size_list=[256, 512, 1024],
226
+ win_size_list=[256, 512, 1024],
227
+ hop_size_list=[128, 256, 512],
228
+ factor_sc=1.5,
229
+ factor_mag=1.0,
230
+ reduction="mean"
231
+ ).to(device)
232
 
233
  # training loop
234
 
 
257
 
258
  total_pesq_score = 0.
259
  total_loss = 0.
260
+ total_mr_stft_loss = 0.
261
  total_neg_si_snr_loss = 0.
262
  total_mask_loss = 0.
263
  total_batches = 0.
 
274
  est_spec, est_wav, est_mask = model.forward(noisy_audios)
275
  denoise_audios = est_wav
276
 
277
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
278
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
279
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
280
 
281
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
282
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
283
  logger.info(f"find nan or inf in loss.")
284
  continue
 
295
 
296
  total_pesq_score += pesq_score
297
  total_loss += loss.item()
298
+ total_mr_stft_loss += mr_stft_loss.item()
299
  total_neg_si_snr_loss += neg_si_snr_loss.item()
300
  total_mask_loss += mask_loss.item()
301
  total_batches += 1
302
 
303
  average_pesq_score = round(total_pesq_score / total_batches, 4)
304
  average_loss = round(total_loss / total_batches, 4)
305
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
306
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
307
  average_mask_loss = round(total_mask_loss / total_batches, 4)
308
 
 
311
  "lr": lr_scheduler.get_last_lr()[0],
312
  "pesq_score": average_pesq_score,
313
  "loss": average_loss,
314
+ "mr_stft_loss": average_mr_stft_loss,
315
  "neg_si_snr_loss": average_neg_si_snr_loss,
316
  "mask_loss": average_mask_loss,
317
  })
 
325
 
326
  total_pesq_score = 0.
327
  total_loss = 0.
328
+ total_mr_stft_loss = 0.
329
  total_neg_si_snr_loss = 0.
330
  total_mask_loss = 0.
331
  total_batches = 0.
 
342
  est_spec, est_wav, est_mask = model.forward(noisy_audios)
343
  denoise_audios = est_wav
344
 
345
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
346
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
347
  mask_loss = model.mask_loss_fn(est_mask, clean_audios, noisy_audios)
348
 
349
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss + 1.0 * mask_loss
350
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
351
  logger.info(f"find nan or inf in loss.")
352
  continue
 
363
 
364
  average_pesq_score = round(total_pesq_score / total_batches, 4)
365
  average_loss = round(total_loss / total_batches, 4)
366
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
367
  average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
368
  average_mask_loss = round(total_mask_loss / total_batches, 4)
369
 
 
372
  "lr": lr_scheduler.get_last_lr()[0],
373
  "pesq_score": average_pesq_score,
374
  "loss": average_loss,
375
+ "mr_stft_loss": average_mr_stft_loss,
376
  "neg_si_snr_loss": average_neg_si_snr_loss,
377
  "mask_loss": average_mask_loss,
378
  })
379
 
380
  total_pesq_score = 0.
381
  total_loss = 0.
382
+ total_mr_stft_loss = 0.
383
  total_neg_si_snr_loss = 0.
384
  total_mask_loss = 0.
385
  total_batches = 0.
toolbox/torchaudio/models/frcrn/modeling_frcrn.py CHANGED
@@ -300,25 +300,38 @@ class FRCRNPretrainedModel(FRCRN):
300
  def main():
301
  # model = FRCRN(
302
  # use_complex_networks=True,
303
- # model_complexity=45,
 
 
 
 
 
 
 
 
 
 
 
304
  # model_depth=14,
305
  # padding_mode="zeros",
306
- # nfft=512,
307
- # win_size=400,
308
- # hop_size=200,
309
  # win_type="hann",
310
  # )
 
311
  model = FRCRN(
312
  use_complex_networks=True,
313
- model_complexity=45,
314
- model_depth=14,
315
  padding_mode="zeros",
316
- nfft=640,
317
- win_size=640,
318
- hop_size=320,
319
  win_type="hann",
320
  )
321
- mixture = torch.rand(size=(1, 8000), dtype=torch.float32)
 
322
 
323
  est_spec, est_wav, est_mask = model.forward(mixture)
324
  print(est_spec.shape)
 
300
  def main():
301
  # model = FRCRN(
302
  # use_complex_networks=True,
303
+ # model_complexity=-1,
304
+ # model_depth=10,
305
+ # padding_mode="zeros",
306
+ # nfft=128,
307
+ # win_size=128,
308
+ # hop_size=64,
309
+ # win_type="hann",
310
+ # )
311
+
312
+ # model = FRCRN(
313
+ # use_complex_networks=True,
314
+ # model_complexity=-1,
315
  # model_depth=14,
316
  # padding_mode="zeros",
317
+ # nfft=640,
318
+ # win_size=640,
319
+ # hop_size=320,
320
  # win_type="hann",
321
  # )
322
+
323
  model = FRCRN(
324
  use_complex_networks=True,
325
+ model_complexity=20,
326
+ model_depth=20,
327
  padding_mode="zeros",
328
+ nfft=512,
329
+ win_size=512,
330
+ hop_size=256,
331
  win_type="hann",
332
  )
333
+
334
+ mixture = torch.rand(size=(1, 32000), dtype=torch.float32)
335
 
336
  est_spec, est_wav, est_mask = model.forward(mixture)
337
  print(est_spec.shape)
toolbox/torchaudio/models/frcrn/unet.py CHANGED
@@ -339,19 +339,8 @@ class UNet(nn.Module):
339
  return cmp_spec
340
 
341
 
342
- def main():
343
  # [batch_size, 1, freq_bins, time_steps, 2]
344
- # x = torch.rand(size=(1, 1, 257, 2000, 2))
345
- # unet = UNet(
346
- # in_channels=1,
347
- # model_complexity=45,
348
- # model_depth=20,
349
- # use_complex_networks=True
350
- # )
351
- # print(unet)
352
- # result = unet.forward(x)
353
- # print(result.shape)
354
-
355
  # x = torch.rand(size=(1, 1, 65, 2000, 2))
356
  x = torch.rand(size=(1, 1, 65, 200, 2))
357
  unet = UNet(
@@ -366,5 +355,20 @@ def main():
366
  return
367
 
368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  if __name__ == "__main__":
370
- main()
 
339
  return cmp_spec
340
 
341
 
342
+ def main10():
343
  # [batch_size, 1, freq_bins, time_steps, 2]
 
 
 
 
 
 
 
 
 
 
 
344
  # x = torch.rand(size=(1, 1, 65, 2000, 2))
345
  x = torch.rand(size=(1, 1, 65, 200, 2))
346
  unet = UNet(
 
355
  return
356
 
357
 
358
+ def main20():
359
+ # [batch_size, 1, freq_bins, time_steps, 2]
360
+ x = torch.rand(size=(1, 1, 257, 2000, 2))
361
+ unet = UNet(
362
+ in_channels=1,
363
+ model_complexity=45,
364
+ model_depth=20,
365
+ use_complex_networks=True
366
+ )
367
+ print(unet)
368
+ result = unet.forward(x)
369
+ print(result.shape)
370
+ return
371
+
372
+
373
  if __name__ == "__main__":
374
+ main20()