update
examples/spectrum_unet_irm_aishell/step_2_train_model.py
CHANGED
@@ -21,6 +21,7 @@ sys.path.append(os.path.join(pwd, "../../"))
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 import torchaudio
 from tqdm import tqdm
@@ -95,6 +96,28 @@ class CollateFunction(object):
             window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
         )
 
+    @staticmethod
+    def make_unfold_snr_db(x: torch.Tensor, n_time_steps: int = 3):
+        batch_size, channels, freq_dim, time_steps = x.shape
+
+        # kernel: [freq_dim, n_time_steps]
+        kernel_size = (freq_dim, n_time_steps)
+
+        # pad the time axis by replicating edge frames so unfold keeps time_steps columns
+        pad = n_time_steps // 2
+        x = torch.concat(tensors=[
+            x[:, :, :, :pad],
+            x,
+            x[:, :, :, -pad:],
+        ], dim=-1)
+
+        x = F.unfold(
+            input=x,
+            kernel_size=kernel_size,
+        )
+        # x shape: [batch_size, fold, time_steps]
+        return x
+
     def __call__(self, batch: List[dict]):
         mix_spec_list = list()
         speech_irm_list = list()
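For context: the new make_unfold_snr_db helper (this is what the added from torch.nn import functional as F import is for) slides a freq_dim x n_time_steps window along the time axis and flattens each window into one column, replicating n_time_steps // 2 frames at each edge so the number of columns still equals time_steps. A minimal standalone sketch of the same mechanics; the toy tensor and variable names below are illustrative, not from the repo:

import torch
from torch.nn import functional as F

# a fake [batch, channel, freq, time] input with freq_dim = 1
x = torch.arange(5, dtype=torch.float32).reshape(1, 1, 1, 5)

# replicate one frame at each edge (pad = 3 // 2), then unfold
pad = 1
x_padded = torch.concat([x[..., :pad], x, x[..., -pad:]], dim=-1)
cols = F.unfold(x_padded, kernel_size=(1, 3))

print(cols.shape)  # torch.Size([1, 3, 5]) -> [batch, freq_dim * 3, time]
print(cols[0])     # each column is a frame together with its two neighbours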
@@ -119,7 +142,13 @@ class CollateFunction(object):
             snr_db: torch.Tensor = 10 * torch.log10(
                 speech_spec / (noise_spec + self.epsilon)
             )
-
+            snr_db_ = torch.unsqueeze(snr_db, dim=0)
+            snr_db_ = torch.unsqueeze(snr_db_, dim=0)
+            snr_db_ = self.make_unfold_snr_db(snr_db_, n_time_steps=3)
+            snr_db_ = torch.squeeze(snr_db_, dim=0)
+            # snr_db_ shape: [fold, time_steps]
+
+            snr_db = torch.mean(snr_db_, dim=0, keepdim=True)
             # snr_db shape: [1, time_steps]
 
             mix_spec_list.append(mix_spec)
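Since snr_db has shape [1, time_steps] here (so freq_dim is 1 after the two unsqueezes), the unfold-then-mean sequence reduces to a 3-tap moving average over time with replicated edges: the per-frame SNR target is smoothed before being used for training. A rough equivalent written as an explicit average filter, under that same shape assumption; the input values are made up:

import torch
from torch.nn import functional as F

snr_db = torch.tensor([[0.0, 10.0, 20.0, 10.0, 0.0]])  # [1, time_steps]

# the same smoothing as a 3-tap average with replicate padding
x = snr_db.unsqueeze(0)                    # [1, 1, time_steps] for conv1d
x = F.pad(x, (1, 1), mode="replicate")
kernel = torch.full((1, 1, 3), 1.0 / 3.0)
smoothed = F.conv1d(x, kernel).squeeze(0)  # back to [1, time_steps]

print(smoothed)  # tensor([[ 3.3333, 10.0000, 13.3333, 10.0000,  3.3333]])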
@@ -262,9 +291,9 @@ def main():
 
         speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
         irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-
-        loss = irm_loss
+        snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
+        loss = irm_loss + 0.1 * snr_loss
+        # loss = irm_loss
 
         total_loss += loss.item()
         total_examples += mix_spec.size(0)
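The training loss now optimizes both heads jointly: the IRM mask MSE plus the LSNR head's MSE against the smoothed SNR target, down-weighted by 0.1, presumably because a dB-scale regression error is numerically much larger than a mask MSE. A toy illustration of the weighting; both magnitudes are invented:

import torch

irm_loss = torch.tensor(0.02)  # mask MSE over bounded targets, small values
snr_loss = torch.tensor(4.0)   # dB-scale MSE, much larger values

loss = irm_loss + 0.1 * snr_loss
print(loss)  # tensor(0.4200) -- the weight keeps the two terms comparable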
@@ -297,9 +326,9 @@ def main():
         with torch.no_grad():
             speech_irm_prediction, lsnr_prediction = model.forward(mix_spec)
             irm_loss = irm_mse_loss.forward(speech_irm_prediction, speech_irm_target)
-
-
-            loss = irm_loss
+            snr_loss = snr_mse_loss.forward(lsnr_prediction, snr_db_target)
+            loss = irm_loss + 0.1 * snr_loss
+            # loss = irm_loss
 
             total_loss += loss.item()
             total_examples += mix_spec.size(0)
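The evaluation loop mirrors the same change, so the reported validation loss remains directly comparable with the training loss.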