HoneyTian committed on
Commit
9b1d5cc
·
1 Parent(s): bd728a1
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -256,7 +256,8 @@ def main():
256
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
257
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
258
 
259
- loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
 
260
 
261
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
262
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -320,7 +321,8 @@ def main():
320
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
321
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
322
 
323
- loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
 
324
 
325
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
326
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
256
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
257
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
258
 
259
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
260
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
261
 
262
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
263
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
321
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
322
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
323
 
324
+ loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
325
+ # loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
326
 
327
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
328
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -191,9 +191,11 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
191
 
192
  def __init__(self,
193
  reduction: str = "mean",
 
194
  ):
195
  super(LogSTFTMagnitudeLoss, self).__init__()
196
  self.reduction = reduction
 
197
 
198
  if reduction not in ("sum", "mean"):
199
  raise AssertionError(f"param reduction must be sum or mean.")
@@ -207,7 +209,7 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
207
  :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
208
  :return:
209
  """
210
- return F.l1_loss(torch.log(denoise_magnitude), torch.log(clean_magnitude))
211
 
212
 
213
  class STFTLoss(torch.nn.Module):
 
191
 
192
  def __init__(self,
193
  reduction: str = "mean",
194
+ eps: float = 1e-8,
195
  ):
196
  super(LogSTFTMagnitudeLoss, self).__init__()
197
  self.reduction = reduction
198
+ self.eps = eps
199
 
200
  if reduction not in ("sum", "mean"):
201
  raise AssertionError(f"param reduction must be sum or mean.")
 
209
  :param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
210
  :return:
211
  """
212
+ return F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
213
 
214
 
215
  class STFTLoss(torch.nn.Module):