Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -256,7 +256,8 @@ def main():
|
|
256 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
257 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
258 |
|
259 |
-
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
|
|
260 |
|
261 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
262 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
@@ -320,7 +321,8 @@ def main():
|
|
320 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
321 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
322 |
|
323 |
-
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
|
|
324 |
|
325 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
326 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
256 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
257 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
258 |
|
259 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
260 |
+
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
|
261 |
|
262 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
263 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
321 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
322 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
323 |
|
324 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss
|
325 |
+
# loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss
|
326 |
|
327 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
328 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
toolbox/torchaudio/losses/spectral.py
CHANGED
@@ -191,9 +191,11 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
|
|
191 |
|
192 |
def __init__(self,
|
193 |
reduction: str = "mean",
|
|
|
194 |
):
|
195 |
super(LogSTFTMagnitudeLoss, self).__init__()
|
196 |
self.reduction = reduction
|
|
|
197 |
|
198 |
if reduction not in ("sum", "mean"):
|
199 |
raise AssertionError(f"param reduction must be sum or mean.")
|
@@ -207,7 +209,7 @@ class LogSTFTMagnitudeLoss(torch.nn.Module):
|
|
207 |
:param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
|
208 |
:return:
|
209 |
"""
|
210 |
-
return F.l1_loss(torch.log(denoise_magnitude), torch.log(clean_magnitude))
|
211 |
|
212 |
|
213 |
class STFTLoss(torch.nn.Module):
|
|
|
191 |
|
192 |
def __init__(self,
|
193 |
reduction: str = "mean",
|
194 |
+
eps: float = 1e-8,
|
195 |
):
|
196 |
super(LogSTFTMagnitudeLoss, self).__init__()
|
197 |
self.reduction = reduction
|
198 |
+
self.eps = eps
|
199 |
|
200 |
if reduction not in ("sum", "mean"):
|
201 |
raise AssertionError(f"param reduction must be sum or mean.")
|
|
|
209 |
:param clean_magnitude: Tensor, shape: [batch_size, time_steps, freq_bins]
|
210 |
:return:
|
211 |
"""
|
212 |
+
return F.l1_loss(torch.log(denoise_magnitude + self.eps), torch.log(clean_magnitude + self.eps))
|
213 |
|
214 |
|
215 |
class STFTLoss(torch.nn.Module):
|