HoneyTian commited on
Commit
b9f223d
·
1 Parent(s): 2ebb5f8
examples/conv_tasnet/run.sh CHANGED
@@ -3,10 +3,10 @@
3
  : <<'END'
4
 
5
 
6
- sh run.sh --stage 1 --stop_stage 1 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
9
- --max_epochs 200
10
 
11
 
12
  END
 
3
  : <<'END'
4
 
5
 
6
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
7
  --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
8
  --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
9
+ --max_epochs 400
10
 
11
 
12
  END
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -47,6 +47,8 @@ def get_args():
47
 
48
  parser.add_argument("--config_file", default="config.yaml", type=str)
49
 
 
 
50
  args = parser.parse_args()
51
  return args
52
 
@@ -115,10 +117,10 @@ def main():
115
 
116
  logger = logging_config(serialization_dir)
117
 
118
- random.seed(config.seed)
119
- np.random.seed(config.seed)
120
- torch.manual_seed(config.seed)
121
- logger.info(f"set seed: {config.seed}")
122
 
123
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
124
  n_gpu = torch.cuda.device_count()
 
47
 
48
  parser.add_argument("--config_file", default="config.yaml", type=str)
49
 
50
+ parser.add_argument("--seed", default=1234, type=int)
51
+
52
  args = parser.parse_args()
53
  return args
54
 
 
117
 
118
  logger = logging_config(serialization_dir)
119
 
120
+ random.seed(args.seed)
121
+ np.random.seed(args.seed)
122
+ torch.manual_seed(args.seed)
123
+ logger.info(f"set seed: {args.seed}")
124
 
125
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
  n_gpu = torch.cuda.device_count()
requirements-python-3-9-9.txt CHANGED
@@ -12,3 +12,4 @@ torch-pesq==0.1.2
12
  torchmetrics==1.6.1
13
  torchmetrics[audio]==1.6.1
14
  einops==0.8.1
 
 
12
  torchmetrics==1.6.1
13
  torchmetrics[audio]==1.6.1
14
  einops==0.8.1
15
+ torch_stoi==0.2.3