HoneyTian committed on
Commit
b580752
·
1 Parent(s): 07fcb5c
examples/cnn_vad_by_webrtcvad/step_4_train_model.py CHANGED
@@ -272,7 +272,7 @@ def main():
272
  dice_loss = dice_loss_fn.forward(probs, targets)
273
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
274
 
275
- loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
276
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
277
  logger.info(f"find nan or inf in loss. continue.")
278
  continue
@@ -352,7 +352,7 @@ def main():
352
  dice_loss = dice_loss_fn.forward(probs, targets)
353
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
354
 
355
- loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.3 * lsnr_loss
356
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
357
  logger.info(f"find nan or inf in loss. continue.")
358
  continue
 
272
  dice_loss = dice_loss_fn.forward(probs, targets)
273
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
274
 
275
+ loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
276
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
277
  logger.info(f"find nan or inf in loss. continue.")
278
  continue
 
352
  dice_loss = dice_loss_fn.forward(probs, targets)
353
  lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
354
 
355
+ loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
356
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
357
  logger.info(f"find nan or inf in loss. continue.")
358
  continue
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py CHANGED
@@ -201,10 +201,6 @@ class CNNVadModel(nn.Module):
201
  raise AssertionError("Input signals must have the same shape")
202
  noise = noisy - clean
203
 
204
- print(f"lsnr: {lsnr.shape}")
205
- print(f"clean: {clean.shape}")
206
- print(f"noisy: {noisy.shape}")
207
-
208
  if clean.dim() == 2:
209
  clean = torch.unsqueeze(clean, dim=1)
210
  if noise.dim() == 2:
@@ -227,9 +223,6 @@ class CNNVadModel(nn.Module):
227
  lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
228
  # lsnr_gth shape: [b, t]
229
 
230
- print(f"lsnr: {lsnr.shape}")
231
- print(f"lsnr_gth: {lsnr_gth.shape}")
232
-
233
  loss = F.mse_loss(lsnr, lsnr_gth)
234
  return loss
235
 
 
201
  raise AssertionError("Input signals must have the same shape")
202
  noise = noisy - clean
203
 
 
 
 
 
204
  if clean.dim() == 2:
205
  clean = torch.unsqueeze(clean, dim=1)
206
  if noise.dim() == 2:
 
223
  lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
224
  # lsnr_gth shape: [b, t]
225
 
 
 
 
226
  loss = F.mse_loss(lsnr, lsnr_gth)
227
  return loss
228