HoneyTian commited on
Commit
29c8c0d
·
1 Parent(s): a8c2bc7
examples/frcrn/run.sh CHANGED
@@ -3,7 +3,7 @@
3
  : <<'END'
4
 
5
 
6
- sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn \
7
  --config_file "yaml/config-20-512.yaml" \
8
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
9
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
 
3
  : <<'END'
4
 
5
 
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-20-512-nx-dns3 \
7
  --config_file "yaml/config-20-512.yaml" \
8
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
9
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
examples/frcrn/step_2_train_model.py CHANGED
@@ -41,11 +41,11 @@ from toolbox.torchaudio.models.frcrn.modeling_frcrn import FRCRN, FRCRNPretraine
41
 
42
  def get_args():
43
  parser = argparse.ArgumentParser()
44
- parser.add_argument("--train_dataset", default="train.xlsx", type=str)
45
- parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
46
 
47
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
48
- parser.add_argument("--patience", default=10, type=int)
49
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
50
 
51
  parser.add_argument("--config_file", default="config.yaml", type=str)
@@ -185,7 +185,7 @@ def main():
185
  step_idx = int(step_idx)
186
  if step_idx > last_step_idx:
187
  last_step_idx = step_idx
188
- last_epoch = 0
189
 
190
  if last_step_idx != -1:
191
  logger.info(f"resume from steps-{last_step_idx}.")
 
41
 
42
  def get_args():
43
  parser = argparse.ArgumentParser()
44
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
45
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
46
 
47
  parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
48
+ parser.add_argument("--patience", default=30, type=int)
49
  parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
50
 
51
  parser.add_argument("--config_file", default="config.yaml", type=str)
 
185
  step_idx = int(step_idx)
186
  if step_idx > last_step_idx:
187
  last_step_idx = step_idx
188
+ # last_epoch = 0
189
 
190
  if last_step_idx != -1:
191
  logger.info(f"resume from steps-{last_step_idx}.")