Spaces:
Running
Running
update
Browse files
examples/conv_tasnet/run.sh
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
-
sh run.sh --stage
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
--max_epochs 400
|
|
|
3 |
: <<'END'
|
4 |
|
5 |
|
6 |
+
sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
|
7 |
--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
|
8 |
--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
|
9 |
--max_epochs 400
|
examples/conv_tasnet/step_2_train_model.py
CHANGED
@@ -139,11 +139,17 @@ def main():
|
|
139 |
jsonl_file=args.train_dataset,
|
140 |
expected_sample_rate=config.sample_rate,
|
141 |
max_wave_value=32768.0,
|
|
|
|
|
|
|
142 |
)
|
143 |
valid_dataset = DenoiseJsonlDataset(
|
144 |
jsonl_file=args.valid_dataset,
|
145 |
expected_sample_rate=config.sample_rate,
|
146 |
max_wave_value=32768.0,
|
|
|
|
|
|
|
147 |
)
|
148 |
train_data_loader = DataLoader(
|
149 |
dataset=train_dataset,
|
|
|
139 |
jsonl_file=args.train_dataset,
|
140 |
expected_sample_rate=config.sample_rate,
|
141 |
max_wave_value=32768.0,
|
142 |
+
min_snr_db=config.min_snr_db,
|
143 |
+
max_snr_db=config.max_snr_db,
|
144 |
+
# skip=625000,
|
145 |
)
|
146 |
valid_dataset = DenoiseJsonlDataset(
|
147 |
jsonl_file=args.valid_dataset,
|
148 |
expected_sample_rate=config.sample_rate,
|
149 |
max_wave_value=32768.0,
|
150 |
+
min_snr_db=config.min_snr_db,
|
151 |
+
max_snr_db=config.max_snr_db,
|
152 |
+
# skip=625000,
|
153 |
)
|
154 |
train_data_loader = DataLoader(
|
155 |
dataset=train_dataset,
|
examples/conv_tasnet/yaml/config.yaml
CHANGED
@@ -16,6 +16,9 @@ norm_type: "gLN"
|
|
16 |
causal: false
|
17 |
mask_nonlinear: "relu"
|
18 |
|
|
|
|
|
|
|
19 |
lr: 0.001
|
20 |
lr_scheduler: "CosineAnnealingLR"
|
21 |
lr_scheduler_kwargs:
|
|
|
16 |
causal: false
|
17 |
mask_nonlinear: "relu"
|
18 |
|
19 |
+
min_snr_db: -10
|
20 |
+
max_snr_db: 20
|
21 |
+
|
22 |
lr: 0.001
|
23 |
lr_scheduler: "CosineAnnealingLR"
|
24 |
lr_scheduler_kwargs:
|
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py
CHANGED
@@ -1,18 +1,13 @@
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
import json
|
4 |
-
import os
|
5 |
import random
|
6 |
from typing import List
|
7 |
|
8 |
import librosa
|
9 |
import numpy as np
|
10 |
-
import pandas as pd
|
11 |
-
from scipy.io import wavfile
|
12 |
import torch
|
13 |
-
import torchaudio
|
14 |
from torch.utils.data import Dataset, IterableDataset
|
15 |
-
from tqdm import tqdm
|
16 |
|
17 |
|
18 |
class DenoiseJsonlDataset(IterableDataset):
|
@@ -22,13 +17,19 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
22 |
resample: bool = False,
|
23 |
max_wave_value: float = 1.0,
|
24 |
buffer_size: int = 1000,
|
|
|
|
|
25 |
eps: float = 1e-8,
|
|
|
26 |
):
|
27 |
self.jsonl_file = jsonl_file
|
28 |
self.expected_sample_rate = expected_sample_rate
|
29 |
self.resample = resample
|
30 |
self.max_wave_value = max_wave_value
|
|
|
|
|
31 |
self.eps = eps
|
|
|
32 |
|
33 |
self.buffer_size = buffer_size
|
34 |
self.buffer_samples: List[dict] = list()
|
@@ -36,6 +37,12 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
36 |
def __iter__(self):
|
37 |
iterable_source = self.iterable_source()
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
# 初始填充缓冲区
|
40 |
try:
|
41 |
for _ in range(self.buffer_size):
|
@@ -74,7 +81,10 @@ class DenoiseJsonlDataset(IterableDataset):
|
|
74 |
speech_offset = row["speech_offset"]
|
75 |
speech_duration = row["speech_duration"]
|
76 |
|
77 |
-
|
|
|
|
|
|
|
78 |
|
79 |
sample = {
|
80 |
"noise_filename": noise_filename,
|
|
|
1 |
#!/usr/bin/python3
|
2 |
# -*- coding: utf-8 -*-
|
3 |
import json
|
|
|
4 |
import random
|
5 |
from typing import List
|
6 |
|
7 |
import librosa
|
8 |
import numpy as np
|
|
|
|
|
9 |
import torch
|
|
|
10 |
from torch.utils.data import Dataset, IterableDataset
|
|
|
11 |
|
12 |
|
13 |
class DenoiseJsonlDataset(IterableDataset):
|
|
|
17 |
resample: bool = False,
|
18 |
max_wave_value: float = 1.0,
|
19 |
buffer_size: int = 1000,
|
20 |
+
min_snr_db: float = None,
|
21 |
+
max_snr_db: float = None,
|
22 |
eps: float = 1e-8,
|
23 |
+
skip: int = 0,
|
24 |
):
|
25 |
self.jsonl_file = jsonl_file
|
26 |
self.expected_sample_rate = expected_sample_rate
|
27 |
self.resample = resample
|
28 |
self.max_wave_value = max_wave_value
|
29 |
+
self.min_snr_db = min_snr_db
|
30 |
+
self.max_snr_db = max_snr_db
|
31 |
self.eps = eps
|
32 |
+
self.skip = skip
|
33 |
|
34 |
self.buffer_size = buffer_size
|
35 |
self.buffer_samples: List[dict] = list()
|
|
|
37 |
def __iter__(self):
|
38 |
iterable_source = self.iterable_source()
|
39 |
|
40 |
+
try:
|
41 |
+
for _ in range(self.skip):
|
42 |
+
next(iterable_source)
|
43 |
+
except StopIteration:
|
44 |
+
pass
|
45 |
+
|
46 |
# 初始填充缓冲区
|
47 |
try:
|
48 |
for _ in range(self.buffer_size):
|
|
|
81 |
speech_offset = row["speech_offset"]
|
82 |
speech_duration = row["speech_duration"]
|
83 |
|
84 |
+
if self.min_snr_db is None or self.max_snr_db is None:
|
85 |
+
snr_db = row["snr_db"]
|
86 |
+
else:
|
87 |
+
snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
|
88 |
|
89 |
sample = {
|
90 |
"noise_filename": noise_filename,
|
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py
CHANGED
@@ -27,12 +27,15 @@ class ConvTasNetConfig(PretrainedConfig):
|
|
27 |
causal: bool = False,
|
28 |
mask_nonlinear: str = "relu",
|
29 |
|
30 |
-
|
31 |
-
|
32 |
|
|
|
33 |
lr_scheduler: str = "CosineAnnealingLR",
|
34 |
lr_scheduler_kwargs: dict = None,
|
35 |
|
|
|
|
|
36 |
**kwargs
|
37 |
):
|
38 |
super(ConvTasNetConfig, self).__init__(**kwargs)
|
@@ -53,12 +56,15 @@ class ConvTasNetConfig(PretrainedConfig):
|
|
53 |
self.causal = causal
|
54 |
self.mask_nonlinear = mask_nonlinear
|
55 |
|
56 |
-
self.
|
57 |
-
self.
|
58 |
|
|
|
59 |
self.lr_scheduler = lr_scheduler
|
60 |
self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
|
61 |
|
|
|
|
|
62 |
|
63 |
if __name__ == "__main__":
|
64 |
pass
|
|
|
27 |
causal: bool = False,
|
28 |
mask_nonlinear: str = "relu",
|
29 |
|
30 |
+
min_snr_db: float = -10,
|
31 |
+
max_snr_db: float = 20,
|
32 |
|
33 |
+
lr: float = 1e-3,
|
34 |
lr_scheduler: str = "CosineAnnealingLR",
|
35 |
lr_scheduler_kwargs: dict = None,
|
36 |
|
37 |
+
eval_steps: int = 25000,
|
38 |
+
|
39 |
**kwargs
|
40 |
):
|
41 |
super(ConvTasNetConfig, self).__init__(**kwargs)
|
|
|
56 |
self.causal = causal
|
57 |
self.mask_nonlinear = mask_nonlinear
|
58 |
|
59 |
+
self.min_snr_db = min_snr_db
|
60 |
+
self.max_snr_db = max_snr_db
|
61 |
|
62 |
+
self.lr = lr
|
63 |
self.lr_scheduler = lr_scheduler
|
64 |
self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
|
65 |
|
66 |
+
self.eval_steps = eval_steps
|
67 |
+
|
68 |
|
69 |
if __name__ == "__main__":
|
70 |
pass
|