HoneyTian commited on
Commit
9bd6f8b
·
1 Parent(s): e07fc8f
examples/nx_clean_unet/step_2_train_model.py CHANGED
@@ -242,6 +242,8 @@ def main():
242
  clean_audios = clean_audios.to(device)
243
  noisy_audios = noisy_audios.to(device)
244
  one_labels = torch.ones(clean_audios.shape[0]).to(device)
 
 
245
 
246
  audio_g = generator.forward(noisy_audios)
247
 
 
242
  clean_audios = clean_audios.to(device)
243
  noisy_audios = noisy_audios.to(device)
244
  one_labels = torch.ones(clean_audios.shape[0]).to(device)
245
+ print(f"clean_audios: {clean_audios.shape}")
246
+ print(f"noisy_audios: {noisy_audios.shape}")
247
 
248
  audio_g = generator.forward(noisy_audios)
249
 
toolbox/torchaudio/models/nx_clean_unet/discriminator.py CHANGED
@@ -116,8 +116,10 @@ def main():
116
  discriminator = MetricDiscriminator(config=config)
117
 
118
  # shape: [batch_size, num_samples]
119
- x = torch.ones([4, int(4.5 * 16000)])
120
- y = torch.ones([4, int(4.5 * 16000)])
 
 
121
 
122
  output = discriminator.forward(x, y)
123
  print(output.shape)
 
116
  discriminator = MetricDiscriminator(config=config)
117
 
118
  # shape: [batch_size, num_samples]
119
+ # x = torch.ones([4, int(4.5 * 16000)])
120
+ # y = torch.ones([4, int(4.5 * 16000)])
121
+ x = torch.ones([4, 16000])
122
+ y = torch.ones([4, 16000])
123
 
124
  output = discriminator.forward(x, y)
125
  print(output.shape)