Spaces:
Running
Running
update
Browse files
examples/rnnoise/step_2_train_model.py
CHANGED
@@ -249,7 +249,6 @@ def main():
|
|
249 |
|
250 |
step_idx = 0 if last_step_idx == -1 else last_step_idx
|
251 |
|
252 |
-
logger.info("training")
|
253 |
early_stop_flag = False
|
254 |
for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
|
255 |
if early_stop_flag:
|
@@ -274,6 +273,7 @@ def main():
|
|
274 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
275 |
|
276 |
denoise_audios, _, _ = model.forward(noisy_audios)
|
|
|
277 |
|
278 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
279 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
@@ -339,6 +339,7 @@ def main():
|
|
339 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
340 |
|
341 |
denoise_audios, _, _ = model.forward(noisy_audios)
|
|
|
342 |
|
343 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
344 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
249 |
|
250 |
step_idx = 0 if last_step_idx == -1 else last_step_idx
|
251 |
|
|
|
252 |
early_stop_flag = False
|
253 |
for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
|
254 |
if early_stop_flag:
|
|
|
273 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
274 |
|
275 |
denoise_audios, _, _ = model.forward(noisy_audios)
|
276 |
+
denoise_audios = torch.squeeze(denoise_audios, dim=1)
|
277 |
|
278 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
279 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
|
|
339 |
noisy_audios: torch.Tensor = noisy_audios.to(device)
|
340 |
|
341 |
denoise_audios, _, _ = model.forward(noisy_audios)
|
342 |
+
denoise_audios = torch.squeeze(denoise_audios, dim=1)
|
343 |
|
344 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
345 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|