Spaces:
Running
Running
update
Browse files- examples/nx_clean_unet/step_2_train_model.py +2 -2
- examples/nx_clean_unet/yaml/config.yaml +1 -1
- toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py +1 -1
- toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py +1 -0
- toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +3 -3
examples/nx_clean_unet/step_2_train_model.py
CHANGED
@@ -143,7 +143,7 @@ def main():
|
|
143 |
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
144 |
collate_fn=collate_fn,
|
145 |
pin_memory=False,
|
146 |
-
|
147 |
)
|
148 |
valid_data_loader = DataLoader(
|
149 |
dataset=valid_dataset,
|
@@ -154,7 +154,7 @@ def main():
|
|
154 |
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
155 |
collate_fn=collate_fn,
|
156 |
pin_memory=False,
|
157 |
-
|
158 |
)
|
159 |
|
160 |
# models
|
|
|
143 |
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
144 |
collate_fn=collate_fn,
|
145 |
pin_memory=False,
|
146 |
+
prefetch_factor=16,
|
147 |
)
|
148 |
valid_data_loader = DataLoader(
|
149 |
dataset=valid_dataset,
|
|
|
154 |
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
155 |
collate_fn=collate_fn,
|
156 |
pin_memory=False,
|
157 |
+
prefetch_factor=16,
|
158 |
)
|
159 |
|
160 |
# models
|
examples/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -20,7 +20,7 @@ tsfm_max_length: 1024
|
|
20 |
tsfm_chunk_size: 1
|
21 |
tsfm_num_left_chunks: 128
|
22 |
|
23 |
-
discriminator_dim:
|
24 |
discriminator_in_channel: 2
|
25 |
|
26 |
compress_factor: 0.3
|
|
|
20 |
tsfm_chunk_size: 1
|
21 |
tsfm_num_left_chunks: 128
|
22 |
|
23 |
+
discriminator_dim: 16
|
24 |
discriminator_in_channel: 2
|
25 |
|
26 |
compress_factor: 0.3
|
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py
CHANGED
@@ -28,7 +28,7 @@ class NXCleanUNetConfig(PretrainedConfig):
|
|
28 |
tsfm_chunk_size: int = 1,
|
29 |
tsfm_num_left_chunks: int = 128,
|
30 |
|
31 |
-
discriminator_dim: int =
|
32 |
discriminator_in_channel: int = 2,
|
33 |
|
34 |
compress_factor: float = 0.3,
|
|
|
28 |
tsfm_chunk_size: int = 1,
|
29 |
tsfm_num_left_chunks: int = 128,
|
30 |
|
31 |
+
discriminator_dim: int = 16,
|
32 |
discriminator_in_channel: int = 2,
|
33 |
|
34 |
compress_factor: float = 0.3,
|
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py
CHANGED
@@ -562,6 +562,7 @@ def main():
|
|
562 |
num_blocks=6,
|
563 |
dropout_rate=0.1,
|
564 |
)
|
|
|
565 |
|
566 |
x = torch.ones([4, 200, 64])
|
567 |
|
|
|
562 |
num_blocks=6,
|
563 |
dropout_rate=0.1,
|
564 |
)
|
565 |
+
print(encoder)
|
566 |
|
567 |
x = torch.ones([4, 200, 64])
|
568 |
|
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml
CHANGED
@@ -12,15 +12,15 @@ down_sampling_hidden_channels: 64
|
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
-
tsfm_hidden_size:
|
16 |
-
tsfm_attention_heads:
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
tsfm_max_length: 1024
|
20 |
tsfm_chunk_size: 1
|
21 |
tsfm_num_left_chunks: 128
|
22 |
|
23 |
-
discriminator_dim:
|
24 |
discriminator_in_channel: 2
|
25 |
|
26 |
compress_factor: 0.3
|
|
|
12 |
down_sampling_kernel_size: 4
|
13 |
down_sampling_stride: 2
|
14 |
|
15 |
+
tsfm_hidden_size: 1024
|
16 |
+
tsfm_attention_heads: 8
|
17 |
tsfm_num_blocks: 6
|
18 |
tsfm_dropout_rate: 0.1
|
19 |
tsfm_max_length: 1024
|
20 |
tsfm_chunk_size: 1
|
21 |
tsfm_num_left_chunks: 128
|
22 |
|
23 |
+
discriminator_dim: 16
|
24 |
discriminator_in_channel: 2
|
25 |
|
26 |
compress_factor: 0.3
|