update
examples/spectrum_dfnet_aishell/step_3_evaluation.py CHANGED
@@ -111,7 +111,6 @@ def enhance(mix_spec_complex: torch.Tensor,
     # print(f"speech_spec_prediction: {speech_spec_prediction.shape}")
     # print(f"noise_spec: {noise_spec.shape}")
 
-    speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
     speech_wave = istft.forward(speech_spec_prediction)
     # speech_wave = istft.forward(speech_spec)
     noise_wave = istft.forward(noise_spec)
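This hunk drops the real-to-complex conversion from enhance(), which now receives an already complex-valued speech_spec_prediction; the call moves into main() right after the forward pass (see the next hunk). A minimal, self-contained sketch of what the conversion does, with illustrative shapes that are not taken from the repository:

import torch

# torch.view_as_complex reinterprets a float tensor whose trailing
# dimension of size 2 holds (real, imag) pairs as a complex tensor,
# without copying. Shapes below are illustrative only.
spec_ri = torch.randn(2, 256, 100, 2)    # [batch, freq, time, (real, imag)]
spec_c = torch.view_as_complex(spec_ri)  # [batch, freq, time], torch.complex64
assert spec_c.shape == (2, 256, 100)
assert spec_c.real.equal(spec_ri[..., 0])

Because view_as_complex returns a view, moving the call costs nothing; it only changes where in the pipeline the (real, imag) pairing starts being treated as complex numbers.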
@@ -245,6 +244,8 @@ def main():
 
     with torch.no_grad():
         speech_spec_prediction, speech_irm_prediction, lsnr_prediction = model.forward(mix_spec_complex)
+        speech_spec_prediction = torch.view_as_complex(speech_spec_prediction)
+
         irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
         # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
         # loss = irm_loss + 0.1 * snr_loss
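With the conversion applied immediately after model.forward, everything downstream in main() sees the complex spectrum, while the mask loss still operates on the model's real-valued IRM output. A minimal sketch of that loss, assuming irm_mse_loss behaves like torch.nn.MSELoss (the name suggests as much, but the actual class is defined elsewhere in the repository):

import torch

# Assumed stand-in for the script's irm_mse_loss: plain MSE between
# the predicted and target ideal-ratio masks, evaluated under no_grad.
irm_mse_loss = torch.nn.MSELoss()
speech_irm_prediction = torch.rand(2, 256, 100)  # mask values in [0, 1]
speech_irm_target = torch.rand(2, 256, 100)
with torch.no_grad():
    irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
print(irm_loss.item())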
@@ -252,14 +253,31 @@ def main():
 
         # mix_spec_complex shape: [batch_size, freq_dim (257), time_steps, 2]
         # speech_irm_prediction shape: [batch_size, freq_dim (256), time_steps]
-
-
-
-
-
-
-
-
+        batch_size, _, time_steps = speech_irm_prediction.shape
+
+
+        mix_spec_complex = torch.concat(
+            [
+                mix_spec_complex,
+                torch.zeros(size=(batch_size, 1, time_steps), dtype=mix_spec_complex.dtype).to(device)
+            ],
+            dim=1,
+        )
+        speech_spec_prediction = torch.concat(
+            [
+                speech_spec_prediction,
+                torch.zeros(size=(batch_size, 1, time_steps), dtype=speech_spec_prediction.dtype).to(device)
+            ],
+            dim=1,
+        )
+        speech_irm_prediction = torch.concat(
+            [
+                speech_irm_prediction,
+                0.5 * torch.ones(size=(batch_size, 1, time_steps), dtype=speech_irm_prediction.dtype).to(device)
+            ],
+            dim=1,
+        )
+
         # speech_irm_prediction shape: [batch_size, freq_dim (257), time_steps]
         speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_spec_prediction, speech_irm_prediction)
         save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
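The new block appends one frequency bin along dim=1 to each tensor before enhance() is called: per the shape comments, the predicted mask covers 256 bins while the iSTFT expects 257 (consistent with n_fft // 2 + 1 for a 512-point FFT, if the comments' bin counts are taken at face value). The spectra are padded with zeros; the mask is padded with 0.5, the midpoint of a [0, 1] ratio mask, as a neutral value for the bin the model does not predict. The same concat pattern repeats three times; a hypothetical helper (not in the repository) that captures it:

import torch

def pad_freq_bin(x: torch.Tensor, fill: float = 0.0) -> torch.Tensor:
    # Hypothetical helper, not part of the repository: append one
    # frequency bin along dim=1 so a 256-bin tensor matches the
    # 257 bins the iSTFT expects.
    batch_size, _, time_steps = x.shape
    pad = torch.full((batch_size, 1, time_steps), fill,
                     dtype=x.dtype, device=x.device)
    return torch.concat([x, pad], dim=1)

# Zeros for the complex spectra, 0.5 for the ratio mask.
spec = torch.randn(2, 256, 100, dtype=torch.complex64)
mask = torch.rand(2, 256, 100)
assert pad_freq_bin(spec).shape == (2, 257, 100)
assert pad_freq_bin(mask, fill=0.5).shape == (2, 257, 100)

pad_freq_bin and the shapes in the example are assumptions for illustration; the committed code inlines the three concat calls exactly as shown in the hunk.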