Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -238,14 +238,17 @@ def main():
|
|
238 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
239 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
240 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
241 |
-
|
|
|
|
|
|
|
242 |
|
243 |
optimizer.zero_grad()
|
244 |
loss.backward()
|
245 |
optimizer.step()
|
246 |
lr_scheduler.step()
|
247 |
|
248 |
-
total_pesq_metric +=
|
249 |
total_loss += loss.item()
|
250 |
total_ae_loss += ae_loss.item()
|
251 |
total_sc_loss += sc_loss.item()
|
|
|
238 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
239 |
clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
|
240 |
pesq_metric = batch_pesq(enhanced_audios_list_r, clean_audios_list_r)
|
241 |
+
if pesq_metric is None:
|
242 |
+
pesq_metric = 0
|
243 |
+
else:
|
244 |
+
pesq_metric = torch.mean(pesq_metric).item()
|
245 |
|
246 |
optimizer.zero_grad()
|
247 |
loss.backward()
|
248 |
optimizer.step()
|
249 |
lr_scheduler.step()
|
250 |
|
251 |
+
total_pesq_metric += pesq_metric
|
252 |
total_loss += loss.item()
|
253 |
total_ae_loss += ae_loss.item()
|
254 |
total_sc_loss += sc_loss.item()
|