HoneyTian committed
Commit 7d18e1c · 1 Parent(s): b8f2975
examples/spectrum_dfnet_aishell/step_3_evaluation.py CHANGED
@@ -111,7 +111,6 @@ def enhance(mix_spec_complex: torch.Tensor,
     # print(f"speech_spec_prediction: {speech_spec_prediction.shape}")
     # print(f"noise_spec: {noise_spec.shape}")
 
-    speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
     speech_wave = istft.forward(speech_spec_prediction)
     # speech_wave = istft.forward(speech_spec)
     noise_wave = istft.forward(noise_spec)
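Note on this hunk: `enhance` calls `istft.forward(...)` directly on the prediction, so after this change the caller must hand it an already-complex spectrogram; PyTorch's iSTFT does not accept the old real-valued [..., 2] layout. A minimal round-trip sketch of that contract, assuming n_fft=512 and hop_length=128 rather than this repo's actual STFT settings:

    import torch

    n_fft, hop_length = 512, 128          # assumed analysis parameters
    window = torch.hann_window(n_fft)

    wave = torch.randn(1, 16000)
    # STFT with return_complex=True yields [1, 257, frames], complex64.
    spec = torch.stft(wave, n_fft=n_fft, hop_length=hop_length,
                      window=window, return_complex=True)
    # A network that emits real/imag as a trailing dim of size 2 ...
    spec_ri = torch.view_as_real(spec)    # [1, 257, frames, 2], float32
    # ... must be converted back to complex before inversion:
    recon = torch.istft(torch.view_as_complex(spec_ri), n_fft=n_fft,
                        hop_length=hop_length, window=window,
                        length=wave.shape[-1])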
@@ -245,6 +244,8 @@ def main():
 
         with torch.no_grad():
             speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex)
+            speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
+
             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
             # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
             # loss = irm_loss + 0.1 * snr_loss
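The conversion now happens once at the call site, immediately after `model.forward`, so everything downstream sees a complex tensor. A hedged sketch of what `torch.view_as_complex` requires of that output (a stand-in tensor here, not the real model):

    import torch

    # Stand-in for the model's spectrum output: float32 with a trailing
    # real/imag dimension of size 2 (the shape values are assumed).
    speech_spec_prediction = torch.randn(4, 256, 100, 2)
    # view_as_complex needs a float tensor whose last dim has size 2 and
    # stride 1; it returns a complex view without copying.
    speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
    print(speech_spec_prediction.shape, speech_spec_prediction.dtype)
    # torch.Size([4, 256, 100]) torch.complex64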
@@ -252,14 +253,31 @@
 
             # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
             # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
-            # batch_size, _, time_steps = speech_irm_prediction.shape
-            # speech_irm_prediction = torch.concat(
-            #     [
-            #         speech_irm_prediction,
-            #         0.5*torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
-            #     ],
-            #     dim=1,
-            # )
+            batch_size, _, time_steps = speech_irm_prediction.shape
+
+
+            mix_spec_complex = torch.concat(
+                [
+                    mix_spec_complex,
+                    torch.zeros(size=(batch_size, 1, time_steps), dtype=mix_spec_complex.dtype).to(device)
+                ],
+                dim=1,
+            )
+            speech_spec_prediction = torch.concat(
+                [
+                    speech_spec_prediction,
+                    torch.zeros(size=(batch_size, 1, time_steps), dtype=speech_spec_prediction.dtype).to(device)
+                ],
+                dim=1,
+            )
+            speech_irm_prediction = torch.concat(
+                [
+                    speech_irm_prediction,
+                    0.5 * torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
+                ],
+                dim=1,
+            )
+
             # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
             speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
             save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
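Per the shape comments above, the model emits 256 frequency bins while the iSTFT of a 512-point analysis expects 257, so this hunk appends one extra bin along dim=1: zeros for the spectra and a neutral 0.5 for the mask. A self-contained sketch of that padding (sizes are assumed, and the fill values simply mirror the diff; whether they are the ideal choice is not asserted here):

    import torch

    batch_size, freq_dim, time_steps = 4, 256, 100   # assumed sizes
    speech_spec_prediction = torch.randn(batch_size, freq_dim, time_steps,
                                         dtype=torch.complex64)
    speech_irm_prediction = torch.rand(batch_size, freq_dim, time_steps)

    # Append a zero bin to the complex spectrum (torch.concat is an
    # alias of torch.cat and works on complex tensors).
    speech_spec_prediction = torch.cat(
        [
            speech_spec_prediction,
            torch.zeros(batch_size, 1, time_steps, dtype=speech_spec_prediction.dtype),
        ],
        dim=1,
    )
    # Append a 0.5 bin to the IRM mask, as in the commit.
    speech_irm_prediction = torch.cat(
        [
            speech_irm_prediction,
            0.5 * torch.ones(batch_size, 1, time_steps, dtype=speech_irm_prediction.dtype),
        ],
        dim=1,
    )
    assert speech_spec_prediction.shape[1] == 257
    assert speech_irm_prediction.shape[1] == 257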
 