HoneyTian commited on
Commit
42e5399
·
1 Parent(s): b27ed9f
examples/clean_unet_aishell/step_1_prepare_data.py CHANGED
@@ -101,7 +101,7 @@ def get_dataset(args):
101
  count = 0
102
  process_bar = tqdm(desc="build dataset excel")
103
  for noise, speech in zip(noise_generator, speech_generator):
104
- if count > args.max_count:
105
  break
106
 
107
  noise_filename = noise["filename"]
 
101
  count = 0
102
  process_bar = tqdm(desc="build dataset excel")
103
  for noise, speech in zip(noise_generator, speech_generator):
104
+ if count >= args.max_count:
105
  break
106
 
107
  noise_filename = noise["filename"]
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -232,8 +232,8 @@ def main():
232
 
233
  loss = ae_loss + sc_loss + mag_loss
234
 
235
- enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
236
- clean_audios_list_r = list(clean_audios.cpu().numpy())
237
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
238
 
239
  optimizer.zero_grad()
@@ -292,8 +292,8 @@ def main():
292
 
293
  loss = ae_loss + sc_loss + mag_loss
294
 
295
- enhanced_audios_list_r = list(enhanced_audios.cpu().numpy())
296
- clean_audios_list_r = list(clean_audios.cpu().numpy())
297
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
298
 
299
  total_pesq_metric += pesq_metric.item()
 
232
 
233
  loss = ae_loss + sc_loss + mag_loss
234
 
235
+ enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
236
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
237
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
238
 
239
  optimizer.zero_grad()
 
292
 
293
  loss = ae_loss + sc_loss + mag_loss
294
 
295
+ enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
296
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
297
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
298
 
299
  total_pesq_metric += pesq_metric.item()