Spaces:
Running
Running
update
Browse files
examples/nx_clean_unet/step_2_train_model.py
CHANGED
@@ -242,6 +242,8 @@ def main():
|
|
242 |
clean_audios = clean_audios.to(device)
|
243 |
noisy_audios = noisy_audios.to(device)
|
244 |
one_labels = torch.ones(clean_audios.shape[0]).to(device)
|
|
|
|
|
245 |
|
246 |
audio_g = generator.forward(noisy_audios)
|
247 |
|
|
|
242 |
clean_audios = clean_audios.to(device)
|
243 |
noisy_audios = noisy_audios.to(device)
|
244 |
one_labels = torch.ones(clean_audios.shape[0]).to(device)
|
245 |
+
print(f"clean_audios: {clean_audios.shape}")
|
246 |
+
print(f"noisy_audios: {noisy_audios.shape}")
|
247 |
|
248 |
audio_g = generator.forward(noisy_audios)
|
249 |
|
toolbox/torchaudio/models/nx_clean_unet/discriminator.py
CHANGED
@@ -116,8 +116,10 @@ def main():
|
|
116 |
discriminator = MetricDiscriminator(config=config)
|
117 |
|
118 |
# shape: [batch_size, num_samples]
|
119 |
-
x = torch.ones([4, int(4.5 * 16000)])
|
120 |
-
y = torch.ones([4, int(4.5 * 16000)])
|
|
|
|
|
121 |
|
122 |
output = discriminator.forward(x, y)
|
123 |
print(output.shape)
|
|
|
116 |
discriminator = MetricDiscriminator(config=config)
|
117 |
|
118 |
# shape: [batch_size, num_samples]
|
119 |
+
# x = torch.ones([4, int(4.5 * 16000)])
|
120 |
+
# y = torch.ones([4, int(4.5 * 16000)])
|
121 |
+
x = torch.ones([4, 16000])
|
122 |
+
y = torch.ones([4, 16000])
|
123 |
|
124 |
output = discriminator.forward(x, y)
|
125 |
print(output.shape)
|