update
Browse files
examples/cnn_vad_by_webrtcvad/step_4_train_model.py
CHANGED
@@ -272,7 +272,7 @@ def main():
|
|
272 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
273 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
274 |
|
275 |
-
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
|
276 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
277 |
logger.info(f"find nan or inf in loss. continue.")
|
278 |
continue
|
@@ -352,7 +352,7 @@ def main():
|
|
352 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
353 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
354 |
|
355 |
-
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.
|
356 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
357 |
logger.info(f"find nan or inf in loss. continue.")
|
358 |
continue
|
|
|
272 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
273 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
274 |
|
275 |
+
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
|
276 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
277 |
logger.info(f"find nan or inf in loss. continue.")
|
278 |
continue
|
|
|
352 |
dice_loss = dice_loss_fn.forward(probs, targets)
|
353 |
lsnr_loss = model.lsnr_loss_fn(lsnr, clean_audios, noisy_audios)
|
354 |
|
355 |
+
loss = 1.0 * bce_loss + 1.0 * dice_loss + 0.03 * lsnr_loss
|
356 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
357 |
logger.info(f"find nan or inf in loss. continue.")
|
358 |
continue
|
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py
CHANGED
@@ -201,10 +201,6 @@ class CNNVadModel(nn.Module):
|
|
201 |
raise AssertionError("Input signals must have the same shape")
|
202 |
noise = noisy - clean
|
203 |
|
204 |
-
print(f"lsnr: {lsnr.shape}")
|
205 |
-
print(f"clean: {clean.shape}")
|
206 |
-
print(f"noisy: {noisy.shape}")
|
207 |
-
|
208 |
if clean.dim() == 2:
|
209 |
clean = torch.unsqueeze(clean, dim=1)
|
210 |
if noise.dim() == 2:
|
@@ -227,9 +223,6 @@ class CNNVadModel(nn.Module):
|
|
227 |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
|
228 |
# lsnr_gth shape: [b, t]
|
229 |
|
230 |
-
print(f"lsnr: {lsnr.shape}")
|
231 |
-
print(f"lsnr_gth: {lsnr_gth.shape}")
|
232 |
-
|
233 |
loss = F.mse_loss(lsnr, lsnr_gth)
|
234 |
return loss
|
235 |
|
|
|
201 |
raise AssertionError("Input signals must have the same shape")
|
202 |
noise = noisy - clean
|
203 |
|
|
|
|
|
|
|
|
|
204 |
if clean.dim() == 2:
|
205 |
clean = torch.unsqueeze(clean, dim=1)
|
206 |
if noise.dim() == 2:
|
|
|
223 |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
|
224 |
# lsnr_gth shape: [b, t]
|
225 |
|
|
|
|
|
|
|
226 |
loss = F.mse_loss(lsnr, lsnr_gth)
|
227 |
return loss
|
228 |
|