Spaces:
Running
Running
update
Browse files
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
CHANGED
@@ -181,6 +181,8 @@ class NXCleanUNet(nn.Module):
|
|
181 |
)
|
182 |
|
183 |
def forward(self, noisy_audios: torch.Tensor):
|
|
|
|
|
184 |
# noisy_audios shape: [batch_size, 1, n_samples]
|
185 |
|
186 |
n_samples = noisy_audios.shape[-1]
|
@@ -208,6 +210,9 @@ class NXCleanUNet(nn.Module):
|
|
208 |
|
209 |
enhanced_audios = enhanced_audios[:, :, :n_samples]
|
210 |
# enhanced_audios shape: [batch_size, 1, n_samples]
|
|
|
|
|
|
|
211 |
return enhanced_audios
|
212 |
|
213 |
|
@@ -310,11 +315,11 @@ def main():
|
|
310 |
|
311 |
# shape: [batch_size, channels, num_samples]
|
312 |
# min length: 94, stride: 32, 32 == 2**5
|
313 |
-
# x = torch.ones([4,
|
314 |
-
# x = torch.ones([4,
|
315 |
-
# x = torch.ones([4,
|
316 |
-
# x = torch.ones([4,
|
317 |
-
x = torch.ones([4,
|
318 |
|
319 |
model = NXCleanUNet(config)
|
320 |
enhanced_audios = model.forward(x)
|
|
|
181 |
)
|
182 |
|
183 |
def forward(self, noisy_audios: torch.Tensor):
|
184 |
+
# noisy_audios shape: [batch_size, n_samples]
|
185 |
+
noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
|
186 |
# noisy_audios shape: [batch_size, 1, n_samples]
|
187 |
|
188 |
n_samples = noisy_audios.shape[-1]
|
|
|
210 |
|
211 |
enhanced_audios = enhanced_audios[:, :, :n_samples]
|
212 |
# enhanced_audios shape: [batch_size, 1, n_samples]
|
213 |
+
|
214 |
+
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
215 |
+
|
216 |
return enhanced_audios
|
217 |
|
218 |
|
|
|
315 |
|
316 |
# shape: [batch_size, channels, num_samples]
|
317 |
# min length: 94, stride: 32, 32 == 2**5
|
318 |
+
# x = torch.ones([4, 94])
|
319 |
+
# x = torch.ones([4, 126])
|
320 |
+
# x = torch.ones([4, 158])
|
321 |
+
# x = torch.ones([4, 190])
|
322 |
+
x = torch.ones([4, 16000])
|
323 |
|
324 |
model = NXCleanUNet(config)
|
325 |
enhanced_audios = model.forward(x)
|