Spaces:
Running
Running
update
Browse files
examples/nx_clean_unet/step_2_train_model.py
CHANGED
@@ -242,8 +242,6 @@ def main():
|
|
242 |
clean_audios = clean_audios.to(device)
|
243 |
noisy_audios = noisy_audios.to(device)
|
244 |
one_labels = torch.ones(clean_audios.shape[0]).to(device)
|
245 |
-
print(f"clean_audios: {clean_audios.shape}")
|
246 |
-
print(f"noisy_audios: {noisy_audios.shape}")
|
247 |
|
248 |
audio_g = generator.forward(noisy_audios)
|
249 |
|
|
|
242 |
clean_audios = clean_audios.to(device)
|
243 |
noisy_audios = noisy_audios.to(device)
|
244 |
one_labels = torch.ones(clean_audios.shape[0]).to(device)
|
|
|
|
|
245 |
|
246 |
audio_g = generator.forward(noisy_audios)
|
247 |
|
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -12,8 +12,8 @@ down_sampling_hidden_channels: 64
|
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
-
tsfm_hidden_size:
|
16 |
-
tsfm_attention_heads:
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
|
|
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
+
tsfm_hidden_size: 512
|
16 |
+
tsfm_attention_heads: 8
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
|