update
Browse files
toolbox/torchaudio/models/vad/cnn_vad/modeling_cnn_vad.py
CHANGED
@@ -201,6 +201,10 @@ class CNNVadModel(nn.Module):
|
|
201 |
raise AssertionError("Input signals must have the same shape")
|
202 |
noise = noisy - clean
|
203 |
|
|
|
|
|
|
|
|
|
204 |
if clean.dim() == 2:
|
205 |
clean = torch.unsqueeze(clean, dim=1)
|
206 |
if noise.dim() == 2:
|
@@ -223,6 +227,9 @@ class CNNVadModel(nn.Module):
|
|
223 |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
|
224 |
# lsnr_gth shape: [b, t]
|
225 |
|
|
|
|
|
|
|
226 |
loss = F.mse_loss(lsnr, lsnr_gth)
|
227 |
return loss
|
228 |
|
|
|
201 |
raise AssertionError("Input signals must have the same shape")
|
202 |
noise = noisy - clean
|
203 |
|
204 |
+
print(f"lsnr: {lsnr.shape}")
|
205 |
+
print(f"clean: {clean.shape}")
|
206 |
+
print(f"noisy: {noisy.shape}")
|
207 |
+
|
208 |
if clean.dim() == 2:
|
209 |
clean = torch.unsqueeze(clean, dim=1)
|
210 |
if noise.dim() == 2:
|
|
|
227 |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
|
228 |
# lsnr_gth shape: [b, t]
|
229 |
|
230 |
+
print(f"lsnr: {lsnr.shape}")
|
231 |
+
print(f"lsnr_gth: {lsnr_gth.shape}")
|
232 |
+
|
233 |
loss = F.mse_loss(lsnr, lsnr_gth)
|
234 |
return loss
|
235 |
|