Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -144,7 +144,7 @@ def main():
|
|
144 |
max_wave_value=32768.0,
|
145 |
min_snr_db=config.min_snr_db,
|
146 |
max_snr_db=config.max_snr_db,
|
147 |
-
skip=225000,
|
148 |
)
|
149 |
valid_dataset = DenoiseJsonlDataset(
|
150 |
jsonl_file=args.valid_dataset,
|
@@ -195,7 +195,7 @@ def main():
|
|
195 |
step_idx = int(step_idx)
|
196 |
if step_idx > last_step_idx:
|
197 |
last_step_idx = step_idx
|
198 |
-
last_epoch =
|
199 |
|
200 |
if last_step_idx != -1:
|
201 |
logger.info(f"resume from steps-{last_step_idx}.")
|
|
|
144 |
max_wave_value=32768.0,
|
145 |
min_snr_db=config.min_snr_db,
|
146 |
max_snr_db=config.max_snr_db,
|
147 |
+
# skip=225000,
|
148 |
)
|
149 |
valid_dataset = DenoiseJsonlDataset(
|
150 |
jsonl_file=args.valid_dataset,
|
|
|
195 |
step_idx = int(step_idx)
|
196 |
if step_idx > last_step_idx:
|
197 |
last_step_idx = step_idx
|
198 |
+
last_epoch = 2
|
199 |
|
200 |
if last_step_idx != -1:
|
201 |
logger.info(f"resume from steps-{last_step_idx}.")
|
examples/conv_tasnet_gan/step_2_train_model.py
CHANGED
@@ -364,7 +364,7 @@ def main():
|
|
364 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
365 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
366 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
367 |
-
pesq_loss = pesq_loss_fn.forward(
|
368 |
|
369 |
metric_g = discriminator.forward(denoise_audios, clean_audios)
|
370 |
discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
|
@@ -451,7 +451,7 @@ def main():
|
|
451 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
452 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
453 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
454 |
-
pesq_loss = pesq_loss_fn.forward(
|
455 |
|
456 |
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
457 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|
|
|
364 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
365 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
366 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
367 |
+
pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
|
368 |
|
369 |
metric_g = discriminator.forward(denoise_audios, clean_audios)
|
370 |
discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
|
|
|
451 |
neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
|
452 |
neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
|
453 |
mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
|
454 |
+
pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
|
455 |
|
456 |
loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
|
457 |
if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
|