Spaces:
Running
Running
update
Browse files
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -79,7 +79,7 @@ class CollateFunction(object):
|
|
79 |
hop_length: int = 80,
|
80 |
window_fn: str = "hamming",
|
81 |
irm_beta: float = 1.0,
|
82 |
-
epsilon: float = 1e-
|
83 |
):
|
84 |
self.n_fft = n_fft
|
85 |
self.win_length = win_length
|
@@ -138,17 +138,10 @@ class CollateFunction(object):
|
|
138 |
|
139 |
# noise_spec, speech_spec, mix_spec, speech_irm
|
140 |
# shape: [freq_dim, time_steps]
|
141 |
-
if torch.any(torch.isnan(speech_spec)) or torch.any(torch.isinf(speech_spec)):
|
142 |
-
raise AssertionError("nan or inf in speech_spec")
|
143 |
-
if torch.any(torch.isnan(noise_spec)) or torch.any(torch.isinf(noise_spec)):
|
144 |
-
raise AssertionError("nan or inf in noise_spec")
|
145 |
|
146 |
snr_db: torch.Tensor = 10 * torch.log10(
|
147 |
-
speech_spec / (noise_spec + self.epsilon)
|
148 |
)
|
149 |
-
if torch.any(torch.isnan(snr_db)) or torch.any(torch.isinf(snr_db)):
|
150 |
-
raise AssertionError("nan or inf in snr_db")
|
151 |
-
|
152 |
snr_db_ = torch.unsqueeze(snr_db, dim=0)
|
153 |
snr_db_ = torch.unsqueeze(snr_db_, dim=0)
|
154 |
snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
|
@@ -305,7 +298,7 @@ def main():
|
|
305 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
306 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
307 |
raise AssertionError("nan or inf in snr_loss")
|
308 |
-
loss = irm_loss + 0 * snr_loss
|
309 |
# loss = irm_loss
|
310 |
|
311 |
total_loss += loss.item()
|
@@ -343,11 +336,11 @@ def main():
|
|
343 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
344 |
raise AssertionError("nan or inf in lsnr_prediction")
|
345 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
346 |
-
|
347 |
-
|
348 |
-
loss = irm_loss
|
349 |
|
350 |
-
|
351 |
total_examples += mix_spec.size(0)
|
352 |
|
353 |
evaluation_loss = total_loss / total_examples
|
|
|
79 |
hop_length: int = 80,
|
80 |
window_fn: str = "hamming",
|
81 |
irm_beta: float = 1.0,
|
82 |
+
epsilon: float = 1e-8,
|
83 |
):
|
84 |
self.n_fft = n_fft
|
85 |
self.win_length = win_length
|
|
|
138 |
|
139 |
# noise_spec, speech_spec, mix_spec, speech_irm
|
140 |
# shape: [freq_dim, time_steps]
|
|
|
|
|
|
|
|
|
141 |
|
142 |
snr_db: torch.Tensor = 10 * torch.log10(
|
143 |
+
speech_spec / (noise_spec + self.epsilon) + self.epsilon
|
144 |
)
|
|
|
|
|
|
|
145 |
snr_db_ = torch.unsqueeze(snr_db, dim=0)
|
146 |
snr_db_ = torch.unsqueeze(snr_db_, dim=0)
|
147 |
snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
|
|
|
298 |
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
299 |
if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
|
300 |
raise AssertionError("nan or inf in snr_loss")
|
301 |
+
loss = irm_loss + 0.1 * snr_loss
|
302 |
# loss = irm_loss
|
303 |
|
304 |
total_loss += loss.item()
|
|
|
336 |
if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
|
337 |
raise AssertionError("nan or inf in lsnr_prediction")
|
338 |
irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
|
339 |
+
snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
|
340 |
+
loss = irm_loss + 0.1 * snr_loss
|
341 |
+
# loss = irm_loss
|
342 |
|
343 |
+
total_loss += loss.item()
|
344 |
total_examples += mix_spec.size(0)
|
345 |
|
346 |
evaluation_loss = total_loss / total_examples
|
examples/test.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Sanity check for the epsilon-guarded log10 used in the SNR computation.

With both spectra all-zero, ``speech_spec / (noise_spec + epsilon)`` is 0,
so without the outer ``+ epsilon`` the ``log10`` would be ``-inf``.  The
guard makes the result a finite ``log10(epsilon)`` (= -8.0 for 1e-8),
which is what the training script relies on to avoid nan/inf losses.
"""
import torch

# Degenerate worst case: silent speech and silent noise.
speech_spec = torch.tensor([0], dtype=torch.float32)
noise_spec = torch.tensor([0], dtype=torch.float32)
epsilon = 1e-8

# Inner epsilon avoids division by zero; outer epsilon avoids log10(0).
result = torch.log10(
    speech_spec / (noise_spec + epsilon) + epsilon
)

if __name__ == '__main__':
    # Print only when run as a script, not on import.
    print(result)