HoneyTian committed
Commit 7f9e32d · 1 Parent(s): e86d760
examples/conv_tasnet/run.sh CHANGED
@@ -3,25 +3,10 @@
 : <<'END'
 
 
-sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
---noise_dir "E:/Users/tianx/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "E:/programmer/asr_datasets/aishell/data_aishell/wav/train"
-
-
-sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name file_dir --final_model_name mpnet-aishell-20250224 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
-
-sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name nx-clean-unet-aishell-20250228 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
---max_epochs 100
-
-
-sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name mpnet-nx-speech-20250224 \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" \
---max_epochs 100 --max_count 10000
+sh run.sh --stage 1 --stop_stage 1 --system_version centos --file_folder_name file_dir --final_model_name conv-tasnet-dns3-20250319 \
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/dns3-noise" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" \
+--max_epochs 200
 
 
 END
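These commands sit inside a `: <<'END' ... END` heredoc, which the shell discards, so they are kept as usage documentation rather than executed. They follow the Kaldi-style stage convention: a step runs only when its index falls inside [stage, stop_stage]. A hypothetical Python rendering of that gating (step names taken from the files in this commit; the real control flow lives in run.sh):

def run_pipeline(stage: int, stop_stage: int):
    # steps mirror the example scripts touched by this commit
    steps = {1: "step_1_prepare_data.py", 2: "step_2_train_model.py"}
    for idx, name in steps.items():
        if stage <= idx <= stop_stage:
            print(f"running {name}")

run_pipeline(stage=1, stop_stage=1)  # the dns3 command above: data prep only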
examples/conv_tasnet/step_1_prepare_data.py CHANGED
@@ -54,28 +54,30 @@ def filename_generator(data_dir: str):
         yield filename.as_posix()
 
 
-def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
+def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000, max_epoch: int = 20000):
     data_dir = Path(data_dir)
-    for filename in data_dir.glob("**/*.wav"):
-        signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
-        raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
-
-        if raw_duration < duration:
-            # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
-            continue
-        if signal.ndim != 1:
-            raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
-
-        signal_length = len(signal)
-        win_size = int(duration * sample_rate)
-        for begin in range(0, signal_length - win_size, win_size):
-            row = {
-                "filename": filename.as_posix(),
-                "raw_duration": round(raw_duration, 4),
-                "offset": round(begin / sample_rate, 4),
-                "duration": round(duration, 4),
-            }
-            yield row
+    for epoch_idx in range(max_epoch):
+        for filename in data_dir.glob("**/*.wav"):
+            signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
+            raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
+
+            if raw_duration < duration:
+                # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
+                continue
+            if signal.ndim != 1:
+                raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
+
+            signal_length = len(signal)
+            win_size = int(duration * sample_rate)
+            for begin in range(0, signal_length - win_size, win_size):
+                row = {
+                    "epoch_idx": epoch_idx,
+                    "filename": filename.as_posix(),
+                    "raw_duration": round(raw_duration, 4),
+                    "offset": round(begin / sample_rate, 4),
+                    "duration": round(duration, 4),
+                }
+                yield row
 
 
 def get_dataset(args):
@@ -88,12 +90,14 @@ def get_dataset(args):
     noise_generator = target_second_signal_generator(
         noise_dir.as_posix(),
         duration=args.duration,
-        sample_rate=args.target_sample_rate
+        sample_rate=args.target_sample_rate,
+        max_epoch=100000,
     )
     speech_generator = target_second_signal_generator(
         speech_dir.as_posix(),
         duration=args.duration,
-        sample_rate=args.target_sample_rate
+        sample_rate=args.target_sample_rate,
+        max_epoch=1,
     )
 
     dataset = list()
@@ -155,7 +159,6 @@ def get_dataset(args):
     return
 
 
-
 def split_dataset(args):
    """Split into train and test sets."""
    file_dir = Path(args.file_dir)
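The asymmetric max_epoch values in get_dataset (100000 for noise, 1 for speech) make the noise generator effectively inexhaustible, so every 2-second speech excerpt can be paired with a noise excerpt even when the noise corpus is much smaller. A minimal sketch of that pairing, assuming the two generators are consumed in lockstep (hypothetical excerpts() stand-in, no audio I/O):

def excerpts(source: str, count: int, max_epoch: int):
    # stand-in for target_second_signal_generator: `count` windows per pass
    for epoch_idx in range(max_epoch):
        for index in range(count):
            yield {"epoch_idx": epoch_idx, "source": source, "index": index}

noise = excerpts("noise", count=3, max_epoch=100000)  # cycles its 3 windows
speech = excerpts("speech", count=7, max_epoch=1)     # single pass
pairs = list(zip(noise, speech))
assert len(pairs) == 7                  # zip stops when speech is exhausted
assert pairs[-1][0]["epoch_idx"] == 2   # the noise stream wrapped around twice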
examples/conv_tasnet/step_2_train_model.py CHANGED
@@ -29,7 +29,7 @@ from toolbox.torch.utils.data.dataset.denoise_excel_dataset import DenoiseExcelD
 from toolbox.torchaudio.models.conv_tasnet.configuration_conv_tasnet import ConvTasNetConfig
 from toolbox.torchaudio.models.conv_tasnet.modeling_conv_tasnet import ConvTasNet, ConvTasNetPretrainedModel
 from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
-from toolbox.torchaudio.losses.spectral import LSDLoss
+from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
 from toolbox.torchaudio.losses.perceptual import NegSTOILoss
 from toolbox.torchaudio.metrics.pesq import run_pesq_score
 
@@ -39,7 +39,7 @@ def get_args():
     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
 
-    parser.add_argument("--max_epochs", default=100, type=int)
+    parser.add_argument("--max_epochs", default=200, type=int)
 
     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
     parser.add_argument("--patience", default=5, type=int)
@@ -201,6 +201,12 @@ def main():
     neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
     neg_stoi_loss_fn = NegSTOILoss(sample_rate=8000, reduction="mean").to(device)
     lds_loss_fn = LSDLoss(reduction="mean").to(device)
+    mr_stft_loss_fn = MultiResolutionSTFTLoss(
+        fft_size_list=[256, 512, 1024],
+        win_size_list=[120, 240, 480],
+        hop_size_list=[25, 50, 100],
+        reduction="mean"
+    ).to(device)
 
     # training loop
 
@@ -245,8 +251,9 @@ def main():
             neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
             neg_stoi_loss = neg_stoi_loss_fn.forward(denoise_audios, clean_audios)
             lds_loss = lds_loss_fn.forward(denoise_audios, clean_audios)
+            mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
 
-            loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss
+            loss = 0.25 * ae_loss + 0.25 * neg_si_snr_loss + 0.25 * neg_stoi_loss + 0.25 * lds_loss + 0.25 * mr_stft_loss
 
             denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
             clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
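A multi-resolution STFT loss compares magnitude spectra at several (fft, window, hop) settings so that no single time-frequency trade-off dominates the objective. Note also that the five 0.25 weights now sum to 1.25, so the composite loss is a scaled sum rather than a convex combination. A self-contained sketch of the multi-resolution idea, using the same three resolutions as main() (this is an illustration, not the repo's MultiResolutionSTFTLoss, which also carries factor_sc/factor_mag terms):

import torch
import torch.nn.functional as F

def mr_stft_mag_loss(denoise: torch.Tensor, clean: torch.Tensor) -> torch.Tensor:
    # average an L1 magnitude-spectrum loss over three STFT resolutions
    total = denoise.new_zeros(())
    for n_fft, win_size, hop_size in [(256, 120, 25), (512, 240, 50), (1024, 480, 100)]:
        window = torch.hann_window(win_size, device=denoise.device)
        mag_d = torch.stft(denoise, n_fft, hop_size, win_size, window, return_complex=True).abs()
        mag_c = torch.stft(clean, n_fft, hop_size, win_size, window, return_complex=True).abs()
        total = total + F.l1_loss(mag_d, mag_c)
    return total / 3

loss = mr_stft_mag_loss(torch.randn(4, 16000), torch.randn(4, 16000))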
examples/conv_tasnet/yaml/config.yaml CHANGED
@@ -1,42 +1,17 @@
-model_name: "nx_clean_unet"
+model_name: "conv_tasnet"
 
 sample_rate: 8000
-segment_size: 16000
-n_fft: 512
-win_size: 200
-hop_size: 80
+segment_size: 4
 
-down_sampling_num_layers: 6
-down_sampling_in_channels: 1
-down_sampling_hidden_channels: 64
-down_sampling_kernel_size: 4
-down_sampling_stride: 2
+win_size: 20
+freq_bins: 256
+bottleneck_channels: 256
+num_speakers: 1
+num_blocks: 4
+num_sub_blocks: 8
+sub_blocks_channels: 512
+sub_blocks_kernel_size: 3
 
-causal_in_channels: 1
-causal_out_channels: 1
-causal_kernel_size: 3
-causal_bias: false
-causal_separable: true
-causal_f_stride: 1
-causal_num_layers: 5
-
-tsfm_hidden_size: 256
-tsfm_attention_heads: 8
-tsfm_num_blocks: 6
-tsfm_dropout_rate: 0.1
-tsfm_max_length: 512
-tsfm_chunk_size: 1
-tsfm_num_left_chunks: 128
-tsfm_num_right_chunks: 4
-
-discriminator_dim: 32
-discriminator_in_channel: 2
-
-compress_factor: 0.3
-
-batch_size: 64
-learning_rate: 0.0005
-adam_b1: 0.8
-adam_b2: 0.99
-lr_decay: 0.99
-seed: 1234
+norm_type: "gLN"
+causal: false
+mask_nonlinear: "relu"
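A minimal sketch of consuming the new file with plain PyYAML (the repo's ConvTasNetConfig loader may read it differently). Note that segment_size changed meaning: the old nx_clean_unet config used a raw sample count (16000), while 4 here is presumably a duration in seconds:

import yaml

with open("examples/conv_tasnet/yaml/config.yaml") as f:
    config = yaml.safe_load(f)

assert config["model_name"] == "conv_tasnet"
# if segment_size is seconds, each training segment is 4 s * 8000 Hz = 32000 samples
num_samples = int(config["segment_size"] * config["sample_rate"])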
examples/data_preprocess/dns_challenge_to_8k/process_dns_challenge_clean_read_speech.py CHANGED
@@ -57,9 +57,11 @@ def get_args():
         # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.read_speech\datasets\clean",
         # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.mandarin_speech\datasets\clean\mandarin_speech",
         # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.singing_voice\datasets\clean\singing_voice",
-        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
+        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.french_data\datasets\clean\french_data",
+        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.german_speech\datasets\clean\german_speech",
         # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.italian_speech\datasets\clean\italian_speech",
-        default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
+        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.russian_speech\datasets\clean\russian_speech",
+        # default=r"E:\programmer\asr_datasets\dns-challenge\datasets.clean.spanish_speech\datasets\clean\spanish_speech",
         type=str
     )
     parser.add_argument(
@@ -67,9 +69,11 @@ def get_args():
         # default=r"E:\programmer\asr_datasets\denoise\dns-clean-read-speech-8k",
         # default=r"E:\programmer\asr_datasets\denoise\dns-clean-mandarin-speech-8k",
         # default=r"E:\programmer\asr_datasets\denoise\dns-clean-singing-voice-8k",
-        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
+        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-french-speech-8k",
+        default=r"E:\programmer\asr_datasets\denoise\dns-clean-german-speech-8k",
         # default=r"E:\programmer\asr_datasets\denoise\dns-clean-italian-speech-8k",
-        default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
+        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-russian-speech-8k",
+        # default=r"E:\programmer\asr_datasets\denoise\dns-clean-spanish-speech-8k",
         type=str
     )
     parser.add_argument("--sample_rate", default=8000, type=int)
@@ -87,24 +91,26 @@ def main():
     # finished_set
     finished_set = set()
     for filename in tqdm(output_dir.glob("**/*.wav")):
-        name = filename.stem
-        finished_set.add(name)
+        filename = Path(filename)
+        relative_name = filename.relative_to(output_dir)
+        relative_name_ = relative_name.as_posix()
+        finished_set.add(relative_name_)
     print(f"finished_set count: {len(finished_set)}")
 
     for filename in tqdm(data_dir.glob("**/*.wav")):
-        label = filename.parts[-2]
-        name = filename.stem
         relative_name = filename.relative_to(data_dir)
-        # print(f"filename: {filename.as_posix()}")
-        if name in finished_set:
+        relative_name_ = relative_name.as_posix()
+        if relative_name_ in finished_set:
             continue
-        finished_set.add(name)
+        finished_set.add(relative_name_)
 
         try:
-            signal, _ = librosa.load(filename.as_posix(), sr=args.sample_rate)
+            signal, _ = librosa.load(filename.as_posix(), mono=False, sr=args.sample_rate)
         except Exception:
            print(f"skip file: {filename.as_posix()}")
            continue
+        if signal.ndim != 1:
+            raise AssertionError
 
         signal = signal * (1 << 15)
         signal = np.array(signal, dtype=np.int16)
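Keying finished_set on the path relative to the data root, rather than the bare stem, prevents files in different subdirectories that share a name from masking each other. The last two lines convert librosa's float output to 16-bit PCM; a standalone sketch of that conversion (note a full-scale +1.0 sample would overflow int16, so the defensive variant clips first):

import numpy as np

signal = np.array([-1.0, -0.5, 0.0, 0.5, 0.999], dtype=np.float32)
pcm = np.clip(signal * (1 << 15), -32768, 32767).astype(np.int16)
print(pcm)  # [-32768 -16384      0  16384  32735]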
toolbox/torchaudio/losses/spectral.py CHANGED
@@ -283,6 +283,7 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
         hop_size_list: List[int] = None,
         factor_sc=0.1,
         factor_mag=0.1,
+        reduction: str = "mean",
     ):
         super(MultiResolutionSTFTLoss, self).__init__()
         fft_size_list = fft_size_list or [1024, 2048, 512]
@@ -299,6 +300,7 @@ class MultiResolutionSTFTLoss(torch.nn.Module):
                 n_fft=n_fft,
                 win_size=win_size,
                 hop_size=hop_size,
+                reduction=reduction,
             )
         )
 
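The new keyword is simply threaded through to every per-resolution STFT loss the module builds, which is what lets the constructor call added in step_2_train_model.py pass reduction="mean" once for all three resolutions:

from toolbox.torchaudio.losses.spectral import MultiResolutionSTFTLoss

mr_stft_loss_fn = MultiResolutionSTFTLoss(
    fft_size_list=[256, 512, 1024],
    win_size_list=[120, 240, 480],
    hop_size_list=[25, 50, 100],
    reduction="mean",
)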
toolbox/torchaudio/models/conv_tasnet/modeling_conv_tasnet.py CHANGED
@@ -388,6 +388,7 @@ class ConvTasNet(nn.Module):
         est_mask = self.separator.forward(mixture_w)
         # est_mask shape: [batch_size, num_speakers, freq_bins, time_steps]
         est_source = self.decoder.forward(mixture_w, est_mask)
+        # est_source shape: [batch_size, num_speakers, num_samples]
 
         num_samples1 = mixture.size(-1)
         num_samples2 = est_source.size(-1)
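The new comment pins down the decoder's output shape; the num_samples1/num_samples2 reads that follow exist because the encoder/decoder round trip can change the sample count. A hypothetical illustration of reconciling the two lengths (the repo's exact handling may differ):

import torch
import torch.nn.functional as F

mixture = torch.randn(8, 16000)        # [batch_size, num_samples]
est_source = torch.randn(8, 1, 16008)  # [batch_size, num_speakers, num_samples]

diff = mixture.size(-1) - est_source.size(-1)
# pad when the estimate is short, trim when it is long
est_source = F.pad(est_source, (0, diff)) if diff > 0 else est_source[..., :mixture.size(-1)]
assert est_source.size(-1) == mixture.size(-1)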