HoneyTian committed
Commit bd3d872 · 1 Parent(s): 602ffc9
examples/dfnet2/yaml/config.yaml CHANGED
@@ -7,6 +7,9 @@ win_size: 200
 hop_size: 80

 spec_bins: 256
+erb_bins: 32
+min_freq_bins_for_erb: 2
+use_ema_norm: true

 # model
 conv_channels: 64
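The three new keys match the `use_ema_norm` / ERB parameters added to `DfNet2Config` further down in this commit. As a quick sanity check, the updated YAML can be loaded and the new keys inspected like this (a minimal sketch assuming PyYAML; the relative path is the one used in this repo):

import yaml

# Load the updated DfNet2 example config and read the keys added by this commit.
with open("examples/dfnet2/yaml/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["erb_bins"])                # 32
print(config["min_freq_bins_for_erb"])   # 2
print(config["use_ema_norm"])            # true -> parsed as Python True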
examples/dtln/step_2_train_model.py CHANGED
@@ -259,6 +259,7 @@ def main():
         noisy_audios: torch.Tensor = noisy_audios.to(device)

         denoise_audios = model.forward(noisy_audios)
+        denoise_audios = torch.squeeze(denoise_audios, dim=1)

         mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
         neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
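The extra squeeze is needed because, with the DTLN change later in this commit, `model.forward` now returns the waveform with an explicit channel dimension, `[batch_size, 1, num_samples]`, while the MR-STFT and SI-SNR losses are computed against `clean_audios`, presumably of shape `[batch_size, num_samples]`. A minimal shape sketch with random tensors (the shapes are the only assumption):

import torch

denoise_audios = torch.randn(4, 1, 16000)  # what DTLN forward now returns: [b, 1, n]
clean_audios = torch.randn(4, 16000)       # training target: [b, n]

denoise_audios = torch.squeeze(denoise_audios, dim=1)  # back to [b, n]
assert denoise_audios.shape == clean_audios.shape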
examples/rnnoise/run.sh CHANGED
@@ -6,10 +6,9 @@ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir

 sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir

-sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name rnnoise-nx-dns3 \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train" \
---sparse


 END
@@ -108,66 +107,6 @@ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     --valid_dataset "${valid_dataset}" \
     --serialization_dir "${file_dir}" \
     --config_file "${config_file}" \
-
-fi
-
-
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-  $verbose && echo "stage 3: test model"
-  cd "${work_dir}" || exit 1
-  python3 step_3_evaluation.py \
-    --valid_dataset "${valid_dataset}" \
-    --model_dir "${file_dir}/best" \
-    --evaluation_audio_dir "${evaluation_audio_dir}" \
-    --limit "${limit}" \
-
-fi
-
-
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-  $verbose && echo "stage 4: export model"
-  cd "${work_dir}" || exit 1
-  python3 step_5_export_models.py \
-    --vocabulary_dir "${vocabulary_dir}" \
-    --model_dir "${file_dir}/best" \
-    --serialization_dir "${file_dir}" \
-
-fi
-
-
-if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
-  $verbose && echo "stage 5: collect files"
-  cd "${work_dir}" || exit 1
-
-  mkdir -p ${final_model_dir}
-
-  cp "${file_dir}/best"/* "${final_model_dir}"
-  cp -r "${file_dir}/vocabulary" "${final_model_dir}"
-
-  cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
-
-  cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
-  cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
-  cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
-  cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
-
-  cd "${final_model_dir}/.." || exit 1;
-
-  if [ -e "${final_model_name}.zip" ]; then
-    rm -rf "${final_model_name}_backup.zip"
-    mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
-  fi
-
-  zip -r "${final_model_name}.zip" "${final_model_name}"
-  rm -rf "${final_model_name}"
-
-fi
-
-
-if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
-  $verbose && echo "stage 6: clear file_dir"
-  cd "${work_dir}" || exit 1
-
-  rm -rf "${file_dir}";
+    --sparse

 fi
toolbox/torchaudio/models/dfnet/inference_dfnet.py CHANGED
@@ -76,13 +76,10 @@ class InferenceDfNet(object):
         with torch.no_grad():
             est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)

-        # shape: [batch_size, num_samples]
-        enhanced_audio = torch.unsqueeze(est_wav, dim=1)
         # shape: [batch_size, 1, num_samples]
-
-        enhanced_audio = enhanced_audio[0]
+        denoise = est_wav[0]
         # shape: [channels, num_samples]
-        return enhanced_audio
+        return denoise


 def main():
@@ -90,7 +87,7 @@ def main():
     infer_model = InferenceDfNet(model_zip_file)

     sample_rate = 8000
-    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
+    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
     noisy_audio, sample_rate = librosa.load(
         noisy_audio_file.as_posix(),
         sr=sample_rate,
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -961,9 +961,8 @@ class DfNet(nn.Module):
         # est_spec shape: [b, f, t], torch.complex64

         est_wav = self.istft.forward(est_spec)
-        est_wav = torch.squeeze(est_wav, dim=1)
-        est_wav = est_wav[:, :n_samples]
-        # est_wav shape: [b, n_samples]
+        est_wav = est_wav[:, :, :n_samples]
+        # est_wav shape: [b, 1, n_samples]

         est_mask = torch.squeeze(mask, dim=1)
         est_mask = est_mask.permute(0, 2, 1)
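The waveform now keeps the channel dimension produced by the ISTFT, so the padding is trimmed with a three-dimensional slice and `forward` returns `[b, 1, n_samples]` instead of `[b, n_samples]`; the inference script above then simply takes `est_wav[0]` to get a `[channels, num_samples]` tensor for `torchaudio.save`. A minimal sketch of the shape convention (a random tensor stands in for the ISTFT output, and the padding length is illustrative):

import torch

n_samples = 16000
est_wav = torch.randn(2, 1, n_samples + 80)  # stand-in for istft output: [b, 1, padded_samples]

est_wav = est_wav[:, :, :n_samples]          # trim padding, keep the channel dim
# est_wav shape: [b, 1, n_samples]

enhanced = est_wav[0]                        # [channels, num_samples], ready for torchaudio.save
print(enhanced.shape)                        # torch.Size([1, 16000])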
toolbox/torchaudio/models/dfnet2/configuration_dfnet2.py CHANGED
@@ -16,6 +16,7 @@ class DfNet2Config(PretrainedConfig):
                  spec_bins: int = 256,
                  erb_bins: int = 32,
                  min_freq_bins_for_erb: int = 2,
+                 use_ema_norm: bool = True,

                  conv_channels: int = 64,
                  conv_kernel_size_input: Tuple[int, int] = (3, 3),
@@ -83,6 +84,8 @@ class DfNet2Config(PretrainedConfig):
         self.erb_bins = erb_bins
         self.min_freq_bins_for_erb = min_freq_bins_for_erb

+        self.use_ema_norm = use_ema_norm
+
         # conv
         self.conv_channels = conv_channels
         self.conv_kernel_size_input = conv_kernel_size_input
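`use_ema_norm` switches the `feature_norm` step on or off in both `DfNet2.forward` and the streaming `forward_chunk_by_chunk` (see modeling_dfnet2.py below). Judging by the name, this is an exponential-moving-average normalization of the ERB and spectral features; a rough, self-contained illustration of that idea follows (the smoothing factor and exact formula are assumptions, not taken from this repository):

import torch

def ema_normalize(feat: torch.Tensor, alpha: float = 0.99) -> torch.Tensor:
    # feat shape: [b, 1, t, bins]; subtract a running mean that is updated
    # frame by frame, so the same rule also works for streaming input.
    b, c, t, bins = feat.shape
    mean = feat[:, :, 0, :].clone()
    out = torch.empty_like(feat)
    for i in range(t):
        mean = alpha * mean + (1.0 - alpha) * feat[:, :, i, :]
        out[:, :, i, :] = feat[:, :, i, :] - mean
    return out

feat_erb = torch.randn(1, 1, 100, 32)  # [b, 1, t, erb_bins]
print(ema_normalize(feat_erb).shape)   # torch.Size([1, 1, 100, 32])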
toolbox/torchaudio/models/dfnet2/inference_dfnet2.py CHANGED
@@ -14,8 +14,8 @@ import torchaudio
 torch.set_num_threads(1)

 from project_settings import project_path
-from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
-from toolbox.torchaudio.models.dfnet.modeling_dfnet import DfNetPretrainedModel, MODEL_FILE
+from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
+from toolbox.torchaudio.models.dfnet2.modeling_dfnet2 import DfNet2PretrainedModel, MODEL_FILE

 logger = logging.getLogger("toolbox")

@@ -43,10 +43,10 @@ class InferenceDfNet(object):
                 f_zip.extractall(path=out_root)
             model_path = out_root / model_path.stem

-        config = DfNetConfig.from_pretrained(
+        config = DfNet2Config.from_pretrained(
             pretrained_model_name_or_path=model_path.as_posix(),
         )
-        model = DfNetPretrainedModel.from_pretrained(
+        model = DfNet2PretrainedModel.from_pretrained(
             pretrained_model_name_or_path=model_path.as_posix(),
         )
         model.to(self.device)
@@ -60,13 +60,13 @@ class InferenceDfNet(object):
         noisy_audio = noisy_audio.unsqueeze(dim=0)

         # noisy_audio shape: [batch_size, n_samples]
-        enhanced_audio = self.enhancement_by_tensor(noisy_audio)
+        enhanced_audio = self.denoise_offline(noisy_audio)
         # enhanced_audio shape: [channels, num_samples]
         enhanced_audio = enhanced_audio[0]
         # enhanced_audio shape: [num_samples]
         return enhanced_audio.cpu().numpy()

-    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
         if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
             raise AssertionError(f"The value range of audio samples should be between -1 and 1.")

@@ -76,21 +76,33 @@ class InferenceDfNet(object):
         with torch.no_grad():
             est_spec, est_wav, est_mask, lsnr = self.model.forward(noisy_audios)

-        # shape: [batch_size, num_samples]
-        enhanced_audio = torch.unsqueeze(est_wav, dim=1)
         # shape: [batch_size, 1, num_samples]
+        denoise = est_wav[0]
+        # shape: [channels, num_samples]
+        return denoise

-        enhanced_audio = enhanced_audio[0]
+    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            denoise = self.model.forward_chunk_by_chunk(noisy_audios)
+
+        # shape: [batch_size, 1, num_samples]
+        denoise = denoise[0]
         # shape: [channels, num_samples]
-        return enhanced_audio
+        return denoise


 def main():
-    model_zip_file = project_path / "trained_models/dfnet-nx-dns3.zip"
+    model_zip_file = project_path / "trained_models/dfnet2-nx-dns3.zip"
     infer_model = InferenceDfNet(model_zip_file)

     sample_rate = 8000
-    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_3.wav"
+    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
     noisy_audio, sample_rate = librosa.load(
         noisy_audio_file.as_posix(),
         sr=sample_rate,
@@ -101,11 +113,19 @@ def main():
     noisy_audio = noisy_audio.unsqueeze(dim=0)

     begin = time.time()
-    enhanced_audio = infer_model.enhancement_by_tensor(noisy_audio)
+    enhanced_audio = infer_model.denoise_offline(noisy_audio)
+    time_cost = time.time() - begin
+    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
+    filename = "enhanced_audio_offline.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    begin = time.time()
+    enhanced_audio = infer_model.denoise_online(noisy_audio)
     time_cost = time.time() - begin
     print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")

-    filename = "enhanced_audio.wav"
+    filename = "enhanced_audio_online.wav"
     torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)

     return
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py CHANGED
@@ -1097,7 +1097,8 @@ class DfNet2(nn.Module):
         noisy = self.signal_prepare(noisy)

         spec, feat_erb, feat_spec = self.feature_prepare(noisy)
-        feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec)
+        if self.config.use_ema_norm:
+            feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec)

         e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)

@@ -1180,11 +1181,12 @@ class DfNet2(nn.Module):
             # spec shape: [b, 1, t, f, 2]
             # feat_erb shape: [b, 1, t, erb_bins]
             # feat_spec shape: [b, 2, t, df_bins]
-            feat_erb, feat_spec, cache_dict6 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict6)
+            if self.config.use_ema_norm:
+                feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0)

-            e0, e1, e2, e3, emb, c0, lsnr, cache_dict0 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict0)
+            e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1)

-            mask, cache_dict1 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict1)
+            mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2)
             # mask shape: [b, 1, t, erb_bins]
             mask = self.erb_bands.erb_scale_inv(mask)
             # mask shape: [b, 1, t, f]
@@ -1198,16 +1200,16 @@ class DfNet2(nn.Module):
             lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
             # lsnr shape: [b, 1, t]

-            df_coefs, cache_dict2 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict2)
+            df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3)
             df_coefs = self.df_out_transform(df_coefs)
             # df_coefs shape: [b, df_order, t, df_bins, 2]

             spec_ = spec[:, :, :, :self.config.spec_bins, :]
             # spec shape: [b, 1, t, spec_bins, 2]
-            spec_f, cache_dict3 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict3)
+            spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4)
             # spec_f shape: [b, 1, t, df_bins, 2], torch.float32

-            spec_e, cache_dict4 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict4)
+            spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5)

             spec_e = torch.squeeze(spec_e, dim=1)
             spec_e = spec_e.permute(0, 2, 1, 3)
@@ -1219,7 +1221,7 @@ class DfNet2(nn.Module):
             est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
             # est_spec shape: [b, f, t], torch.complex64

-            est_wav, cache_dict5 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict5)
+            est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6)
             # est_wav shape: [b, 1, hop_size]

             waveform_list.append(est_wav)
@@ -1361,14 +1363,22 @@ class DfNet2PretrainedModel(DfNet2):


 def main():
+    import time
+    # torch.set_num_threads(1)

     config = DfNet2Config()
     model = DfNet2PretrainedModel(config=config)
     model.eval()

-    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
+    num_samples = 16000
+    noisy = torch.randn(size=(1, num_samples), dtype=torch.float32)
+    duration = num_samples / config.sample_rate

+    begin = time.time()
     est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
+    time_cost = time.time() - begin
+    print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
     # print(f"est_spec.shape: {est_spec.shape}")
     # print(f"est_wav.shape: {est_wav.shape}")
     # print(f"est_mask.shape: {est_mask.shape}")
@@ -1381,7 +1391,11 @@ def main():
     print(waveform[:, :, 15760: 15762])
     print(waveform[:, :, 15840: 15842])

+    begin = time.time()
     waveform = model.forward_chunk_by_chunk(noisy)
+    time_cost = time.time() - begin
+    print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
     waveform = waveform[:, :, (config.df_lookahead*config.hop_size):]
     print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
     print(waveform[:, :, 300: 302])
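The timing added to `main()` reports `fpr` as processing time divided by audio duration, i.e. a real-time factor: values below 1.0 mean the model processes audio faster than it plays back. The same arithmetic in isolation (the numbers are illustrative, not measurements from this commit):

# Real-time-factor calculation mirroring the print statements above.
num_samples = 16000
sample_rate = 8000
duration = num_samples / sample_rate  # 2.0 seconds of audio

time_cost = 0.5                       # hypothetical processing time in seconds
fpr = time_cost / duration            # 0.25 -> four times faster than real time
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {fpr:.4f}")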
toolbox/torchaudio/models/dfnet2/yaml/config.yaml CHANGED
@@ -7,6 +7,9 @@ win_size: 200
 hop_size: 80

 spec_bins: 256
+erb_bins: 32
+min_freq_bins_for_erb: 2
+use_ema_norm: true

 # model
 conv_channels: 64
toolbox/torchaudio/models/dtln/inference_dtln.py ADDED
@@ -0,0 +1,137 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+import logging
+from pathlib import Path
+import shutil
+import tempfile, time
+import zipfile
+
+import librosa
+import numpy as np
+import torch
+import torchaudio
+
+torch.set_num_threads(1)
+
+from project_settings import project_path
+from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
+from toolbox.torchaudio.models.dtln.modeling_dtln import DTLNPretrainedModel, MODEL_FILE
+
+logger = logging.getLogger("toolbox")
+
+
+class InferenceDTLN(object):
+    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
+        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
+        self.device = torch.device(device)
+
+        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
+        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
+        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
+
+        self.config = config
+        self.model = model
+        self.model.to(device)
+        self.model.eval()
+
+    def load_models(self, model_path: str):
+        model_path = Path(model_path)
+        if model_path.name.endswith(".zip"):
+            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
+                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
+                out_root.mkdir(parents=True, exist_ok=True)
+                f_zip.extractall(path=out_root)
+            model_path = out_root / model_path.stem
+
+        config = DTLNConfig.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model = DTLNPretrainedModel.from_pretrained(
+            pretrained_model_name_or_path=model_path.as_posix(),
+        )
+        model.to(self.device)
+        model.eval()
+
+        shutil.rmtree(model_path)
+        return config, model
+
+    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
+        noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+        noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+        # noisy_audio shape: [batch_size, n_samples]
+        enhanced_audio = self.denoise_offline(noisy_audio)
+        # enhanced_audio shape: [channels, num_samples]
+        enhanced_audio = enhanced_audio[0]
+        # enhanced_audio shape: [num_samples]
+        return enhanced_audio.cpu().numpy()
+
+    def denoise_offline(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            denoise = self.model.forward(noisy_audios)
+
+        # denoise shape: [batch_size, 1, num_samples]
+        denoise = denoise[0]
+        # shape: [channels, num_samples]
+        return denoise
+
+    def denoise_online(self, noisy_audio: torch.Tensor) -> torch.Tensor:
+        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
+            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
+
+        # noisy_audio shape: [batch_size, num_samples]
+        noisy_audios = noisy_audio.to(self.device)
+
+        with torch.no_grad():
+            denoise = self.model.forward_chunk_by_chunk(noisy_audios)
+
+        # denoise shape: [batch_size, 1, num_samples]
+        denoise = denoise[0]
+        # shape: [channels, num_samples]
+        return denoise
+
+
+def main():
+    model_zip_file = project_path / "trained_models/dtln-nx-dns3.zip"
+    infer_model = InferenceDTLN(model_zip_file)
+
+    sample_rate = 8000
+    noisy_audio_file = project_path / "data/examples/ai_agent/chinese-3.wav"
+    noisy_audio, sample_rate = librosa.load(
+        noisy_audio_file.as_posix(),
+        sr=sample_rate,
+    )
+    duration = librosa.get_duration(y=noisy_audio, sr=sample_rate)
+    # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
+    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
+    noisy_audio = noisy_audio.unsqueeze(dim=0)
+
+    # offline
+    begin = time.time()
+    enhanced_audio = infer_model.denoise_offline(noisy_audio)
+    time_cost = time.time() - begin
+    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
+    filename = "enhanced_audio_offline.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    # online
+    begin = time.time()
+    enhanced_audio = infer_model.denoise_online(noisy_audio)
+    time_cost = time.time() - begin
+    print(f"enhanced_audio.shape: {enhanced_audio.shape}, time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
+
+    filename = "enhanced_audio_online.wav"
+    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate)
+
+    return
+
+
+if __name__ == "__main__":
+    main()
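Besides the tensor-based `denoise_offline` / `denoise_online` paths exercised in `main()`, the wrapper also accepts plain NumPy audio through `enhancement_by_ndarray`, which batches the array, runs the offline path, and returns a NumPy array. A hedged usage sketch (the class and paths come from this commit; the model zip must exist locally, and the random array here only stands in for real 8 kHz audio in [-1, 1]):

import numpy as np

from project_settings import project_path
from toolbox.torchaudio.models.dtln.inference_dtln import InferenceDTLN

infer_model = InferenceDTLN(project_path / "trained_models/dtln-nx-dns3.zip")

# Stand-in for a real mono recording scaled to [-1, 1].
noisy = (np.random.rand(16000).astype(np.float32) - 0.5) * 0.1
enhanced = infer_model.enhancement_by_ndarray(noisy)  # np.ndarray, shape [num_samples]
print(enhanced.shape)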
toolbox/torchaudio/models/dtln/modeling_dtln.py CHANGED
@@ -167,12 +167,13 @@ class DTLNModel(nn.Module):
         if remainder > 0:
             n_samples_pad = self.hop_size - remainder
             signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
-        return signal, n_samples
+        return signal

     def forward(self,
                 noisy: torch.Tensor,
                 ):
-        noisy, num_samples = self.signal_prepare(noisy)
+        num_samples = noisy.shape[-1]
+        noisy = self.signal_prepare(noisy)
         batch_size, _, num_samples_pad = noisy.shape
         # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")

@@ -182,6 +183,8 @@ class DTLNModel(nn.Module):

         denoise = denoise[:, :num_samples]
         # denoise shape: [b, num_samples]
+        denoise = torch.unsqueeze(denoise, dim=1)
+        # denoise shape: [b, 1, num_samples]
         return denoise

     def forward_chunk(self,
@@ -189,7 +192,7 @@ class DTLNModel(nn.Module):
                       in_state1: torch.Tensor = None,
                       in_state2: torch.Tensor = None,
                       ):
-        # noisy shape: [b, num_samples]
+        # noisy shape: [b, 1, num_samples]
         spec = self.stft.forward(noisy)
         # spec shape: [b, f, t], torch.complex64
         # t = (num_samples - win_size) / hop_size + 1
@@ -233,6 +236,44 @@ class DTLNModel(nn.Module):

         return denoise_frame, out_state1, out_state2

+    def forward_chunk_by_chunk(self, noisy: torch.Tensor):
+        noisy = self.signal_prepare(noisy)
+        # noisy shape: [b, 1, num_samples]
+        batch_size, _, num_samples_pad = noisy.shape
+        # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
+
+        t = (num_samples_pad - self.fft_size) // self.hop_size + 1
+
+        denoise_list = list()
+        out_state1 = None
+        out_state2 = None
+        overlap_size = self.fft_size - self.hop_size
+        denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype)
+        # denoise_list.append(torch.clone(denoise_cache))
+        for i in range(t):
+            begin = i * self.hop_size
+            end = begin + self.fft_size
+            sub_noisy = noisy[:, :, begin: end]
+            # sub_noisy shape: [b, 1, frame_size]
+            with torch.no_grad():
+                sub_denoise_frame, out_state1, out_state2 = self.forward_chunk(sub_noisy, out_state1, out_state2)
+            # sub_denoise_frame shape: [b, fft_size, 1]
+            sub_denoise_frame = sub_denoise_frame[:, :, 0]
+            # sub_denoise_frame shape: [b, fft_size]
+
+            sub_denoise_frame[:, :overlap_size] += denoise_cache
+            denoise_out = sub_denoise_frame[:, :self.hop_size]
+            denoise_cache = sub_denoise_frame[:, self.hop_size:]
+            # denoise_cache shape: [b, overlap_size]
+
+            denoise_list.append(denoise_out)
+
+        denoise = torch.concat(denoise_list, dim=-1)
+        # denoise shape: [b, num_samples]
+        denoise = torch.unsqueeze(denoise, dim=1)
+        # denoise shape: [b, 1, num_samples]
+        return denoise
+
     def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
         # overlap and add

@@ -301,43 +342,28 @@ class DTLNPretrainedModel(DTLNModel):


 def main():
-    fft_size = 512
-    hop_size = 128
-
-    model = DTLNModel(fft_size=fft_size, hop_size=hop_size)
+    config = DTLNConfig()
+    model = DTLNPretrainedModel(config)
+    model.eval()

     noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
-    batch_size, num_samples = noisy.shape

-    denoise = model.forward(noisy)
+    with torch.no_grad():
+        denoise = model.forward(noisy)
     print(f"denoise.shape: {denoise.shape}")
+    print(denoise[:, :, 300: 302])
+    print(denoise[:, :, 15680: 15682])
+    print(denoise[:, :, 15760: 15762])
+    print(denoise[:, :, 15840: 15842])

-    t = (num_samples - fft_size) // hop_size + 1
-
-    denoise_list = list()
-    out_state1 = None
-    out_state2 = None
-    denoise_cache = torch.zeros(size=(batch_size, fft_size - hop_size,), dtype=noisy.dtype)
-    denoise_list.append(torch.clone(denoise_cache))
-    for i in range(t):
-        begin = i * hop_size
-        end = begin + fft_size
-        sub_noisy = noisy[:, begin: end]
-        with torch.no_grad():
-            sub_denoise_frame, out_state1, out_state2 = model.forward_chunk(sub_noisy, out_state1, out_state2)
-        # sub_denoise_frame shape: [b, fft_size, 1]
-        sub_denoise_frame = sub_denoise_frame[:, :, 0]
-        # sub_denoise_frame shape: [b, fft_size]
-
-        sub_denoise_frame[:, hop_size:] += denoise_cache
-        denoise_out = sub_denoise_frame[:, :hop_size]
-        denoise_cache = sub_denoise_frame[:, hop_size:]
-        # denoise_cache shape: [b, fft_size - hop_size]
-
-        denoise_list.append(denoise_out)
-
-    denoise = torch.concat(denoise_list, dim=-1)
+    denoise = model.forward_chunk_by_chunk(noisy)
     print(f"denoise.shape: {denoise.shape}")
+    # denoise = denoise[:, :, (config.fft_size - config.hop_size):]
+    print(denoise[:, :, 300: 302])
+    print(denoise[:, :, 15680: 15682])
+    print(denoise[:, :, 15760: 15762])
+    print(denoise[:, :, 15840: 15842])
+
     return

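`forward_chunk_by_chunk` streams the signal in frames of `fft_size` samples advanced by `hop_size` and rebuilds the output by overlap-add: each frame contributes `hop_size` finished samples, while the remaining `fft_size - hop_size` samples are kept in `denoise_cache` and added into the start of the next frame. The bookkeeping in isolation, with an identity pass-through standing in for `forward_chunk` and the 512/128 frame sizes from the old `main()` (a self-contained sketch, not the model itself):

import torch

fft_size, hop_size = 512, 128
overlap_size = fft_size - hop_size

x = torch.randn(1, 1, 16000)  # [b, 1, num_samples], assumed already padded to a hop multiple
batch_size, _, num_samples = x.shape
t = (num_samples - fft_size) // hop_size + 1

cache = torch.zeros(batch_size, overlap_size)
outputs = []
for i in range(t):
    frame = x[:, 0, i * hop_size: i * hop_size + fft_size].clone()  # identity stands in for forward_chunk
    frame[:, :overlap_size] += cache     # add the tail carried over from previous frames
    outputs.append(frame[:, :hop_size])  # emit hop_size finished samples
    cache = frame[:, hop_size:]          # carry the rest into the next frame

denoise = torch.cat(outputs, dim=-1)
print(denoise.shape)                     # [1, t * hop_size]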