HoneyTian commited on
Commit
d791cee
·
1 Parent(s): 39d295e
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -144,7 +144,7 @@ def main():
144
  max_wave_value=32768.0,
145
  min_snr_db=config.min_snr_db,
146
  max_snr_db=config.max_snr_db,
147
- skip=225000,
148
  )
149
  valid_dataset = DenoiseJsonlDataset(
150
  jsonl_file=args.valid_dataset,
@@ -195,7 +195,7 @@ def main():
195
  step_idx = int(step_idx)
196
  if step_idx > last_step_idx:
197
  last_step_idx = step_idx
198
- last_epoch = 1
199
 
200
  if last_step_idx != -1:
201
  logger.info(f"resume from steps-{last_step_idx}.")
 
144
  max_wave_value=32768.0,
145
  min_snr_db=config.min_snr_db,
146
  max_snr_db=config.max_snr_db,
147
+ # skip=225000,
148
  )
149
  valid_dataset = DenoiseJsonlDataset(
150
  jsonl_file=args.valid_dataset,
 
195
  step_idx = int(step_idx)
196
  if step_idx > last_step_idx:
197
  last_step_idx = step_idx
198
+ last_epoch = 2
199
 
200
  if last_step_idx != -1:
201
  logger.info(f"resume from steps-{last_step_idx}.")
examples/conv_tasnet_gan/step_2_train_model.py CHANGED
@@ -364,7 +364,7 @@ def main():
364
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
365
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
366
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
367
- pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
368
 
369
  metric_g = discriminator.forward(denoise_audios, clean_audios)
370
  discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
@@ -451,7 +451,7 @@ def main():
451
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
452
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
453
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
454
- pesq_loss = pesq_loss_fn.forward(clean_audios, denoise_audios)
455
 
456
  loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
457
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
 
364
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
365
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
366
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
367
+ pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
368
 
369
  metric_g = discriminator.forward(denoise_audios, clean_audios)
370
  discriminator_g_loss = F.mse_loss(metric_g.flatten(), one_labels)
 
451
  neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
452
  neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
453
  mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
454
+ pesq_loss = pesq_loss_fn.forward(denoise_audios, clean_audios)
455
 
456
  loss = 1.0 * ae_loss + 0.8 * neg_si_snr_loss + 0.7 * mr_stft_loss + 0.5 * neg_stoi_loss + 0.5 * pesq_loss
457
  if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):