HoneyTian commited on
Commit
8eecf8d
·
1 Parent(s): ec8bf87
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -238,14 +238,17 @@ def main():
238
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
239
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
240
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
241
- print(f"pesq_metric: {pesq_metric}")
 
 
 
242
 
243
  optimizer.zero_grad()
244
  loss.backward()
245
  optimizer.step()
246
  lr_scheduler.step()
247
 
248
- total_pesq_metric += 0 if pesq_metric is None else pesq_metric.item()
249
  total_loss += loss.item()
250
  total_ae_loss += ae_loss.item()
251
  total_sc_loss += sc_loss.item()
 
238
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
239
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
240
  pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
241
+ if pesq_metric is None:
242
+ pesq_metric = 0
243
+ else:
244
+ pesq_metric = torch.mean(pesq_metric).item()
245
 
246
  optimizer.zero_grad()
247
  loss.backward()
248
  optimizer.step()
249
  lr_scheduler.step()
250
 
251
+ total_pesq_metric += pesq_metric
252
  total_loss += loss.item()
253
  total_ae_loss += ae_loss.item()
254
  total_sc_loss += sc_loss.item()