HoneyTian commited on
Commit
8cf37ea
·
1 Parent(s): 9e01c3d
examples/spectrum_unet_irm_aishell/step_2_train_model.py CHANGED
@@ -21,6 +21,7 @@ sys.path.append(os.path.join(pwd, "../../"))
21
  import numpy as np
22
  import torch
23
  import torch.nn as nn
 
24
  from torch.utils.data.dataloader import DataLoader
25
  import torchaudio
26
  from tqdm import tqdm
@@ -95,6 +96,28 @@ class CollateFunction(object):
95
  window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
96
  )
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def __call__(self, batch: List[dict]):
99
  mix_spec_list = list()
100
  speech_irm_list = list()
@@ -119,7 +142,13 @@ class CollateFunction(object):
119
  snr_db: torch.Tensor = 10 * torch.log10(
120
  speech_spec / (noise_spec + self.epsilon)
121
  )
122
- snr_db = torch.mean(snr_db, dim=0, keepdim=True)
 
 
 
 
 
 
123
  # snr_db shape: [1, time_steps]
124
 
125
  mix_spec_list.append(mix_spec)
@@ -262,9 +291,9 @@ def main():
262
 
263
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
264
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
265
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
266
- # loss = irm_loss + 0.1 * snr_loss
267
- loss = irm_loss
268
 
269
  total_loss += loss.item()
270
  total_examples += mix_spec.size(0)
@@ -297,9 +326,9 @@ def main():
297
  with torch.no_grad():
298
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
299
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
300
- # snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
301
- # loss = irm_loss + 0.1 * snr_loss
302
- loss = irm_loss
303
 
304
  total_loss += loss.item()
305
  total_examples += mix_spec.size(0)
 
21
  import numpy as np
22
  import torch
23
  import torch.nn as nn
24
+ from torch.nn import functional as F
25
  from torch.utils.data.dataloader import DataLoader
26
  import torchaudio
27
  from tqdm import tqdm
 
96
  window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
97
  )
98
 
99
+ @staticmethod
100
+ def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
101
+ batch_size, channels, freq_dim, time_steps = x.shape
102
+
103
+ # kernel: [freq_dim, n_time_step]
104
+ kernel_size = (freq_dim, n_time_steps)
105
+
106
+ # pad
107
+ pad = n_time_steps // 2
108
+ x = torch.concat(tensors=[
109
+ x[:, :, :, :pad],
110
+ x,
111
+ x[:, :, :, -pad:],
112
+ ], dim=-1)
113
+
114
+ x = F.unfold(
115
+ input=x,
116
+ kernel_size=kernel_size,
117
+ )
118
+ # x shape: [batch_size, fold, time_steps]
119
+ return x
120
+
121
  def __call__(self, batch: List[dict]):
122
  mix_spec_list = list()
123
  speech_irm_list = list()
 
142
  snr_db: torch.Tensor = 10 * torch.log10(
143
  speech_spec / (noise_spec + self.epsilon)
144
  )
145
+ snr_db_ = torch.unsqueeze(snr_db, dim=0)
146
+ snr_db_ = torch.unsqueeze(snr_db_, dim=0)
147
+ snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
148
+ snr_db_ = torch.squeeze(snr_db_, dim=0)
149
+ # snr_db_ shape: [fold, time_steps]
150
+
151
+ snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
152
  # snr_db shape: [1, time_steps]
153
 
154
  mix_spec_list.append(mix_spec)
 
291
 
292
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
293
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
294
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
295
+ loss = irm_loss + 0.1 * snr_loss
296
+ # loss = irm_loss
297
 
298
  total_loss += loss.item()
299
  total_examples += mix_spec.size(0)
 
326
  with torch.no_grad():
327
  speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
328
  irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
329
+ snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
330
+ loss = irm_loss + 0.1 * snr_loss
331
+ # loss = irm_loss
332
 
333
  total_loss += loss.item()
334
  total_examples += mix_spec.size(0)