Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -204,7 +204,6 @@ def main():
|
|
204 |
ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
|
205 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
206 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
207 |
-
lds_loss_fn = LSDLoss(reduction="mean").to(device)
|
208 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
209 |
# fft_size_list=[256, 512, 1024],
|
210 |
win_size_list=[120, 240, 480],
|
@@ -220,7 +219,6 @@ def main():
|
|
220 |
average_ae_loss = 1000000000
|
221 |
average_neg_si_snr_loss = 1000000000
|
222 |
average_neg_stoi_loss = 1000000000
|
223 |
-
average_lds_loss = 1000000000
|
224 |
|
225 |
model_list = list()
|
226 |
best_idx_epoch = None
|
@@ -237,7 +235,6 @@ def main():
|
|
237 |
total_ae_loss = 0.
|
238 |
total_neg_si_snr_loss = 0.
|
239 |
total_neg_stoi_loss = 0.
|
240 |
-
total_lds_loss = 0.
|
241 |
total_batches = 0.
|
242 |
progress_bar = tqdm(
|
243 |
desc="Training; epoch-{}".format(idx_epoch),
|
@@ -256,10 +253,9 @@ def main():
|
|
256 |
ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
|
257 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
258 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
259 |
-
lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
|
260 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
261 |
|
262 |
-
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
|
263 |
|
264 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
265 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
@@ -275,7 +271,6 @@ def main():
|
|
275 |
total_ae_loss += ae_loss.item()
|
276 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
277 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
278 |
-
total_lds_loss += lds_loss.item()
|
279 |
total_batches += 1
|
280 |
|
281 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
@@ -283,7 +278,6 @@ def main():
|
|
283 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
284 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
285 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
286 |
-
average_lds_loss = round(total_lds_loss / total_batches, 4)
|
287 |
|
288 |
progress_bar.update(1)
|
289 |
progress_bar.set_postfix({
|
@@ -292,7 +286,6 @@ def main():
|
|
292 |
"ae_loss": average_ae_loss,
|
293 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
294 |
"neg_stoi_loss": average_neg_stoi_loss,
|
295 |
-
"lds_loss": average_lds_loss,
|
296 |
})
|
297 |
|
298 |
# evaluation
|
@@ -304,7 +297,6 @@ def main():
|
|
304 |
total_ae_loss = 0.
|
305 |
total_neg_si_snr_loss = 0.
|
306 |
total_neg_stoi_loss = 0.
|
307 |
-
total_lds_loss = 0.
|
308 |
total_batches = 0.
|
309 |
|
310 |
progress_bar = tqdm(
|
@@ -322,9 +314,9 @@ def main():
|
|
322 |
ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
|
323 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
324 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
325 |
-
|
326 |
|
327 |
-
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
|
328 |
|
329 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
330 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
@@ -335,7 +327,6 @@ def main():
|
|
335 |
total_ae_loss += ae_loss.item()
|
336 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
337 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
338 |
-
total_lds_loss += lds_loss.item()
|
339 |
total_batches += 1
|
340 |
|
341 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
@@ -343,7 +334,6 @@ def main():
|
|
343 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
344 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
345 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
346 |
-
average_lds_loss = round(total_lds_loss / total_batches, 4)
|
347 |
|
348 |
progress_bar.update(1)
|
349 |
progress_bar.set_postfix({
|
@@ -352,7 +342,6 @@ def main():
|
|
352 |
"ae_loss": average_ae_loss,
|
353 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
354 |
"neg_stoi_loss": average_neg_stoi_loss,
|
355 |
-
"lds_loss": average_lds_loss,
|
356 |
})
|
357 |
|
358 |
# scheduler
|
@@ -392,7 +381,6 @@ def main():
|
|
392 |
"ae_loss": average_ae_loss,
|
393 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
394 |
"neg_stoi_loss": average_neg_stoi_loss,
|
395 |
-
"lds_loss": average_lds_loss,
|
396 |
}
|
397 |
metrics_filename = epoch_dir / "metrics_epoch.json"
|
398 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|
|
|
204 |
ae_loss_fn = nn.L1Loss(reduction="mean").to(device)
|
205 |
neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
|
206 |
neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
|
|
|
207 |
mr_stft_loss_fn = MultiResolutionSTFTLoss(
|
208 |
# fft_size_list=[256, 512, 1024],
|
209 |
win_size_list=[120, 240, 480],
|
|
|
219 |
average_ae_loss = 1000000000
|
220 |
average_neg_si_snr_loss = 1000000000
|
221 |
average_neg_stoi_loss = 1000000000
|
|
|
222 |
|
223 |
model_list = list()
|
224 |
best_idx_epoch = None
|
|
|
235 |
total_ae_loss = 0.
|
236 |
total_neg_si_snr_loss = 0.
|
237 |
total_neg_stoi_loss = 0.
|
|
|
238 |
total_batches = 0.
|
239 |
progress_bar = tqdm(
|
240 |
desc="Training; epoch-{}".format(idx_epoch),
|
|
|
253 |
ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
|
254 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
255 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
256 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
257 |
|
258 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
259 |
|
260 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
261 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
271 |
total_ae_loss += ae_loss.item()
|
272 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
273 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
|
|
274 |
total_batches += 1
|
275 |
|
276 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
|
|
278 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
279 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
280 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
|
|
281 |
|
282 |
progress_bar.update(1)
|
283 |
progress_bar.set_postfix({
|
|
|
286 |
"ae_loss": average_ae_loss,
|
287 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
288 |
"neg_stoi_loss": average_neg_stoi_loss,
|
|
|
289 |
})
|
290 |
|
291 |
# evaluation
|
|
|
297 |
total_ae_loss = 0.
|
298 |
total_neg_si_snr_loss = 0.
|
299 |
total_neg_stoi_loss = 0.
|
|
|
300 |
total_batches = 0.
|
301 |
|
302 |
progress_bar = tqdm(
|
|
|
314 |
ae_loss = ae_loss_fn.forward(denoise_audios, clean_audios)
|
315 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
316 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
317 |
+
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
318 |
|
319 |
+
loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * mr_stft_loss
|
320 |
|
321 |
denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
|
322 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
|
|
327 |
total_ae_loss += ae_loss.item()
|
328 |
total_neg_si_snr_loss += neg_si_snr_loss.item()
|
329 |
total_neg_stoi_loss += neg_stoi_loss.item()
|
|
|
330 |
total_batches += 1
|
331 |
|
332 |
average_pesq_score = round(total_pesq_score / total_batches, 4)
|
|
|
334 |
average_ae_loss = round(total_ae_loss / total_batches, 4)
|
335 |
average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
|
336 |
average_neg_stoi_loss = round(total_neg_stoi_loss / total_batches, 4)
|
|
|
337 |
|
338 |
progress_bar.update(1)
|
339 |
progress_bar.set_postfix({
|
|
|
342 |
"ae_loss": average_ae_loss,
|
343 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
344 |
"neg_stoi_loss": average_neg_stoi_loss,
|
|
|
345 |
})
|
346 |
|
347 |
# scheduler
|
|
|
381 |
"ae_loss": average_ae_loss,
|
382 |
"neg_si_snr_loss": average_neg_si_snr_loss,
|
383 |
"neg_stoi_loss": average_neg_stoi_loss,
|
|
|
384 |
}
|
385 |
metrics_filename = epoch_dir / "metrics_epoch.json"
|
386 |
with open(metrics_filename, "w", encoding="utf-8") as f:
|