HoneyTian committed on
Commit
b27ed9f
·
1 Parent(s): e27a095
examples/clean_unet_aishell/run.sh CHANGED
@@ -14,7 +14,8 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
14
 
15
  sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
- --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
 
18
 
19
 
20
  END
@@ -35,6 +36,8 @@ limit=10
35
  noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
36
  speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
37
 
 
 
38
  nohup_name=nohup.out
39
 
40
  # model params
@@ -101,6 +104,7 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
101
  --speech_dir "${speech_dir}" \
102
  --train_dataset "${train_dataset}" \
103
  --valid_dataset "${valid_dataset}" \
 
104
 
105
  fi
106
 
 
14
 
15
  sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir \
16
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
17
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
18
+ --max_count 10000
19
 
20
 
21
  END
 
36
  noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
37
  speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
38
 
39
+ max_count=10000000
40
+
41
  nohup_name=nohup.out
42
 
43
  # model params
 
104
  --speech_dir "${speech_dir}" \
105
  --train_dataset "${train_dataset}" \
106
  --valid_dataset "${valid_dataset}" \
107
+ --max_count "${max_count}" \
108
 
109
  fi
110
 
examples/clean_unet_aishell/step_1_prepare_data.py CHANGED
@@ -42,7 +42,7 @@ def get_args():
42
 
43
  parser.add_argument("--target_sample_rate", default=8000, type=int)
44
 
45
- parser.add_argument("--scale", default=1, type=float)
46
 
47
  args = parser.parse_args()
48
  return args
@@ -101,9 +101,8 @@ def get_dataset(args):
101
  count = 0
102
  process_bar = tqdm(desc="build dataset excel")
103
  for noise, speech in zip(noise_generator, speech_generator):
104
- flag = random.random()
105
- if flag > args.scale:
106
- continue
107
 
108
  noise_filename = noise["filename"]
109
  noise_raw_duration = noise["raw_duration"]
 
42
 
43
  parser.add_argument("--target_sample_rate", default=8000, type=int)
44
 
45
+ parser.add_argument("--max_count", default=10000, type=int)
46
 
47
  args = parser.parse_args()
48
  return args
 
101
  count = 0
102
  process_bar = tqdm(desc="build dataset excel")
103
  for noise, speech in zip(noise_generator, speech_generator):
104
+ if count > args.max_count:
105
+ break
 
106
 
107
  noise_filename = noise["filename"]
108
  noise_raw_duration = noise["raw_duration"]
examples/clean_unet_aishell/step_2_train_model.py CHANGED
@@ -44,10 +44,10 @@ def get_args():
44
 
45
  parser.add_argument("--batch_size", default=64, type=int)
46
  parser.add_argument("--learning_rate", default=2e-4, type=float)
47
-
48
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
49
  parser.add_argument("--patience", default=5, type=int)
50
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
 
51
 
52
  parser.add_argument("--config_file", default="config.yaml", type=str)
53
 
@@ -119,10 +119,10 @@ def main():
119
 
120
  logger = logging_config(serialization_dir)
121
 
122
- random.seed(config.seed)
123
- np.random.seed(config.seed)
124
- torch.manual_seed(config.seed)
125
- logger.info(f"set seed: {config.seed}")
126
 
127
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
  n_gpu = torch.cuda.device_count()
@@ -141,7 +141,7 @@ def main():
141
  )
142
  train_data_loader = DataLoader(
143
  dataset=train_dataset,
144
- batch_size=config.batch_size,
145
  shuffle=True,
146
  sampler=None,
147
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
@@ -152,7 +152,7 @@ def main():
152
  )
153
  valid_data_loader = DataLoader(
154
  dataset=valid_dataset,
155
- batch_size=config.batch_size,
156
  shuffle=True,
157
  sampler=None,
158
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
@@ -168,7 +168,7 @@ def main():
168
 
169
  # optimizer
170
  logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
171
- optimizer = torch.optim.AdamW(model.parameters(), config.learning_rate)
172
  lr_scheduler = LinearWarmupCosineDecay(
173
  optimizer,
174
  lr_max=args.learning_rate,
 
44
 
45
  parser.add_argument("--batch_size", default=64, type=int)
46
  parser.add_argument("--learning_rate", default=2e-4, type=float)
 
47
  parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
48
  parser.add_argument("--patience", default=5, type=int)
49
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
50
+ parser.add_argument("--seed", default=0, type=int)
51
 
52
  parser.add_argument("--config_file", default="config.yaml", type=str)
53
 
 
119
 
120
  logger = logging_config(serialization_dir)
121
 
122
+ random.seed(args.seed)
123
+ np.random.seed(args.seed)
124
+ torch.manual_seed(args.seed)
125
+ logger.info(f"set seed: {args.seed}")
126
 
127
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
  n_gpu = torch.cuda.device_count()
 
141
  )
142
  train_data_loader = DataLoader(
143
  dataset=train_dataset,
144
+ batch_size=args.batch_size,
145
  shuffle=True,
146
  sampler=None,
147
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
 
152
  )
153
  valid_data_loader = DataLoader(
154
  dataset=valid_dataset,
155
+ batch_size=args.batch_size,
156
  shuffle=True,
157
  sampler=None,
158
  # Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
 
168
 
169
  # optimizer
170
  logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
171
+ optimizer = torch.optim.AdamW(model.parameters(), args.learning_rate)
172
  lr_scheduler = LinearWarmupCosineDecay(
173
  optimizer,
174
  lr_max=args.learning_rate,
examples/clean_unet_aishell/yaml/config.yaml CHANGED
@@ -11,3 +11,4 @@ tsfm_n_layers: 5
11
  tsfm_n_head: 8
12
  tsfm_d_model: 512
13
  tsfm_d_inner: 2048
 
 
11
  tsfm_n_head: 8
12
  tsfm_d_model: 512
13
  tsfm_d_inner: 2048
14
+
examples/{mpnet_aishell → mpnet}/run.sh RENAMED
File without changes
examples/{mpnet_aishell → mpnet}/step_1_prepare_data.py RENAMED
File without changes
examples/{mpnet_aishell → mpnet}/step_2_train_model.py RENAMED
File without changes
examples/{mpnet_aishell → mpnet}/step_3_evaluation.py RENAMED
File without changes
examples/{mpnet_aishell → mpnet}/yaml/config.yaml RENAMED
File without changes