HoneyTian committed on
Commit
9a47ac9
·
1 Parent(s): de8138b
examples/nx_clean_unet/step_2_train_model.py CHANGED
@@ -143,7 +143,7 @@ def main():
143
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
144
  collate_fn=collate_fn,
145
  pin_memory=False,
146
- # prefetch_factor=64,
147
  )
148
  valid_data_loader = DataLoader(
149
  dataset=valid_dataset,
@@ -154,7 +154,7 @@ def main():
154
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
155
  collate_fn=collate_fn,
156
  pin_memory=False,
157
- # prefetch_factor=64,
158
  )
159
 
160
  # models
 
143
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
144
  collate_fn=collate_fn,
145
  pin_memory=False,
146
+ prefetch_factor=16,
147
  )
148
  valid_data_loader = DataLoader(
149
  dataset=valid_dataset,
 
154
  num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
155
  collate_fn=collate_fn,
156
  pin_memory=False,
157
+ prefetch_factor=16,
158
  )
159
 
160
  # models
examples/nx_clean_unet/yaml/config.yaml CHANGED
@@ -20,7 +20,7 @@ tsfm_max_length: 1024
20
  tsfm_chunk_size: 1
21
  tsfm_num_left_chunks: 128
22
 
23
- discriminator_dim: 32
24
  discriminator_in_channel: 2
25
 
26
  compress_factor: 0.3
 
20
  tsfm_chunk_size: 1
21
  tsfm_num_left_chunks: 128
22
 
23
+ discriminator_dim: 16
24
  discriminator_in_channel: 2
25
 
26
  compress_factor: 0.3
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py CHANGED
@@ -28,7 +28,7 @@ class NXCleanUNetConfig(PretrainedConfig):
28
  tsfm_chunk_size: int = 1,
29
  tsfm_num_left_chunks: int = 128,
30
 
31
- discriminator_dim: int = 32,
32
  discriminator_in_channel: int = 2,
33
 
34
  compress_factor: float = 0.3,
 
28
  tsfm_chunk_size: int = 1,
29
  tsfm_num_left_chunks: int = 128,
30
 
31
+ discriminator_dim: int = 16,
32
  discriminator_in_channel: int = 2,
33
 
34
  compress_factor: float = 0.3,
toolbox/torchaudio/models/nx_clean_unet/transformer/transformer.py CHANGED
@@ -562,6 +562,7 @@ def main():
562
  num_blocks=6,
563
  dropout_rate=0.1,
564
  )
 
565
 
566
  x = torch.ones([4, 200, 64])
567
 
 
562
  num_blocks=6,
563
  dropout_rate=0.1,
564
  )
565
+ print(encoder)
566
 
567
  x = torch.ones([4, 200, 64])
568
 
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml CHANGED
@@ -12,15 +12,15 @@ down_sampling_hidden_channels: 64
12
  down_sampling_kernel_size: 4
13
  down_sampling_stride: 2
14
 
15
- tsfm_hidden_size: 512
16
- tsfm_attention_heads: 4
17
  tsfm_num_blocks: 6
18
  tsfm_dropout_rate: 0.1
19
  tsfm_max_length: 1024
20
  tsfm_chunk_size: 1
21
  tsfm_num_left_chunks: 128
22
 
23
- discriminator_dim: 32
24
  discriminator_in_channel: 2
25
 
26
  compress_factor: 0.3
 
12
  down_sampling_kernel_size: 4
13
  down_sampling_stride: 2
14
 
15
+ tsfm_hidden_size: 1024
16
+ tsfm_attention_heads: 8
17
  tsfm_num_blocks: 6
18
  tsfm_dropout_rate: 0.1
19
  tsfm_max_length: 1024
20
  tsfm_chunk_size: 1
21
  tsfm_num_left_chunks: 128
22
 
23
+ discriminator_dim: 16
24
  discriminator_in_channel: 2
25
 
26
  compress_factor: 0.3