Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/run.sh
CHANGED
@@ -3,10 +3,10 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
-
--max_epochs
|
10 |
|
11 |
|
12 |
END
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
+
--max_epochs 400
|
10 |
|
11 |
|
12 |
END
|
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -47,6 +47,8 @@ def get_args():
|
|
47 |
|
48 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
49 |
|
|
|
|
|
50 |
args = parser.parse_args()
|
51 |
return args
|
52 |
|
@@ -115,10 +117,10 @@ def main():
|
|
115 |
|
116 |
logger = logging_config(serialization_dir)
|
117 |
|
118 |
-
random.seed(
|
119 |
-
np.random.seed(
|
120 |
-
torch.manual_seed(
|
121 |
-
logger.info(f"set seed: {
|
122 |
|
123 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
124 |
n_gpu = torch.cuda.device_count()
|
|
|
47 |
|
48 |
parser.add_argument("--config_file", default="config.yaml", type=str)
|
49 |
|
50 |
+
parser.add_argument("--seed", default=1234, type=int)
|
51 |
+
|
52 |
args = parser.parse_args()
|
53 |
return args
|
54 |
|
|
|
117 |
|
118 |
logger = logging_config(serialization_dir)
|
119 |
|
120 |
+
random.seed(args.seed)
|
121 |
+
np.random.seed(args.seed)
|
122 |
+
torch.manual_seed(args.seed)
|
123 |
+
logger.info(f"set seed: {args.seed}")
|
124 |
|
125 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
126 |
n_gpu = torch.cuda.device_count()
|
requirements-python-3-9-9.txt
CHANGED
@@ -12,3 +12,4 @@ torch-pesq==0.1.2
|
|
12 |
torchmetrics==1.6.1
|
13 |
torchmetrics[audio]==1.6.1
|
14 |
einops==0.8.1
|
|
|
|
12 |
torchmetrics==1.6.1
|
13 |
torchmetrics[audio]==1.6.1
|
14 |
einops==0.8.1
|
15 |
+
torch_stoi==0.2.3
|