HoneyTian commited on
Commit
60e65ac
·
1 Parent(s): cbe8ba1
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -79,7 +79,7 @@ class CollateFunction(object):
79
  hop_length: int = 80,
80
  window_fn: str = "hamming",
81
  irm_beta: float = 1.0,
82
- epsilon: float = 1e-7,
83
  ):
84
  self.n_fft = n_fft
85
  self.win_length = win_length
@@ -138,17 +138,10 @@ class CollateFunction(object):
138
 
139
  # noise_spec, speech_spec, mix_spec, speech_irm
140
  # shape: [freq_dim, time_steps]
141
- if torch.any(torch.isnan(speech_spec)) or torch.any(torch.isinf(speech_spec)):
142
- raise AssertionError("nan or inf in speech_spec")
143
- if torch.any(torch.isnan(noise_spec)) or torch.any(torch.isinf(noise_spec)):
144
- raise AssertionError("nan or inf in noise_spec")
145
 
146
  snr_db: torch.Tensor = 10 * torch.log10(
147
- speech_spec / (noise_spec + self.epsilon)
148
  )
149
- if torch.any(torch.isnan(snr_db)) or torch.any(torch.isinf(snr_db)):
150
- raise AssertionError("nan or inf in snr_db")
151
-
152
  snr_db_ = torch.unsqueeze(snr_db, dim=0)
153
  snr_db_ = torch.unsqueeze(snr_db_, dim=0)
154
  snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
@@ -305,7 +298,7 @@ def main():
305
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
306
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
307
  raise AssertionError("nan or inf in snr_loss")
308
- loss = irm_loss + 0 * snr_loss
309
  # loss = irm_loss
310
 
311
  total_loss += loss.item()
@@ -343,11 +336,11 @@ def main():
343
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
344
  raise AssertionError("nan or inf in lsnr_prediction")
345
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
346
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
347
- # loss = irm_loss + 0*snr_loss
348
- loss = irm_loss
349
 
350
- total_loss += loss.item()
351
  total_examples += mix_spec.size(0)
352
 
353
  evaluation_loss = total_loss / total_examples
 
79
  hop_length: int = 80,
80
  window_fn: str = "hamming",
81
  irm_beta: float = 1.0,
82
+ epsilon: float = 1e-8,
83
  ):
84
  self.n_fft = n_fft
85
  self.win_length = win_length
 
138
 
139
  # noise_spec, speech_spec, mix_spec, speech_irm
140
  # shape: [freq_dim, time_steps]
 
 
 
 
141
 
142
  snr_db: torch.Tensor = 10 * torch.log10(
143
+ speech_spec / (noise_spec + self.epsilon) + self.epsilon
144
  )
 
 
 
145
  snr_db_ = torch.unsqueeze(snr_db, dim=0)
146
  snr_db_ = torch.unsqueeze(snr_db_, dim=0)
147
  snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
 
298
  snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
299
  if torch.any(torch.isnan(snr_loss)) or torch.any(torch.isinf(snr_loss)):
300
  raise AssertionError("nan or inf in snr_loss")
301
+ loss = irm_loss + 0.1 * snr_loss
302
  # loss = irm_loss
303
 
304
  total_loss += loss.item()
 
336
  if torch.any(torch.isnan(lsnr_prediction)) or torch.any(torch.isinf(lsnr_prediction)):
337
  raise AssertionError("nan or inf in lsnr_prediction")
338
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
339
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
340
+ loss = irm_loss + 0.1 * snr_loss
341
+ # loss = irm_loss
342
 
343
+ total_loss += loss.item()
344
  total_examples += mix_spec.size(0)
345
 
346
  evaluation_loss = total_loss / total_examples
examples/test.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+
5
+ speech_spec = torch.tensor([0], dtype=torch.float32)
6
+ noise_spec = torch.tensor([0], dtype=torch.float32)
7
+ epsilon = 1e-8
8
+
9
+
10
+ result = torch.log10(
11
+ speech_spec / (noise_spec + epsilon) + epsilon
12
+ )
13
+
14
+ print(result)
15
+
16
+
17
+ if __name__ == '__main__':
18
+ pass