HoneyTian committed
Commit b408ac3 · 1 Parent(s): 10059e6
examples/conv_tasnet/run.sh CHANGED
@@ -3,7 +3,7 @@
 : <<'END'
 
 
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
+sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
 --max_epochs 400
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -139,11 +139,17 @@ def main():
         jsonl_file=args.train_dataset,
         expected_sample_rate=config.sample_rate,
         max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+        # skip=625000,
     )
     valid_dataset = DenoiseJsonlDataset(
         jsonl_file=args.valid_dataset,
         expected_sample_rate=config.sample_rate,
         max_wave_value=32768.0,
+        min_snr_db=config.min_snr_db,
+        max_snr_db=config.max_snr_db,
+        # skip=625000,
     )
     train_data_loader = DataLoader(
         dataset=train_dataset,
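
For context, a minimal sketch of how the updated call site could be exercised outside of main(); the jsonl path, sample rate, and batch size below are placeholders, the module path is inferred from the file location in this commit, and skip is the new parameter that the commit leaves commented out.

from torch.utils.data import DataLoader
from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset

train_dataset = DenoiseJsonlDataset(
    jsonl_file="data/train.jsonl",      # placeholder; the script passes args.train_dataset
    expected_sample_rate=8000,          # placeholder; the script passes config.sample_rate
    max_wave_value=32768.0,
    min_snr_db=-10,                     # the script passes config.min_snr_db
    max_snr_db=20,                      # the script passes config.max_snr_db
    skip=0,                             # e.g. 625000 to fast-forward a resumed run
)
train_data_loader = DataLoader(dataset=train_dataset, batch_size=64)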
examples/conv_tasnet/yaml/config.yaml CHANGED
@@ -16,6 +16,9 @@ norm_type: "gLN"
 causal: false
 mask_nonlinear: "relu"
 
+min_snr_db: -10
+max_snr_db: 20
+
 lr: 0.001
 lr_scheduler: "CosineAnnealingLR"
 lr_scheduler_kwargs:
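
A quick way to sanity-check the two new keys, as a sketch assuming PyYAML is installed and the file sits at its path in this repository:

import yaml

with open("examples/conv_tasnet/yaml/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# the new keys bound the SNR range used when mixing speech and noise
print(cfg["min_snr_db"], cfg["max_snr_db"])   # -10 20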
toolbox/torch/utils/data/dataset/denoise_jsonl_dataset.py CHANGED
@@ -1,18 +1,13 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 import json
-import os
 import random
 from typing import List
 
 import librosa
 import numpy as np
-import pandas as pd
-from scipy.io import wavfile
 import torch
-import torchaudio
 from torch.utils.data import Dataset, IterableDataset
-from tqdm import tqdm
 
 
 class DenoiseJsonlDataset(IterableDataset):
@@ -22,13 +17,19 @@ class DenoiseJsonlDataset(IterableDataset):
         resample: bool = False,
         max_wave_value: float = 1.0,
         buffer_size: int = 1000,
+        min_snr_db: float = None,
+        max_snr_db: float = None,
         eps: float = 1e-8,
+        skip: int = 0,
     ):
         self.jsonl_file = jsonl_file
         self.expected_sample_rate = expected_sample_rate
         self.resample = resample
         self.max_wave_value = max_wave_value
+        self.min_snr_db = min_snr_db
+        self.max_snr_db = max_snr_db
         self.eps = eps
+        self.skip = skip
 
         self.buffer_size = buffer_size
         self.buffer_samples: List[dict] = list()
@@ -36,6 +37,12 @@ class DenoiseJsonlDataset(IterableDataset):
     def __iter__(self):
         iterable_source = self.iterable_source()
 
+        try:
+            for _ in range(self.skip):
+                next(iterable_source)
+        except StopIteration:
+            pass
+
         # initial fill of the buffer
         try:
             for _ in range(self.buffer_size):
@@ -74,7 +81,10 @@ class DenoiseJsonlDataset(IterableDataset):
         speech_offset = row["speech_offset"]
         speech_duration = row["speech_duration"]
 
-        snr_db = row["snr_db"]
+        if self.min_snr_db is None or self.max_snr_db is None:
+            snr_db = row["snr_db"]
+        else:
+            snr_db = random.uniform(self.min_snr_db, self.max_snr_db)
 
         sample = {
             "noise_filename": noise_filename,
toolbox/torchaudio/models/conv_tasnet/configuration_conv_tasnet.py CHANGED
@@ -27,12 +27,15 @@ class ConvTasNetConfig(PretrainedConfig):
         causal: bool = False,
         mask_nonlinear: str = "relu",
 
-        lr: float = 1e-3,
-        eval_steps: int = 25000,
+        min_snr_db: float = -10,
+        max_snr_db: float = 20,
 
+        lr: float = 1e-3,
         lr_scheduler: str = "CosineAnnealingLR",
         lr_scheduler_kwargs: dict = None,
 
+        eval_steps: int = 25000,
+
         **kwargs
     ):
         super(ConvTasNetConfig, self).__init__(**kwargs)
@@ -53,12 +56,15 @@ class ConvTasNetConfig(PretrainedConfig):
         self.causal = causal
         self.mask_nonlinear = mask_nonlinear
 
-        self.lr = lr
-        self.eval_steps = eval_steps
+        self.min_snr_db = min_snr_db
+        self.max_snr_db = max_snr_db
 
+        self.lr = lr
         self.lr_scheduler = lr_scheduler
         self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
 
+        self.eval_steps = eval_steps
+
 
 if __name__ == "__main__":
     pass
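
A short usage sketch of the reorganised config, assuming the constructor arguments not shown in this hunk keep their defaults and the module is importable from the repository root:

from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig

config = ConvTasNetConfig(min_snr_db=-10, max_snr_db=20, lr=1e-3, eval_steps=25000)
print(config.min_snr_db, config.max_snr_db, config.lr, config.eval_steps)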