Spaces:
Running
Running
update
Browse files- examples/dfnet2/run.sh +1 -1
- examples/dfnet2/step_1_prepare_data.py +2 -2
- examples/dfnet2/yaml/config.yaml +1 -1
- examples/dtln/run.sh +1 -1
- examples/dtln/step_1_prepare_data.py +2 -2
- main.py +11 -13
- toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py +16 -2
- toolbox/torchaudio/models/dfnet2/yaml/{config.yaml → config-200.yaml} +2 -2
- toolbox/torchaudio/models/dfnet2/yaml/config-512.yaml +75 -0
examples/dfnet2/run.sh
CHANGED
@@ -29,7 +29,7 @@ limit=10
|
|
29 |
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
30 |
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
31 |
|
32 |
-
max_count
|
33 |
|
34 |
nohup_name=nohup.out
|
35 |
|
|
|
29 |
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
30 |
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
31 |
|
32 |
+
max_count=-1
|
33 |
|
34 |
nohup_name=nohup.out
|
35 |
|
examples/dfnet2/step_1_prepare_data.py
CHANGED
@@ -33,13 +33,13 @@ def get_args():
|
|
33 |
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
34 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
35 |
|
36 |
-
parser.add_argument("--duration", default=
|
37 |
parser.add_argument("--min_snr_db", default=-10, type=float)
|
38 |
parser.add_argument("--max_snr_db", default=20, type=float)
|
39 |
|
40 |
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
41 |
|
42 |
-
parser.add_argument("--max_count", default
|
43 |
|
44 |
args = parser.parse_args()
|
45 |
return args
|
|
|
33 |
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
34 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
35 |
|
36 |
+
parser.add_argument("--duration", default=2.0, type=float)
|
37 |
parser.add_argument("--min_snr_db", default=-10, type=float)
|
38 |
parser.add_argument("--max_snr_db", default=20, type=float)
|
39 |
|
40 |
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
41 |
|
42 |
+
parser.add_argument("--max_count", default=-1, type=int)
|
43 |
|
44 |
args = parser.parse_args()
|
45 |
return args
|
examples/dfnet2/yaml/config.yaml
CHANGED
@@ -68,7 +68,7 @@ clip_grad_norm: 10.0
|
|
68 |
seed: 1234
|
69 |
|
70 |
num_workers: 8
|
71 |
-
batch_size:
|
72 |
eval_steps: 10000
|
73 |
|
74 |
# runtime
|
|
|
68 |
seed: 1234
|
69 |
|
70 |
num_workers: 8
|
71 |
+
batch_size: 96
|
72 |
eval_steps: 10000
|
73 |
|
74 |
# runtime
|
examples/dtln/run.sh
CHANGED
@@ -31,7 +31,7 @@ limit=10
|
|
31 |
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
32 |
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
33 |
|
34 |
-
max_count
|
35 |
|
36 |
nohup_name=nohup.out
|
37 |
|
|
|
31 |
noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
|
32 |
speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
|
33 |
|
34 |
+
max_count=-1
|
35 |
|
36 |
nohup_name=nohup.out
|
37 |
|
examples/dtln/step_1_prepare_data.py
CHANGED
@@ -33,13 +33,13 @@ def get_args():
|
|
33 |
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
34 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
35 |
|
36 |
-
parser.add_argument("--duration", default=
|
37 |
parser.add_argument("--min_snr_db", default=-10, type=float)
|
38 |
parser.add_argument("--max_snr_db", default=20, type=float)
|
39 |
|
40 |
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
41 |
|
42 |
-
parser.add_argument("--max_count", default
|
43 |
|
44 |
args = parser.parse_args()
|
45 |
return args
|
|
|
33 |
parser.add_argument("--train_dataset", default="train.jsonl", type=str)
|
34 |
parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
|
35 |
|
36 |
+
parser.add_argument("--duration", default=2.0, type=float)
|
37 |
parser.add_argument("--min_snr_db", default=-10, type=float)
|
38 |
parser.add_argument("--max_snr_db", default=20, type=float)
|
39 |
|
40 |
parser.add_argument("--target_sample_rate", default=8000, type=int)
|
41 |
|
42 |
+
parser.add_argument("--max_count", default=-1, type=int)
|
43 |
|
44 |
args = parser.parse_args()
|
45 |
return args
|
main.py
CHANGED
@@ -72,22 +72,22 @@ def shell(cmd: str):
|
|
72 |
|
73 |
|
74 |
denoise_engines = {
|
75 |
-
"dtln-nx-dns3": {
|
76 |
"infer_cls": InferenceDTLN,
|
77 |
"kwargs": {
|
78 |
-
"pretrained_model_path_or_zip_file": (project_path / "trained_models/dtln-nx-dns3.zip").as_posix()
|
79 |
}
|
80 |
},
|
81 |
-
"
|
82 |
-
"infer_cls":
|
83 |
"kwargs": {
|
84 |
-
"pretrained_model_path_or_zip_file": (project_path / "trained_models/
|
85 |
}
|
86 |
},
|
87 |
-
"
|
88 |
-
"infer_cls":
|
89 |
"kwargs": {
|
90 |
-
"pretrained_model_path_or_zip_file": (project_path / "trained_models/
|
91 |
}
|
92 |
},
|
93 |
"frcrn-dns3": {
|
@@ -114,13 +114,11 @@ def load_denoise_model(infer_cls, **kwargs):
|
|
114 |
|
115 |
def generate_spectrogram(signal: np.ndarray, sample_rate: int = 8000, title: str = "Spectrogram"):
|
116 |
mag = np.abs(librosa.stft(signal))
|
117 |
-
mag_db = librosa.amplitude_to_db(mag, ref=np.max)
|
|
|
118 |
|
119 |
-
|
120 |
-
plt.figure(figsize=(10, 3))
|
121 |
librosa.display.specshow(mag_db, sr=sample_rate)
|
122 |
-
# librosa.display.specshow(mag_db, sr=sample_rate, x_axis='time', y_axis='log')
|
123 |
-
# plt.colorbar(format='%+2.0f dB')
|
124 |
plt.title(title)
|
125 |
|
126 |
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
|
|
72 |
|
73 |
|
74 |
denoise_engines = {
|
75 |
+
"dtln-256-nx-dns3": {
|
76 |
"infer_cls": InferenceDTLN,
|
77 |
"kwargs": {
|
78 |
+
"pretrained_model_path_or_zip_file": (project_path / "trained_models/dtln-256-nx-dns3.zip").as_posix()
|
79 |
}
|
80 |
},
|
81 |
+
"dtln-512-nx-dns3": {
|
82 |
+
"infer_cls": InferenceDTLN,
|
83 |
"kwargs": {
|
84 |
+
"pretrained_model_path_or_zip_file": (project_path / "trained_models/dtln-512-nx-dns3.zip").as_posix()
|
85 |
}
|
86 |
},
|
87 |
+
"dfnet2-nx-dns3": {
|
88 |
+
"infer_cls": InferenceDfNet2,
|
89 |
"kwargs": {
|
90 |
+
"pretrained_model_path_or_zip_file": (project_path / "trained_models/dfnet2-nx-dns3.zip").as_posix()
|
91 |
}
|
92 |
},
|
93 |
"frcrn-dns3": {
|
|
|
114 |
|
115 |
def generate_spectrogram(signal: np.ndarray, sample_rate: int = 8000, title: str = "Spectrogram"):
|
116 |
mag = np.abs(librosa.stft(signal))
|
117 |
+
# mag_db = librosa.amplitude_to_db(mag, ref=np.max)
|
118 |
+
mag_db = librosa.amplitude_to_db(mag, ref=20)
|
119 |
|
120 |
+
plt.figure(figsize=(10, 4))
|
|
|
121 |
librosa.display.specshow(mag_db, sr=sample_rate)
|
|
|
|
|
122 |
plt.title(title)
|
123 |
|
124 |
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
|
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py
CHANGED
@@ -1464,7 +1464,14 @@ def main():
|
|
1464 |
import time
|
1465 |
# torch.set_num_threads(1)
|
1466 |
|
1467 |
-
config = DfNet2Config(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1468 |
model = DfNet2PretrainedModel(config=config)
|
1469 |
model.eval()
|
1470 |
|
@@ -1473,7 +1480,8 @@ def main():
|
|
1473 |
duration = num_samples / config.sample_rate
|
1474 |
|
1475 |
begin = time.time()
|
1476 |
-
|
|
|
1477 |
time_cost = time.time() - begin
|
1478 |
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
|
1479 |
|
@@ -1485,6 +1493,9 @@ def main():
|
|
1485 |
waveform = est_wav
|
1486 |
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
1487 |
print(waveform[:, :, 300: 302])
|
|
|
|
|
|
|
1488 |
print(waveform[:, :, 15680: 15682])
|
1489 |
print(waveform[:, :, 15760: 15762])
|
1490 |
print(waveform[:, :, 15840: 15842])
|
@@ -1497,6 +1508,9 @@ def main():
|
|
1497 |
waveform = waveform[:, :, (config.df_lookahead*config.hop_size):]
|
1498 |
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
1499 |
print(waveform[:, :, 300: 302])
|
|
|
|
|
|
|
1500 |
print(waveform[:, :, 15680: 15682])
|
1501 |
print(waveform[:, :, 15760: 15762])
|
1502 |
print(waveform[:, :, 15840: 15842])
|
|
|
1464 |
import time
|
1465 |
# torch.set_num_threads(1)
|
1466 |
|
1467 |
+
config = DfNet2Config(
|
1468 |
+
# nfft=512,
|
1469 |
+
# win_size=200,
|
1470 |
+
# hop_size=80,
|
1471 |
+
nfft=512,
|
1472 |
+
win_size=512,
|
1473 |
+
hop_size=128,
|
1474 |
+
)
|
1475 |
model = DfNet2PretrainedModel(config=config)
|
1476 |
model.eval()
|
1477 |
|
|
|
1480 |
duration = num_samples / config.sample_rate
|
1481 |
|
1482 |
begin = time.time()
|
1483 |
+
with torch.no_grad():
|
1484 |
+
est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
|
1485 |
time_cost = time.time() - begin
|
1486 |
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
|
1487 |
|
|
|
1493 |
waveform = est_wav
|
1494 |
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
1495 |
print(waveform[:, :, 300: 302])
|
1496 |
+
print(waveform[:, :, 1000: 1002])
|
1497 |
+
print(waveform[:, :, 8000: 8002])
|
1498 |
+
print(waveform[:, :, 14000: 14002])
|
1499 |
print(waveform[:, :, 15680: 15682])
|
1500 |
print(waveform[:, :, 15760: 15762])
|
1501 |
print(waveform[:, :, 15840: 15842])
|
|
|
1508 |
waveform = waveform[:, :, (config.df_lookahead*config.hop_size):]
|
1509 |
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
|
1510 |
print(waveform[:, :, 300: 302])
|
1511 |
+
print(waveform[:, :, 1000: 1002])
|
1512 |
+
print(waveform[:, :, 8000: 8002])
|
1513 |
+
print(waveform[:, :, 14000: 14002])
|
1514 |
print(waveform[:, :, 15680: 15682])
|
1515 |
print(waveform[:, :, 15760: 15762])
|
1516 |
print(waveform[:, :, 15840: 15842])
|
toolbox/torchaudio/models/dfnet2/yaml/{config.yaml → config-200.yaml}
RENAMED
@@ -1,4 +1,4 @@
|
|
1 |
-
model_name: "
|
2 |
|
3 |
# spec
|
4 |
sample_rate: 8000
|
@@ -68,7 +68,7 @@ clip_grad_norm: 10.0
|
|
68 |
seed: 1234
|
69 |
|
70 |
num_workers: 8
|
71 |
-
batch_size:
|
72 |
eval_steps: 10000
|
73 |
|
74 |
# runtime
|
|
|
1 |
+
model_name: "dfnet2"
|
2 |
|
3 |
# spec
|
4 |
sample_rate: 8000
|
|
|
68 |
seed: 1234
|
69 |
|
70 |
num_workers: 8
|
71 |
+
batch_size: 96
|
72 |
eval_steps: 10000
|
73 |
|
74 |
# runtime
|
toolbox/torchaudio/models/dfnet2/yaml/config-512.yaml
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "dfnet"
|
2 |
+
|
3 |
+
# spec
|
4 |
+
sample_rate: 8000
|
5 |
+
nfft: 512
|
6 |
+
win_size: 512
|
7 |
+
hop_size: 128
|
8 |
+
|
9 |
+
spec_bins: 256
|
10 |
+
erb_bins: 32
|
11 |
+
min_freq_bins_for_erb: 2
|
12 |
+
use_ema_norm: true
|
13 |
+
|
14 |
+
# model
|
15 |
+
conv_channels: 64
|
16 |
+
conv_kernel_size_input:
|
17 |
+
- 3
|
18 |
+
- 3
|
19 |
+
conv_kernel_size_inner:
|
20 |
+
- 1
|
21 |
+
- 3
|
22 |
+
convt_kernel_size_inner:
|
23 |
+
- 1
|
24 |
+
- 3
|
25 |
+
|
26 |
+
embedding_hidden_size: 256
|
27 |
+
encoder_combine_op: "concat"
|
28 |
+
|
29 |
+
encoder_emb_skip_op: "none"
|
30 |
+
encoder_emb_linear_groups: 16
|
31 |
+
encoder_emb_hidden_size: 256
|
32 |
+
|
33 |
+
encoder_linear_groups: 32
|
34 |
+
|
35 |
+
decoder_emb_num_layers: 3
|
36 |
+
decoder_emb_skip_op: "none"
|
37 |
+
decoder_emb_linear_groups: 16
|
38 |
+
decoder_emb_hidden_size: 256
|
39 |
+
|
40 |
+
df_decoder_hidden_size: 256
|
41 |
+
df_num_layers: 2
|
42 |
+
df_order: 5
|
43 |
+
df_bins: 96
|
44 |
+
df_gru_skip: "grouped_linear"
|
45 |
+
df_decoder_linear_groups: 16
|
46 |
+
df_pathway_kernel_size_t: 5
|
47 |
+
df_lookahead: 2
|
48 |
+
|
49 |
+
# lsnr
|
50 |
+
n_frame: 3
|
51 |
+
lsnr_max: 30
|
52 |
+
lsnr_min: -15
|
53 |
+
norm_tau: 1.
|
54 |
+
|
55 |
+
# data
|
56 |
+
min_snr_db: -10
|
57 |
+
max_snr_db: 20
|
58 |
+
|
59 |
+
# train
|
60 |
+
lr: 0.001
|
61 |
+
lr_scheduler: "CosineAnnealingLR"
|
62 |
+
lr_scheduler_kwargs:
|
63 |
+
T_max: 250000
|
64 |
+
eta_min: 0.0001
|
65 |
+
|
66 |
+
max_epochs: 100
|
67 |
+
clip_grad_norm: 10.0
|
68 |
+
seed: 1234
|
69 |
+
|
70 |
+
num_workers: 8
|
71 |
+
batch_size: 96
|
72 |
+
eval_steps: 10000
|
73 |
+
|
74 |
+
# runtime
|
75 |
+
use_post_filter: true
|