HoneyTian commited on
Commit
46c2bb3
·
1 Parent(s): 85947fe
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -236,19 +236,9 @@ def main():
236
  enhanced_audios = model.forward(noisy_audios)
237
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
238
 
239
- if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
240
- raise AssertionError("nan or inf in clean_audios")
241
- if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
242
- raise AssertionError("nan or inf in noisy_audios")
243
- if torch.any(torch.isnan(enhanced_audios)) or torch.any(torch.isinf(enhanced_audios)):
244
- raise AssertionError("nan or inf in enhanced_audios")
245
-
246
  ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
247
  sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
248
 
249
- if torch.any(torch.isnan(mag_loss)) or torch.any(torch.isinf(mag_loss)):
250
- raise AssertionError("nan or inf in mag_loss")
251
-
252
  loss = ae_loss + sc_loss + mag_loss
253
 
254
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
 
236
  enhanced_audios = model.forward(noisy_audios)
237
  enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
238
 
 
 
 
 
 
 
 
239
  ae_loss = ae_loss_fn(enhanced_audios, clean_audios)
240
  sc_loss, mag_loss = mr_stft_loss_fn(enhanced_audios, clean_audios)
241
 
 
 
 
242
  loss = ae_loss + sc_loss + mag_loss
243
 
244
  enhanced_audios_list_r = list(enhanced_audios.detach().cpu().numpy())
toolbox/torchaudio/models/clean_unet/modeling_clean_unet.py CHANGED
@@ -144,7 +144,7 @@ class CleanUNet(nn.Module):
144
  nn.Conv1d(channels_h, channels_h * 2, 1),
145
  nn.GLU(dim=1),
146
  nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
147
- nn.ReLU(inplace=False)
148
  ))
149
  channels_output = channels_h
150
 
 
144
  nn.Conv1d(channels_h, channels_h * 2, 1),
145
  nn.GLU(dim=1),
146
  nn.ConvTranspose1d(channels_h, channels_output, kernel_size, stride),
147
+ # nn.ReLU(inplace=False)
148
  ))
149
  channels_output = channels_h
150