HoneyTian committed
Commit 35a4689 · 1 Parent(s): f1a5461
examples/lstm/step_2_train_model.py CHANGED
@@ -26,6 +26,8 @@ import torchaudio
26
  from tqdm import tqdm
27
 
28
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
 
 
29
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
30
  from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
31
  from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel
@@ -72,95 +74,32 @@ def logging_config(file_dir: str):
72
 
73
 
74
  class CollateFunction(object):
75
- def __init__(self,
76
- n_fft: int = 512,
77
- win_length: int = 200,
78
- hop_length: int = 80,
79
- window_fn: str = "hamming",
80
- irm_beta: float = 1.0,
81
- epsilon: float = 1e-8,
82
- ):
83
- self.n_fft = n_fft
84
- self.win_length = win_length
85
- self.hop_length = hop_length
86
- self.window_fn = window_fn
87
- self.irm_beta = irm_beta
88
- self.epsilon = epsilon
89
-
90
- self.stft_mag = torchaudio.transforms.Spectrogram(
91
- n_fft=self.n_fft,
92
- win_length=self.win_length,
93
- hop_length=self.hop_length,
94
- power=1.0,
95
- window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
96
- )
97
- self.stft_complex = torchaudio.transforms.Spectrogram(
98
- n_fft=self.n_fft,
99
- win_length=self.win_length,
100
- hop_length=self.hop_length,
101
- power=None,
102
- window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
103
- )
104
-
105
- self.istft = torchaudio.transforms.InverseSpectrogram(
106
- n_fft=self.n_fft,
107
- win_length=self.win_length,
108
- hop_length=self.hop_length,
109
- window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
110
- )
111
 
112
  def __call__(self, batch: List[dict]):
113
- mag_noisy_audios = list()
114
- pha_noisy_audios = list()
115
- irm_gth = list()
116
-
117
  clean_audios = list()
 
 
118
 
119
  for sample in batch:
120
- noise_audio: torch.Tensor = sample["noise_wave"]
121
  clean_audio: torch.Tensor = sample["speech_wave"]
122
  noisy_audio: torch.Tensor = sample["mix_wave"]
123
- snr_db: float = sample["snr_db"]
124
-
125
- mag_noise = self.stft_mag.forward(noise_audio)
126
- mag_clean = self.stft_mag.forward(clean_audio)
127
- stft_noisy = self.stft_complex.forward(noisy_audio)
128
-
129
- irm_clean = mag_clean / (mag_noise + mag_clean + self.epsilon)
130
- irm_clean = torch.pow(irm_clean, self.irm_beta)
131
-
132
- real = torch.real(stft_noisy)
133
- imag = torch.imag(stft_noisy)
134
- mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
135
- pha_noisy = torch.atan2(imag, real)
136
 
137
- mag_noisy_audios.append(mag_noisy)
138
- pha_noisy_audios.append(pha_noisy)
139
- irm_gth.append(irm_clean)
140
  clean_audios.append(clean_audio)
 
141
 
142
- mag_noisy_audios = torch.stack(mag_noisy_audios)
143
- pha_noisy_audios = torch.stack(pha_noisy_audios)
144
- irm_gth = torch.stack(irm_gth)
145
  clean_audios = torch.stack(clean_audios)
 
146
 
147
  # assert
148
- if torch.any(torch.isnan(mag_noisy_audios)):
149
- raise AssertionError("nan in mag_noisy_audios Tensor")
150
- if torch.any(torch.isnan(pha_noisy_audios)):
151
- raise AssertionError("nan in pha_noisy_audios Tensor")
152
- if torch.any(torch.isnan(irm_gth)):
153
- raise AssertionError("nan in irm_gth Tensor")
154
- if torch.any(torch.isnan(clean_audios)):
155
- raise AssertionError("nan in clean_audios Tensor")
156
-
157
- return mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios
158
-
159
- def enhance(self, mag_noisy: torch.Tensor, pha_noisy: torch.Tensor, irm_speech: torch.Tensor):
160
- mag_denoise = mag_noisy * irm_speech
161
- stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
162
- denoise = self.istft.forward(stft_denoise)
163
- return denoise
164
 
165
 
166
  collate_fn = CollateFunction()
@@ -282,8 +221,14 @@ def main():
282
  else:
283
  raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
284
 
285
- mse_loss_fn = nn.MSELoss(
286
- reduction="mean",
287
  ).to(device)
288
 
289
  # training loop
@@ -291,6 +236,8 @@ def main():
291
 
292
  average_pesq_score = 1000000000
293
  average_loss = 1000000000
 
 
294
 
295
  model_list = list()
296
  best_epoch_idx = None
@@ -311,6 +258,8 @@ def main():
311
 
312
  total_pesq_score = 0.
313
  total_loss = 0.
 
 
314
  total_batches = 0.
315
 
316
  progress_bar_train = tqdm(
@@ -318,15 +267,19 @@ def main():
318
  desc="Training; epoch: {}".format(epoch_idx),
319
  )
320
  for train_batch in train_data_loader:
321
- mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = train_batch
322
- mag_noisy_audios = mag_noisy_audios.to(device)
323
- pha_noisy_audios = pha_noisy_audios.to(device)
324
- irm_gth = irm_gth.to(device)
325
- clean_audios = clean_audios.to(device)
326
 
327
- irm = model.forward(mag_noisy_audios)
328
- denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
329
- loss = mse_loss_fn.forward(irm, irm_gth)
 
330
 
331
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
332
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
@@ -340,16 +293,22 @@ def main():
340
 
341
  total_pesq_score += pesq_score
342
  total_loss += loss.item()
 
 
343
  total_batches += 1
344
 
345
  average_pesq_score = round(total_pesq_score / total_batches, 4)
346
  average_loss = round(total_loss / total_batches, 4)
 
 
347
 
348
  progress_bar_train.update(1)
349
  progress_bar_train.set_postfix({
350
  "lr": lr_scheduler.get_last_lr()[0],
351
  "pesq_score": average_pesq_score,
352
  "loss": average_loss,
 
 
353
  })
354
 
355
  # evaluation
@@ -360,6 +319,8 @@ def main():
360
 
361
  total_pesq_score = 0.
362
  total_loss = 0.
 
 
363
  total_batches = 0.
364
 
365
  progress_bar_train.close()
@@ -368,43 +329,48 @@ def main():
368
  )
369
 
370
  for eval_batch in valid_data_loader:
371
- mag_noisy_audios, pha_noisy_audios, irm_gth, clean_audios = eval_batch
372
- mag_noisy_audios = mag_noisy_audios.to(device)
373
- pha_noisy_audios = pha_noisy_audios.to(device)
374
- irm_gth = irm_gth.to(device)
375
- clean_audios = clean_audios.to(device)
376
 
377
- with torch.no_grad():
378
- irm = model.forward(mag_noisy_audios)
379
- denoise_audios = collate_fn.enhance(mag_noisy_audios, pha_noisy_audios, irm)
380
- loss = mse_loss_fn.forward(irm, irm_gth)
381
 
382
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
383
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
384
  pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
385
 
386
- optimizer.zero_grad()
387
- loss.backward()
388
- torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
389
- optimizer.step()
390
- lr_scheduler.step()
391
-
392
  total_pesq_score += pesq_score
393
  total_loss += loss.item()
 
 
394
  total_batches += 1
395
 
396
  average_pesq_score = round(total_pesq_score / total_batches, 4)
397
  average_loss = round(total_loss / total_batches, 4)
 
 
398
 
399
  progress_bar_eval.update(1)
400
  progress_bar_eval.set_postfix({
401
  "lr": lr_scheduler.get_last_lr()[0],
402
  "pesq_score": average_pesq_score,
403
  "loss": average_loss,
 
 
404
  })
405
 
406
  total_pesq_score = 0.
407
  total_loss = 0.
 
 
408
  total_batches = 0.
409
 
410
  progress_bar_eval.close()
 
26
  from tqdm import tqdm
27
 
28
  from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
29
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
30
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
31
  from toolbox.torchaudio.metrics.pesq import run_pesq_score
32
  from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
33
  from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel
 
74
 
75
 
76
  class CollateFunction(object):
77
+ def __init__(self):
78
+ pass
79
 
80
  def __call__(self, batch: List[dict]):
81
  clean_audios = list()
82
+ noisy_audios = list()
83
+ snr_db_list = list()
84
 
85
  for sample in batch:
86
+ # noise_wave: torch.Tensor = sample["noise_wave"]
87
  clean_audio: torch.Tensor = sample["speech_wave"]
88
  noisy_audio: torch.Tensor = sample["mix_wave"]
89
+ # snr_db: float = sample["snr_db"]
90
 
91
  clean_audios.append(clean_audio)
92
+ noisy_audios.append(noisy_audio)
93
 
 
 
 
94
  clean_audios = torch.stack(clean_audios)
95
+ noisy_audios = torch.stack(noisy_audios)
96
 
97
  # assert
98
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
99
+ raise AssertionError("nan or inf in clean_audios")
100
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
101
+ raise AssertionError("nan or inf in noisy_audios")
102
+ return clean_audios, noisy_audios
103
 
104
 
105
  collate_fn = CollateFunction()
 
221
  else:
222
  raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
223
 
224
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
225
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
226
+ fft_size_list=[256, 512, 1024],
227
+ win_size_list=[256, 512, 1024],
228
+ hop_size_list=[128, 256, 512],
229
+ factor_sc=1.5,
230
+ factor_mag=1.0,
231
+ reduction="mean"
232
  ).to(device)
233
 
234
  # training loop
 
236
 
237
  average_pesq_score = 1000000000
238
  average_loss = 1000000000
239
+ average_mr_stft_loss = 1000000000
240
+ average_neg_si_snr_loss = 1000000000
241
 
242
  model_list = list()
243
  best_epoch_idx = None
 
258
 
259
  total_pesq_score = 0.
260
  total_loss = 0.
261
+ total_mr_stft_loss = 0.
262
+ total_neg_si_snr_loss = 0.
263
  total_batches = 0.
264
 
265
  progress_bar_train = tqdm(
 
267
  desc="Training; epoch: {}".format(epoch_idx),
268
  )
269
  for train_batch in train_data_loader:
270
+ clean_audios, noisy_audios = train_batch
271
+ clean_audios: torch.Tensor = clean_audios.to(device)
272
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
273
+
274
+ denoise_audios, _, _ = model.forward(noisy_audios)
275
+
276
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
277
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
278
 
279
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
280
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
281
+ logger.info("found nan or inf in loss.")
282
+ continue
283
 
284
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
285
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
 
293
 
294
  total_pesq_score += pesq_score
295
  total_loss += loss.item()
296
+ total_mr_stft_loss += mr_stft_loss.item()
297
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
298
  total_batches += 1
299
 
300
  average_pesq_score = round(total_pesq_score / total_batches, 4)
301
  average_loss = round(total_loss / total_batches, 4)
302
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
303
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
304
 
305
  progress_bar_train.update(1)
306
  progress_bar_train.set_postfix({
307
  "lr": lr_scheduler.get_last_lr()[0],
308
  "pesq_score": average_pesq_score,
309
  "loss": average_loss,
310
+ "mr_stft_loss": average_mr_stft_loss,
311
+ "neg_si_snr_loss": average_neg_si_snr_loss,
312
  })
313
 
314
  # evaluation
 
319
 
320
  total_pesq_score = 0.
321
  total_loss = 0.
322
+ total_mr_stft_loss = 0.
323
+ total_neg_si_snr_loss = 0.
324
  total_batches = 0.
325
 
326
  progress_bar_train.close()
 
329
  )
330
 
331
  for eval_batch in valid_data_loader:
332
+ clean_audios, noisy_audios = eval_batch
333
+ clean_audios: torch.Tensor = clean_audios.to(device)
334
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
 
 
335
 
336
+ denoise_audios, _, _ = model.forward(noisy_audios)
337
+
338
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
339
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
340
+
341
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
342
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
343
+ logger.info("found nan or inf in loss.")
344
+ continue
345
 
346
  denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
347
  clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
348
  pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
349
 
350
  total_pesq_score += pesq_score
351
  total_loss += loss.item()
352
+ total_mr_stft_loss += mr_stft_loss.item()
353
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
354
  total_batches += 1
355
 
356
  average_pesq_score = round(total_pesq_score / total_batches, 4)
357
  average_loss = round(total_loss / total_batches, 4)
358
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
359
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
360
 
361
  progress_bar_eval.update(1)
362
  progress_bar_eval.set_postfix({
363
  "lr": lr_scheduler.get_last_lr()[0],
364
  "pesq_score": average_pesq_score,
365
  "loss": average_loss,
366
+ "mr_stft_loss": average_mr_stft_loss,
367
+ "neg_si_snr_loss": average_neg_si_snr_loss,
368
  })
369
 
370
  total_pesq_score = 0.
371
  total_loss = 0.
372
+ total_mr_stft_loss = 0.
373
+ total_neg_si_snr_loss = 0.
374
  total_batches = 0.
375
 
376
  progress_bar_eval.close()
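The change above replaces the IRM target and MSE objective with waveform-domain losses (multi-resolution STFT plus negative SI-SNR). For reference, a minimal sketch of a negative SI-SNR loss under its conventional definition; the repo's NegativeSISNRLoss in toolbox.torchaudio.losses.snr is assumed to behave like this, but may differ in zero-mean handling, eps and reduction.

import torch

def negative_si_snr(estimate: torch.Tensor, target: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # estimate, target: [batch, num_samples]
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    target = target - target.mean(dim=-1, keepdim=True)
    # project the estimate onto the target to split it into a target component and a noise component
    dot = torch.sum(estimate * target, dim=-1, keepdim=True)
    target_energy = torch.sum(target ** 2, dim=-1, keepdim=True) + eps
    s_target = dot / target_energy * target
    e_noise = estimate - s_target
    si_snr = 10 * torch.log10(
        (torch.sum(s_target ** 2, dim=-1) + eps) / (torch.sum(e_noise ** 2, dim=-1) + eps)
    )
    return -si_snr.mean()  # negate so that minimizing the loss maximizes SI-SNR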
examples/lstm/yaml/config.yaml ADDED
@@ -0,0 +1,32 @@
1
+ model_name: "lstm"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ segment_size: 32000
6
+ n_fft: 320
7
+ win_size: 320
8
+ hop_size: 160
9
+ win_type: hann
10
+
11
+ # data
12
+ max_snr_db: 20
13
+ min_snr_db: -10
14
+
15
+ # model
16
+ hidden_size: 512
17
+ num_layers: 3
18
+ dropout: 0.1
19
+
20
+ # train
21
+ max_epochs: 100
22
+ batch_size: 32
23
+ num_workers: 4
24
+ seed: 1234
25
+
26
+ lr: 0.001
27
+ lr_scheduler: CosineAnnealingLR
28
+ lr_scheduler_kwargs: {}
29
+
30
+ weight_decay: 0.00001
31
+ clip_grad_norm: 10.0
32
+ eval_steps: 25000
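As a quick sanity check of the spec section (this snippet is not part of the repo), the frame geometry implied by these values can be read off directly: 320-sample windows with a 160-sample hop at 8 kHz are 40 ms frames every 20 ms, and one 32000-sample training segment covers about 200 hops.

import yaml  # assumes PyYAML is installed

with open("examples/lstm/yaml/config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

frame_ms = cfg["win_size"] / cfg["sample_rate"] * 1000        # 40.0
hop_ms = cfg["hop_size"] / cfg["sample_rate"] * 1000          # 20.0
frames_per_segment = cfg["segment_size"] // cfg["hop_size"]   # 200
print(frame_ms, hop_ms, frames_per_segment)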
examples/rnnoise/run.sh ADDED
@@ -0,0 +1,172 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir
6
+
7
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir
8
+
9
+ sh run.sh --stage 1 --stop_stage 3 --system_version centos --file_folder_name file_dir \
10
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
11
+ --speech_dir "/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train"
12
+
13
+
14
+ END
15
+
16
+
17
+ # params
18
+ system_version="windows";
19
+ verbose=true;
20
+ stage=0 # start from 0 if you need to start from data preparation
21
+ stop_stage=9
22
+
23
+ work_dir="$(pwd)"
24
+ file_folder_name=file_folder_name
25
+ final_model_name=final_model_name
26
+ config_file="yaml/config.yaml"
27
+ limit=10
28
+
29
+ noise_dir=/data/tianxing/HuggingDatasets/nx_noise/data/noise
30
+ speech_dir=/data/tianxing/HuggingDatasets/aishell/data_aishell/wav/train
31
+
32
+ nohup_name=nohup.out
33
+
34
+ # model params
35
+ batch_size=64
36
+ max_epochs=200
37
+ save_top_k=10
38
+ patience=5
39
+
40
+
41
+ # parse options
42
+ while true; do
43
+ [ -z "${1:-}" ] && break; # break if there are no arguments
44
+ case "$1" in
45
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
46
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
47
+ old_value="(eval echo \\$$name)";
48
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
49
+ was_bool=true;
50
+ else
51
+ was_bool=false;
52
+ fi
53
+
54
+ # Set the variable to the right value-- the escaped quotes make it work if
55
+ # the option had spaces, like --cmd "queue.pl -sync y"
56
+ eval "${name}=\"$2\"";
57
+
58
+ # Check that Boolean-valued arguments are really Boolean.
59
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
60
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
61
+ exit 1;
62
+ fi
63
+ shift 2;
64
+ ;;
65
+
66
+ *) break;
67
+ esac
68
+ done
69
+
70
+ file_dir="${work_dir}/${file_folder_name}"
71
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
72
+ evaluation_audio_dir="${file_dir}/evaluation_audio"
73
+
74
+ dataset="${file_dir}/dataset.xlsx"
75
+ train_dataset="${file_dir}/train.xlsx"
76
+ valid_dataset="${file_dir}/valid.xlsx"
77
+
78
+ $verbose && echo "system_version: ${system_version}"
79
+ $verbose && echo "file_folder_name: ${file_folder_name}"
80
+
81
+ if [ $system_version == "windows" ]; then
82
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/nx_denoise/Scripts/python.exe'
83
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
84
+ #source /data/local/bin/nx_denoise/bin/activate
85
+ alias python3='/data/local/bin/nx_denoise/bin/python3'
86
+ fi
87
+
88
+
89
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
90
+ $verbose && echo "stage 1: prepare data"
91
+ cd "${work_dir}" || exit 1
92
+ python3 step_1_prepare_data.py \
93
+ --file_dir "${file_dir}" \
94
+ --noise_dir "${noise_dir}" \
95
+ --speech_dir "${speech_dir}" \
96
+ --train_dataset "${train_dataset}" \
97
+ --valid_dataset "${valid_dataset}" \
98
+
99
+ fi
100
+
101
+
102
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
103
+ $verbose && echo "stage 2: train model"
104
+ cd "${work_dir}" || exit 1
105
+ python3 step_2_train_model.py \
106
+ --train_dataset "${train_dataset}" \
107
+ --valid_dataset "${valid_dataset}" \
108
+ --serialization_dir "${file_dir}" \
109
+ --config_file "${config_file}" \
110
+
111
+ fi
112
+
113
+
114
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
115
+ $verbose && echo "stage 3: test model"
116
+ cd "${work_dir}" || exit 1
117
+ python3 step_3_evaluation.py \
118
+ --valid_dataset "${valid_dataset}" \
119
+ --model_dir "${file_dir}/best" \
120
+ --evaluation_audio_dir "${evaluation_audio_dir}" \
121
+ --limit "${limit}" \
122
+
123
+ fi
124
+
125
+
126
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
127
+ $verbose && echo "stage 4: export model"
128
+ cd "${work_dir}" || exit 1
129
+ python3 step_5_export_models.py \
130
+ --vocabulary_dir "${vocabulary_dir}" \
131
+ --model_dir "${file_dir}/best" \
132
+ --serialization_dir "${file_dir}" \
133
+
134
+ fi
135
+
136
+
137
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
138
+ $verbose && echo "stage 5: collect files"
139
+ cd "${work_dir}" || exit 1
140
+
141
+ mkdir -p ${final_model_dir}
142
+
143
+ cp "${file_dir}/best"/* "${final_model_dir}"
144
+ cp -r "${file_dir}/vocabulary" "${final_model_dir}"
145
+
146
+ cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
147
+
148
+ cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
149
+ cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
150
+ cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
151
+ cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
152
+
153
+ cd "${final_model_dir}/.." || exit 1;
154
+
155
+ if [ -e "${final_model_name}.zip" ]; then
156
+ rm -rf "${final_model_name}_backup.zip"
157
+ mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
158
+ fi
159
+
160
+ zip -r "${final_model_name}.zip" "${final_model_name}"
161
+ rm -rf "${final_model_name}"
162
+
163
+ fi
164
+
165
+
166
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
167
+ $verbose && echo "stage 6: clear file_dir"
168
+ cd "${work_dir}" || exit 1
169
+
170
+ rm -rf "${file_dir}";
171
+
172
+ fi
examples/rnnoise/step_1_prepare_data.py ADDED
@@ -0,0 +1,197 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import random
7
+ import sys
8
+ import shutil
9
+
10
+ pwd = os.path.abspath(os.path.dirname(__file__))
11
+ sys.path.append(os.path.join(pwd, "../../"))
12
+
13
+ import pandas as pd
14
+ from scipy.io import wavfile
15
+ from tqdm import tqdm
16
+ import librosa
17
+
18
+ from project_settings import project_path
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument("--file_dir", default="./", type=str)
24
+
25
+ parser.add_argument(
26
+ "--noise_dir",
27
+ default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\noise",
28
+ type=str
29
+ )
30
+ parser.add_argument(
31
+ "--speech_dir",
32
+ default=r"E:\programmer\asr_datasets\aishell\data_aishell\wav\train",
33
+ type=str
34
+ )
35
+
36
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
37
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
38
+
39
+ parser.add_argument("--duration", default=2.0, type=float)
40
+ parser.add_argument("--min_nsr_db", default=-20, type=float)
41
+ parser.add_argument("--max_nsr_db", default=5, type=float)
42
+
43
+ parser.add_argument("--target_sample_rate", default=8000, type=int)
44
+
45
+ args = parser.parse_args()
46
+ return args
47
+
48
+
49
+ def filename_generator(data_dir: str):
50
+ data_dir = Path(data_dir)
51
+ for filename in data_dir.glob("**/*.wav"):
52
+ yield filename.as_posix()
53
+
54
+
55
+ def target_second_signal_generator(data_dir: str, duration: int = 2, sample_rate: int = 8000):
56
+ data_dir = Path(data_dir)
57
+ for filename in data_dir.glob("**/*.wav"):
58
+ signal, _ = librosa.load(filename.as_posix(), sr=sample_rate)
59
+ raw_duration = librosa.get_duration(y=signal, sr=sample_rate)
60
+
61
+ if raw_duration < duration:
62
+ # print(f"duration less than {duration} s. skip filename: {filename.as_posix()}")
63
+ continue
64
+ if signal.ndim != 1:
65
+ raise AssertionError(f"expected ndim 1, instead of {signal.ndim}")
66
+
67
+ signal_length = len(signal)
68
+ win_size = int(duration * sample_rate)
69
+ for begin in range(0, signal_length - win_size, win_size):
70
+ row = {
71
+ "filename": filename.as_posix(),
72
+ "raw_duration": round(raw_duration, 4),
73
+ "offset": round(begin / sample_rate, 4),
74
+ "duration": round(duration, 4),
75
+ }
76
+ yield row
77
+
78
+
79
+ def get_dataset(args):
80
+ file_dir = Path(args.file_dir)
81
+ file_dir.mkdir(exist_ok=True)
82
+
83
+ noise_dir = Path(args.noise_dir)
84
+ speech_dir = Path(args.speech_dir)
85
+
86
+ noise_generator = target_second_signal_generator(
87
+ noise_dir.as_posix(),
88
+ duration=args.duration,
89
+ sample_rate=args.target_sample_rate
90
+ )
91
+ speech_generator = target_second_signal_generator(
92
+ speech_dir.as_posix(),
93
+ duration=args.duration,
94
+ sample_rate=args.target_sample_rate
95
+ )
96
+
97
+ dataset = list()
98
+
99
+ count = 0
100
+ process_bar = tqdm(desc="build dataset excel")
101
+ for noise, speech in zip(noise_generator, speech_generator):
102
+
103
+ noise_filename = noise["filename"]
104
+ noise_raw_duration = noise["raw_duration"]
105
+ noise_offset = noise["offset"]
106
+ noise_duration = noise["duration"]
107
+
108
+ speech_filename = speech["filename"]
109
+ speech_raw_duration = speech["raw_duration"]
110
+ speech_offset = speech["offset"]
111
+ speech_duration = speech["duration"]
112
+
113
+ random1 = random.random()
114
+ random2 = random.random()
115
+
116
+ row = {
117
+ "noise_filename": noise_filename,
118
+ "noise_raw_duration": noise_raw_duration,
119
+ "noise_offset": noise_offset,
120
+ "noise_duration": noise_duration,
121
+
122
+ "speech_filename": speech_filename,
123
+ "speech_raw_duration": speech_raw_duration,
124
+ "speech_offset": speech_offset,
125
+ "speech_duration": speech_duration,
126
+
127
+ "snr_db": random.uniform(args.min_nsr_db, args.max_nsr_db),
128
+
129
+ "random1": random1,
130
+ "random2": random2,
131
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
132
+ }
133
+ dataset.append(row)
134
+ count += 1
135
+ duration_seconds = count * args.duration
136
+ duration_hours = duration_seconds / 3600
137
+
138
+ process_bar.update(n=1)
139
+ process_bar.set_postfix({
140
+ # "duration_seconds": round(duration_seconds, 4),
141
+ "duration_hours": round(duration_hours, 4),
142
+
143
+ })
144
+
145
+ dataset = pd.DataFrame(dataset)
146
+ dataset = dataset.sort_values(by=["random1"], ascending=False)
147
+ dataset.to_excel(
148
+ file_dir / "dataset.xlsx",
149
+ index=False,
150
+ )
151
+ return
152
+
153
+
154
+
155
+ def split_dataset(args):
156
+ """Split the dataset into train and test sets."""
157
+ file_dir = Path(args.file_dir)
158
+ file_dir.mkdir(exist_ok=True)
159
+
160
+ df = pd.read_excel(file_dir / "dataset.xlsx")
161
+
162
+ train = list()
163
+ test = list()
164
+
165
+ for i, row in df.iterrows():
166
+ flag = row["flag"]
167
+ if flag == "TRAIN":
168
+ train.append(row)
169
+ else:
170
+ test.append(row)
171
+
172
+ train = pd.DataFrame(train)
173
+ train.to_excel(
174
+ args.train_dataset,
175
+ index=False,
176
+ # encoding="utf_8_sig"
177
+ )
178
+ test = pd.DataFrame(test)
179
+ test.to_excel(
180
+ args.valid_dataset,
181
+ index=False,
182
+ # encoding="utf_8_sig"
183
+ )
184
+
185
+ return
186
+
187
+
188
+ def main():
189
+ args = get_args()
190
+
191
+ get_dataset(args)
192
+ split_dataset(args)
193
+ return
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
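Each row written to dataset.xlsx records the speech and noise file paths, the offset and duration of the excerpt, and a sampled snr_db. A hypothetical helper (not part of the repo; the name load_mix and the mixing rule are illustrative assumptions) shows how such a row could be turned into a (clean, noisy) training pair:

import librosa
import numpy as np

def load_mix(row: dict, sample_rate: int = 8000):
    speech, _ = librosa.load(row["speech_filename"], sr=sample_rate,
                             offset=row["speech_offset"], duration=row["speech_duration"])
    noise, _ = librosa.load(row["noise_filename"], sr=sample_rate,
                            offset=row["noise_offset"], duration=row["noise_duration"])
    # scale the noise so that 10*log10(P_speech / P_noise) equals the sampled snr_db
    p_speech = float(np.mean(speech ** 2)) + 1e-8
    p_noise = float(np.mean(noise ** 2)) + 1e-8
    noise = noise * np.sqrt(p_speech / (p_noise * 10.0 ** (row["snr_db"] / 10.0)))
    return speech, speech + noise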
examples/rnnoise/step_2_train_model.py ADDED
@@ -0,0 +1,442 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
5
+ """
6
+ import argparse
7
+ import json
8
+ import logging
9
+ from logging.handlers import TimedRotatingFileHandler
10
+ import os
11
+ import platform
12
+ from pathlib import Path
13
+ import random
14
+ import sys
15
+ import shutil
16
+ from typing import List
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.utils.data.dataloader import DataLoader
25
+ import torchaudio
26
+ from tqdm import tqdm
27
+
28
+ from toolbox.torch.utils.data.dataset.denoise_jsonl_dataset import DenoiseJsonlDataset
29
+ from toolbox.torchaudio.losses.snr import NegativeSISNRLoss
30
+ from toolbox.torchaudio.losses.spectral import LSDLoss, MultiResolutionSTFTLoss
31
+ from toolbox.torchaudio.metrics.pesq import run_pesq_score
32
+ from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig
33
+ from toolbox.torchaudio.models.rnnoise.modeling_rnnoise import RNNoisePretrainedModel
34
+
35
+
36
+ def get_args():
37
+ parser = argparse.ArgumentParser()
38
+ parser.add_argument("--train_dataset", default="train.jsonl", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.jsonl", type=str)
40
+ parser.add_argument("--max_epochs", default=100, type=int)
41
+
42
+ parser.add_argument("--batch_size", default=64, type=int)
43
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
44
+ parser.add_argument("--num_serialized_models_to_keep", default=15, type=int)
45
+ parser.add_argument("--patience", default=10, type=int)
46
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
47
+ parser.add_argument("--seed", default=0, type=int)
48
+
49
+ parser.add_argument("--config_file", default="config.yaml", type=str)
50
+
51
+ args = parser.parse_args()
52
+ return args
53
+
54
+
55
+ def logging_config(file_dir: str):
56
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
57
+
58
+ logging.basicConfig(format=fmt,
59
+ datefmt="%m/%d/%Y %H:%M:%S",
60
+ level=logging.INFO)
61
+ file_handler = TimedRotatingFileHandler(
62
+ filename=os.path.join(file_dir, "main.log"),
63
+ encoding="utf-8",
64
+ when="D",
65
+ interval=1,
66
+ backupCount=7
67
+ )
68
+ file_handler.setLevel(logging.INFO)
69
+ file_handler.setFormatter(logging.Formatter(fmt))
70
+ logger = logging.getLogger(__name__)
71
+ logger.addHandler(file_handler)
72
+
73
+ return logger
74
+
75
+
76
+ class CollateFunction(object):
77
+ def __init__(self):
78
+ pass
79
+
80
+ def __call__(self, batch: List[dict]):
81
+ clean_audios = list()
82
+ noisy_audios = list()
83
+ snr_db_list = list()
84
+
85
+ for sample in batch:
86
+ # noise_wave: torch.Tensor = sample["noise_wave"]
87
+ clean_audio: torch.Tensor = sample["speech_wave"]
88
+ noisy_audio: torch.Tensor = sample["mix_wave"]
89
+ # snr_db: float = sample["snr_db"]
90
+
91
+ clean_audios.append(clean_audio)
92
+ noisy_audios.append(noisy_audio)
93
+
94
+ clean_audios = torch.stack(clean_audios)
95
+ noisy_audios = torch.stack(noisy_audios)
96
+
97
+ # assert
98
+ if torch.any(torch.isnan(clean_audios)) or torch.any(torch.isinf(clean_audios)):
99
+ raise AssertionError("nan or inf in clean_audios")
100
+ if torch.any(torch.isnan(noisy_audios)) or torch.any(torch.isinf(noisy_audios)):
101
+ raise AssertionError("nan or inf in noisy_audios")
102
+ return clean_audios, noisy_audios
103
+
104
+
105
+ collate_fn = CollateFunction()
106
+
107
+
108
+ def main():
109
+ args = get_args()
110
+
111
+ config = RNNoiseConfig.from_pretrained(
112
+ pretrained_model_name_or_path=args.config_file,
113
+ )
114
+
115
+ serialization_dir = Path(args.serialization_dir)
116
+ serialization_dir.mkdir(parents=True, exist_ok=True)
117
+
118
+ logger = logging_config(serialization_dir)
119
+
120
+ random.seed(args.seed)
121
+ np.random.seed(args.seed)
122
+ torch.manual_seed(args.seed)
123
+ logger.info("set seed: {}".format(args.seed))
124
+
125
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
+ n_gpu = torch.cuda.device_count()
127
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
128
+
129
+ # datasets
130
+ logger.info("prepare datasets")
131
+ train_dataset = DenoiseJsonlDataset(
132
+ jsonl_file=args.train_dataset,
133
+ expected_sample_rate=config.sample_rate,
134
+ max_wave_value=32768.0,
135
+ min_snr_db=config.min_snr_db,
136
+ max_snr_db=config.max_snr_db,
137
+ # skip=225000,
138
+ )
139
+ valid_dataset = DenoiseJsonlDataset(
140
+ jsonl_file=args.valid_dataset,
141
+ expected_sample_rate=config.sample_rate,
142
+ max_wave_value=32768.0,
143
+ min_snr_db=config.min_snr_db,
144
+ max_snr_db=config.max_snr_db,
145
+ )
146
+ train_data_loader = DataLoader(
147
+ dataset=train_dataset,
148
+ batch_size=config.batch_size,
149
+ # shuffle=True,
150
+ sampler=None,
151
+ # On Linux, data loading can use multiple worker processes; on Windows it cannot.
152
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
153
+ collate_fn=collate_fn,
154
+ pin_memory=False,
155
+ prefetch_factor=None if platform.system() == "Windows" else 2,
156
+ )
157
+ valid_data_loader = DataLoader(
158
+ dataset=valid_dataset,
159
+ batch_size=config.batch_size,
160
+ # shuffle=True,
161
+ sampler=None,
162
+ # On Linux, data loading can use multiple worker processes; on Windows it cannot.
163
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
164
+ collate_fn=collate_fn,
165
+ pin_memory=False,
166
+ prefetch_factor=None if platform.system() == "Windows" else 2,
167
+ )
168
+
169
+ # models
170
+ logger.info(f"prepare models. config_file: {args.config_file}")
171
+ model = RNNoisePretrainedModel(
172
+ config=config,
173
+ )
174
+ model.to(device)
175
+ model.train()
176
+
177
+ # optimizer
178
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, evaluation_metric")
179
+ optimizer = torch.optim.AdamW(model.parameters(), config.lr)
180
+
181
+ # resume training
182
+ last_step_idx = -1
183
+ last_epoch = -1
184
+ for step_idx_str in serialization_dir.glob("steps-*"):
185
+ step_idx_str = Path(step_idx_str)
186
+ step_idx = step_idx_str.stem.split("-")[1]
187
+ step_idx = int(step_idx)
188
+ if step_idx > last_step_idx:
189
+ last_step_idx = step_idx
190
+ # last_epoch = 1
191
+
192
+ if last_step_idx != -1:
193
+ logger.info(f"resume from steps-{last_step_idx}.")
194
+ model_pt = serialization_dir / f"steps-{last_step_idx}/model.pt"
195
+ optimizer_pth = serialization_dir / f"steps-{last_step_idx}/optimizer.pth"
196
+
197
+ logger.info(f"load state dict for model.")
198
+ with open(model_pt.as_posix(), "rb") as f:
199
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
200
+ model.load_state_dict(state_dict, strict=True)
201
+
202
+ logger.info(f"load state dict for optimizer.")
203
+ with open(optimizer_pth.as_posix(), "rb") as f:
204
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
205
+ optimizer.load_state_dict(state_dict)
206
+
207
+ if config.lr_scheduler == "CosineAnnealingLR":
208
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
209
+ optimizer,
210
+ last_epoch=last_epoch,
211
+ # T_max=10 * config.eval_steps,
212
+ # eta_min=0.01 * config.lr,
213
+ **config.lr_scheduler_kwargs,
214
+ )
215
+ elif config.lr_scheduler == "MultiStepLR":
216
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
217
+ optimizer,
218
+ last_epoch=last_epoch,
219
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
220
+ )
221
+ else:
222
+ raise AssertionError(f"invalid lr_scheduler: {config.lr_scheduler}")
223
+
224
+ neg_si_snr_loss_fn = NegativeSISNRLoss(reduction="mean").to(device)
225
+ mr_stft_loss_fn = MultiResolutionSTFTLoss(
226
+ fft_size_list=[256, 512, 1024],
227
+ win_size_list=[256, 512, 1024],
228
+ hop_size_list=[128, 256, 512],
229
+ factor_sc=1.5,
230
+ factor_mag=1.0,
231
+ reduction="mean"
232
+ ).to(device)
233
+
234
+ # training loop
235
+ logger.info("training")
236
+
237
+ average_pesq_score = 1000000000
238
+ average_loss = 1000000000
239
+ average_mr_stft_loss = 1000000000
240
+ average_neg_si_snr_loss = 1000000000
241
+
242
+ model_list = list()
243
+ best_epoch_idx = None
244
+ best_step_idx = None
245
+ best_metric = None
246
+ patience_count = 0
247
+
248
+ step_idx = 0 if last_step_idx == -1 else last_step_idx
249
+
250
+ logger.info("training")
251
+ early_stop_flag = False
252
+ for epoch_idx in range(max(0, last_epoch+1), config.max_epochs):
253
+ if early_stop_flag:
254
+ break
255
+
256
+ # train
257
+ model.train()
258
+
259
+ total_pesq_score = 0.
260
+ total_loss = 0.
261
+ total_mr_stft_loss = 0.
262
+ total_neg_si_snr_loss = 0.
263
+ total_batches = 0.
264
+
265
+ progress_bar_train = tqdm(
266
+ initial=step_idx,
267
+ desc="Training; epoch: {}".format(epoch_idx),
268
+ )
269
+ for train_batch in train_data_loader:
270
+ clean_audios, noisy_audios = train_batch
271
+ clean_audios: torch.Tensor = clean_audios.to(device)
272
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
273
+
274
+ denoise_audios, _, _ = model.forward(noisy_audios)
275
+
276
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
277
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
278
+
279
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
280
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
281
+ logger.info("found nan or inf in loss.")
282
+ continue
283
+
284
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
285
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
286
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
287
+
288
+ optimizer.zero_grad()
289
+ loss.backward()
290
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.clip_grad_norm)
291
+ optimizer.step()
292
+ lr_scheduler.step()
293
+
294
+ total_pesq_score += pesq_score
295
+ total_loss += loss.item()
296
+ total_mr_stft_loss += mr_stft_loss.item()
297
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
298
+ total_batches += 1
299
+
300
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
301
+ average_loss = round(total_loss / total_batches, 4)
302
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
303
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
304
+
305
+ progress_bar_train.update(1)
306
+ progress_bar_train.set_postfix({
307
+ "lr": lr_scheduler.get_last_lr()[0],
308
+ "pesq_score": average_pesq_score,
309
+ "loss": average_loss,
310
+ "mr_stft_loss": average_mr_stft_loss,
311
+ "neg_si_snr_loss": average_neg_si_snr_loss,
312
+ })
313
+
314
+ # evaluation
315
+ step_idx += 1
316
+ if step_idx % config.eval_steps == 0:
317
+ with torch.no_grad():
318
+ torch.cuda.empty_cache()
319
+
320
+ total_pesq_score = 0.
321
+ total_loss = 0.
322
+ total_mr_stft_loss = 0.
323
+ total_neg_si_snr_loss = 0.
324
+ total_batches = 0.
325
+
326
+ progress_bar_train.close()
327
+ progress_bar_eval = tqdm(
328
+ desc="Evaluation; steps-{}k".format(int(step_idx / 1000)),
329
+ )
330
+
331
+ for eval_batch in valid_data_loader:
332
+ clean_audios, noisy_audios = eval_batch
333
+ clean_audios: torch.Tensor = clean_audios.to(device)
334
+ noisy_audios: torch.Tensor = noisy_audios.to(device)
335
+
336
+ denoise_audios, _, _ = model.forward(noisy_audios)
337
+
338
+ mr_stft_loss = mr_stft_loss_fn.forward(denoise_audios, clean_audios)
339
+ neg_si_snr_loss = neg_si_snr_loss_fn.forward(denoise_audios, clean_audios)
340
+
341
+ loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss
342
+ if torch.any(torch.isnan(loss)) or torch.any(torch.isinf(loss)):
343
+ logger.info("found nan or inf in loss.")
344
+ continue
345
+
346
+ denoise_audios_list_r = list(denoise_audios.detach().cpu().numpy())
347
+ clean_audios_list_r = list(clean_audios.detach().cpu().numpy())
348
+ pesq_score = run_pesq_score(clean_audios_list_r, denoise_audios_list_r, sample_rate=config.sample_rate, mode="nb")
349
+
350
+ total_pesq_score += pesq_score
351
+ total_loss += loss.item()
352
+ total_mr_stft_loss += mr_stft_loss.item()
353
+ total_neg_si_snr_loss += neg_si_snr_loss.item()
354
+ total_batches += 1
355
+
356
+ average_pesq_score = round(total_pesq_score / total_batches, 4)
357
+ average_loss = round(total_loss / total_batches, 4)
358
+ average_mr_stft_loss = round(total_mr_stft_loss / total_batches, 4)
359
+ average_neg_si_snr_loss = round(total_neg_si_snr_loss / total_batches, 4)
360
+
361
+ progress_bar_eval.update(1)
362
+ progress_bar_eval.set_postfix({
363
+ "lr": lr_scheduler.get_last_lr()[0],
364
+ "pesq_score": average_pesq_score,
365
+ "loss": average_loss,
366
+ "mr_stft_loss": average_mr_stft_loss,
367
+ "neg_si_snr_loss": average_neg_si_snr_loss,
368
+ })
369
+
370
+ total_pesq_score = 0.
371
+ total_loss = 0.
372
+ total_mr_stft_loss = 0.
373
+ total_neg_si_snr_loss = 0.
374
+ total_batches = 0.
375
+
376
+ progress_bar_eval.close()
377
+ progress_bar_train = tqdm(
378
+ initial=progress_bar_train.n,
379
+ postfix=progress_bar_train.postfix,
380
+ desc=progress_bar_train.desc,
381
+ )
382
+
383
+ # save path
384
+ epoch_dir = serialization_dir / "epoch-{}".format(epoch_idx)
385
+ epoch_dir.mkdir(parents=True, exist_ok=False)
386
+
387
+ # save models
388
+ model.save_pretrained(epoch_dir.as_posix())
389
+
390
+ model_list.append(epoch_dir)
391
+ if len(model_list) >= args.num_serialized_models_to_keep:
392
+ model_to_delete: Path = model_list.pop(0)
393
+ shutil.rmtree(model_to_delete.as_posix())
394
+
395
+ # save metric
396
+ if best_metric is None:
397
+ best_epoch_idx = epoch_idx
398
+ best_step_idx = step_idx
399
+ best_metric = average_pesq_score
400
+ elif average_pesq_score >= best_metric:
401
+ # greater is better.
402
+ best_epoch_idx = epoch_idx
403
+ best_step_idx = step_idx
404
+ best_metric = average_pesq_score
405
+ else:
406
+ pass
407
+
408
+ metrics = {
409
+ "epoch_idx": epoch_idx,
410
+ "best_epoch_idx": best_epoch_idx,
411
+ "best_step_idx": best_step_idx,
412
+ "pesq_score": average_pesq_score,
413
+ "loss": average_loss,
414
+ }
415
+ metrics_filename = epoch_dir / "metrics_epoch.json"
416
+ with open(metrics_filename, "w", encoding="utf-8") as f:
417
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
418
+
419
+ # save best
420
+ best_dir = serialization_dir / "best"
421
+ if best_epoch_idx == epoch_idx:
422
+ if best_dir.exists():
423
+ shutil.rmtree(best_dir)
424
+ shutil.copytree(epoch_dir, best_dir)
425
+
426
+ # early stop
427
+ early_stop_flag = False
428
+ if best_epoch_idx == epoch_idx and best_step_idx == step_idx:
429
+ patience_count = 0
430
+ else:
431
+ patience_count += 1
432
+ if patience_count >= args.patience:
433
+ early_stop_flag = True
434
+
435
+ # early stop
436
+ if early_stop_flag:
437
+ break
438
+ return
439
+
440
+
441
+ if __name__ == '__main__':
442
+ main()
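The training objective above is loss = 1.0 * mr_stft_loss + 1.0 * neg_si_snr_loss. As a reference for the first term, a minimal sketch of what a multi-resolution STFT loss with the factor_sc / factor_mag weighting typically computes (spectral convergence plus log-magnitude L1 at several FFT sizes); the repo's MultiResolutionSTFTLoss may differ in windowing, normalization and reduction.

import torch

def stft_magnitude(x: torch.Tensor, fft_size: int, hop_size: int, win_size: int) -> torch.Tensor:
    window = torch.hann_window(win_size, device=x.device)
    spec = torch.stft(x, n_fft=fft_size, hop_length=hop_size, win_length=win_size,
                      window=window, return_complex=True)
    return spec.abs().clamp(min=1e-7)

def multi_resolution_stft_loss(estimate: torch.Tensor, target: torch.Tensor,
                               factor_sc: float = 1.5, factor_mag: float = 1.0) -> torch.Tensor:
    resolutions = [(256, 128, 256), (512, 256, 512), (1024, 512, 1024)]  # (fft_size, hop_size, win_size)
    loss = estimate.new_zeros(())
    for fft_size, hop_size, win_size in resolutions:
        m_est = stft_magnitude(estimate, fft_size, hop_size, win_size)
        m_ref = stft_magnitude(target, fft_size, hop_size, win_size)
        sc = torch.norm(m_ref - m_est, p="fro") / torch.norm(m_ref, p="fro")   # spectral convergence
        mag = torch.nn.functional.l1_loss(torch.log(m_est), torch.log(m_ref))  # log STFT magnitude
        loss = loss + factor_sc * sc + factor_mag * mag
    return loss / len(resolutions)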
examples/rnnoise/yaml/config.yaml ADDED
@@ -0,0 +1,31 @@
1
+ model_name: "rnnoise"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ segment_size: 32000
6
+ nfft: 512
7
+ win_size: 512
8
+ hop_size: 256
9
+ win_type: hann
10
+
11
+ # data
12
+ max_snr_db: 20
13
+ min_snr_db: -10
14
+
15
+ # model
16
+ conv_size: 256
17
+ gru_size: 256
18
+
19
+ # train
20
+ max_epochs: 100
21
+ batch_size: 32
22
+ num_workers: 4
23
+ seed: 1234
24
+
25
+ lr: 0.001
26
+ lr_scheduler: CosineAnnealingLR
27
+ lr_scheduler_kwargs: {}
28
+
29
+ weight_decay: 0.00001
30
+ clip_grad_norm: 10.0
31
+ eval_steps: 20000
examples/test.py ADDED
@@ -0,0 +1,39 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ inputs = torch.randn(size=(1, 1, 16000))
8
+
9
+ conv1d = nn.Conv1d(
10
+ in_channels=1,
11
+ out_channels=1,
12
+ kernel_size=3,
13
+ stride=2,
14
+ padding=0,
15
+ dilation=1,
16
+ )
17
+ conv1dt = nn.ConvTranspose1d(
18
+ in_channels=1,
19
+ out_channels=1,
20
+ kernel_size=3,
21
+ stride=2,
22
+ padding=0,
23
+ output_padding=1,
24
+ dilation=1,
25
+ )
26
+
27
+ x = conv1d.forward(inputs)
28
+
29
+ print(x.shape)
30
+
31
+ x = conv1dt.forward(x)
32
+ print(x.shape)
33
+ print(x[:, :, 0])
34
+ print(x[:, :, -2])
35
+ print(x[:, :, -1])
36
+
37
+
38
+ if __name__ == "__main__":
39
+ pass
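The printed shapes above can be checked against the standard PyTorch length formulas for Conv1d and ConvTranspose1d; a small helper (not part of the repo) makes the arithmetic explicit for the kernel_size=3, stride=2 example:

def conv1d_out(l_in: int, kernel: int = 3, stride: int = 2, padding: int = 0, dilation: int = 1) -> int:
    # L_out = floor((L_in + 2*padding - dilation*(kernel-1) - 1) / stride) + 1
    return (l_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv_transpose1d_out(l_in: int, kernel: int = 3, stride: int = 2, padding: int = 0,
                         output_padding: int = 1, dilation: int = 1) -> int:
    # L_out = (L_in - 1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1
    return (l_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

l_down = conv1d_out(16000)             # 7999
l_up = conv_transpose1d_out(l_down)    # 16000, thanks to output_padding=1
print(l_down, l_up)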
toolbox/{torchaudio/models/dfnet3 → torch/sparsification}/__init__.py RENAMED
@@ -2,5 +2,5 @@
 # -*- coding: utf-8 -*-


-if __name__ == '__main__':
+if __name__ == "__main__":
 pass
toolbox/torch/sparsification/common.py ADDED
@@ -0,0 +1,131 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ /* Copyright (c) 2023 Amazon
5
+ Written by Jan Buethe */
6
+ /*
7
+ Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions
9
+ are met:
10
+
11
+ - Redistributions of source code must retain the above copyright
12
+ notice, this list of conditions and the following disclaimer.
13
+
14
+ - Redistributions in binary form must reproduce the above copyright
15
+ notice, this list of conditions and the following disclaimer in the
16
+ documentation and/or other materials provided with the distribution.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
22
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+ """
31
+ import torch
32
+
33
+
34
+ """
35
+ https://github.com/xiph/rnnoise/blob/main/torch/sparsification/common.py
36
+ """
37
+
38
+ def sparsify_matrix(matrix : torch.tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
39
+ """ sparsifies matrix with specified block size
40
+
41
+ Parameters:
42
+ -----------
43
+ matrix : torch.tensor
44
+ matrix to sparsify
45
+ density : int
46
+ target density
47
+ block_size : [int, int]
48
+ block size dimensions
49
+ keep_diagonal : bool
50
+ If true, the diagonal will be kept. This option requires block_size[0] == block_size[1] and defaults to False
51
+ """
52
+
53
+ m, n = matrix.shape
54
+ m1, n1 = block_size
55
+
56
+ if m % m1 or n % n1:
57
+ raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")
58
+
59
+ # extract diagonal if keep_diagonal = True
60
+ if keep_diagonal:
61
+ if m != n:
62
+ raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")
63
+
64
+ to_spare = torch.diag(torch.diag(matrix))
65
+ matrix = matrix - to_spare
66
+ else:
67
+ to_spare = torch.zeros_like(matrix)
68
+
69
+ # calculate energy in sub-blocks
70
+ x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
71
+ x = x ** 2
72
+ block_energies = torch.sum(torch.sum(x, dim=3), dim=1)
73
+
74
+ number_of_blocks = (m * n) // (m1 * n1)
75
+ number_of_survivors = round(number_of_blocks * density)
76
+
77
+ # masking threshold
78
+ if number_of_survivors == 0:
79
+ threshold = 0
80
+ else:
81
+ threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]
82
+
83
+ # create mask
84
+ mask = torch.ones_like(block_energies)
85
+ mask[block_energies < threshold] = 0
86
+ mask = torch.repeat_interleave(mask, m1, dim=0)
87
+ mask = torch.repeat_interleave(mask, n1, dim=1)
88
+
89
+ # perform masking
90
+ masked_matrix = mask * matrix + to_spare
91
+
92
+ if return_mask:
93
+ return masked_matrix, mask
94
+ else:
95
+ return masked_matrix
96
+
97
+ def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
98
+ input_size = gru.input_size
99
+ hidden_size = gru.hidden_size
100
+ flops = 0
101
+
102
+ input_density = (
103
+ sparsification_dict.get('W_ir', [1])[0]
104
+ + sparsification_dict.get('W_in', [1])[0]
105
+ + sparsification_dict.get('W_iz', [1])[0]
106
+ ) / 3
107
+
108
+ recurrent_density = (
109
+ sparsification_dict.get('W_hr', [1])[0]
110
+ + sparsification_dict.get('W_hn', [1])[0]
111
+ + sparsification_dict.get('W_hz', [1])[0]
112
+ ) / 3
113
+
114
+ # input matrix vector multiplications
115
+ if not drop_input:
116
+ flops += 2 * 3 * input_size * hidden_size * input_density
117
+
118
+ # recurrent matrix vector multiplications
119
+ flops += 2 * 3 * hidden_size * hidden_size * recurrent_density
120
+
121
+ # biases
122
+ flops += 6 * hidden_size
123
+
124
+ # activations estimated by 10 flops per activation
125
+ flops += 30 * hidden_size
126
+
127
+ return flops
128
+
129
+
130
+ if __name__ == "__main__":
131
+ pass
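A small usage example for sparsify_matrix (the example itself is not part of the repo): an 8x8 matrix sparsified to 25% density in 2x2 blocks keeps the 4 highest-energy blocks out of 16 and zeroes the rest.

import torch

from toolbox.torch.sparsification.common import sparsify_matrix

torch.manual_seed(0)
w = torch.randn(8, 8)
w_sparse, mask = sparsify_matrix(w, density=0.25, block_size=[2, 2],
                                 keep_diagonal=False, return_mask=True)
print(mask.mean())                      # ~0.25: a quarter of the entries survive
print((w_sparse != 0).float().mean())   # matches the mask for a random matrix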
toolbox/torch/sparsification/gru_sparsifier.py ADDED
@@ -0,0 +1,190 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ /* Copyright (c) 2023 Amazon
5
+ Written by Jan Buethe */
6
+ /*
7
+ Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions
9
+ are met:
10
+
11
+ - Redistributions of source code must retain the above copyright
12
+ notice, this list of conditions and the following disclaimer.
13
+
14
+ - Redistributions in binary form must reproduce the above copyright
15
+ notice, this list of conditions and the following disclaimer in the
16
+ documentation and/or other materials provided with the distribution.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
22
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
25
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
26
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
27
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+ """
31
+ import torch
32
+
33
+ from toolbox.torch.sparsification.common import sparsify_matrix
34
+
35
+
36
+ """
37
+ https://github.com/xiph/rnnoise/blob/main/torch/sparsification/gru_sparsifier.py
38
+ """
39
+
40
+ class GRUSparsifier:
41
+ def __init__(self, task_list, start, stop, interval, exponent=3):
42
+ """ Sparsifier for torch.nn.GRUs
43
+
44
+ Parameters:
45
+ -----------
46
+ task_list : list
47
+ task_list contains a list of tuples (gru, sparsify_dict), where gru is an instance
48
+ of torch.nn.GRU and sparsify_dict is a dictionary with keys in {'W_ir', 'W_iz', 'W_in',
49
+ 'W_hr', 'W_hz', 'W_hn'} corresponding to the input and recurrent weights for the reset,
50
+ update, and new gate. The values of sparsify_dict are tuples (density, [m, n], keep_diagonal),
51
+ where density is the target density in [0, 1], [m, n] is the shape sub-blocks to which
52
+ sparsification is applied and keep_diagonal is a bool variable indicating whether the diagonal
53
+ should be kept.
54
+
55
+ start : int
56
+ training step after which sparsification will be started.
57
+
58
+ stop : int
59
+ training step after which sparsification will be completed.
60
+
61
+ interval : int
62
+ sparsification interval for steps between start and stop. After stop sparsification will be
63
+ carried out after every call to GRUSparsifier.step()
64
+
65
+ exponent : float
66
+ Interpolation exponent for sparsification interval. In step i sparsification will be carried out
67
+ with density (alpha + target_density * (1 - alpha)), where
68
+ alpha = ((stop - i) / (stop - start)) ** exponent
69
+
70
+ Example:
71
+ --------
72
+ >>> import torch
73
+ >>> gru = torch.nn.GRU(10, 20)
74
+ >>> sparsify_dict = {
75
+ ... 'W_ir' : (0.5, [2, 2], False),
76
+ ... 'W_iz' : (0.6, [2, 2], False),
77
+ ... 'W_in' : (0.7, [2, 2], False),
78
+ ... 'W_hr' : (0.1, [4, 4], True),
79
+ ... 'W_hz' : (0.2, [4, 4], True),
80
+ ... 'W_hn' : (0.3, [4, 4], True),
81
+ ... }
82
+ >>> sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 50)
83
+ >>> for i in range(100):
84
+ ... sparsifier.step()
85
+ """
86
+ # just copying parameters...
87
+ self.start = start
88
+ self.stop = stop
89
+ self.interval = interval
90
+ self.exponent = exponent
91
+ self.task_list = task_list
92
+
93
+ # ... and setting counter to 0
94
+ self.step_counter = 0
95
+
96
+ self.last_masks = {key : None for key in ['W_ir', 'W_in', 'W_iz', 'W_hr', 'W_hn', 'W_hz']}
97
+
98
+ def step(self, verbose=False):
99
+ """ carries out sparsification step
100
+
101
+ Call this function after optimizer.step in your
102
+ training loop.
103
+
104
+ Parameters:
105
+ ----------
106
+ verbose : bool
107
+ if true, densities are printed out
108
+
109
+ Returns:
110
+ --------
111
+ None
112
+
113
+ """
114
+ # compute current interpolation factor
115
+ self.step_counter += 1
116
+
117
+ if self.step_counter < self.start:
118
+ return
119
+ elif self.step_counter < self.stop:
120
+ # update only every self.interval-th interval
121
+ if self.step_counter % self.interval:
122
+ return
123
+
124
+ alpha = ((self.stop - self.step_counter) / (self.stop - self.start)) ** self.exponent
125
+ else:
126
+ alpha = 0
127
+
128
+ with torch.no_grad():
129
+ for gru, params in self.task_list:
130
+ hidden_size = gru.hidden_size
131
+
132
+ # input weights
133
+ for i, key in enumerate(['W_ir', 'W_iz', 'W_in']):
134
+ if key in params:
135
+ density = alpha + (1 - alpha) * params[key][0]
136
+ if verbose:
137
+ print(f"[{self.step_counter}]: {key} density: {density}")
138
+
139
+ gru.weight_ih_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
140
+ gru.weight_ih_l0[i * hidden_size : (i + 1) * hidden_size, : ],
141
+ density, # density
142
+ params[key][1], # block_size
143
+ params[key][2], # keep_diagonal (might want to set this to False)
144
+ return_mask=True
145
+ )
146
+
147
+ if self.last_masks[key] is not None:
148
+ if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
149
+ print(f"sparsification mask {key} changed for gru {gru}")
150
+
151
+ self.last_masks[key] = new_mask
152
+
153
+ # recurrent weights
154
+ for i, key in enumerate(['W_hr', 'W_hz', 'W_hn']):
155
+ if key in params:
156
+ density = alpha + (1 - alpha) * params[key][0]
157
+ if verbose:
158
+ print(f"[{self.step_counter}]: {key} density: {density}")
159
+ gru.weight_hh_l0[i * hidden_size : (i+1) * hidden_size, : ], new_mask = sparsify_matrix(
160
+ gru.weight_hh_l0[i * hidden_size : (i + 1) * hidden_size, : ],
161
+ density,
162
+ params[key][1], # block_size
163
+ params[key][2], # keep_diagonal (might want to set this to False)
164
+ return_mask=True
165
+ )
166
+
167
+ if self.last_masks[key] is not None:
168
+ if not torch.all(self.last_masks[key] == new_mask) and self.step_counter > self.stop:
169
+ print(f"sparsification mask {key} changed for gru {gru}")
170
+
171
+ self.last_masks[key] = new_mask
172
+
173
+
174
+ if __name__ == "__main__":
175
+ print("Testing sparsifier")
176
+
177
+ gru = torch.nn.GRU(10, 20)
178
+ sparsify_dict = {
179
+ 'W_ir' : (0.5, [2, 2], False),
180
+ 'W_iz' : (0.6, [2, 2], False),
181
+ 'W_in' : (0.7, [2, 2], False),
182
+ 'W_hr' : (0.1, [4, 4], True),
183
+ 'W_hz' : (0.2, [4, 4], True),
184
+ 'W_hn' : (0.3, [4, 4], True),
185
+ }
186
+
187
+ sparsifier = GRUSparsifier([(gru, sparsify_dict)], 0, 100, 10)
188
+
189
+ for i in range(100):
190
+ sparsifier.step(verbose=True)
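The class docstring says to call step() once per optimizer step; a minimal sketch of how that wiring looks in a training loop (the GRU, head and random data here are placeholders, not part of the repo):

import torch

from toolbox.torch.sparsification.gru_sparsifier import GRUSparsifier

gru = torch.nn.GRU(10, 20, batch_first=True)
head = torch.nn.Linear(20, 1)
sparsify_dict = {
    'W_hr': (0.2, [4, 4], True),
    'W_hz': (0.2, [4, 4], True),
    'W_hn': (0.2, [4, 4], True),
}
sparsifier = GRUSparsifier([(gru, sparsify_dict)], start=100, stop=1000, interval=50)
optimizer = torch.optim.Adam(list(gru.parameters()) + list(head.parameters()), lr=1e-3)

for step in range(2000):
    x = torch.randn(8, 16, 10)
    y, _ = gru(x)
    loss = head(y[:, -1, :]).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    sparsifier.step()  # gradually sparsifies the recurrent weights between steps 100 and 1000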
toolbox/torchaudio/models/dfnet/modeling_dfnet.py CHANGED
@@ -1,5 +1,11 @@
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
  import os
4
  import math
5
  from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
@@ -8,7 +14,6 @@ import numpy as np
8
  import torch
9
  import torch.nn as nn
10
  from torch.nn import functional as F
11
- import torchaudio
12
 
13
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
  from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
@@ -480,6 +485,7 @@ class Encoder(nn.Module):
480
 
481
 
482
  class Decoder(nn.Module):
 
483
  def __init__(self, config: DfNetConfig):
484
  super(Decoder, self).__init__()
485
 
@@ -800,6 +806,9 @@ class DeepFiltering(nn.Module):
800
 
801
 
802
  class DfNet(nn.Module):
 
 
 
803
  def __init__(self, config: DfNetConfig):
804
  super(DfNet, self).__init__()
805
  self.config = config
@@ -867,23 +876,11 @@ class DfNet(nn.Module):
867
  if remainder > 0:
868
  n_samples_pad = self.hop_size - remainder
869
  signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
870
- return signal, n_samples
871
-
872
- def forward(self,
873
- noisy: torch.Tensor,
874
- ):
875
- """
876
- :param noisy:
877
- :return:
878
- est_spec: shape: [b, 257*2, t]
879
- est_wav: shape: [b, num_samples]
880
- est_mask: shape: [b, 257, t]
881
- lsnr: shape: [b, 1, t]
882
- """
883
- noisy, n_samples = self.signal_prepare(noisy)
884
 
 
885
  # noisy shape: [b, num_samples_pad]
886
- spec_cmp = self.stft.forward(noisy)
887
  # spec_complex shape: [b, f, t], torch.complex64
888
  spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
889
  # spec_complex shape: [b, t, f], torch.complex64
@@ -906,6 +903,24 @@ class DfNet(nn.Module):
906
  feat_spec = feat_spec[..., :self.df_decoder.df_bins]
907
  # feat_spec shape: [b, 2, t, df_bins]
908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
909
  e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_erb, feat_spec)
910
 
911
  mask = self.decoder.forward(emb, e3, e2, e1, e0)
 
1
  #!/usr/bin/python3
2
  # -*- coding: utf-8 -*-
3
+ """
4
+ The native DeepFilterNet implementation does not directly support streaming inference
5
+
6
+ Community developers (e.g. Rikorose) have provided a Torch-based streaming inference implementation
7
+ https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
8
+ """
9
  import os
10
  import math
11
  from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
14
  import torch
15
  import torch.nn as nn
16
  from torch.nn import functional as F
 
17
 
18
  from toolbox.torchaudio.configuration_utils import CONFIG_FILE
19
  from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
 
485
 
486
 
487
  class Decoder(nn.Module):
488
+ """ErbDecoder"""
489
  def __init__(self, config: DfNetConfig):
490
  super(Decoder, self).__init__()
491
 
 
806
 
807
 
808
  class DfNet(nn.Module):
809
+ """
810
+ I suspect this model cannot achieve fully consistent streaming inference.
811
+ """
812
  def __init__(self, config: DfNetConfig):
813
  super(DfNet, self).__init__()
814
  self.config = config
 
876
  if remainder > 0:
877
  n_samples_pad = self.hop_size - remainder
878
  signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
879
+ return signal
 
 
 
 
 
 
 
 
 
 
 
 
 
880
 
881
+ def feature_prepare(self, signal: torch.Tensor):
882
  # noisy shape: [b, num_samples_pad]
883
+ spec_cmp = self.stft.forward(signal)
884
  # spec_complex shape: [b, f, t], torch.complex64
885
  spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
886
  # spec_complex shape: [b, t, f], torch.complex64
 
903
  feat_spec = feat_spec[..., :self.df_decoder.df_bins]
904
  # feat_spec shape: [b, 2, t, df_bins]
905
 
906
+ return spec, feat_erb, feat_spec
907
+
908
+ def forward(self,
909
+ noisy: torch.Tensor,
910
+ ):
911
+ """
912
+ :param noisy:
913
+ :return:
914
+ est_spec: shape: [b, 257*2, t]
915
+ est_wav: shape: [b, num_samples]
916
+ est_mask: shape: [b, 257, t]
917
+ lsnr: shape: [b, 1, t]
918
+ """
919
+ n_samples = noisy.shape[-1]
920
+ noisy = self.signal_prepare(noisy)
921
+
922
+ spec, feat_erb, feat_spec = self.feature_prepare(noisy)
923
+
924
  e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_erb, feat_spec)
925
 
926
  mask = self.decoder.forward(emb, e3, e2, e1, e0)
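The reorganised forward above now records n_samples before signal_prepare pads the waveform, and feature extraction is pulled out into feature_prepare. As a rough illustration of the padding step, here is a self-contained sketch; the hop_size of 256 and the plain modulo rule are assumptions, since the full body of signal_prepare is not visible in this diff.

    import torch
    import torch.nn.functional as F

    def pad_to_hop(signal: torch.Tensor, hop_size: int = 256) -> torch.Tensor:
        # Zero-pad the last dimension so its length becomes a multiple of hop_size,
        # mirroring the tail of DfNet.signal_prepare shown above (exact rule assumed).
        remainder = signal.shape[-1] % hop_size
        if remainder > 0:
            signal = F.pad(signal, pad=(0, hop_size - remainder), mode="constant", value=0)
        return signal

    noisy = torch.randn(1, 16000)
    n_samples = noisy.shape[-1]          # remembered before padding, as in the new forward
    padded = pad_to_hop(noisy)
    print(n_samples, padded.shape[-1])   # 16000 16128 when hop_size == 256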
toolbox/torchaudio/models/dfnet/modeling_dfnet_online.py ADDED
@@ -0,0 +1,226 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ The native DeepFilterNet implementation does not directly support streaming inference
5
+
6
+ Community developers (e.g. Rikorose) have provided a Torch-based streaming inference implementation
7
+ https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
8
+
9
+ This file attempts to implement a dfnet that supports streaming inference
10
+
11
+ """
12
+ import os
13
+ import math
14
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+
21
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
22
+ from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
23
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
24
+ from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
25
+ from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
26
+
27
+
28
+ MODEL_FILE = "model.pt"
29
+
30
+
31
+ norm_layer_dict = {
32
+ "batch_norm_2d": torch.nn.BatchNorm2d
33
+ }
34
+
35
+
36
+ activation_layer_dict = {
37
+ "relu": torch.nn.ReLU,
38
+ "identity": torch.nn.Identity,
39
+ "sigmoid": torch.nn.Sigmoid,
40
+ }
41
+
42
+
43
+ class CausalConv2d(nn.Module):
44
+ def __init__(self,
45
+ in_channels: int,
46
+ out_channels: int,
47
+ kernel_size: Union[int, Iterable[int]],
48
+ fstride: int = 1,
49
+ dilation: int = 1,
50
+ pad_f_dim: bool = True,
51
+ bias: bool = True,
52
+ separable: bool = False,
53
+ norm_layer: str = "batch_norm_2d",
54
+ activation_layer: str = "relu",
55
+ ):
56
+ super(CausalConv2d, self).__init__()
57
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
58
+
59
+ if pad_f_dim:
60
+ fpad = kernel_size[1] // 2 + dilation - 1
61
+ else:
62
+ fpad = 0
63
+
64
+ # for last 2 dim, pad (left, right, top, bottom).
65
+ self.lookback = kernel_size[0] - 1
66
+ if self.lookback > 0:
67
+ self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
68
+ else:
69
+ self.tpad = nn.Identity()
70
+
71
+ groups = math.gcd(in_channels, out_channels) if separable else 1
72
+ if groups == 1:
73
+ separable = False
74
+ if max(kernel_size) == 1:
75
+ separable = False
76
+
77
+ self.conv = nn.Conv2d(
78
+ in_channels,
79
+ out_channels,
80
+ kernel_size=kernel_size,
81
+ padding=(0, fpad),
82
+ stride=(1, fstride), # stride over time is always 1
83
+ dilation=(1, dilation), # dilation over time is always 1
84
+ groups=groups,
85
+ bias=bias,
86
+ )
87
+
88
+ if separable:
89
+ self.convp = nn.Conv2d(
90
+ out_channels,
91
+ out_channels,
92
+ kernel_size=1,
93
+ bias=False,
94
+ )
95
+ else:
96
+ self.convp = nn.Identity()
97
+
98
+ if norm_layer is not None:
99
+ norm_layer = norm_layer_dict[norm_layer]
100
+ self.norm = norm_layer(out_channels)
101
+ else:
102
+ self.norm = nn.Identity()
103
+
104
+ if activation_layer is not None:
105
+ activation_layer = activation_layer_dict[activation_layer]
106
+ self.activation = activation_layer()
107
+ else:
108
+ self.activation = nn.Identity()
109
+
110
+
111
+
112
+ def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
113
+ """
114
+ :param inputs: shape: [b, c, t, f]
115
+ :param cache: shape: [b, c, lookback, f];
116
+ :return:
117
+ """
118
+ x = inputs
119
+
120
+ if cache is None:
121
+ x = self.tpad(x)
122
+ else:
123
+ x = torch.concat(tensors=[cache, x], dim=2)
124
+ new_cache = x[:, :, -self.lookback:, :]
125
+
126
+ x = self.conv(x)
127
+
128
+ x = self.convp(x)
129
+ x = self.norm(x)
130
+ x = self.activation(x)
131
+
132
+ return x, new_cache
133
+
134
+
135
+ class CausalConvTranspose2d(nn.Module):
136
+ def __init__(self,
137
+ in_channels: int,
138
+ out_channels: int,
139
+ kernel_size: Union[int, Iterable[int]],
140
+ fstride: int = 1,
141
+ dilation: int = 1,
142
+ pad_f_dim: bool = True,
143
+ bias: bool = True,
144
+ separable: bool = False,
145
+ norm_layer: str = "batch_norm_2d",
146
+ activation_layer: str = "relu",
147
+ ):
148
+ super(CausalConvTranspose2d, self).__init__()
149
+
150
+ kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
151
+
152
+ if pad_f_dim:
153
+ fpad = kernel_size[1] // 2
154
+ else:
155
+ fpad = 0
156
+
157
+ # for last 2 dim, pad (left, right, top, bottom).
158
+ self.lookback = kernel_size[0] - 1
159
+
160
+ groups = math.gcd(in_channels, out_channels) if separable else 1
161
+ if groups == 1:
162
+ separable = False
163
+
164
+ self.convt = nn.ConvTranspose2d(
165
+ in_channels,
166
+ out_channels,
167
+ kernel_size=kernel_size,
168
+ padding=(0, fpad),
169
+ output_padding=(0, 0),
170
+ stride=(1, fstride), # stride over time is always 1
171
+ dilation=(1, dilation), # dilation over time is always 1
172
+ groups=groups,
173
+ bias=bias,
174
+ )
175
+
176
+ if separable:
177
+ self.convp = nn.Conv2d(
178
+ out_channels,
179
+ out_channels,
180
+ kernel_size=1,
181
+ bias=False,
182
+ )
183
+ else:
184
+ self.convp = nn.Identity()
185
+
186
+ if norm_layer is not None:
187
+ norm_layer = norm_layer_dict[norm_layer]
188
+ self.norm = norm_layer(out_channels)
189
+ else:
190
+ self.norm = nn.Identity()
191
+
192
+ if activation_layer is not None:
193
+ activation_layer = activation_layer_dict[activation_layer]
194
+ self.activation = activation_layer()
195
+ else:
196
+ self.activation = nn.Identity()
197
+
198
+ def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
199
+ """
200
+ :param inputs: shape: [b, c, t, f]
201
+ :param cache: shape: [b, c, lookback, f];
202
+ :return:
203
+ """
204
+ x = inputs
205
+
206
+ # x shape: [b, c, t, f]
207
+ x = self.convt(x)
208
+ # x shape: [b, c, t+lookback, f]
209
+
210
+ if cache is not None:
211
+ x = torch.concat(tensors=[
212
+ x[:, :, :self.lookback, :] + cache,
213
+ x[:, :, self.lookback:, :]
214
+ ], dim=2)
215
+ new_cache = x[:, :, -self.lookback:, :]
216
+ x = x[:, :, :-self.lookback, :]
217
+
218
+ x = self.convp(x)
219
+ x = self.norm(x)
220
+ x = self.activation(x)
221
+
222
+ return x, new_cache
223
+
224
+
225
+ if __name__ == "__main__":
226
+ pass
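To make the caching contract of the new CausalConv2d concrete, the following sketch compares one full pass with two cached chunks; in eval mode the two should agree. It assumes the module is importable from the path of the file added above (i.e. the repository root is on PYTHONPATH); the batch/channel/frame/bin shapes and the kernel size are arbitrary choices for illustration.

    import torch
    from toolbox.torchaudio.models.dfnet.modeling_dfnet_online import CausalConv2d

    torch.manual_seed(0)
    conv = CausalConv2d(in_channels=2, out_channels=4, kernel_size=(2, 3)).eval()

    x = torch.randn(1, 2, 10, 16)  # [b, c, t, f]

    with torch.no_grad():
        y_full, _ = conv(x)                           # all 10 frames at once
        y1, cache = conv(x[:, :, :6, :])              # first chunk, no cache yet
        y2, _ = conv(x[:, :, 6:, :], cache=cache)     # second chunk reuses the lookback frames
        y_chunked = torch.cat([y1, y2], dim=2)

    print(torch.allclose(y_full, y_chunked, atol=1.0e-6))  # expected: True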
toolbox/torchaudio/models/dfnet3/configuration_dfnet3.py DELETED
@@ -1,89 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from typing import Any, Dict, List, Tuple, Union
4
-
5
- from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
-
7
-
8
- class DfNetConfig(PretrainedConfig):
9
- def __init__(self,
10
- sample_rate: int,
11
- fft_size: int,
12
- hop_size: int,
13
- df_bins: int,
14
- erb_bins: int,
15
- min_freq_bins_for_erb: int,
16
- df_order: int,
17
- df_lookahead: int,
18
- norm_tau: int,
19
- lsnr_max: int,
20
- lsnr_min: int,
21
- conv_channels: int,
22
- conv_kernel_size_input: Tuple[int, int],
23
- conv_kernel_size_inner: Tuple[int, int],
24
- convt_kernel_size_inner: Tuple[int, int],
25
- conv_lookahead: int,
26
- emb_hidden_dim: int,
27
- mask_post_filter: bool,
28
- df_hidden_dim: int,
29
- df_num_layers: int,
30
- df_pathway_kernel_size_t: int,
31
- df_gru_skip: str,
32
- post_filter_beta: float,
33
- df_n_iter: float,
34
- lsnr_dropout: bool,
35
- encoder_gru_skip_op: str,
36
- encoder_linear_groups: int,
37
- encoder_squeezed_gru_linear_groups: int,
38
- encoder_concat: bool,
39
- erb_decoder_gru_skip_op: str,
40
- erb_decoder_linear_groups: int,
41
- erb_decoder_emb_num_layers: int,
42
- df_decoder_linear_groups: int,
43
- **kwargs
44
- ):
45
- super(DfNetConfig, self).__init__(**kwargs)
46
- if df_gru_skip not in ("none", "identity", "grouped_linear"):
47
- raise AssertionError
48
-
49
- self.sample_rate = sample_rate
50
- self.fft_size = fft_size
51
- self.hop_size = hop_size
52
- self.df_bins = df_bins
53
- self.erb_bins = erb_bins
54
- self.min_freq_bins_for_erb = min_freq_bins_for_erb
55
- self.df_order = df_order
56
- self.df_lookahead = df_lookahead
57
- self.norm_tau = norm_tau
58
- self.lsnr_max = lsnr_max
59
- self.lsnr_min = lsnr_min
60
-
61
- self.conv_channels = conv_channels
62
- self.conv_kernel_size_input = conv_kernel_size_input
63
- self.conv_kernel_size_inner = conv_kernel_size_inner
64
- self.convt_kernel_size_inner = convt_kernel_size_inner
65
- self.conv_lookahead = conv_lookahead
66
-
67
- self.emb_hidden_dim = emb_hidden_dim
68
- self.mask_post_filter = mask_post_filter
69
- self.df_hidden_dim = df_hidden_dim
70
- self.df_num_layers = df_num_layers
71
- self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
72
- self.df_gru_skip = df_gru_skip
73
- self.post_filter_beta = post_filter_beta
74
- self.df_n_iter = df_n_iter
75
- self.lsnr_dropout = lsnr_dropout
76
- self.encoder_gru_skip_op = encoder_gru_skip_op
77
- self.encoder_linear_groups = encoder_linear_groups
78
- self.encoder_squeezed_gru_linear_groups = encoder_squeezed_gru_linear_groups
79
- self.encoder_concat = encoder_concat
80
-
81
- self.erb_decoder_gru_skip_op = erb_decoder_gru_skip_op
82
- self.erb_decoder_linear_groups = erb_decoder_linear_groups
83
- self.erb_decoder_emb_num_layers = erb_decoder_emb_num_layers
84
-
85
- self.df_decoder_linear_groups = df_decoder_linear_groups
86
-
87
-
88
- if __name__ == "__main__":
89
- pass
toolbox/torchaudio/models/dfnet3/features.py DELETED
@@ -1,192 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import math
4
-
5
- import numpy as np
6
-
7
-
8
- def freq2erb(freq_hz: float) -> float:
9
- """
10
- https://www.cnblogs.com/LXP-Never/p/16011229.html
11
- 1 / (24.7 * 9.265) = 0.00436976
12
- """
13
- return 9.265 * math.log(freq_hz / (24.7 * 9.265) + 1)
14
-
15
-
16
- def erb2freq(n_erb: float) -> float:
17
- return 24.7 * 9.265 * (math.exp(n_erb / 9.265) - 1)
18
-
19
-
20
- def get_erb_widths(sample_rate: int, fft_size: int, erb_bins: int, min_freq_bins_for_erb: int) -> np.ndarray:
21
- """
22
- https://github.com/Rikorose/DeepFilterNet/blob/main/libDF/src/lib.rs
23
- :param sample_rate:
24
- :param fft_size:
25
- :param erb_bins: number of erb (Equivalent Rectangular Bandwidth) bands.
26
- :param min_freq_bins_for_erb: Minimum number of frequency bands per erb band
27
- :return:
28
- """
29
- nyq_freq = sample_rate / 2.
30
- freq_width: float = sample_rate / fft_size
31
-
32
- min_erb: float = freq2erb(0.)
33
- max_erb: float = freq2erb(nyq_freq)
34
-
35
- erb = [0] * erb_bins
36
- step = (max_erb - min_erb) / erb_bins
37
-
38
- prev_freq_bin = 0
39
- freq_over = 0
40
- for i in range(1, erb_bins + 1):
41
- f = erb2freq(min_erb + i * step)
42
- freq_bin = int(round(f / freq_width))
43
- freq_bins = freq_bin - prev_freq_bin - freq_over
44
-
45
- if freq_bins < min_freq_bins_for_erb:
46
- freq_over = min_freq_bins_for_erb - freq_bins
47
- freq_bins = min_freq_bins_for_erb
48
- else:
49
- freq_over = 0
50
- erb[i - 1] = freq_bins
51
- prev_freq_bin = freq_bin
52
-
53
- erb[erb_bins - 1] += 1
54
- too_large = sum(erb) - (fft_size / 2 + 1)
55
- if too_large > 0:
56
- erb[erb_bins - 1] -= too_large
57
- return np.array(erb, dtype=np.uint64)
58
-
59
-
60
- def get_erb_filter_bank(erb_widths: np.ndarray,
61
- sample_rate: int,
62
- normalized: bool = True,
63
- inverse: bool = False,
64
- ):
65
- num_freq_bins = int(np.sum(erb_widths))
66
- num_erb_bins = len(erb_widths)
67
-
68
- fb: np.ndarray = np.zeros(shape=(num_freq_bins, num_erb_bins))
69
-
70
- points = np.cumsum([0] + erb_widths.tolist()).astype(int)[:-1]
71
- for i, (b, w) in enumerate(zip(points.tolist(), erb_widths.tolist())):
72
- fb[b: b + w, i] = 1
73
-
74
- if inverse:
75
- fb = fb.T
76
- if not normalized:
77
- fb /= np.sum(fb, axis=1, keepdims=True)
78
- else:
79
- if normalized:
80
- fb /= np.sum(fb, axis=0)
81
- return fb
82
-
83
-
84
- def spec2erb(spec: np.ndarray, erb_fb: np.ndarray, db: bool = True):
85
- """
86
- ERB filterbank and transform to decibel scale.
87
-
88
- :param spec: Spectrum of shape [B, C, T, F].
89
- :param erb_fb: ERB filterbank array of shape [B] containing the ERB widths,
90
- where B are the number of ERB bins.
91
- :param db: Whether to transform the output into decibel scale. Defaults to `True`.
92
- :return:
93
- """
94
- # complex spec to power spec. (real * real + image * image)
95
- spec_ = np.abs(spec) ** 2
96
-
97
- # spec to erb feature.
98
- erb_feat = np.matmul(spec_, erb_fb)
99
-
100
- if db:
101
- erb_feat = 10 * np.log10(erb_feat + 1e-10)
102
-
103
- erb_feat = np.array(erb_feat, dtype=np.float32)
104
- return erb_feat
105
-
106
-
107
- def _calculate_norm_alpha(sample_rate: int, hop_size: int, tau: float):
108
- """Exponential decay factor alpha for a given tau (decay window size [s])."""
109
- dt = hop_size / sample_rate
110
- result = math.exp(-dt / tau)
111
- return result
112
-
113
-
114
- def get_norm_alpha(sample_rate: int, hop_size: int, norm_tau: float) -> float:
115
- a_ = _calculate_norm_alpha(sample_rate=sample_rate, hop_size=hop_size, tau=norm_tau)
116
-
117
- precision = 3
118
- a = 1.0
119
- while a >= 1.0:
120
- a = round(a_, precision)
121
- precision += 1
122
-
123
- return a
124
-
125
-
126
- MEAN_NORM_INIT = [-60., -90.]
127
-
128
-
129
- def make_erb_norm_state(erb_bins: int, channels: int) -> np.ndarray:
130
- state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)
131
- state = np.expand_dims(state, axis=0)
132
- state = np.repeat(state, channels, axis=0)
133
-
134
- # state shape: (audio_channels, erb_bins)
135
- return state
136
-
137
-
138
- def erb_normalize(erb_feat: np.ndarray, alpha: float, state: np.ndarray = None):
139
- erb_feat = np.copy(erb_feat)
140
- batch_size, time_steps, erb_bins = erb_feat.shape
141
-
142
- if state is None:
143
- state = make_erb_norm_state(erb_bins, erb_feat.shape[0])
144
- # state = np.linspace(MEAN_NORM_INIT[0], MEAN_NORM_INIT[1], erb_bins)
145
- # state = np.expand_dims(state, axis=0)
146
- # state = np.repeat(state, erb_feat.shape[0], axis=0)
147
-
148
- for i in range(batch_size):
149
- for j in range(time_steps):
150
- for k in range(erb_bins):
151
- x = erb_feat[i][j][k]
152
- s = state[i][k]
153
-
154
- state[i][k] = x * (1. - alpha) + s * alpha
155
- erb_feat[i][j][k] -= state[i][k]
156
- erb_feat[i][j][k] /= 40.
157
-
158
- return erb_feat
159
-
160
-
161
- UNIT_NORM_INIT = [0.001, 0.0001]
162
-
163
-
164
- def make_spec_norm_state(df_bins: int, channels: int) -> np.ndarray:
165
- state = np.linspace(UNIT_NORM_INIT[0], UNIT_NORM_INIT[1], df_bins)
166
- state = np.expand_dims(state, axis=0)
167
- state = np.repeat(state, channels, axis=0)
168
-
169
- # state shape: (audio_channels, df_bins)
170
- return state
171
-
172
-
173
- def spec_normalize(spec_feat: np.ndarray, alpha: float, state: np.ndarray = None):
174
- spec_feat = np.copy(spec_feat)
175
- batch_size, time_steps, df_bins = spec_feat.shape
176
-
177
- if state is None:
178
- state = make_spec_norm_state(df_bins, spec_feat.shape[0])
179
-
180
- for i in range(batch_size):
181
- for j in range(time_steps):
182
- for k in range(df_bins):
183
- x = spec_feat[i][j][k]
184
- s = state[i][k]
185
-
186
- state[i][k] = np.abs(x) * (1. - alpha) + s * alpha
187
- spec_feat[i][j][k] /= np.sqrt(state[i][k])
188
- return spec_feat
189
-
190
-
191
- if __name__ == '__main__':
192
- pass
toolbox/torchaudio/models/dfnet3/modeling_dfnet3.py DELETED
@@ -1,835 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import logging
4
- import math
5
- from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6
-
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
-
11
- from toolbox.torchaudio.models.dfnet3.configuration_dfnet3 import DfNetConfig
12
- from toolbox.torchaudio.models.dfnet3 import multiframes as MF
13
- from toolbox.torchaudio.models.dfnet3 import utils
14
-
15
- logger = logging.getLogger("toolbox")
16
-
17
- PI = 3.1415926535897932384626433
18
-
19
-
20
- norm_layer_dict = {
21
- "batch_norm_2d": torch.nn.BatchNorm2d
22
- }
23
-
24
- activation_layer_dict = {
25
- "relu": torch.nn.ReLU,
26
- "identity": torch.nn.Identity,
27
- "sigmoid": torch.nn.Sigmoid,
28
- }
29
-
30
-
31
- class CausalConv2d(nn.Sequential):
32
- def __init__(self,
33
- in_channels: int,
34
- out_channels: int,
35
- kernel_size: Union[int, Iterable[int]],
36
- fstride: int = 1,
37
- dilation: int = 1,
38
- fpad: bool = True,
39
- bias: bool = True,
40
- separable: bool = False,
41
- norm_layer: str = "batch_norm_2d",
42
- activation_layer: str = "relu",
43
- ):
44
- """
45
- Causal Conv2d by delaying the signal for any lookahead.
46
-
47
- Expected input format: [B, C, T, F]
48
-
49
- :param in_channels:
50
- :param out_channels:
51
- :param kernel_size:
52
- :param fstride:
53
- :param dilation:
54
- :param fpad:
55
- """
56
- super(CausalConv2d, self).__init__()
57
- lookahead = 0
58
-
59
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
60
-
61
- if fpad:
62
- fpad_ = kernel_size[1] // 2 + dilation - 1
63
- else:
64
- fpad_ = 0
65
-
66
- # for last 2 dim, pad (left, right, top, bottom).
67
- pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
68
-
69
- layers = []
70
- if any(x > 0 for x in pad):
71
- layers.append(nn.ConstantPad2d(pad, 0.0))
72
-
73
- groups = math.gcd(in_channels, out_channels) if separable else 1
74
- if groups == 1:
75
- separable = False
76
- if max(kernel_size) == 1:
77
- separable = False
78
-
79
- layers.append(
80
- nn.Conv2d(
81
- in_channels,
82
- out_channels,
83
- kernel_size=kernel_size,
84
- padding=(0, fpad_),
85
- stride=(1, fstride), # stride over time is always 1
86
- dilation=(1, dilation), # dilation over time is always 1
87
- groups=groups,
88
- bias=bias,
89
- )
90
- )
91
-
92
- if separable:
93
- layers.append(
94
- nn.Conv2d(
95
- out_channels,
96
- out_channels,
97
- kernel_size=1,
98
- bias=False,
99
- )
100
- )
101
-
102
- if norm_layer is not None:
103
- norm_layer = norm_layer_dict[norm_layer]
104
- layers.append(norm_layer(out_channels))
105
-
106
- if activation_layer is not None:
107
- activation_layer = activation_layer_dict[activation_layer]
108
- layers.append(activation_layer())
109
-
110
- super().__init__(*layers)
111
-
112
-
113
- class CausalConvTranspose2d(nn.Sequential):
114
- def __init__(self,
115
- in_channels: int,
116
- out_channels: int,
117
- kernel_size: Union[int, Iterable[int]],
118
- fstride: int = 1,
119
- dilation: int = 1,
120
- fpad: bool = True,
121
- bias: bool = True,
122
- separable: bool = False,
123
- norm_layer: str = "batch_norm_2d",
124
- activation_layer: str = "relu",
125
- ):
126
- """
127
- Causal ConvTranspose2d.
128
-
129
- Expected input format: [B, C, T, F]
130
- """
131
- super(CausalConvTranspose2d, self).__init__()
132
- lookahead = 0
133
-
134
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
135
-
136
- if fpad:
137
- fpad_ = kernel_size[1] // 2
138
- else:
139
- fpad_ = 0
140
-
141
- # for last 2 dim, pad (left, right, top, bottom).
142
- pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
143
-
144
- layers = []
145
- if any(x > 0 for x in pad):
146
- layers.append(nn.ConstantPad2d(pad, 0.0))
147
-
148
- groups = math.gcd(in_channels, out_channels) if separable else 1
149
- if groups == 1:
150
- separable = False
151
-
152
- layers.append(
153
- nn.ConvTranspose2d(
154
- in_channels,
155
- out_channels,
156
- kernel_size=kernel_size,
157
- padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
158
- output_padding=(0, fpad_),
159
- stride=(1, fstride), # stride over time is always 1
160
- dilation=(1, dilation), # dilation over time is always 1
161
- groups=groups,
162
- bias=bias,
163
- )
164
- )
165
-
166
- if separable:
167
- layers.append(
168
- nn.Conv2d(
169
- out_channels,
170
- out_channels,
171
- kernel_size=1,
172
- bias=False,
173
- )
174
- )
175
-
176
- if norm_layer is not None:
177
- norm_layer = norm_layer_dict[norm_layer]
178
- layers.append(norm_layer(out_channels))
179
-
180
- if activation_layer is not None:
181
- activation_layer = activation_layer_dict[activation_layer]
182
- layers.append(activation_layer())
183
-
184
- super().__init__(*layers)
185
-
186
-
187
- class GroupedLinear(nn.Module):
188
-
189
- def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
190
- super().__init__()
191
- # self.weight: Tensor
192
- self.input_size = input_size
193
- self.hidden_size = hidden_size
194
- self.groups = groups
195
- assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
196
- assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
197
- self.ws = input_size // groups
198
- self.register_parameter(
199
- "weight",
200
- torch.nn.Parameter(
201
- torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
202
- ),
203
- )
204
- self.reset_parameters()
205
-
206
- def reset_parameters(self):
207
- nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
208
-
209
- def forward(self, x: torch.Tensor) -> torch.Tensor:
210
- # x: [..., I]
211
- b, t, _ = x.shape
212
- # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
213
- new_shape = (b, t, self.groups, self.ws)
214
- x = x.view(new_shape)
215
- # The better way, but not supported by torchscript
216
- # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
217
- x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
218
- x = x.flatten(2, 3) # [B, T, H]
219
- return x
220
-
221
- def __repr__(self):
222
- cls = self.__class__.__name__
223
- return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
224
-
225
-
226
- class SqueezedGRU_S(nn.Module):
227
- """
228
- SGE net: Video object detection with squeezed GRU and information entropy map
229
- https://arxiv.org/abs/2106.07224
230
- """
231
-
232
- def __init__(
233
- self,
234
- input_size: int,
235
- hidden_size: int,
236
- output_size: Optional[int] = None,
237
- num_layers: int = 1,
238
- linear_groups: int = 8,
239
- batch_first: bool = True,
240
- skip_op: str = "none",
241
- activation_layer: str = "identity",
242
- ):
243
- super().__init__()
244
- self.input_size = input_size
245
- self.hidden_size = hidden_size
246
-
247
- self.linear_in = nn.Sequential(
248
- GroupedLinear(
249
- input_size=input_size,
250
- hidden_size=hidden_size,
251
- groups=linear_groups,
252
- ),
253
- activation_layer_dict[activation_layer](),
254
- )
255
-
256
- # gru skip operator
257
- self.gru_skip_op = None
258
-
259
- if skip_op == "none":
260
- self.gru_skip_op = None
261
- elif skip_op == "identity":
262
- if not input_size != output_size:
263
- raise AssertionError("Dimensions do not match")
264
- self.gru_skip_op = nn.Identity()
265
- elif skip_op == "grouped_linear":
266
- self.gru_skip_op = GroupedLinear(
267
- input_size=hidden_size,
268
- hidden_size=hidden_size,
269
- groups=linear_groups,
270
- )
271
- else:
272
- raise NotImplementedError()
273
-
274
- self.gru = nn.GRU(
275
- input_size=hidden_size,
276
- hidden_size=hidden_size,
277
- num_layers=num_layers,
278
- batch_first=batch_first,
279
- )
280
-
281
- if output_size is not None:
282
- self.linear_out = nn.Sequential(
283
- GroupedLinear(
284
- input_size=hidden_size,
285
- hidden_size=output_size,
286
- groups=linear_groups,
287
- ),
288
- activation_layer_dict[activation_layer](),
289
- )
290
- else:
291
- self.linear_out = nn.Identity()
292
-
293
- def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
294
- x = self.linear_in(inputs)
295
-
296
- x, h = self.gru(x, h)
297
-
298
- x = self.linear_out(x)
299
-
300
- if self.gru_skip_op is not None:
301
- x = x + self.gru_skip_op(inputs)
302
-
303
- return x, h
304
-
305
-
306
- class Add(nn.Module):
307
- def forward(self, a, b):
308
- return a + b
309
-
310
-
311
- class Concat(nn.Module):
312
- def forward(self, a, b):
313
- return torch.cat((a, b), dim=-1)
314
-
315
-
316
- class Encoder(nn.Module):
317
- def __init__(self, config: DfNetConfig):
318
- super(Encoder, self).__init__()
319
- self.emb_in_dim = config.conv_channels * config.erb_bins // 4
320
- self.emb_out_dim = config.conv_channels * config.erb_bins // 4
321
- self.emb_hidden_dim = config.emb_hidden_dim
322
-
323
- self.erb_conv0 = CausalConv2d(
324
- in_channels=1,
325
- out_channels=config.conv_channels,
326
- kernel_size=config.conv_kernel_size_input,
327
- bias=False,
328
- separable=True,
329
- )
330
- self.erb_conv1 = CausalConv2d(
331
- in_channels=config.conv_channels,
332
- out_channels=config.conv_channels,
333
- kernel_size=config.conv_kernel_size_inner,
334
- bias=False,
335
- separable=True,
336
- fstride=2,
337
- )
338
- self.erb_conv2 = CausalConv2d(
339
- in_channels=config.conv_channels,
340
- out_channels=config.conv_channels,
341
- kernel_size=config.conv_kernel_size_inner,
342
- bias=False,
343
- separable=True,
344
- fstride=2,
345
- )
346
- self.erb_conv3 = CausalConv2d(
347
- in_channels=config.conv_channels,
348
- out_channels=config.conv_channels,
349
- kernel_size=config.conv_kernel_size_inner,
350
- bias=False,
351
- separable=True,
352
- fstride=1,
353
- )
354
-
355
- self.df_conv0 = CausalConv2d(
356
- in_channels=2,
357
- out_channels=config.conv_channels,
358
- kernel_size=config.conv_kernel_size_input,
359
- bias=False,
360
- separable=True,
361
- )
362
- self.df_conv1 = CausalConv2d(
363
- in_channels=config.conv_channels,
364
- out_channels=config.conv_channels,
365
- kernel_size=config.conv_kernel_size_inner,
366
- bias=False,
367
- separable=True,
368
- fstride=2,
369
- )
370
-
371
- self.df_fc_emb = nn.Sequential(
372
- GroupedLinear(
373
- config.conv_channels * config.df_bins // 2,
374
- self.emb_in_dim,
375
- groups=config.encoder_linear_groups
376
- ),
377
- nn.ReLU(inplace=True)
378
- )
379
-
380
- if config.encoder_concat:
381
- self.emb_in_dim *= 2
382
- self.combine = Concat()
383
- else:
384
- self.combine = Add()
385
-
386
- self.emb_gru = SqueezedGRU_S(
387
- self.emb_in_dim,
388
- self.emb_hidden_dim,
389
- output_size=self.emb_out_dim,
390
- num_layers=1,
391
- batch_first=True,
392
- skip_op=config.encoder_gru_skip_op,
393
- linear_groups=config.encoder_squeezed_gru_linear_groups,
394
- activation_layer="relu",
395
- )
396
-
397
- self.lsnr_fc = nn.Sequential(
398
- nn.Linear(self.emb_out_dim, 1),
399
- nn.Sigmoid()
400
- )
401
- self.lsnr_scale = config.lsnr_max - config.lsnr_min
402
- self.lsnr_offset = config.lsnr_min
403
-
404
- def forward(self,
405
- feat_erb: torch.Tensor,
406
- feat_spec: torch.Tensor,
407
- h: torch.Tensor = None,
408
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
409
- # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands.
410
- # erb: [B, 1, T, Fe]
411
- # spec: [B, 2, T, Fc]
412
- # b, _, t, _ = feat_erb.shape
413
- e0 = self.erb_conv0(feat_erb) # [B, C, T, F]
414
- e1 = self.erb_conv1(e0) # [B, C*2, T, F/2]
415
- e2 = self.erb_conv2(e1) # [B, C*4, T, F/4]
416
- e3 = self.erb_conv3(e2) # [B, C*4, T, F/4]
417
- c0 = self.df_conv0(feat_spec) # [B, C, T, Fc]
418
- c1 = self.df_conv1(c0) # [B, C*2, T, Fc/2]
419
- cemb = c1.permute(0, 2, 3, 1).flatten(2) # [B, T, -1]
420
- cemb = self.df_fc_emb(cemb) # [T, B, C * F/4]
421
- emb = e3.permute(0, 2, 3, 1).flatten(2) # [B, T, C * F]
422
- emb = self.combine(emb, cemb)
423
- emb, h = self.emb_gru(emb, h) # [B, T, -1]
424
-
425
- lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
426
- return e0, e1, e2, e3, emb, c0, lsnr, h
427
-
428
-
429
- class ErbDecoder(nn.Module):
430
- def __init__(self,
431
- config: DfNetConfig,
432
- ):
433
- super(ErbDecoder, self).__init__()
434
- if config.erb_bins % 8 != 0:
435
- raise AssertionError("erb_bins should be divisible by 8")
436
-
437
- self.emb_in_dim = config.conv_channels * config.erb_bins // 4
438
- self.emb_out_dim = config.conv_channels * config.erb_bins // 4
439
- self.emb_hidden_dim = config.emb_hidden_dim
440
-
441
- self.emb_gru = SqueezedGRU_S(
442
- self.emb_in_dim,
443
- self.emb_hidden_dim,
444
- output_size=self.emb_out_dim,
445
- num_layers=config.erb_decoder_emb_num_layers - 1,
446
- batch_first=True,
447
- skip_op=config.erb_decoder_gru_skip_op,
448
- linear_groups=config.erb_decoder_linear_groups,
449
- activation_layer="relu",
450
- )
451
-
452
- # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions
453
- self.conv3p = CausalConv2d(
454
- in_channels=config.conv_channels,
455
- out_channels=config.conv_channels,
456
- kernel_size=1,
457
- bias=False,
458
- separable=True,
459
- )
460
- self.convt3 = CausalConv2d(
461
- in_channels=config.conv_channels,
462
- out_channels=config.conv_channels,
463
- kernel_size=config.conv_kernel_size_inner,
464
- bias=False,
465
- separable=True,
466
- )
467
- self.conv2p = CausalConv2d(
468
- in_channels=config.conv_channels,
469
- out_channels=config.conv_channels,
470
- kernel_size=1,
471
- bias=False,
472
- separable=True,
473
- )
474
- self.convt2 = CausalConvTranspose2d(
475
- in_channels=config.conv_channels,
476
- out_channels=config.conv_channels,
477
- fstride=2,
478
- kernel_size=config.convt_kernel_size_inner,
479
- bias=False,
480
- separable=True,
481
- )
482
- self.conv1p = CausalConv2d(
483
- in_channels=config.conv_channels,
484
- out_channels=config.conv_channels,
485
- kernel_size=1,
486
- bias=False,
487
- separable=True,
488
- )
489
- self.convt1 = CausalConvTranspose2d(
490
- in_channels=config.conv_channels,
491
- out_channels=config.conv_channels,
492
- fstride=2,
493
- kernel_size=config.convt_kernel_size_inner,
494
- bias=False,
495
- separable=True,
496
- )
497
- self.conv0p = CausalConv2d(
498
- in_channels=config.conv_channels,
499
- out_channels=config.conv_channels,
500
- kernel_size=1,
501
- bias=False,
502
- separable=True,
503
- )
504
- self.conv0_out = CausalConv2d(
505
- in_channels=config.conv_channels,
506
- out_channels=1,
507
- kernel_size=config.conv_kernel_size_inner,
508
- activation_layer="sigmoid",
509
- bias=False,
510
- separable=True,
511
- )
512
-
513
- def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
514
- # Estimates erb mask
515
- b, _, t, f8 = e3.shape
516
- emb, _ = self.emb_gru(emb)
517
- emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) # [B, C*8, T, F/8]
518
- e3 = self.convt3(self.conv3p(e3) + emb) # [B, C*4, T, F/4]
519
- e2 = self.convt2(self.conv2p(e2) + e3) # [B, C*2, T, F/2]
520
- e1 = self.convt1(self.conv1p(e1) + e2) # [B, C, T, F]
521
- m = self.conv0_out(self.conv0p(e0) + e1) # [B, 1, T, F]
522
- return m
523
-
524
-
525
- class Mask(nn.Module):
526
- def __init__(self, erb_inv_fb: torch.FloatTensor, post_filter: bool = False, eps: float = 1e-12):
527
- super().__init__()
528
- self.erb_inv_fb: torch.FloatTensor
529
- self.register_buffer("erb_inv_fb", erb_inv_fb.float())
530
- self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
531
- self.post_filter = post_filter
532
- self.eps = eps
533
-
534
- def pf(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
535
- """
536
- Post-Filter
537
-
538
- A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
539
- https://arxiv.org/abs/2008.04259
540
-
541
- :param mask: Real valued mask, typically of shape [B, C, T, F].
542
- :param beta: Global gain factor.
543
- :return:
544
- """
545
- mask_sin = mask * torch.sin(np.pi * mask / 2)
546
- mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
547
- return mask_pf
548
-
549
- def forward(self, spec: torch.Tensor, mask: torch.Tensor, atten_lim: Optional[torch.Tensor] = None) -> torch.Tensor:
550
- # spec (real) [B, 1, T, F, 2], F: freq_bins
551
- # mask (real): [B, 1, T, Fe], Fe: erb_bins
552
- # atten_lim: [B]
553
- if not self.training and self.post_filter:
554
- mask = self.pf(mask)
555
- if atten_lim is not None:
556
- # dB to amplitude
557
- atten_lim = 10 ** (-atten_lim / 20)
558
- # Greater equal (__ge__) not implemented for TorchVersion.
559
- if self.clamp_tensor:
560
- # Supported by torch >= 1.9
561
- mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
562
- else:
563
- m_out = []
564
- for i in range(atten_lim.shape[0]):
565
- m_out.append(mask[i].clamp_min(atten_lim[i].item()))
566
- mask = torch.stack(m_out, dim=0)
567
- mask = mask.matmul(self.erb_inv_fb) # [B, 1, T, F]
568
- if not spec.is_complex():
569
- mask = mask.unsqueeze(4)
570
- return spec * mask
571
-
572
-
573
- class DfDecoder(nn.Module):
574
- def __init__(self,
575
- config: DfNetConfig,
576
- ):
577
- super().__init__()
578
- layer_width = config.conv_channels
579
-
580
- self.emb_in_dim = config.conv_channels * config.erb_bins // 4
581
- self.emb_dim = config.df_hidden_dim
582
-
583
- self.df_n_hidden = config.df_hidden_dim
584
- self.df_n_layers = config.df_num_layers
585
- self.df_order = config.df_order
586
- self.df_bins = config.df_bins
587
- self.df_out_ch = config.df_order * 2
588
-
589
- self.df_convp = CausalConv2d(
590
- layer_width,
591
- self.df_out_ch,
592
- fstride=1,
593
- kernel_size=(config.df_pathway_kernel_size_t, 1),
594
- separable=True,
595
- bias=False,
596
- )
597
- self.df_gru = SqueezedGRU_S(
598
- self.emb_in_dim,
599
- self.emb_dim,
600
- num_layers=self.df_n_layers,
601
- batch_first=True,
602
- skip_op="none",
603
- activation_layer="relu",
604
- )
605
-
606
- if config.df_gru_skip == "none":
607
- self.df_skip = None
608
- elif config.df_gru_skip == "identity":
609
- if config.emb_hidden_dim != config.df_hidden_dim:
610
- raise AssertionError("Dimensions do not match")
611
- self.df_skip = nn.Identity()
612
- elif config.df_gru_skip == "grouped_linear":
613
- self.df_skip = GroupedLinear(self.emb_in_dim, self.emb_dim, groups=config.df_decoder_linear_groups)
614
- else:
615
- raise NotImplementedError()
616
-
617
- self.df_out: nn.Module
618
- out_dim = self.df_bins * self.df_out_ch
619
-
620
- self.df_out = nn.Sequential(
621
- GroupedLinear(
622
- input_size=self.df_n_hidden,
623
- hidden_size=out_dim,
624
- groups=config.df_decoder_linear_groups
625
- ),
626
- nn.Tanh()
627
- )
628
- self.df_fc_a = nn.Sequential(
629
- nn.Linear(self.df_n_hidden, 1),
630
- nn.Sigmoid()
631
- )
632
-
633
- def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
634
- b, t, _ = emb.shape
635
- c, _ = self.df_gru(emb) # [B, T, H], H: df_n_hidden
636
- if self.df_skip is not None:
637
- c = c + self.df_skip(emb)
638
- c0 = self.df_convp(c0).permute(0, 2, 3, 1) # [B, T, F, O*2], channels_last
639
- c = self.df_out(c) # [B, T, F*O*2], O: df_order
640
- c = c.view(b, t, self.df_bins, self.df_out_ch) + c0 # [B, T, F, O*2]
641
- return c
642
-
643
-
644
- class DfOutputReshapeMF(nn.Module):
645
- """Coefficients output reshape for multiframe/MultiFrameModule
646
-
647
- Requires input of shape B, C, T, F, 2.
648
- """
649
-
650
- def __init__(self, df_order: int, df_bins: int):
651
- super().__init__()
652
- self.df_order = df_order
653
- self.df_bins = df_bins
654
-
655
- def forward(self, coefs: torch.Tensor) -> torch.Tensor:
656
- # [B, T, F, O*2] -> [B, O, T, F, 2]
657
- new_shape = list(coefs.shape)
658
- new_shape[-1] = -1
659
- new_shape.append(2)
660
- coefs = coefs.view(new_shape)
661
- coefs = coefs.permute(0, 3, 1, 2, 4)
662
- return coefs
663
-
664
-
665
- class DfNet(nn.Module):
666
- """
667
- DeepFilterNet: Perceptually Motivated Real-Time Speech Enhancement
668
- https://arxiv.org/abs/2305.08227
669
-
670
671
- """
672
- def __init__(self,
673
- config: DfNetConfig,
674
- erb_fb: torch.FloatTensor,
675
- erb_inv_fb: torch.FloatTensor,
676
- run_df: bool = True,
677
- train_mask: bool = True,
678
- ):
679
- """
680
- :param erb_fb: erb filter bank.
681
- """
682
- super(DfNet, self).__init__()
683
- if config.erb_bins % 8 != 0:
684
- raise AssertionError("erb_bins should be divisible by 8")
685
-
686
- self.df_lookahead = config.df_lookahead
687
- self.df_bins = config.df_bins
688
- self.freq_bins: int = config.fft_size // 2 + 1
689
- self.emb_dim: int = config.conv_channels * config.erb_bins
690
- self.erb_bins: int = config.erb_bins
691
-
692
- if config.conv_lookahead > 0:
693
- if config.conv_lookahead < config.df_lookahead:
694
- raise AssertionError
695
- # for last 2 dim, pad (left, right, top, bottom).
696
- self.pad_feat = nn.ConstantPad2d((0, 0, -config.conv_lookahead, config.conv_lookahead), 0.0)
697
- else:
698
- self.pad_feat = nn.Identity()
699
-
700
- if config.df_lookahead > 0:
701
- # for last 3 dim, pad (left, right, top, bottom, front, back).
702
- self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -config.df_lookahead, config.df_lookahead), 0.0)
703
- else:
704
- self.pad_spec = nn.Identity()
705
-
706
- self.register_buffer("erb_fb", erb_fb)
707
-
708
- self.enc = Encoder(config)
709
- self.erb_dec = ErbDecoder(config)
710
- self.mask = Mask(erb_inv_fb)
711
-
712
- self.erb_inv_fb = erb_inv_fb
713
- self.post_filter = config.mask_post_filter
714
- self.post_filter_beta = config.post_filter_beta
715
-
716
- self.df_order = config.df_order
717
- self.df_op = MF.DF(num_freqs=config.df_bins, frame_size=config.df_order, lookahead=self.df_lookahead)
718
- self.df_dec = DfDecoder(config)
719
- self.df_out_transform = DfOutputReshapeMF(self.df_order, config.df_bins)
720
-
721
- self.run_erb = config.df_bins + 1 < self.freq_bins
722
- if not self.run_erb:
723
- logger.warning("Running without ERB stage")
724
- self.run_df = run_df
725
- if not run_df:
726
- logger.warning("Running without DF stage")
727
- self.train_mask = train_mask
728
- self.lsnr_dropout = config.lsnr_dropout
729
- if config.df_n_iter != 1:
730
- raise AssertionError
731
-
732
- def forward1(
733
- self,
734
- spec: torch.Tensor,
735
- feat_erb: torch.Tensor,
736
- feat_spec: torch.Tensor, # Not used, take spec modified by mask instead
737
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
738
- """Forward method of DeepFilterNet2.
739
-
740
- Args:
741
- spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
742
- feat_erb (Tensor): ERB features of shape [B, 1, T, E]
743
- feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2]
744
-
745
- Returns:
746
- spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
747
- m (Tensor): ERB mask estimate of shape [B, 1, T, E]
748
- lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
749
- """
750
- # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
751
- feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
752
- # feat_spec shape: [batch_size, 2, time_steps, freq_dim]
753
-
754
- # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
755
- # assert time_steps >= conv_lookahead.
756
- feat_erb = self.pad_feat(feat_erb)
757
- feat_spec = self.pad_feat(feat_spec)
758
- e0, e1, e2, e3, emb, c0, lsnr, h = self.enc(feat_erb, feat_spec)
759
-
760
- if self.lsnr_droput:
761
- idcs = lsnr.squeeze() > -10.0
762
- b, t = (spec.shape[0], spec.shape[2])
763
- m = torch.zeros((b, 1, t, self.erb_bins), device=spec.device)
764
- df_coefs = torch.zeros((b, t, self.nb_df, self.df_order * 2))
765
- spec_m = spec.clone()
766
- emb = emb[:, idcs]
767
- e0 = e0[:, :, idcs]
768
- e1 = e1[:, :, idcs]
769
- e2 = e2[:, :, idcs]
770
- e3 = e3[:, :, idcs]
771
- c0 = c0[:, :, idcs]
772
-
773
- if self.run_erb:
774
- if self.lsnr_dropout:
775
- m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0)
776
- else:
777
- m = self.erb_dec(emb, e3, e2, e1, e0)
778
- spec_m = self.mask(spec, m)
779
- else:
780
- m = torch.zeros((), device=spec.device)
781
- spec_m = torch.zeros_like(spec)
782
-
783
- if self.run_df:
784
- if self.lsnr_dropout:
785
- df_coefs[:, idcs] = self.df_dec(emb, c0)
786
- else:
787
- df_coefs = self.df_dec(emb, c0)
788
- df_coefs = self.df_out_transform(df_coefs)
789
- spec_e = self.df_op(spec.clone(), df_coefs)
790
- spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
791
- else:
792
- df_coefs = torch.zeros((), device=spec.device)
793
- spec_e = spec_m
794
-
795
- if self.post_filter:
796
- beta = self.post_filter_beta
797
- eps = 1e-12
798
- mask = (utils.as_complex(spec_e).abs() / utils.as_complex(spec).abs().add(eps)).clamp(eps, 1)
799
- mask_sin = mask * torch.sin(PI * mask / 2).clamp_min(eps)
800
- pf = (1 + beta) / (1 + beta * mask.div(mask_sin).pow(2))
801
- spec_e = spec_e * pf.unsqueeze(-1)
802
-
803
- return spec_e, m, lsnr, df_coefs
804
-
805
- def forward(
806
- self,
807
- spec: torch.Tensor,
808
- feat_erb: torch.Tensor,
809
- feat_spec: torch.Tensor, # Not used, take spec modified by mask instead
810
- erb_encoder_h: torch.Tensor = None,
811
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
812
- # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
813
- feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
814
- # feat_spec shape: [batch_size, 2, time_steps, freq_dim]
815
-
816
- # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
817
- # assert time_steps >= conv_lookahead.
818
- feat_erb = self.pad_feat(feat_erb)
819
- feat_spec = self.pad_feat(feat_spec)
820
- e0, e1, e2, e3, emb, c0, lsnr, erb_encoder_h = self.enc(feat_erb, feat_spec, erb_encoder_h)
821
-
822
- m = self.erb_dec(emb, e3, e2, e1, e0)
823
- spec_m = self.mask(spec, m)
824
- # spec_e = spec_m
825
-
826
- df_coefs = self.df_dec(emb, c0)
827
- df_coefs = self.df_out_transform(df_coefs)
828
- spec_e = self.df_op(spec.clone(), df_coefs)
829
- spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
830
-
831
- return spec_e, m, lsnr, df_coefs, erb_encoder_h
832
-
833
-
834
- if __name__ == "__main__":
835
- pass
toolbox/torchaudio/models/dfnet3/multiframes.py DELETED
@@ -1,145 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
-
8
- # From torchaudio
9
- def _compute_mat_trace(input: torch.Tensor, dim1: int = -2, dim2: int = -1) -> torch.Tensor:
10
- r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
11
- Args:
12
- input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
13
- dim1 (int, optional): the first dimension of the diagonal matrix
14
- (Default: -1)
15
- dim2 (int, optional): the second dimension of the diagonal matrix
16
- (Default: -2)
17
- Returns:
18
- Tensor: trace of the input Tensor
19
- """
20
- assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
21
- assert (
22
- input.shape[dim1] == input.shape[dim2]
23
- ), "The size of ``dim1`` and ``dim2`` must be the same."
24
- input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
25
- return input.sum(dim=-1)
26
-
27
-
28
- def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
29
- """Perform Tikhonov regularization (only modifying real part).
30
- Args:
31
- mat (torch.Tensor): input matrix (..., channel, channel)
32
- reg (float, optional): regularization factor (Default: 1e-8)
33
- eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``)
34
- Returns:
35
- Tensor: regularized matrix (..., channel, channel)
36
- """
37
- # Add eps
38
- C = mat.size(-1)
39
- eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
40
- epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
41
- # in case that correlation_matrix is all-zero
42
- epsilon = epsilon + eps
43
- mat = mat + epsilon * eye[..., :, :]
44
- return mat
45
-
46
-
47
- class MultiFrameModule(nn.Module):
48
- """
49
- Multi-frame speech enhancement modules.
50
-
51
- Signal model and notation:
52
- Noisy: `x = s + n`
53
- Enhanced: `y = f(x)`
54
- Objective: `min ||s - y||`
55
-
56
- PSD: Power spectral density, notated eg. as `Rxx` for noisy PSD.
57
- IFC: Inter-frame correlation vector: PSD*u, u: selection vector. Notated as `rxx`
58
- RTF: Relative transfere function, also called steering vector.
59
- """
60
- def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, real: bool = False):
61
- """
62
- Multi-Frame filtering module.
63
-
64
- :param num_freqs: int. Number of frequency bins used for filtering.
65
- :param frame_size: int. Frame size in FD domain.
66
- :param lookahead: int. Lookahead, may be used to select the output time step.
67
- Note: This module does not add additional padding according to lookahead!
68
- :param real:
69
- """
70
- super().__init__()
71
- self.num_freqs = num_freqs
72
- self.frame_size = frame_size
73
- self.real = real
74
- if real:
75
- self.pad = nn.ConstantPad3d((0, 0, 0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
76
- else:
77
- self.pad = nn.ConstantPad2d((0, 0, frame_size - 1 - lookahead, lookahead), 0.0)
78
- self.need_unfold = frame_size > 1
79
- self.lookahead = lookahead
80
-
81
- def spec_unfold_real(self, spec: torch.Tensor):
82
- if self.need_unfold:
83
- spec = self.pad(spec).unfold(-3, self.frame_size, 1)
84
- return spec.permute(0, 1, 5, 2, 3, 4)
85
- # return as_windowed(self.pad(spec), self.frame_size, 1, dim=-3)
86
- return spec.unsqueeze(-1)
87
-
88
- def spec_unfold(self, spec: torch.Tensor):
89
- """Pads and unfolds the spectrogram according to frame_size.
90
-
91
- Args:
92
- spec (complex Tensor): Spectrogram of shape [B, C, T, F]
93
- Returns:
94
- spec (Tensor): Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
95
- """
96
- if self.need_unfold:
97
- return self.pad(spec).unfold(2, self.frame_size, 1)
98
- return spec.unsqueeze(-1)
99
-
100
- @staticmethod
101
- def solve(Rxx, rss, diag_eps: float = 1e-8, eps: float = 1e-7) -> torch.Tensor:
102
- return torch.einsum(
103
- "...nm,...m->...n", torch.inverse(_tik_reg(Rxx, diag_eps, eps)), rss
104
- ) # [T, F, N]
105
-
106
- @staticmethod
107
- def apply_coefs(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
108
- # spec: [B, C, T, F, N]
109
- # coefs: [B, C, T, F, N]
110
- return torch.einsum("...n,...n->...", spec, coefs)
111
-
112
-
113
- class DF(MultiFrameModule):
114
- """Deep Filtering."""
115
-
116
- def __init__(self, num_freqs: int, frame_size: int, lookahead: int = 0, conj: bool = False):
117
- super().__init__(num_freqs, frame_size, lookahead)
118
- self.conj: bool = conj
119
-
120
- def forward(self, spec: torch.Tensor, coefs: torch.Tensor):
121
- spec_u = self.spec_unfold(torch.view_as_complex(spec))
122
- coefs = torch.view_as_complex(coefs)
123
- spec_f = spec_u.narrow(-2, 0, self.num_freqs)
124
- coefs = coefs.view(coefs.shape[0], -1, self.frame_size, *coefs.shape[2:])
125
- if self.conj:
126
- coefs = coefs.conj()
127
- spec_f = self.df(spec_f, coefs)
128
- if self.training:
129
- spec = spec.clone()
130
- spec[..., : self.num_freqs, :] = torch.view_as_real(spec_f)
131
- return spec
132
-
133
- @staticmethod
134
- def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
135
- """
136
- Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
137
- :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
138
- :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
139
- :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
140
- """
141
- return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
142
-
143
-
144
- if __name__ == '__main__':
145
- pass
toolbox/torchaudio/models/dfnet3/utils.py DELETED
@@ -1,17 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import torch
4
-
5
-
6
- def as_complex(x: torch.Tensor):
7
- if torch.is_complex(x):
8
- return x
9
- if x.shape[-1] != 2:
10
- raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
11
- if x.stride(-1) != 1:
12
- x = x.contiguous()
13
- return torch.view_as_complex(x)
14
-
15
-
16
- if __name__ == '__main__':
17
- pass
toolbox/torchaudio/models/dtln/modeling_dtln.py CHANGED
@@ -2,6 +2,10 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/AkenoSyuRi/DTLNPytorch
 
 
 
 
5
  """
6
  import os
7
  from typing import Optional, Union
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/AkenoSyuRi/DTLNPytorch
5
+
6
+ https://github.com/breizhn/DTLN
7
+ Trained on 500 hours of dns3 data, it reaches PESQ 3.04 on the dns3 test set.
8
+
9
  """
10
  import os
11
  from typing import Optional, Union
toolbox/torchaudio/models/frcrn/modeling_frcrn.py CHANGED
@@ -6,6 +6,8 @@ https://arxiv.org/abs/2206.07293
6
  https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py
7
  https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py
8
 
 
 
9
  """
10
  import os
11
  from typing import Optional, Union
 
6
  https://github.com/modelscope/modelscope/blob/master/modelscope/models/audio/ans/frcrn.py
7
  https://huggingface.co/spaces/alibabasglab/ClearVoice/blob/main/models/frcrn_se/frcrn.py
8
 
9
+ https://github.com/modelscope/ClearerVoice-Studio/tree/main/clearvoice/clearvoice/models/frcrn_se
10
+
11
  """
12
  import os
13
  from typing import Optional, Union
toolbox/torchaudio/models/gtcrn/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+
5
+ if __name__ == "__main__":
6
+ pass
toolbox/torchaudio/models/gtcrn/modeling_gtcrn.py ADDED
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ https://blog.csdn.net/gitblog_00478/article/details/141522595
5
+
6
+ https://github.com/Xiaobin-Rong/gtcrn/blob/main/gtcrn.py
7
+ https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/gtcrn_stream.py
8
+ """
9
+ import torch
10
+ import torch.nn as nn
11
+ from typing import List, Tuple, Union
12
+
13
+
14
+ if __name__ == "__main__":
15
+ pass
toolbox/torchaudio/models/lstm/modeling_lstm.py CHANGED
@@ -85,13 +85,14 @@ class LstmModel(nn.Module):
85
  if remainder > 0:
86
  n_samples_pad = self.hop_size - remainder
87
  signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
88
- return signal, n_samples
89
 
90
  def forward(self,
91
  noisy: torch.Tensor,
92
  h_state: Tuple[torch.Tensor, torch.Tensor] = None,
93
  ):
94
- noisy, num_samples = self.signal_prepare(noisy)
 
95
  batch_size, _, num_samples_pad = noisy.shape
96
  # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
97
 
@@ -207,7 +208,7 @@ def main():
207
  model.eval()
208
 
209
  noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
210
- noisy, _ = model.signal_prepare(noisy)
211
  b, _, num_samples = noisy.shape
212
  t = (num_samples - config.win_size) / config.hop_size + 1
213
 
 
85
  if remainder > 0:
86
  n_samples_pad = self.hop_size - remainder
87
  signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
88
+ return signal
89
 
90
  def forward(self,
91
  noisy: torch.Tensor,
92
  h_state: Tuple[torch.Tensor, torch.Tensor] = None,
93
  ):
94
+ num_samples = noisy.shape[-1]
95
+ noisy = self.signal_prepare(noisy)
96
  batch_size, _, num_samples_pad = noisy.shape
97
  # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
98
 
 
208
  model.eval()
209
 
210
  noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
211
+ noisy = model.signal_prepare(noisy)
212
  b, _, num_samples = noisy.shape
213
  t = (num_samples - config.win_size) / config.hop_size + 1
214
 
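With this change, signal_prepare returns only the padded signal, and forward reads the original length from its own input before padding. A minimal sketch of the new calling convention, assuming the LstmConfig / LstmPretrainedModel construction used elsewhere in this repo (the constructor calls are not shown in this hunk and are an assumption):

import torch

from toolbox.torchaudio.models.lstm.configuration_lstm import LstmConfig
from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel

config = LstmConfig()
model = LstmPretrainedModel(config)  # assumed constructor, mirroring the script's main()
model.eval()

noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

# before: noisy, num_samples = model.signal_prepare(noisy)
# after:  signal_prepare only pads; callers keep the original length themselves
num_samples = noisy.shape[-1]
noisy = model.signal_prepare(noisy)

b, _, num_samples_pad = noisy.shape
t = (num_samples_pad - config.win_size) / config.hop_size + 1
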
toolbox/torchaudio/models/rnnoise/configuration_rnnoise.py ADDED
@@ -0,0 +1,77 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
+
5
+
6
+ class RNNoiseConfig(PretrainedConfig):
7
+ def __init__(self,
8
+ sample_rate: int = 8000,
9
+ segment_size: int = 32000,
10
+ nfft: int = 512,
11
+ win_size: int = 512,
12
+ hop_size: int = 256,
13
+ win_type: str = "hann",
14
+
15
+ erb_bins: int = 32,
16
+ min_freq_bins_for_erb: int = 2,
17
+
18
+ conv_size: int = 128,
19
+ gru_size: int = 256,
20
+
21
+ min_snr_db: float = -10,
22
+ max_snr_db: float = 20,
23
+
24
+ max_epochs: int = 100,
25
+ batch_size: int = 4,
26
+ num_workers: int = 4,
27
+ seed: int = 1234,
28
+
29
+ lr: float = 0.001,
30
+ lr_scheduler: str = "CosineAnnealingLR",
31
+ lr_scheduler_kwargs: dict = None,
32
+
33
+ weight_decay: float = 0.00001,
34
+ clip_grad_norm: float = 10.,
35
+ eval_steps: int = 25000,
36
+
37
+ **kwargs
38
+ ):
39
+ super(RNNoiseConfig, self).__init__(**kwargs)
40
+ self.sample_rate = sample_rate
41
+ self.segment_size = segment_size
42
+ self.nfft = nfft
43
+ self.win_size = win_size
44
+ self.hop_size = hop_size
45
+ self.win_type = win_type
46
+
47
+ self.erb_bins = erb_bins
48
+ self.min_freq_bins_for_erb = min_freq_bins_for_erb
49
+
50
+ self.conv_size = conv_size
51
+ self.gru_size = gru_size
52
+
53
+ self.min_snr_db = min_snr_db
54
+ self.max_snr_db = max_snr_db
55
+
56
+ self.max_epochs = max_epochs
57
+ self.batch_size = batch_size
58
+ self.num_workers = num_workers
59
+ self.seed = seed
60
+
61
+ self.lr = lr
62
+ self.lr_scheduler = lr_scheduler
63
+ self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()
64
+
65
+ self.weight_decay = weight_decay
66
+ self.clip_grad_norm = clip_grad_norm
67
+ self.eval_steps = eval_steps
68
+
69
+
70
+ def main():
71
+ config = RNNoiseConfig()
72
+ config.to_yaml_file("yaml/config.yaml")
73
+ return
74
+
75
+
76
+ if __name__ == "__main__":
77
+ main()
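
The constructor defaults above mirror the key names in yaml/config.yaml added below. A minimal sketch of building the config and dumping it, as main() does (the output path is illustrative):

from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig

config = RNNoiseConfig()
print(config.sample_rate, config.nfft, config.win_size, config.hop_size)  # 8000 512 512 256

# same as main() above: serialize the config to YAML
config.to_yaml_file("yaml/config.yaml")
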
toolbox/torchaudio/models/rnnoise/modeling_rnnoise.py CHANGED
@@ -2,10 +2,401 @@
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/xiph/rnnoise
 
5
 
6
  https://arxiv.org/abs/1709.08243
7
 
8
  """
 
 
9
 
10
- if __name__ == '__main__':
11
- pass
 
2
  # -*- coding: utf-8 -*-
3
  """
4
  https://github.com/xiph/rnnoise
5
+ https://github.com/xiph/rnnoise/blob/main/torch/rnnoise/rnnoise.py
6
 
7
  https://arxiv.org/abs/1709.08243
8
 
9
  """
10
+ import os
11
+ from typing import Optional, Union, Tuple
12
 
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.nn import functional as F
16
+
17
+ from toolbox.torch.sparsification.gru_sparsifier import GRUSparsifier
18
+ from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig
19
+ from toolbox.torchaudio.configuration_utils import CONFIG_FILE
20
+ from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
21
+ from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
22
+
23
+
24
+ sparsify_start = 6000
25
+ sparsify_stop = 20000
26
+ sparsify_interval = 100
27
+ sparsify_exponent = 3
28
+
29
+
30
+ sparse_params1 = {
31
+ "W_hr" : (0.3, [8, 4], True),
32
+ "W_hz" : (0.2, [8, 4], True),
33
+ "W_hn" : (0.5, [8, 4], True),
34
+ "W_ir" : (0.3, [8, 4], False),
35
+ "W_iz" : (0.2, [8, 4], False),
36
+ "W_in" : (0.5, [8, 4], False),
37
+ }
38
+
39
+
40
+ def init_weights(module):
41
+ if isinstance(module, nn.GRU):
42
+ for p in module.named_parameters():
43
+ if p[0].startswith("weight_hh_"):
44
+ nn.init.orthogonal_(p[1])
45
+
46
+
47
+ class RNNoise(nn.Module):
48
+ def __init__(self,
49
+ sample_rate: int = 8000,
50
+ nfft: int = 512,
51
+ win_size: int = 512,
52
+ hop_size: int = 256,
53
+ win_type: str = "hann",
54
+ erb_bins: int = 32,
55
+ min_freq_bins_for_erb: int = 2,
56
+ conv_size: int = 128,
57
+ gru_size: int = 256,
58
+ ):
59
+ super(RNNoise, self).__init__()
60
+ self.sample_rate = sample_rate
61
+ self.nfft = nfft
62
+ self.win_size = win_size
63
+ self.hop_size = hop_size
64
+ self.win_type = win_type
65
+
66
+ self.erb_bins = erb_bins
67
+ self.min_freq_bins_for_erb = min_freq_bins_for_erb
68
+ self.conv_size = conv_size
69
+ self.gru_size = gru_size
70
+
71
+ self.input_dim = nfft // 2 + 1
72
+
73
+ self.eps = 1e-12
74
+
75
+ self.erb_bands = ErbBands(
76
+ sample_rate=self.sample_rate,
77
+ nfft=self.nfft,
78
+ erb_bins=self.erb_bins,
79
+ min_freq_bins_for_erb=self.min_freq_bins_for_erb,
80
+ )
81
+
82
+ self.stft = ConvSTFT(
83
+ nfft=self.nfft,
84
+ win_size=self.win_size,
85
+ hop_size=self.hop_size,
86
+ win_type=self.win_type,
87
+ power=None,
88
+ requires_grad=False
89
+ )
90
+ self.istft = ConviSTFT(
91
+ nfft=self.nfft,
92
+ win_size=self.win_size,
93
+ hop_size=self.hop_size,
94
+ win_type=self.win_type,
95
+ requires_grad=False
96
+ )
97
+
98
+ self.pad = nn.ConstantPad1d(padding=(2, 2), value=0)
99
+ self.conv1 = nn.Conv1d(self.erb_bins, conv_size, kernel_size=3, padding="valid")
100
+ self.conv2 = nn.Conv1d(conv_size, gru_size, kernel_size=3, padding="valid")
101
+
102
+ self.gru1 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
103
+ self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
104
+ self.gru3 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
105
+
106
+ self.dense_out = nn.Linear(4*self.gru_size, self.erb_bins)
107
+
108
+ nb_params = sum(p.numel() for p in self.parameters())
109
+ print(f"model: {nb_params} weights")
110
+ self.apply(init_weights)
111
+
112
+ self.sparsifier = [
113
+ GRUSparsifier(
114
+ task_list=[(self.gru1, sparse_params1)],
115
+ start=sparsify_start,
116
+ stop=sparsify_stop,
117
+ interval=sparsify_interval,
118
+ exponent=sparsify_exponent,
119
+ ),
120
+ GRUSparsifier(
121
+ task_list=[(self.gru2, sparse_params1)],
122
+ start=sparsify_start,
123
+ stop=sparsify_stop,
124
+ interval=sparsify_interval,
125
+ exponent=sparsify_exponent,
126
+ ),
127
+ GRUSparsifier(
128
+ task_list=[(self.gru3, sparse_params1)],
129
+ start=sparsify_start,
130
+ stop=sparsify_stop,
131
+ interval=sparsify_interval,
132
+ exponent=sparsify_exponent,
133
+ )
134
+ ]
135
+
136
+ def sparsify(self):
137
+ for sparsifier in self.sparsifier:
138
+ sparsifier.step()
139
+
140
+ def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
141
+ if signal.dim() == 2:
142
+ signal = torch.unsqueeze(signal, dim=1)
143
+ _, _, n_samples = signal.shape
144
+ remainder = (n_samples - self.win_size) % self.hop_size
145
+ if remainder > 0:
146
+ n_samples_pad = self.hop_size - remainder
147
+ signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
148
+ return signal
149
+
150
+ def forward(self,
151
+ noisy: torch.Tensor,
152
+ states: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
153
+ ):
154
+ num_samples = noisy.shape[-1]
155
+ noisy = self.signal_prepare(noisy)
156
+ batch_size, _, num_samples_pad = noisy.shape
157
+ # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
158
+
159
+ mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
160
+ # shape: (b, f, t)
161
+ # t = (num_samples - win_size) / hop_size + 1
162
+
163
+ mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2)
164
+ # shape: (b, t, f)
165
+ mag_noisy_t_erb = self.erb_bands.erb_scale(mag_noisy_t, db=True)
166
+ # shape: (b, t, erb_bins)
167
+ mag_noisy_t_erb = torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2)
168
+ # shape: (b, erb_bins, t)
169
+
170
+ mag_noisy_t_erb = self.pad(mag_noisy_t_erb)
171
+ mag_noisy_t_erb = self.forward_conv(mag_noisy_t_erb)
172
+ gru_out, states = self.forward_gru(mag_noisy_t_erb, states)
173
+ # gru_out shape: [b, t, f]
174
+ mask_erb = torch.sigmoid(self.dense_out(gru_out))
175
+ # mask_erb shape: (b, t, erb_bins)
176
+
177
+ mask = self.erb_bands.erb_scale_inv(mask_erb)
178
+ # mask shape: (b, t, f)
179
+ mask = torch.transpose(mask, dim0=1, dim1=2)
180
+ # mask shape: (b, f, t)
181
+
182
+ stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
183
+ denoise = self.istft.forward(stft_denoise)
184
+ # denoise shape: [b, 1, num_samples_pad]
185
+
186
+ denoise = denoise[:, :, :num_samples]
187
+ # denoise shape: [b, 1, num_samples]
188
+ return denoise, mask, states
189
+
190
+ def forward_conv(self, mag_noisy: torch.Tensor):
191
+ # mag_noisy shape: [b, f, t]
192
+ tmp = mag_noisy
193
+ # tmp shape: [b, f, t]
194
+ tmp = torch.tanh(self.conv1(tmp))
195
+ tmp = torch.tanh(self.conv2(tmp))
196
+ # tmp shape: [b, f, t]
197
+ return tmp
198
+
199
+ def forward_gru(self,
200
+ mag_noisy: torch.Tensor,
201
+ states: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
202
+ ):
203
+ if states is None:
204
+ gru1_state = None
205
+ gru2_state = None
206
+ gru3_state = None
207
+ else:
208
+ gru1_state = states[0]
209
+ gru2_state = states[1]
210
+ gru3_state = states[2]
211
+
212
+ # mag_noisy shape: [b, f, t]
213
+ tmp = mag_noisy.permute(0, 2, 1)
214
+ # tmp shape: [b, t, f]
215
+
216
+ gru1_out, gru1_state = self.gru1(tmp, gru1_state)
217
+ gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
218
+ gru3_out, gru3_state = self.gru3(gru2_out, gru3_state)
219
+ new_states = [gru1_state, gru2_state, gru3_state]
220
+
221
+ gru_out = torch.cat(tensors=[tmp, gru1_out, gru2_out, gru3_out], dim=-1)
222
+ # gru_out shape: [b, t, f]
223
+ return gru_out, new_states
224
+
225
+ def forward_chunk_by_chunk(self,
226
+ noisy: torch.Tensor,
227
+ ):
228
+ noisy = self.signal_prepare(noisy)
229
+ b, _, num_samples = noisy.shape
230
+ t = (num_samples - self.win_size) / self.hop_size + 1
231
+
232
+ waveform = torch.zeros(size=(b, 1, 0), dtype=torch.float32)
233
+
234
+ states = None
235
+ waveform_cache = None
236
+ coff_cache = None
237
+
238
+ cache_list = list()
239
+ for i in range(int(t)):
240
+ begin = i * self.hop_size
241
+ end = begin + self.win_size
242
+ sub_noisy = noisy[:, :, begin:end]
243
+ mag_noisy, pha_noisy = self.mag_pha_stft(sub_noisy)
244
+ mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2)
245
+ mag_noisy_t_erb = self.erb_bands.erb_scale(mag_noisy_t, db=True)
246
+ mag_noisy_t_erb = torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2)
247
+ # mag_noisy_t_erb shape: (b, erb_bins, t)
248
+
249
+ if len(cache_list) == 0:
250
+ cache_list.extend([{
251
+ "mag_noisy": torch.zeros_like(mag_noisy),
252
+ "pha_noisy": torch.zeros_like(pha_noisy),
253
+ "mag_noisy_t_erb": torch.zeros_like(mag_noisy_t_erb),
254
+ }] * 2)
255
+ cache_list.append({
256
+ "mag_noisy": mag_noisy,
257
+ "pha_noisy": pha_noisy,
258
+ "mag_noisy_t_erb": mag_noisy_t_erb,
259
+ })
260
+ if len(cache_list) < 5:
261
+ continue
262
+ mag_noisy_t_erb = torch.concat(
263
+ tensors=[c["mag_noisy_t_erb"] for c in cache_list],
264
+ dim=-1
265
+ )
266
+ mag_noisy = cache_list[2]["mag_noisy"]
267
+ pha_noisy = cache_list[2]["pha_noisy"]
268
+ cache_list.pop(0)
269
+ # mag_noisy_t_erb shape: [b, f, 5]
270
+ mag_noisy_t_erb = self.forward_conv(mag_noisy_t_erb)
271
+ # mag_noisy_t_erb shape: [b, f, 1]
272
+ gru_out, states = self.forward_gru(mag_noisy_t_erb, states)
273
+ mask_erb = torch.sigmoid(self.dense_out(gru_out))
274
+ mask = self.erb_bands.erb_scale_inv(mask_erb)
275
+ mask = torch.transpose(mask, dim0=1, dim1=2)
276
+ stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
277
+ sub_waveform, waveform_cache, coff_cache = self.istft.forward_chunk(stft_denoise, waveform_cache, coff_cache)
278
+ waveform = torch.concat(tensors=[waveform, sub_waveform], dim=-1)
279
+
280
+ return waveform
281
+
282
+ def do_mask(self,
283
+ mag_noisy: torch.Tensor,
284
+ pha_noisy: torch.Tensor,
285
+ mask: torch.Tensor,
286
+ ):
287
+ # (b, f, t)
288
+ mag_denoise = mag_noisy * mask
289
+ stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
290
+ return stft_denoise
291
+
292
+ def mag_pha_stft(self, noisy: torch.Tensor):
293
+ # noisy shape: [b, num_samples]
294
+ stft_noisy = self.stft.forward(noisy)
295
+ # stft_noisy shape: [b, f, t], torch.complex64
296
+
297
+ real = torch.real(stft_noisy)
298
+ imag = torch.imag(stft_noisy)
299
+ mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
300
+ pha_noisy = torch.atan2(imag, real)
301
+ # shape: (b, f, t)
302
+ return mag_noisy, pha_noisy
303
+
304
+
305
+ MODEL_FILE = "model.pt"
306
+
307
+
308
+ class RNNoisePretrainedModel(RNNoise):
309
+ def __init__(self,
310
+ config: RNNoiseConfig,
311
+ ):
312
+ super(RNNoisePretrainedModel, self).__init__(
313
+ sample_rate=config.sample_rate,
314
+ nfft=config.nfft,
315
+ win_size=config.win_size,
316
+ hop_size=config.hop_size,
317
+ win_type=config.win_type,
318
+ erb_bins=config.erb_bins,
319
+ min_freq_bins_for_erb=config.min_freq_bins_for_erb,
320
+ conv_size=config.conv_size,
321
+ gru_size=config.gru_size,
322
+ )
323
+ self.config = config
324
+
325
+ @classmethod
326
+ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
327
+ config = RNNoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
328
+
329
+ model = cls(config)
330
+
331
+ if os.path.isdir(pretrained_model_name_or_path):
332
+ ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
333
+ else:
334
+ ckpt_file = pretrained_model_name_or_path
335
+
336
+ with open(ckpt_file, "rb") as f:
337
+ state_dict = torch.load(f, map_location="cpu", weights_only=True)
338
+ model.load_state_dict(state_dict, strict=True)
339
+ return model
340
+
341
+ def save_pretrained(self,
342
+ save_directory: Union[str, os.PathLike],
343
+ state_dict: Optional[dict] = None,
344
+ ):
345
+
346
+ model = self
347
+
348
+ if state_dict is None:
349
+ state_dict = model.state_dict()
350
+
351
+ os.makedirs(save_directory, exist_ok=True)
352
+
353
+ # save state dict
354
+ model_file = os.path.join(save_directory, MODEL_FILE)
355
+ torch.save(state_dict, model_file)
356
+
357
+ # save config
358
+ config_file = os.path.join(save_directory, CONFIG_FILE)
359
+ self.config.to_yaml_file(config_file)
360
+ return save_directory
361
+
362
+
363
+ def main1():
364
+ config = RNNoiseConfig()
365
+ model = RNNoisePretrainedModel(config)
366
+ model.eval()
367
+
368
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
369
+ noisy = model.signal_prepare(noisy)
370
+ b, _, num_samples = noisy.shape
371
+ t = (num_samples - config.win_size) / config.hop_size + 1
372
+
373
+ waveform, mask, h_state = model.forward(noisy)
374
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
375
+ print(waveform[:, :, 300: 302])
376
+
377
+ return
378
+
379
+
380
+ def main2():
381
+ config = RNNoiseConfig()
382
+ model = RNNoisePretrainedModel(config)
383
+ model.eval()
384
+
385
+ noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
386
+ noisy = model.signal_prepare(noisy)
387
+ b, _, num_samples = noisy.shape
388
+ t = (num_samples - config.win_size) / config.hop_size + 1
389
+
390
+ waveform, mask, h_state = model.forward(noisy)
391
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
392
+ print(waveform[:, :, 300: 302])
393
+
394
+ waveform = model.forward_chunk_by_chunk(noisy)
395
+ print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
396
+ print(waveform[:, :, 300: 302])
397
+
398
+ return
399
+
400
+
401
+ if __name__ == "__main__":
402
+ main2()
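
Beyond main1()/main2(), a minimal save/load round trip for the new pretrained wrapper could look like the sketch below. The directory name is illustrative, and RNNoiseConfig.from_pretrained is assumed to resolve a directory to its config.yaml, as the model's from_pretrained implies:

import torch

from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig
from toolbox.torchaudio.models.rnnoise.modeling_rnnoise import RNNoisePretrainedModel

config = RNNoiseConfig()
model = RNNoisePretrainedModel(config)
model.eval()

# save_pretrained() writes model.pt plus config.yaml into the directory and returns it
save_dir = model.save_pretrained("trained_models/rnnoise-demo")

# from_pretrained() reads the config and the state dict back (strict key matching)
restored = RNNoisePretrainedModel.from_pretrained(save_dir)

noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
with torch.no_grad():
    denoise, mask, states = restored.forward(noisy)
print(denoise.shape)  # forward pads internally and trims back to the input length
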
toolbox/torchaudio/models/rnnoise/yaml/config.yaml ADDED
@@ -0,0 +1,34 @@
1
+ model_name: "rnnoise"
2
+
3
+ # spec
4
+ sample_rate: 8000
5
+ segment_size: 32000
6
+ nfft: 512
7
+ win_size: 512
8
+ hop_size: 256
9
+ win_type: hann
10
+
11
+ erb_bins: 32
12
+ min_freq_bins_for_erb: 2
13
+
14
+ # data
15
+ max_snr_db: 20
16
+ min_snr_db: -10
17
+
18
+ # model
19
+ conv_size: 256
20
+ gru_size: 256
21
+
22
+ # train
23
+ max_epochs: 100
24
+ batch_size: 32
25
+ num_workers: 4
26
+ seed: 1234
27
+
28
+ lr: 0.001
29
+ lr_scheduler: CosineAnnealingLR
30
+ lr_scheduler_kwargs: {}
31
+
32
+ weight_decay: 0.00001
33
+ clip_grad_norm: 10.0
34
+ eval_steps: 20000
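
For orientation, the spec block above fixes the frame geometry; plain arithmetic (no repo code involved) gives:

sample_rate = 8000
segment_size = 32000   # 4 s of audio per training segment
nfft = 512
win_size = 512
hop_size = 256

freq_bins = nfft // 2 + 1                            # 257 STFT bins
frames = (segment_size - win_size) // hop_size + 1   # 124 frames per segment
frame_rate = sample_rate / hop_size                  # 31.25 frames per second
print(freq_bins, frames, frame_rate)
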
toolbox/torchaudio/modules/freq_bands/erb_bands.py CHANGED
@@ -147,6 +147,7 @@ class ErbBands(nn.Module):
147
  return erb_fb, erb_fb_inv
148
 
149
  def erb_scale(self, spec: torch.Tensor, db: bool = True):
 
150
  spec_erb = torch.matmul(spec, self.erb_fb)
151
  if db:
152
  spec_erb = 10 * torch.log10(spec_erb + 1e-10)
 
147
  return erb_fb, erb_fb_inv
148
 
149
  def erb_scale(self, spec: torch.Tensor, db: bool = True):
150
+ # spec shape: (b, t, f)
151
  spec_erb = torch.matmul(spec, self.erb_fb)
152
  if db:
153
  spec_erb = 10 * torch.log10(spec_erb + 1e-10)
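
The added comment pins down the layout erb_scale expects, (b, t, f), which the RNNoise forward pass above relies on. A small shape-flow sketch; the constructor arguments are copied from the RNNoise defaults, and the erb_scale_inv output shape is inferred from its use in RNNoise.forward, so treat it as illustrative:

import torch

from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands

erb_bands = ErbBands(sample_rate=8000, nfft=512, erb_bins=32, min_freq_bins_for_erb=2)

mag = torch.rand(size=(2, 100, 257))          # (b, t, f), with f = nfft // 2 + 1
mag_erb = erb_bands.erb_scale(mag, db=True)   # (b, t, erb_bins) = (2, 100, 32)

mask_erb = torch.sigmoid(torch.randn(2, 100, 32))
mask = erb_bands.erb_scale_inv(mask_erb)      # back to (b, t, f) = (2, 100, 257), assumed
print(mag_erb.shape, mask.shape)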