Spaces:
Running
Running
update
Browse files
examples/clean_unet_aishell/step_2_train_model.py
CHANGED
@@ -236,19 +236,9 @@ def main():
|
|
236 |
enhanced_audios = model.forward(noisy_audios)
|
237 |
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
238 |
|
239 |
-
if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
|
240 |
-
raise AssertionError("nan or inf in clean_audios")
|
241 |
-
if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
|
242 |
-
raise AssertionError("nan or inf in noisy_audios")
|
243 |
-
if torch.any(torch.isnan(enhanced_audios)) or torch.any(torch.isinf(enhanced_audios)):
|
244 |
-
raise AssertionError("nan or inf in enhanced_audios")
|
245 |
-
|
246 |
ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
|
247 |
sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
|
248 |
|
249 |
-
if torch.any(torch.isnan(mag_loss)) or torch.any(torch.isinf(mag_loss)):
|
250 |
-
raise AssertionError("nan or inf in mag_loss")
|
251 |
-
|
252 |
loss = ae_loss + sc_loss + mag_loss
|
253 |
|
254 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
|
|
236 |
enhanced_audios = model.forward(noisy_audios)
|
237 |
enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
|
238 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
239 |
ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
|
240 |
sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
|
241 |
|
|
|
|
|
|
|
242 |
loss = ae_loss + sc_loss + mag_loss
|
243 |
|
244 |
enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
|
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py
CHANGED
@@ -144,7 +144,7 @@ class CleanUNet(nn.Module):
|
|
144 |
nn.Conv1d(channels_h, channels_h * 2, 1),
|
145 |
nn.GLU(dim=1),
|
146 |
nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
|
147 |
-
nn.ReLU(inplace=False)
|
148 |
))
|
149 |
channels_output = channels_h
|
150 |
|
|
|
144 |
nn.Conv1d(channels_h, channels_h * 2, 1),
|
145 |
nn.GLU(dim=1),
|
146 |
nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
|
147 |
+
# nn.ReLU(inplace=False)
|
148 |
))
|
149 |
channels_output = channels_h
|
150 |
|