HoneyTian committed on
Commit 87129e4 · 1 Parent(s): b10ef9c
This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py +0 -83
  2. examples/dfnet2/run.sh +3 -3
  3. examples/dtln/run.sh +9 -2
  4. examples/frcrn/run.sh +3 -3
  5. main.py +20 -13
  6. toolbox/torchaudio/models/{nx_clean_unet/transformers → dccrn}/__init__.py +1 -1
  7. toolbox/torchaudio/models/{nx_denoise/stftnet/istftnet.py → dccrn/modeling_dccrn.py} +6 -3
  8. toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py +97 -84
  9. toolbox/torchaudio/models/dtln/modeling_dtln.py +9 -2
  10. toolbox/torchaudio/models/ehnet/modeling_ehnet.py +0 -1
  11. toolbox/torchaudio/models/nx_clean_unet/__init__.py +0 -6
  12. toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py +0 -6
  13. toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py +0 -261
  14. toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py +0 -100
  15. toolbox/torchaudio/models/nx_clean_unet/discriminator.py +0 -132
  16. toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav +0 -0
  17. toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py +0 -96
  18. toolbox/torchaudio/models/nx_clean_unet/loss.py +0 -22
  19. toolbox/torchaudio/models/nx_clean_unet/metrics.py +0 -80
  20. toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +0 -401
  21. toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py +0 -270
  22. toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py +0 -74
  23. toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py +0 -266
  24. toolbox/torchaudio/models/nx_clean_unet/utils.py +0 -45
  25. toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +0 -51
  26. toolbox/torchaudio/models/nx_denoise/__init__.py +0 -6
  27. toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py +0 -6
  28. toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py +0 -281
  29. toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py +0 -102
  30. toolbox/torchaudio/models/nx_denoise/discriminator.py +0 -132
  31. toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py +0 -97
  32. toolbox/torchaudio/models/nx_denoise/loss.py +0 -22
  33. toolbox/torchaudio/models/nx_denoise/metrics.py +0 -80
  34. toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py +0 -392
  35. toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py +0 -6
  36. toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py +0 -9
  37. toolbox/torchaudio/models/nx_denoise/transformers/__init__.py +0 -6
  38. toolbox/torchaudio/models/nx_denoise/transformers/attention.py +0 -263
  39. toolbox/torchaudio/models/nx_denoise/transformers/mask.py +0 -74
  40. toolbox/torchaudio/models/nx_denoise/transformers/transformers.py +0 -479
  41. toolbox/torchaudio/models/nx_denoise/utils.py +0 -45
  42. toolbox/torchaudio/models/nx_denoise/yaml/config.yaml +0 -51
  43. toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py +0 -102
  44. toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py +0 -989
  45. toolbox/torchaudio/models/nx_dfnet/utils.py +0 -55
  46. toolbox/torchaudio/models/nx_mpnet/__init__.py +0 -6
  47. toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py +0 -6
  48. toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py +0 -445
  49. toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py +0 -90
  50. toolbox/torchaudio/models/nx_mpnet/discriminator.py +0 -102
examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py DELETED
@@ -1,83 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import argparse
- import os
- from pathlib import Path
- import sys
-
- from gradio_client import Client, handle_file
- import numpy as np
- from tqdm import tqdm
- import shutil
-
- pwd = os.path.abspath(os.path.dirname(__file__))
- sys.path.append(os.path.join(pwd, "../../"))
-
- import librosa
- from scipy.io import wavfile
-
-
- def get_args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--src_dir",
- default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-PH",
- # default=r"/data/tianxing/HuggingDatasets/nx_noise/data/speech/en-PH",
- type=str
- )
- parser.add_argument(
- "--tgt_dir",
- default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech-denoise\en-PH",
- # default=r"/data/tianxing/HuggingDatasets/nx_noise/data/speech-denoise/en-PH",
- type=str
- )
- args = parser.parse_args()
- return args
-
-
- def main():
- args = get_args()
-
- # client = Client(src="http://10.75.27.247:7865/")
- client = Client(src="http://127.0.0.1:7865/")
-
- src_dir = Path(args.src_dir)
- tgt_dir = Path(args.tgt_dir)
- tgt_dir.mkdir(parents=True, exist_ok=True)
-
- tgt_date_list = list(sorted([date.name for date in src_dir.glob("*") if not date.name.endswith(".zip")]))
- finished_date_set = set(tgt_date_list[:-1])
- current_date = tgt_date_list[-1]
-
- print(f"finished_date_set: {finished_date_set}")
- print(f"current_date: {current_date}")
-
- finished_set = set()
- for filename in (tgt_dir / current_date).glob("*.wav"):
- name = filename.name
- finished_set.add(name)
-
- src_date_list = list(sorted([date.name for date in src_dir.glob("*")]))
- for date in src_date_list:
- if date in finished_date_set:
- continue
- for filename in (src_dir / current_date).glob("**/*.wav"):
- result = client.predict(
- noisy_audio_file_t=handle_file(filename.as_posix()),
- noisy_audio_microphone_t=None,
- engine="frcrn-dns3",
- api_name="/when_click_denoise_button"
- )
- denoise_file = result[0]
- tgt_file = tgt_dir / current_date / f"{filename.name}"
- tgt_file.parent.mkdir(parents=True, exist_ok=True)
-
- shutil.move(denoise_file, tgt_file)
- print(denoise_file)
- exit(0)
-
- return
-
-
- if __name__ == "__main__":
- main()
examples/dfnet2/run.sh CHANGED
@@ -10,9 +10,9 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"

- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-devoice \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise"
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"


 END
examples/dtln/run.sh CHANGED
@@ -7,16 +7,23 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"

+
 sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
 --config_file "yaml/config-512.yaml" \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"


- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-1024 --final_model_name dtln-1024-nx \
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
 --config_file "yaml/config-1024.yaml" \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech"
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
+
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-1024 --final_model_name dtln-1024-nx-devoice \
+ --config_file "yaml/config-1024.yaml" \
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise"


 END
examples/frcrn/run.sh CHANGED
@@ -9,10 +9,10 @@ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name fi
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"


- sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx-devoice \
+ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
 --config_file "yaml/config-10.yaml" \
- --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech" \
- --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise"
+ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
+ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"

 END

main.py CHANGED
@@ -177,14 +177,10 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
 infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)

 begin = time.time()
- enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
+ denoise_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
 time_cost = time.time() - begin

- noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy")
- denoise_mag_db = generate_spectrogram(enhanced_audio, title="denoise")
-
 fpr = time_cost / audio_duration
-
 info = {
 "time_cost": round(time_cost, 4),
 "audio_duration": round(audio_duration, 4),
@@ -192,12 +188,21 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
 }
 message = json.dumps(info, ensure_ascii=False, indent=4)

- enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16)
+ noise_audio = noisy_audio - denoise_audio
+
+ noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy")
+ denoise_mag_db = generate_spectrogram(denoise_audio, title="denoise")
+ noise_mag_db = generate_spectrogram(noise_audio, title="noise")
+
+ denoise_audio = np.array(denoise_audio * (1 << 15), dtype=np.int16)
+ noise_audio = np.array(noise_audio * (1 << 15), dtype=np.int16)
+
 except Exception as e:
 raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")

- enhanced_audio_t = (sample_rate, enhanced_audio)
- return enhanced_audio_t, message, noisy_mag_db, denoise_mag_db
+ denoise_audio_t = (sample_rate, denoise_audio)
+ noise_audio_t = (sample_rate, noise_audio)
+ return denoise_audio_t, noise_audio_t, message, noisy_mag_db, denoise_mag_db, noise_mag_db


 def main():
@@ -255,21 +260,23 @@ def main():
 with gr.Column(variant="panel", scale=5):
 with gr.Tabs():
 with gr.TabItem("audio"):
- dn_enhanced_audio = gr.Audio(label="enhanced_audio")
+ dn_denoise_audio = gr.Audio(label="denoise_audio")
+ dn_noise_audio = gr.Audio(label="noise_audio")
 dn_message = gr.Textbox(lines=1, max_lines=20, label="message")
 with gr.TabItem("mag_db"):
 dn_noisy_mag_db = gr.Image(label="noisy_mag_db")
 dn_denoise_mag_db = gr.Image(label="denoise_mag_db")
+ dn_noise_mag_db = gr.Image(label="noise_mag_db")

 dn_button.click(
 when_click_denoise_button,
 inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
- outputs=[dn_enhanced_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db]
+ outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db]
 )
 gr.Examples(
 examples=examples,
 inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
- outputs=[dn_enhanced_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db],
+ outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db],
 fn=when_click_denoise_button,
 # cache_examples=True,
 # cache_mode="lazy",
@@ -289,8 +296,8 @@ def main():
 # http://127.0.0.1:7865/
 # http://10.75.27.247:7865/
 blocks.queue().launch(
- share=True,
- # share=False if platform.system() == "Windows" else False,
+ # share=True,
+ share=False if platform.system() == "Windows" else False,
 server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
 server_port=args.server_port
 )
toolbox/torchaudio/models/{nx_clean_unet/transformers → dccrn}/__init__.py RENAMED
@@ -2,5 +2,5 @@
 # -*- coding: utf-8 -*-


- if __name__ == '__main__':
+ if __name__ == "__main__":
 pass
toolbox/torchaudio/models/{nx_denoise/stftnet/istftnet.py → dccrn/modeling_dccrn.py} RENAMED
@@ -1,9 +1,12 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
- https://arxiv.org/abs/2203.02395
- """

+ https://arxiv.org/abs/2008.00264
+
+ https://github.com/huyanxin/DeepComplexCRN
+
+ """

- if __name__ == '__main__':
+ if __name__ == "__main__":
 pass
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py CHANGED
@@ -11,7 +11,6 @@ https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd444
11
  """
12
  import os
13
  import math
14
- from collections import defaultdict
15
  from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
16
 
17
  import numpy as np
@@ -109,7 +108,7 @@ class CausalConv2d(nn.Module):
109
  else:
110
  self.activation = nn.Identity()
111
 
112
- def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
113
  """
114
  :param inputs: shape: [b, c, t, f]
115
  :param cache: shape: [b, c, lookback, f];
@@ -560,15 +559,14 @@ class Encoder(nn.Module):
560
  feat_spec: torch.Tensor,
561
  cache_dict: dict = None,
562
  ):
563
- if cache_dict is None:
564
- cache_dict = defaultdict(lambda: None)
565
- cache0 = cache_dict["cache0"]
566
- cache1 = cache_dict["cache1"]
567
- cache2 = cache_dict["cache2"]
568
- cache3 = cache_dict["cache3"]
569
- cache4 = cache_dict["cache4"]
570
- cache5 = cache_dict["cache5"]
571
- cache6 = cache_dict["cache6"]
572
 
573
  # feat_erb shape: (b, 1, t, erb_bins)
574
  e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0)
@@ -716,13 +714,12 @@ class ErbDecoder(nn.Module):
716
  )
717
 
718
  def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> torch.Tensor:
719
- if cache_dict is None:
720
- cache_dict = defaultdict(lambda: None)
721
- cache0 = cache_dict["cache0"]
722
- cache1 = cache_dict["cache1"]
723
- cache2 = cache_dict["cache2"]
724
- cache3 = cache_dict["cache3"]
725
- cache4 = cache_dict["cache4"]
726
 
727
  # Estimates erb mask
728
  b, _, t, f8 = e3.shape
@@ -814,10 +811,9 @@ class DfDecoder(nn.Module):
814
  )
815
 
816
  def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> torch.Tensor:
817
- if cache_dict is None:
818
- cache_dict = defaultdict(lambda: None)
819
- cache0 = cache_dict["cache0"]
820
- cache1 = cache_dict["cache1"]
821
 
822
  # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
823
  b, t, _ = emb.shape
@@ -995,10 +991,9 @@ class DeepFiltering(nn.Module):
995
  coefs: torch.Tensor,
996
  cache_dict: dict = None,
997
  ):
998
- if cache_dict is None:
999
- cache_dict = defaultdict(lambda: None)
1000
- cache0 = cache_dict["cache0"]
1001
- cache1 = cache_dict["cache1"]
1002
 
1003
  # spec shape: [b, 1, t, spec_bins, 2]
1004
  spec_c = torch.view_as_complex(spec.contiguous())
@@ -1163,10 +1158,9 @@ class DfNet2(nn.Module):
1163
  return spec, feat_erb, feat_spec
1164
 
1165
  def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None):
1166
- if cache_dict is None:
1167
- cache_dict = defaultdict(lambda: None)
1168
- cache0 = cache_dict["cache0"]
1169
- cache1 = cache_dict["cache1"]
1170
 
1171
  feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0)
1172
  feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1)
@@ -1249,6 +1243,65 @@ class DfNet2(nn.Module):
1249
 
1250
  return est_spec, est_wav, est_mask, lsnr
1251
 
1252
  def forward_chunk_by_chunk(self,
1253
  noisy: torch.Tensor,
1254
  ):
@@ -1275,52 +1328,13 @@ class DfNet2(nn.Module):
1275
  end = begin + self.win_size
1276
  sub_noisy = noisy[:, :, begin: end]
1277
 
1278
- spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy)
1279
- # spec shape: [b, 1, t, f, 2]
1280
- # feat_erb shape: [b, 1, t, erb_bins]
1281
- # feat_spec shape: [b, 2, t, df_bins]
1282
- if self.config.use_ema_norm:
1283
- feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0)
1284
-
1285
- e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1)
1286
-
1287
- mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2)
1288
- # mask shape: [b, 1, t, erb_bins]
1289
- mask = self.erb_bands.erb_scale_inv(mask)
1290
- # mask shape: [b, 1, t, f]
1291
-
1292
- spec_m = self.mask.forward(spec, mask)
1293
- # spec_m shape: [b, 1, t, f, 2]
1294
- spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
1295
- # spec_m shape: [b, 1, t, spec_bins, 2]
1296
-
1297
- # lsnr shape: [b, t, 1]
1298
- lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
1299
- # lsnr shape: [b, 1, t]
1300
-
1301
- df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3)
1302
- df_coefs = self.df_out_transform(df_coefs)
1303
- # df_coefs shape: [b, df_order, t, df_bins, 2]
1304
-
1305
- spec_ = spec[:, :, :, :self.config.spec_bins, :]
1306
- # spec shape: [b, 1, t, spec_bins, 2]
1307
- spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4)
1308
- # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
1309
-
1310
- spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5)
1311
-
1312
- spec_e = torch.squeeze(spec_e, dim=1)
1313
- spec_e = spec_e.permute(0, 2, 1, 3)
1314
- # spec_e shape: [b, spec_bins, t, 2]
1315
-
1316
- # spec_e shape: [b, spec_bins, t, 2]
1317
- est_spec = torch.view_as_complex(spec_e.contiguous())
1318
- # est_spec shape: [b, spec_bins, t], torch.complex64
1319
- est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
1320
- # est_spec shape: [b, f, t], torch.complex64
1321
-
1322
- est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6)
1323
- # est_wav shape: [b, 1, hop_size]
1324
 
1325
  waveform_list.append(est_wav)
1326
 
@@ -1335,27 +1349,26 @@ class DfNet2(nn.Module):
1335
  :param cache_dict:
1336
  :return:
1337
  """
1338
- if cache_dict is None:
1339
- cache_dict = defaultdict(lambda: None)
1340
- cache_spec_m = cache_dict["cache_spec_m"]
1341
 
1342
- if cache_spec_m is None:
1343
  b, c, t, f, _ = spec_m.shape
1344
- cache_spec_m = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2))
1345
  # cache0 shape: [b, 1, lookahead, f, 2]
1346
  spec_m_cat = torch.concat(tensors=[
1347
- cache_spec_m, spec_m,
1348
  ], dim=2)
1349
 
1350
  spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :]
1351
- new_cache_spec_m = spec_m_cat[:, :, -self.config.df_lookahead:, :, :]
1352
 
1353
  spec_e = torch.concat(tensors=[
1354
  spec_f, spec_m[..., self.df_decoder.df_bins:, :]
1355
  ], dim=3)
1356
 
1357
  new_cache_dict = {
1358
- "cache_spec_m": new_cache_spec_m,
1359
  }
1360
  return spec_e, new_cache_dict
1361
 
 
11
  """
12
  import os
13
  import math
 
14
  from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
15
 
16
  import numpy as np
 
108
  else:
109
  self.activation = nn.Identity()
110
 
111
+ def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
112
  """
113
  :param inputs: shape: [b, c, t, f]
114
  :param cache: shape: [b, c, lookback, f];
 
559
  feat_spec: torch.Tensor,
560
  cache_dict: dict = None,
561
  ):
562
+ cache_dict = cache_dict or dict()
563
+ cache0 = cache_dict.get("cache0", None)
564
+ cache1 = cache_dict.get("cache1", None)
565
+ cache2 = cache_dict.get("cache2", None)
566
+ cache3 = cache_dict.get("cache3", None)
567
+ cache4 = cache_dict.get("cache4", None)
568
+ cache5 = cache_dict.get("cache5", None)
569
+ cache6 = cache_dict.get("cache6", None)
 
570
 
571
  # feat_erb shape: (b, 1, t, erb_bins)
572
  e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0)
 
714
  )
715
 
716
  def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> torch.Tensor:
717
+ cache_dict = cache_dict or dict()
718
+ cache0 = cache_dict.get("cache0", None)
719
+ cache1 = cache_dict.get("cache1", None)
720
+ cache2 = cache_dict.get("cache2", None)
721
+ cache3 = cache_dict.get("cache3", None)
722
+ cache4 = cache_dict.get("cache4", None)
 
723
 
724
  # Estimates erb mask
725
  b, _, t, f8 = e3.shape
 
811
  )
812
 
813
  def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> torch.Tensor:
814
+ cache_dict = cache_dict or dict()
815
+ cache0 = cache_dict.get("cache0", None)
816
+ cache1 = cache_dict.get("cache1", None)
 
817
 
818
  # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
819
  b, t, _ = emb.shape
 
991
  coefs: torch.Tensor,
992
  cache_dict: dict = None,
993
  ):
994
+ cache_dict = cache_dict or dict()
995
+ cache0 = cache_dict.get("cache0", None)
996
+ cache1 = cache_dict.get("cache1", None)
 
997
 
998
  # spec shape: [b, 1, t, spec_bins, 2]
999
  spec_c = torch.view_as_complex(spec.contiguous())
 
1158
  return spec, feat_erb, feat_spec
1159
 
1160
  def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None):
1161
+ cache_dict = cache_dict or dict()
1162
+ cache0 = cache_dict.get("cache0", None)
1163
+ cache1 = cache_dict.get("cache1", None)
 
1164
 
1165
  feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0)
1166
  feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1)
 
1243
 
1244
  return est_spec, est_wav, est_mask, lsnr
1245
 
1246
+ def forward_chunk(self,
1247
+ sub_noisy: torch.Tensor,
1248
+ cache_dict0: dict = None,
1249
+ cache_dict1: dict = None,
1250
+ cache_dict2: dict = None,
1251
+ cache_dict3: dict = None,
1252
+ cache_dict4: dict = None,
1253
+ cache_dict5: dict = None,
1254
+ cache_dict6: dict = None,
1255
+ ):
1256
+
1257
+ spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy)
1258
+ # spec shape: [b, 1, t, f, 2]
1259
+ # feat_erb shape: [b, 1, t, erb_bins]
1260
+ # feat_spec shape: [b, 2, t, df_bins]
1261
+ if self.config.use_ema_norm:
1262
+ feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0)
1263
+
1264
+ e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1)
1265
+
1266
+ mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2)
1267
+ # mask shape: [b, 1, t, erb_bins]
1268
+ mask = self.erb_bands.erb_scale_inv(mask)
1269
+ # mask shape: [b, 1, t, f]
1270
+
1271
+ spec_m = self.mask.forward(spec, mask)
1272
+ # spec_m shape: [b, 1, t, f, 2]
1273
+ spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
1274
+ # spec_m shape: [b, 1, t, spec_bins, 2]
1275
+
1276
+ # lsnr shape: [b, t, 1]
1277
+ lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
1278
+ # lsnr shape: [b, 1, t]
1279
+
1280
+ df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3)
1281
+ df_coefs = self.df_out_transform(df_coefs)
1282
+ # df_coefs shape: [b, df_order, t, df_bins, 2]
1283
+
1284
+ spec_ = spec[:, :, :, :self.config.spec_bins, :]
1285
+ # spec shape: [b, 1, t, spec_bins, 2]
1286
+ spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4)
1287
+ # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
1288
+
1289
+ spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5)
1290
+
1291
+ spec_e = torch.squeeze(spec_e, dim=1)
1292
+ spec_e = spec_e.permute(0, 2, 1, 3)
1293
+ # spec_e shape: [b, spec_bins, t, 2]
1294
+
1295
+ # spec_e shape: [b, spec_bins, t, 2]
1296
+ est_spec = torch.view_as_complex(spec_e.contiguous())
1297
+ # est_spec shape: [b, spec_bins, t], torch.complex64
1298
+ est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
1299
+ # est_spec shape: [b, f, t], torch.complex64
1300
+
1301
+ est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6)
1302
+ # est_wav shape: [b, 1, hop_size]
1303
+ return est_wav, cache_dict0, cache_dict1, cache_dict2, cache_dict3, cache_dict4, cache_dict5, cache_dict6
1304
+
1305
  def forward_chunk_by_chunk(self,
1306
  noisy: torch.Tensor,
1307
  ):
 
1328
  end = begin + self.win_size
1329
  sub_noisy = noisy[:, :, begin: end]
1330
 
1331
+ (est_wav,
1332
+ cache_dict0, cache_dict1, cache_dict2, cache_dict3,
1333
+ cache_dict4, cache_dict5, cache_dict6) = self.forward_chunk(
1334
+ sub_noisy,
1335
+ cache_dict0, cache_dict1, cache_dict2, cache_dict3,
1336
+ cache_dict4, cache_dict5, cache_dict6
1337
+ )
 
1338
 
1339
  waveform_list.append(est_wav)
1340
 
 
1349
  :param cache_dict:
1350
  :return:
1351
  """
1352
+ cache_dict = cache_dict or dict()
1353
+ cache0 = cache_dict.get("cache0", None)
 
1354
 
1355
+ if cache0 is None:
1356
  b, c, t, f, _ = spec_m.shape
1357
+ cache0 = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2))
1358
  # cache0 shape: [b, 1, lookahead, f, 2]
1359
  spec_m_cat = torch.concat(tensors=[
1360
+ cache0, spec_m,
1361
  ], dim=2)
1362
 
1363
  spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :]
1364
+ new_cache0 = spec_m_cat[:, :, -self.config.df_lookahead:, :, :]
1365
 
1366
  spec_e = torch.concat(tensors=[
1367
  spec_f, spec_m[..., self.df_decoder.df_bins:, :]
1368
  ], dim=3)
1369
 
1370
  new_cache_dict = {
1371
+ "cache0": new_cache0,
1372
  }
1373
  return spec_e, new_cache_dict
1374
 
toolbox/torchaudio/models/dtln/modeling_dtln.py CHANGED
@@ -1,9 +1,17 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
+ https://www.isca-archive.org/interspeech_2020/westhausen20_interspeech.pdf
+
 https://github.com/AkenoSyuRi/DTLNPytorch

 https://github.com/breizhn/DTLN
+
+ Dataset: DNS3 DNS-Challenge
+ SNR: -5 to 25 dB
+      5 to 30 dB
+ Window length 32 ms, hop 8 ms
+
 Trained on 500 hours of dns3 data; reaches PESQ 3.04 on the dns3 test set.

 """
@@ -245,13 +253,12 @@ class DTLNModel(nn.Module):
 # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")

 t = (num_samples_pad - self.fft_size) // self.hop_size + 1
+ overlap_size = self.fft_size - self.hop_size

 denoise_list = list()
 out_state1 = None
 out_state2 = None
- overlap_size = self.fft_size - self.hop_size
 denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype)
- # denoise_list.append(torch.clone(denoise_cache))
 for i in range(t):
 begin = i * self.hop_size
 end = begin + self.fft_size
toolbox/torchaudio/models/ehnet/modeling_ehnet.py CHANGED
@@ -71,7 +71,6 @@ class CausalTransConvBlock(nn.Module):
 return x


-
 class CRN(nn.Module):
 """
 Input: [batch size, channels=1, T, n_fft]
toolbox/torchaudio/models/nx_clean_unet/__init__.py DELETED
@@ -1,6 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
-
-
- if __name__ == '__main__':
- pass
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py DELETED
@@ -1,6 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
-
-
- if __name__ == '__main__':
- pass
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py DELETED
@@ -1,261 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import math
4
- import os
5
- from typing import List, Optional, Union, Iterable
6
-
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
- from torch.nn import functional as F
11
-
12
-
13
- norm_layer_dict = {
14
- "batch_norm_2d": torch.nn.BatchNorm2d
15
- }
16
-
17
-
18
- activation_layer_dict = {
19
- "relu": torch.nn.ReLU,
20
- "identity": torch.nn.Identity,
21
- "sigmoid": torch.nn.Sigmoid,
22
- }
23
-
24
-
25
- class CausalConv2d(nn.Module):
26
- def __init__(self,
27
- in_channels: int,
28
- out_channels: int,
29
- kernel_size: Union[int, Iterable[int]],
30
- f_stride: int = 1,
31
- dilation: int = 1,
32
- do_f_pad: bool = True,
33
- bias: bool = True,
34
- separable: bool = False,
35
- norm_layer: str = "batch_norm_2d",
36
- activation_layer: str = "relu",
37
- lookahead: int = 0
38
- ):
39
- super(CausalConv2d, self).__init__()
40
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
41
-
42
- if do_f_pad:
43
- f_pad = kernel_size[1] // 2 + dilation - 1
44
- else:
45
- f_pad = 0
46
-
47
- self.causal_left_pad = kernel_size[0] - 1 - lookahead
48
- self.causal_right_pad = lookahead
49
- self.constant_pad = nn.ConstantPad2d(
50
- padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
51
- value=0.0
52
- )
53
-
54
- groups = math.gcd(in_channels, out_channels) if separable else 1
55
- self.conv1 = nn.Conv2d(
56
- in_channels,
57
- out_channels,
58
- kernel_size=kernel_size,
59
- padding=(0, f_pad),
60
- stride=(1, f_stride),
61
- dilation=(1, dilation),
62
- groups=groups,
63
- bias=bias,
64
- )
65
-
66
- self.conv2 = None
67
- if not any([groups == 1, max(kernel_size) == 1]):
68
- self.conv2 = nn.Conv2d(
69
- out_channels,
70
- out_channels,
71
- kernel_size=1,
72
- bias=False,
73
- )
74
-
75
- self.norm = None
76
- if norm_layer is not None:
77
- norm_layer = norm_layer_dict[norm_layer]
78
- self.norm = norm_layer(out_channels)
79
-
80
- self.activation = None
81
- if activation_layer is not None:
82
- activation_layer = activation_layer_dict[activation_layer]
83
- self.activation = activation_layer()
84
-
85
- def forward(self,
86
- inputs: torch.Tensor,
87
- causal_cache: torch.Tensor = None,
88
- ):
89
-
90
- if causal_cache is None:
91
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
92
- x = self.constant_pad.forward(inputs)
93
- else:
94
- # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
95
- # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
96
- x = torch.concat(tensors=[causal_cache, inputs], dim=2)
97
- # x shape: [batch_size, 1, time_steps2, hidden_size]
98
- # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
99
-
100
- x = self.conv1.forward(x)
101
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
102
-
103
- if self.conv2:
104
- x = self.conv2.forward(x)
105
-
106
- if self.norm:
107
- x = self.norm(x)
108
- if self.activation:
109
- x = self.activation(x)
110
-
111
- causal_cache = x[:, :, -self.causal_left_pad:, :]
112
-
113
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
114
- return x, causal_cache
115
-
116
-
117
- class CausalConv2dEncoder(nn.Module):
118
- def __init__(self,
119
- in_channels: int,
120
- out_channels: int,
121
- kernel_size: Union[int, Iterable[int]],
122
- f_stride: int = 1,
123
- dilation: int = 1,
124
- do_f_pad: bool = True,
125
- bias: bool = True,
126
- separable: bool = False,
127
- norm_layer: str = "batch_norm_2d",
128
- activation_layer: str = "relu",
129
- lookahead: int = 0,
130
- num_layers: int = 5,
131
- ):
132
- super(CausalConv2dEncoder, self).__init__()
133
- self.num_layers = num_layers
134
-
135
- self.total_causal_left_pad = 0
136
- self.total_causal_right_pad = 0
137
-
138
- self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
139
- for i_layer in range(num_layers):
140
- conv = CausalConv2d(
141
- in_channels=in_channels,
142
- out_channels=out_channels,
143
- kernel_size=kernel_size,
144
- f_stride=f_stride,
145
- dilation=dilation,
146
- do_f_pad=do_f_pad,
147
- bias=bias,
148
- separable=separable,
149
- norm_layer=norm_layer,
150
- activation_layer=activation_layer,
151
- lookahead=lookahead,
152
- )
153
- self.causal_conv_list.append(conv)
154
-
155
- self.total_causal_left_pad += conv.causal_left_pad
156
- self.total_causal_right_pad += conv.causal_right_pad
157
-
158
- in_channels = out_channels
159
-
160
- def forward(self, inputs: torch.Tensor):
161
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
162
-
163
- x = inputs
164
- for layer in self.causal_conv_list:
165
- x, _ = layer.forward(x)
166
- return x
167
-
168
- def forward_chunk(self,
169
- chunk: torch.Tensor,
170
- causal_cache: torch.Tensor = None,
171
- ):
172
- # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size]
173
-
174
- new_causal_cache_list = list()
175
- for idx, causal_conv in enumerate(self.causal_conv_list):
176
- chunk, new_causal_cache = causal_conv.forward(
177
- inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None
178
- )
179
- new_causal_cache_list.append(new_causal_cache)
180
-
181
- new_causal_cache = torch.cat(new_causal_cache_list, dim=0)
182
- return chunk, new_causal_cache
183
-
184
- def forward_chunk_by_chunk(self, inputs: torch.Tensor):
185
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
186
- # batch_size = 1
187
-
188
- batch_size, channels, time_steps, hidden_size = inputs.shape
189
-
190
- causal_cache = None
191
-
192
- outputs = []
193
- for idx in range(0, time_steps, 1):
194
- begin = idx
195
- end = begin + self.total_causal_right_pad + 1
196
- chunk_xs = inputs[:, :, begin:end, :]
197
-
198
- ys, attention_cache = self.forward_chunk(
199
- chunk=chunk_xs,
200
- causal_cache=causal_cache,
201
- )
202
- # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
203
- ys = ys[:, :, :1, :]
204
-
205
- # ys shape: [batch_size, chunk_size, hidden_size]
206
- outputs.append(ys)
207
-
208
- ys = torch.cat(outputs, 2)
209
- return ys
210
-
211
-
212
- def main2():
213
- conv = CausalConv2d(
214
- in_channels=1,
215
- out_channels=64,
216
- kernel_size=3,
217
- bias=False,
218
- separable=True,
219
- f_stride=1,
220
- lookahead=0,
221
- )
222
-
223
- spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
224
- # spec shape: [batch_size, 1, time_steps, hidden_size]
225
- cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
226
-
227
- output, _ = conv.forward(spec)
228
- print(output.shape)
229
-
230
- output, _ = conv.forward(spec, cache)
231
- print(output.shape)
232
-
233
- return
234
-
235
-
236
- def main():
237
- causal = CausalConv2dEncoder(
238
- in_channels=1,
239
- out_channels=1,
240
- kernel_size=3,
241
- bias=False,
242
- separable=True,
243
- f_stride=1,
244
- lookahead=0,
245
- num_layers=3,
246
- )
247
-
248
- spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
249
- # spec shape: [batch_size, 1, time_steps, hidden_size]
250
-
251
- output = causal.forward(spec)
252
- print(output.shape)
253
-
254
- output = causal.forward_chunk_by_chunk(spec)
255
- print(output.shape)
256
-
257
- return
258
-
259
-
260
- if __name__ == '__main__':
261
- main()
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py DELETED
@@ -1,100 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
-
5
-
6
- class NXCleanUNetConfig(PretrainedConfig):
7
- """
8
- https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
9
- """
10
- def __init__(self,
11
- sample_rate: int = 8000,
12
- segment_size: int = 16000,
13
- n_fft: int = 512,
14
- win_length: int = 200,
15
- hop_length: int = 80,
16
-
17
- down_sampling_num_layers: int = 5,
18
- down_sampling_in_channels: int = 1,
19
- down_sampling_hidden_channels: int = 64,
20
- down_sampling_kernel_size: int = 4,
21
- down_sampling_stride: int = 2,
22
-
23
- causal_in_channels: int = 64,
24
- causal_out_channels: int = 64,
25
- causal_kernel_size: int = 3,
26
- causal_bias: bool = False,
27
- causal_separable: bool = True,
28
- causal_f_stride: int = 1,
29
- # causal_lookahead: int = 0,
30
- causal_num_layers: int = 3,
31
-
32
- tsfm_hidden_size: int = 256,
33
- tsfm_attention_heads: int = 4,
34
- tsfm_num_blocks: int = 6,
35
- tsfm_dropout_rate: float = 0.1,
36
- tsfm_max_length: int = 1024,
37
- tsfm_chunk_size: int = 4,
38
- tsfm_num_left_chunks: int = 128,
39
- tsfm_num_right_chunks: int = 2,
40
-
41
- discriminator_dim: int = 16,
42
- discriminator_in_channel: int = 2,
43
-
44
- compress_factor: float = 0.3,
45
-
46
- batch_size: int = 4,
47
- learning_rate: float = 0.0005,
48
- adam_b1: float = 0.8,
49
- adam_b2: float = 0.99,
50
- lr_decay: float = 0.99,
51
- seed: int = 1234,
52
-
53
- **kwargs
54
- ):
55
- super(NXCleanUNetConfig, self).__init__(**kwargs)
56
- self.sample_rate = sample_rate
57
- self.segment_size = segment_size
58
- self.n_fft = n_fft
59
- self.win_length = win_length
60
- self.hop_length = hop_length
61
-
62
- self.down_sampling_num_layers = down_sampling_num_layers
63
- self.down_sampling_in_channels = down_sampling_in_channels
64
- self.down_sampling_hidden_channels = down_sampling_hidden_channels
65
- self.down_sampling_kernel_size = down_sampling_kernel_size
66
- self.down_sampling_stride = down_sampling_stride
67
-
68
- self.causal_in_channels = causal_in_channels
69
- self.causal_out_channels = causal_out_channels
70
- self.causal_kernel_size = causal_kernel_size
71
- self.causal_bias = causal_bias
72
- self.causal_separable = causal_separable
73
- self.causal_f_stride = causal_f_stride
74
- # self.causal_lookahead = causal_lookahead
75
- self.causal_num_layers = causal_num_layers
76
-
77
- self.tsfm_hidden_size = tsfm_hidden_size
78
- self.tsfm_attention_heads = tsfm_attention_heads
79
- self.tsfm_num_blocks = tsfm_num_blocks
80
- self.tsfm_dropout_rate = tsfm_dropout_rate
81
- self.tsfm_max_length = tsfm_max_length
82
- self.tsfm_chunk_size = tsfm_chunk_size
83
- self.tsfm_num_left_chunks = tsfm_num_left_chunks
84
- self.tsfm_num_right_chunks = tsfm_num_right_chunks
85
-
86
- self.discriminator_dim = discriminator_dim
87
- self.discriminator_in_channel = discriminator_in_channel
88
-
89
- self.compress_factor = compress_factor
90
-
91
- self.batch_size = batch_size
92
- self.learning_rate = learning_rate
93
- self.adam_b1 = adam_b1
94
- self.adam_b2 = adam_b2
95
- self.lr_decay = lr_decay
96
- self.seed = seed
97
-
98
-
99
- if __name__ == '__main__':
100
- pass
toolbox/torchaudio/models/nx_clean_unet/discriminator.py DELETED
@@ -1,132 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import os
4
- from typing import Optional, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torchaudio
9
-
10
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
- from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
12
- from toolbox.torchaudio.models.nx_clean_unet.utils import LearnableSigmoid1d
13
-
14
-
15
- class MetricDiscriminator(nn.Module):
16
- def __init__(self, config: NXCleanUNetConfig):
17
- super(MetricDiscriminator, self).__init__()
18
- dim = config.discriminator_dim
19
- self.in_channel = config.discriminator_in_channel
20
-
21
- self.n_fft = config.n_fft
22
- self.win_length = config.win_length
23
- self.hop_length = config.hop_length
24
-
25
- self.transform = torchaudio.transforms.Spectrogram(
26
- n_fft=self.n_fft,
27
- win_length=self.win_length,
28
- hop_length=self.hop_length,
29
- power=1.0,
30
- window_fn=torch.hann_window,
31
- # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
32
- )
33
-
34
- self.layers = nn.Sequential(
35
- nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
36
- nn.InstanceNorm2d(dim, affine=True),
37
- nn.PReLU(dim),
38
- nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
39
- nn.InstanceNorm2d(dim*2, affine=True),
40
- nn.PReLU(dim*2),
41
- nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
42
- nn.InstanceNorm2d(dim*4, affine=True),
43
- nn.PReLU(dim*4),
44
- nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
45
- nn.InstanceNorm2d(dim*8, affine=True),
46
- nn.PReLU(dim*8),
47
- nn.AdaptiveMaxPool2d(1),
48
- nn.Flatten(),
49
- nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
50
- nn.Dropout(0.3),
51
- nn.PReLU(dim*4),
52
- nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
53
- LearnableSigmoid1d(1)
54
- )
55
-
56
- def forward(self, x, y):
57
- x = self.transform.forward(x)
58
- y = self.transform.forward(y)
59
-
60
- xy = torch.stack((x, y), dim=1)
61
- return self.layers(xy)
62
-
63
-
64
- MODEL_FILE = "discriminator.pt"
65
-
66
-
67
- class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
68
- def __init__(self,
69
- config: NXCleanUNetConfig,
70
- ):
71
- super(MetricDiscriminatorPretrainedModel, self).__init__(
72
- config=config,
73
- )
74
- self.config = config
75
-
76
- @classmethod
77
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
78
- config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
79
-
80
- model = cls(config)
81
-
82
- if os.path.isdir(pretrained_model_name_or_path):
83
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
84
- else:
85
- ckpt_file = pretrained_model_name_or_path
86
-
87
- with open(ckpt_file, "rb") as f:
88
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
89
- model.load_state_dict(state_dict, strict=True)
90
- return model
91
-
92
- def save_pretrained(self,
93
- save_directory: Union[str, os.PathLike],
94
- state_dict: Optional[dict] = None,
95
- ):
96
-
97
- model = self
98
-
99
- if state_dict is None:
100
- state_dict = model.state_dict()
101
-
102
- os.makedirs(save_directory, exist_ok=True)
103
-
104
- # save state dict
105
- model_file = os.path.join(save_directory, MODEL_FILE)
106
- torch.save(state_dict, model_file)
107
-
108
- # save config
109
- config_file = os.path.join(save_directory, CONFIG_FILE)
110
- self.config.to_yaml_file(config_file)
111
- return save_directory
112
-
113
-
114
- def main():
115
- config = NXCleanUNetConfig()
116
- discriminator = MetricDiscriminator(config=config)
117
-
118
- # shape: [batch_size, num_samples]
119
- # x = torch.ones([4, int(4.5 * 16000)])
120
- # y = torch.ones([4, int(4.5 * 16000)])
121
- x = torch.ones([4, 16000])
122
- y = torch.ones([4, 16000])
123
-
124
- output = discriminator.forward(x, y)
125
- print(output.shape)
126
- print(output)
127
-
128
- return
129
-
130
-
131
- if __name__ == "__main__":
132
- main()
toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav DELETED
Binary file (63.8 kB)
 
toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py DELETED
@@ -1,96 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import logging
4
- from pathlib import Path
5
- import shutil
6
- import tempfile
7
- import zipfile
8
-
9
- import librosa
10
- import numpy as np
11
- import torch
12
- import torchaudio
13
-
14
- from project_settings import project_path
15
- from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
16
- from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNetPretrainedModel, MODEL_FILE
17
-
18
- logger = logging.getLogger("toolbox")
19
-
20
-
21
- class InferenceNXCleanUNet(object):
22
- def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
23
- self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
24
- self.device = torch.device(device)
25
-
26
- logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
27
- config, model = self.load_models(self.pretrained_model_path_or_zip_file)
28
- logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
29
-
30
- self.config = config
31
- self.model = model
32
- self.model.to(device)
33
- self.model.eval()
34
-
35
- def load_models(self, model_path: str):
36
- model_path = Path(model_path)
37
- if model_path.name.endswith(".zip"):
38
- with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
39
- out_root = Path(tempfile.gettempdir()) / "nx_denoise"
40
- out_root.mkdir(parents=True, exist_ok=True)
41
- f_zip.extractall(path=out_root)
42
- model_path = out_root / model_path.stem
43
-
44
- config = NXCleanUNetConfig.from_pretrained(
45
- pretrained_model_name_or_path=model_path.as_posix(),
46
- )
47
- model = NXCleanUNetPretrainedModel.from_pretrained(
48
- pretrained_model_name_or_path=model_path.as_posix(),
49
- )
50
- model.to(self.device)
51
- model.eval()
52
-
53
- shutil.rmtree(model_path)
54
- return config, model
55
-
56
- def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
57
- if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
58
- raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
59
-
60
- # noisy_audio shape: [batch_size, num_samples]
61
- noisy_audios = noisy_audio.to(self.device)
62
-
63
- with torch.no_grad():
64
- enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
65
- # enhanced_audios = self.model.forward(noisy_audios)
66
- # enhanced_audio shape: [batch_size, n_samples]
67
- # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
68
-
69
- enhanced_audio = enhanced_audios[0]
70
- # enhanced_audio shape: [num_samples,]
71
- return enhanced_audio
72
-
73
- def main():
74
- model_zip_file = project_path / "trained_models/nx-clean-unet-14-epoch.zip"
75
- infer_nx_clean_unet = InferenceNXCleanUNet(model_zip_file)
76
-
77
- sample_rate = 8000
78
- noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
79
- noisy_audio, _ = librosa.load(
80
- noisy_audio_file.as_posix(),
81
- sr=sample_rate,
82
- )
83
- noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
84
- noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
85
- noisy_audio = noisy_audio.unsqueeze(dim=0)
86
-
87
- enhanced_audio = infer_nx_clean_unet.enhancement_by_tensor(noisy_audio)
88
-
89
- filename = "enhanced_audio.wav"
90
- torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)
91
-
92
- return
93
-
94
-
95
- if __name__ == '__main__':
96
- main()
toolbox/torchaudio/models/nx_clean_unet/loss.py DELETED
@@ -1,22 +0,0 @@
- #!/usr/bin/python3
- # -*- coding: utf-8 -*-
- import numpy as np
- import torch
-
-
- def anti_wrapping_function(x):
-
- return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
-
-
- def phase_losses(phase_r, phase_g):
-
- ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
- gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
- iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
-
- return ip_loss, gd_loss, iaf_loss
-
-
- if __name__ == '__main__':
- pass
toolbox/torchaudio/models/nx_clean_unet/metrics.py DELETED
@@ -1,80 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from joblib import Parallel, delayed
4
- import numpy as np
5
- from pesq import pesq
6
- from typing import List
7
-
8
- from pesq import cypesq
9
-
10
-
11
- def run_pesq(clean_audio: np.ndarray,
12
- noisy_audio: np.ndarray,
13
- sample_rate: int = 16000,
14
- mode: str = "wb",
15
- ) -> float:
16
- if sample_rate == 8000 and mode == "wb":
17
- raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
- try:
19
- pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
- except cypesq.NoUtterancesError as e:
21
- pesq_score = -1
22
- except Exception as e:
23
- print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
- pesq_score = -1
25
- return pesq_score
26
-
27
-
28
- def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
- noisy_audio_list: List[np.ndarray],
30
- sample_rate: int = 16000,
31
- mode: str = "wb",
32
- n_jobs: int = 4,
33
- ) -> List[float]:
34
- parallel = Parallel(n_jobs=n_jobs)
35
-
36
- parallel_tasks = list()
37
- for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
- parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
- parallel_tasks.append(parallel_task)
40
-
41
- pesq_score_list = parallel.__call__(parallel_tasks)
42
- return pesq_score_list
43
-
44
-
45
- def run_pesq_score(clean_audio_list: List[np.ndarray],
46
- noisy_audio_list: List[np.ndarray],
47
- sample_rate: int = 16000,
48
- mode: str = "wb",
49
- n_jobs: int = 4,
50
- ) -> List[float]:
51
-
52
- pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
- noisy_audio_list=noisy_audio_list,
54
- sample_rate=sample_rate,
55
- mode=mode,
56
- n_jobs=n_jobs,
57
- )
58
-
59
- pesq_score = np.mean(pesq_score_list)
60
- return pesq_score
61
-
62
-
63
- def main():
64
- clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
- noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
-
67
- clean_audio_list = list(clean_audio)
68
- noisy_audio_list = list(noisy_audio)
69
-
70
- pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
- print(pesq_score_list)
72
-
73
- pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
- print(pesq_score)
75
-
76
- return
77
-
78
-
79
- if __name__ == "__main__":
80
- main()
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py DELETED
@@ -1,401 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import os
4
- from typing import List, Optional, Union
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- from torch.nn import functional as F
10
-
11
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
12
- from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
13
- from toolbox.torchaudio.models.nx_clean_unet.transformers.transformers import TransformerEncoder
14
- from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder
15
-
16
-
17
- class DownSamplingBlock(nn.Module):
18
- def __init__(self,
19
- in_channels: int,
20
- hidden_channels: int,
21
- kernel_size: int,
22
- stride: int,
23
- ):
24
- super(DownSamplingBlock, self).__init__()
25
- self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride)
26
- self.relu = nn.ReLU()
27
- self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
28
- self.glu = nn.GLU(dim=1)
29
-
30
- def forward(self, x: torch.Tensor):
31
- # x shape: [batch_size, 1, num_samples]
32
- x = self.conv1.forward(x)
33
- # x shape: [batch_size, hidden_channels, new_num_samples]
34
- x = self.relu(x)
35
- x = self.conv2.forward(x)
36
- # x shape: [batch_size, hidden_channels*2, new_num_samples]
37
- x = self.glu(x)
38
- # x shape: [batch_size, hidden_channels, new_num_samples]
39
- # new_num_samples = (num_samples-kernel_size) // stride + 1
40
- return x
41
-
42
-
43
- class DownSampling(nn.Module):
44
- def __init__(self,
45
- num_layers: int,
46
- in_channels: int,
47
- hidden_channels: int,
48
- kernel_size: int,
49
- stride: int,
50
- ):
51
- super(DownSampling, self).__init__()
52
- self.num_layers = num_layers
53
-
54
- down_sampling_block_list = list()
55
- for idx in range(self.num_layers):
56
- down_sampling_block = DownSamplingBlock(
57
- in_channels=in_channels,
58
- hidden_channels=hidden_channels,
59
- kernel_size=kernel_size,
60
- stride=stride,
61
- )
62
- down_sampling_block_list.append(down_sampling_block)
63
- in_channels = hidden_channels
64
-
65
- self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)
66
-
67
- def forward(self, x: torch.Tensor):
68
- # x shape: [batch_size, channels, num_samples]
69
- skip_connection_list = list()
70
- for down_sampling_block in self.down_sampling_block_list:
71
- x = down_sampling_block.forward(x)
72
- skip_connection_list.append(x)
73
- # x shape: [batch_size, hidden_channels, num_samples**]
74
- return x, skip_connection_list
75
-
76
-
77
- class UpSamplingBlock(nn.Module):
78
- def __init__(self,
79
- out_channels: int,
80
- hidden_channels: int,
81
- kernel_size: int,
82
- stride: int,
83
- do_relu: bool = True,
84
- ):
85
- super(UpSamplingBlock, self).__init__()
86
- self.do_relu = do_relu
87
-
88
- self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
89
- self.glu = nn.GLU(dim=1)
90
- self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride)
91
- self.relu = nn.ReLU()
92
-
93
- def forward(self, x: torch.Tensor):
94
- # x shape: [batch_size, hidden_channels*2, num_samples]
95
- x = self.conv1.forward(x)
96
- # x shape: [batch_size, hidden_channels, num_samples]
97
- x = self.glu(x)
98
- # x shape: [batch_size, hidden_channels, num_samples]
99
- x = self.convt.forward(x)
100
- # x shape: [batch_size, hidden_channels, new_num_samples]
101
- # new_num_samples = (num_samples - 1) * stride + kernel_size
102
- if self.do_relu:
103
- x = self.relu(x)
104
- return x
105
-
106
-
107
- class UpSampling(nn.Module):
108
- def __init__(self,
109
- num_layers: int,
110
- out_channels: int,
111
- hidden_channels: int,
112
- kernel_size: int,
113
- stride: int,
114
- ):
115
- super(UpSampling, self).__init__()
116
- self.num_layers = num_layers
117
-
118
- up_sampling_block_list = list()
119
- for idx in range(self.num_layers-1):
120
- up_sampling_block = UpSamplingBlock(
121
- out_channels=hidden_channels,
122
- hidden_channels=hidden_channels,
123
- kernel_size=kernel_size,
124
- stride=stride,
125
- do_relu=True,
126
- )
127
- up_sampling_block_list.append(up_sampling_block)
128
- else:
129
- up_sampling_block = UpSamplingBlock(
130
- out_channels=out_channels,
131
- hidden_channels=hidden_channels,
132
- kernel_size=kernel_size,
133
- stride=stride,
134
- do_relu=False,
135
- )
136
- up_sampling_block_list.append(up_sampling_block)
137
- self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
138
-
139
- def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
140
- skip_connection_list = skip_connection_list[::-1]
141
-
142
- # x shape: [batch_size, channels, num_samples]
143
- for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
144
- skip_x = skip_connection_list[idx]
145
- x = x + skip_x
146
- # x = x + skip_x[:, :, :x.size(2)]
147
- x = up_sampling_block.forward(x)
148
- return x
149
-
150
-
151
- def get_padding_length(length, num_layers: int, kernel_size: int, stride: int):
152
- for _ in range(num_layers):
153
- if length < kernel_size:
154
- length = 1
155
- else:
156
- length = 1 + np.ceil((length - kernel_size) / stride)
157
-
158
- for _ in range(num_layers):
159
- length = (length - 1) * stride + kernel_size
160
-
161
- padded_length = int(length)
162
- return padded_length
163
-
164
-
165
- class NXCleanUNet(nn.Module):
166
- def __init__(self, config):
167
- super().__init__()
168
- self.config = config
169
-
170
- self.down_sampling = DownSampling(
171
- num_layers=config.down_sampling_num_layers,
172
- in_channels=config.down_sampling_in_channels,
173
- hidden_channels=config.down_sampling_hidden_channels,
174
- kernel_size=config.down_sampling_kernel_size,
175
- stride=config.down_sampling_stride,
176
- )
177
- self.causal_encoder = CausalConv2dEncoder(
178
- in_channels=config.causal_in_channels,
179
- out_channels=config.causal_out_channels,
180
- kernel_size=config.causal_kernel_size,
181
- bias=config.causal_bias,
182
- separable=config.causal_separable,
183
- f_stride=config.causal_f_stride,
184
- lookahead=0,
185
- num_layers=config.causal_num_layers,
186
- )
187
- self.transformer = TransformerEncoder(
188
- input_size=config.down_sampling_hidden_channels,
189
- hidden_size=config.tsfm_hidden_size,
190
- attention_heads=config.tsfm_attention_heads,
191
- num_blocks=config.tsfm_num_blocks,
192
- dropout_rate=config.tsfm_dropout_rate,
193
- chunk_size=config.tsfm_chunk_size,
194
- num_left_chunks=config.tsfm_num_left_chunks,
195
- num_right_chunks=config.tsfm_num_right_chunks,
196
- )
197
- self.up_sampling = UpSampling(
198
- num_layers=config.down_sampling_num_layers,
199
- out_channels=config.down_sampling_in_channels,
200
- hidden_channels=config.down_sampling_hidden_channels,
201
- kernel_size=config.down_sampling_kernel_size,
202
- stride=config.down_sampling_stride,
203
- )
204
-
205
- def forward(self, noisy_audios: torch.Tensor):
206
- # noisy_audios shape: [batch_size, n_samples]
207
- noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
208
- # noisy_audios shape: [batch_size, 1, n_samples]
209
-
210
- n_samples = noisy_audios.shape[-1]
211
- padded_length = get_padding_length(
212
- n_samples,
213
- num_layers=self.config.down_sampling_num_layers,
214
- kernel_size=self.config.down_sampling_kernel_size,
215
- stride=self.config.down_sampling_stride,
216
- )
217
- noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
218
-
219
- bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
220
- # bottle_neck shape: [batch_size, channels, time_steps]
221
-
222
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
223
- # bottle_neck shape: [batch_size, time_steps, input_size]
224
-
225
- bottle_neck = bottle_neck.unsqueeze(dim=1)
226
- bottle_neck = self.causal_encoder.forward(bottle_neck)
227
- bottle_neck = bottle_neck.squeeze(dim=1)
228
- # bottle_neck shape: [batch_size, time_steps, input_size]
229
-
230
- bottle_neck = self.transformer.forward(bottle_neck)
231
- # bottle_neck shape: [batch_size, time_steps, input_size]
232
-
233
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
234
- # bottle_neck shape: [batch_size, channels, time_steps]
235
-
236
- enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
237
-
238
- enhanced_audios = enhanced_audios[:, :, :n_samples]
239
- # enhanced_audios shape: [batch_size, 1, n_samples]
240
-
241
- enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
242
- # enhanced_audios shape: [batch_size, n_samples]
243
-
244
- return enhanced_audios
245
-
246
- def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
247
- # noisy_audios shape: [batch_size, n_samples]
248
- noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
249
- # noisy_audios shape: [batch_size, 1, n_samples]
250
-
251
- n_samples = noisy_audios.shape[-1]
252
- padded_length = get_padding_length(
253
- n_samples,
254
- num_layers=self.config.down_sampling_num_layers,
255
- kernel_size=self.config.down_sampling_kernel_size,
256
- stride=self.config.down_sampling_stride,
257
- )
258
- noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
259
-
260
- bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
261
- # bottle_neck shape: [batch_size, channels, time_steps]
262
-
263
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
264
- # bottle_neck shape: [batch_size, time_steps, input_size]
265
-
266
- bottle_neck = bottle_neck.unsqueeze(dim=1)
267
- bottle_neck = self.causal_encoder.forward_chunk_by_chunk(bottle_neck)
268
- bottle_neck = bottle_neck.squeeze(dim=1)
269
- # bottle_neck shape: [batch_size, time_steps, input_size]
270
-
271
- bottle_neck = self.transformer.forward_chunk_by_chunk(bottle_neck)
272
- # bottle_neck shape: [batch_size, time_steps, input_size]
273
-
274
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
275
- # bottle_neck shape: [batch_size, channels, time_steps]
276
-
277
- enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
278
-
279
- enhanced_audios = enhanced_audios[:, :, :n_samples]
280
- # enhanced_audios shape: [batch_size, 1, n_samples]
281
-
282
- enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
283
- # enhanced_audios shape: [batch_size, n_samples]
284
-
285
- return enhanced_audios
286
-
287
-
288
-
289
- MODEL_FILE = "generator.pt"
290
-
291
-
292
- class NXCleanUNetPretrainedModel(NXCleanUNet):
293
- def __init__(self,
294
- config: NXCleanUNetConfig,
295
- ):
296
- super(NXCleanUNetPretrainedModel, self).__init__(
297
- config=config,
298
- )
299
- self.config = config
300
-
301
- @classmethod
302
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
303
- config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
304
-
305
- model = cls(config)
306
-
307
- if os.path.isdir(pretrained_model_name_or_path):
308
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
309
- else:
310
- ckpt_file = pretrained_model_name_or_path
311
-
312
- with open(ckpt_file, "rb") as f:
313
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
314
- model.load_state_dict(state_dict, strict=True)
315
- return model
316
-
317
- def save_pretrained(self,
318
- save_directory: Union[str, os.PathLike],
319
- state_dict: Optional[dict] = None,
320
- ):
321
-
322
- model = self
323
-
324
- if state_dict is None:
325
- state_dict = model.state_dict()
326
-
327
- os.makedirs(save_directory, exist_ok=True)
328
-
329
- # save state dict
330
- model_file = os.path.join(save_directory, MODEL_FILE)
331
- torch.save(state_dict, model_file)
332
-
333
- # save config
334
- config_file = os.path.join(save_directory, CONFIG_FILE)
335
- self.config.to_yaml_file(config_file)
336
- return save_directory
337
-
338
-
339
-
340
- def main2():
341
-
342
- config = NXCleanUNetConfig()
343
- down_sampling = DownSampling(
344
- num_layers=config.down_sampling_num_layers,
345
- in_channels=config.down_sampling_in_channels,
346
- hidden_channels=config.down_sampling_hidden_channels,
347
- kernel_size=config.down_sampling_kernel_size,
348
- stride=config.down_sampling_stride,
349
- )
350
- up_sampling = UpSampling(
351
- num_layers=config.down_sampling_num_layers,
352
- out_channels=config.down_sampling_in_channels,
353
- hidden_channels=config.down_sampling_hidden_channels,
354
- kernel_size=config.down_sampling_kernel_size,
355
- stride=config.down_sampling_stride,
356
- )
357
-
358
- # shape: [batch_size, channels, num_samples]
359
- # min length: 94, stride: 32, 32 == 2**5
360
- # x = torch.ones([4, 1, 94])
361
- # x = torch.ones([4, 1, 126])
362
- # x = torch.ones([4, 1, 158])
363
- x = torch.ones([4, 1, 190])
364
-
365
- length = x.shape[-1]
366
- padded_length = get_padding_length(
367
- length,
368
- num_layers=config.down_sampling_num_layers,
369
- kernel_size=config.down_sampling_kernel_size,
370
- stride=config.down_sampling_stride,
371
- )
372
- x = F.pad(input=x, pad=(0, padded_length - length), mode="constant", value=0)
373
- # print(x)
374
- print(x.shape)
375
- bottle_neck = down_sampling.forward(x)
376
- print("-" * 150)
377
- x = up_sampling.forward(bottle_neck)
378
- print(x.shape)
379
- return
380
-
381
-
382
- def main():
383
-
384
- config = NXCleanUNetConfig()
385
-
386
- # shape: [batch_size, channels, num_samples]
387
- # min length: 94, stride: 32, 32 == 2**5
388
- # x = torch.ones([4, 94])
389
- # x = torch.ones([4, 126])
390
- # x = torch.ones([4, 158])
391
- # x = torch.ones([4, 190])
392
- x = torch.ones([4, 16000])
393
-
394
- model = NXCleanUNet(config)
395
- enhanced_audios = model.forward(x)
396
- print(enhanced_audios.shape)
397
- return
398
-
399
-
400
- if __name__ == "__main__":
401
- main2()
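Note: get_padding_length in the removed model pads the waveform so that the strided Conv1d encoder and the matching ConvTranspose1d decoder round-trip to a whole number of steps. A small standalone check of that arithmetic, assuming the yaml defaults (num_layers=6, kernel_size=4, stride=2):

    import numpy as np

    def padded_length(length: int, num_layers: int = 6, kernel_size: int = 4, stride: int = 2) -> int:
        # encoder: each Conv1d layer maps L -> 1 + ceil((L - kernel_size) / stride)
        for _ in range(num_layers):
            length = 1 if length < kernel_size else 1 + int(np.ceil((length - kernel_size) / stride))
        # decoder: each ConvTranspose1d layer maps L -> (L - 1) * stride + kernel_size
        for _ in range(num_layers):
            length = (length - 1) * stride + kernel_size
        return int(length)

    # 16000 samples (2 s at 8 kHz) are padded up to the next length the 6-layer U-Net can reproduce
    print(padded_length(16000))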
 
toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py DELETED
@@ -1,270 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import math
4
- from typing import Tuple
5
-
6
- import torch
7
- import torch.nn as nn
8
-
9
-
10
- class MultiHeadSelfAttention(nn.Module):
11
- def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
12
- """
13
- :param n_head: int. the number of heads.
14
- :param n_feat: int. the number of features.
15
- :param dropout_rate: float. dropout rate.
16
- """
17
- super().__init__()
18
- assert n_feat % n_head == 0
19
- # We assume d_v always equals d_k
20
- self.d_k = n_feat // n_head
21
- self.h = n_head
22
- self.linear_q = nn.Linear(n_feat, n_feat)
23
- self.linear_k = nn.Linear(n_feat, n_feat)
24
- self.linear_v = nn.Linear(n_feat, n_feat)
25
- self.linear_out = nn.Linear(n_feat, n_feat)
26
- self.dropout = nn.Dropout(p=dropout_rate)
27
-
28
- def forward_qkv(self,
29
- query: torch.Tensor,
30
- key: torch.Tensor,
31
- value: torch.Tensor
32
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
33
- """
34
- transform query, key and value.
35
- :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
36
- :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
37
- :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
38
- :return:
39
- """
40
- n_batch = query.size(0)
41
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
42
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
43
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
44
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
45
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
46
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
47
-
48
- return q, k, v
49
-
50
- def forward_attention(self,
51
- value: torch.Tensor,
52
- scores: torch.Tensor,
53
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
54
- ) -> torch.Tensor:
55
- """
56
- compute attention context vector.
57
- :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
58
- :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
59
- :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
60
- (batch_size, time1, time2), (0, 0, 0) means fake mask.
61
- :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
62
- weighted by the attention score (batch_size, time1, time2).
63
- """
64
- n_batch = value.size(0)
65
- # NOTE: When will `if mask.size(2) > 0` be True?
66
- # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
67
- # 1st chunk to ease the onnx export.]
68
- # 2. pytorch training
69
- if mask.size(2) > 0: # time2 > 0
70
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
71
- # For last chunk, time2 might be larger than scores.size(-1)
72
- mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
73
- scores = scores.masked_fill(mask, -float('inf'))
74
- attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
75
-
76
- # NOTE: When will `if mask.size(2) > 0` be False?
77
- # 1. onnx(16/-1, -1/-1, 16/0)
78
- # 2. jit (16/-1, -1/-1, 16/0, 16/4)
79
- else:
80
- attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
81
-
82
- p_attn = self.dropout(attn)
83
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
84
- x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
85
-
86
- return self.linear_out(x) # (batch, time1, n_feat)
87
-
88
- def forward(self,
89
- x: torch.Tensor,
90
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
91
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
92
- ) -> Tuple[torch.Tensor, torch.Tensor]:
93
-
94
- q, k, v = self.forward_qkv(x, x, x)
95
-
96
- if cache.size(0) > 0:
97
- key_cache, value_cache = torch.split(
98
- cache, cache.size(-1) // 2, dim=-1)
99
- k = torch.cat([key_cache, k], dim=2)
100
- v = torch.cat([value_cache, v], dim=2)
101
- # NOTE: We do cache slicing in encoder.forward_chunk, since it's
102
- # non-trivial to calculate `next_cache_start` here.
103
- new_cache = torch.cat((k, v), dim=-1)
104
-
105
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
106
- return self.forward_attention(v, scores, mask), new_cache
107
-
108
-
109
- class RelativeMultiHeadSelfAttention(nn.Module):
110
-
111
- def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
112
- """
113
- :param n_head: int. the number of heads.
114
- :param n_feat: int. the number of features.
115
- :param dropout_rate: float. dropout rate.
116
- :param max_relative_position: int. maximum relative position for relative position encoding.
117
- """
118
- super().__init__()
119
- assert n_feat % n_head == 0
120
- # We assume d_v always equals d_k
121
- self.d_k = n_feat // n_head
122
- self.h = n_head
123
- self.linear_q = nn.Linear(n_feat, n_feat)
124
- self.linear_k = nn.Linear(n_feat, n_feat)
125
- self.linear_v = nn.Linear(n_feat, n_feat)
126
- self.linear_out = nn.Linear(n_feat, n_feat)
127
- self.dropout = nn.Dropout(p=dropout_rate)
128
-
129
- # Relative position encoding
130
- self.max_relative_position = max_relative_position
131
- self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
132
-
133
- def forward_qkv(self,
134
- query: torch.Tensor,
135
- key: torch.Tensor,
136
- value: torch.Tensor
137
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
138
- """
139
- transform query, key and value.
140
- :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
141
- :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
142
- :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
143
- :return:
144
- """
145
- n_batch = query.size(0)
146
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
147
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
148
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
149
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
150
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
151
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
152
-
153
- return q, k, v
154
-
155
- def forward_attention(self,
156
- value: torch.Tensor,
157
- scores: torch.Tensor,
158
- mask: torch.Tensor = None
159
- ) -> torch.Tensor:
160
- """
161
- compute attention context vector.
162
- :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
163
- :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
164
- :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
165
- :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
166
- weighted by the attention score (batch_size, query_time_steps, key_time_steps).
167
- """
168
- n_batch = value.size(0)
169
- if mask is not None:
170
- mask = mask.unsqueeze(1).eq(0)
171
- # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
172
- scores = scores.masked_fill(mask, -float('inf'))
173
- attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
174
- else:
175
- attn = torch.softmax(scores, dim=-1)
176
- # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]
177
-
178
- p_attn = self.dropout(attn)
179
-
180
- x = torch.matmul(p_attn, value)
181
- # x shape: [batch_size, n_head, query_time_steps, d_k]
182
- x = x.transpose(1, 2)
183
- # x shape: [batch_size, query_time_steps, n_head, d_k]
184
-
185
- x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
186
- # x shape: [batch_size, query_time_steps, n_head * d_k]
187
- # x shape: [batch_size, query_time_steps, n_feat]
188
-
189
- x = self.linear_out(x)
190
- # x shape: [batch_size, query_time_steps, n_feat]
191
- return x
192
-
193
- def relative_position_encoding(self, length: int) -> torch.Tensor:
194
- """
195
- Generate relative position encoding.
196
- :param length: int. length of the sequence.
197
- :return: torch.Tensor. relative position encoding. shape=(length, length, d_k).
198
- """
199
- range_vec = torch.arange(length)
200
- distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
201
- distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
202
- final_mat = distance_mat_clipped + self.max_relative_position
203
- return final_mat
204
-
205
- def forward(self,
206
- x: torch.Tensor,
207
- mask: torch.Tensor = None,
208
- cache: torch.Tensor = None
209
- ) -> Tuple[torch.Tensor, torch.Tensor]:
210
- """
211
-
212
- :param x:
213
- :param mask:
214
- :param cache: Tensor, shape: [1, n_heads, time_steps, dim]
215
- :return:
216
- """
217
- # attention! self attention.
218
-
219
- q, k, v = self.forward_qkv(x, x, x)
220
- # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]
221
-
222
- if cache is not None:
223
- key_cache, value_cache = torch.split(
224
- cache, cache.size(-1) // 2, dim=-1)
225
- k = torch.cat([key_cache, k], dim=2)
226
- v = torch.cat([value_cache, v], dim=2)
227
-
228
- # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
229
- new_cache = torch.cat((k, v), dim=-1)
230
-
231
- # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
232
- native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
233
-
234
- # Compute relative position encoding
235
- q_length, k_length = q.size(2), k.size(2)
236
- relative_position = self.relative_position_encoding(k_length)
237
-
238
- relative_position = relative_position[-q_length:]
239
-
240
- relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
241
-
242
- relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k)
243
- relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k)
244
-
245
- relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
246
- # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
247
-
248
- # score
249
- scores = native_scores + relative_position_scores
250
-
251
- return self.forward_attention(v, scores, mask), new_cache
252
-
253
-
254
- def main():
255
- rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
256
-
257
- x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
258
- xt, new_cache = rel_attention.forward(x, x, x)
259
-
260
- # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
261
- # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
262
- # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
263
-
264
- print(xt.shape)
265
- print(new_cache.shape)
266
- return
267
-
268
-
269
- if __name__ == '__main__':
270
- main()
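Note: relative_position_encoding above turns every (query, key) offset into an index in [0, 2 * max_relative_position] that selects a row of relative_position_k. A quick illustration of that index matrix, assuming max_relative_position=2 and a sequence of length 4:

    import torch

    length, max_rel = 4, 2
    rng = torch.arange(length)
    distance = rng.unsqueeze(0) - rng.unsqueeze(1)                 # entry [i, j] = j - i (key minus query position)
    indices = torch.clamp(distance, -max_rel, max_rel) + max_rel   # clip, then shift into [0, 2 * max_rel]
    print(indices)
    # tensor([[2, 3, 4, 4],
    #         [1, 2, 3, 4],
    #         [0, 1, 2, 3],
    #         [0, 0, 1, 2]])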
 
toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py DELETED
@@ -1,74 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import torch
4
-
5
-
6
- def make_pad_mask(lengths: torch.Tensor,
7
- max_len: int = 0,
8
- ) -> torch.Tensor:
9
- batch_size = lengths.size(0)
10
- max_len = max_len if max_len > 0 else lengths.max().item()
11
- seq_range = torch.arange(
12
- 0,
13
- max_len,
14
- dtype=torch.int64,
15
- device=lengths.device
16
- )
17
- seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
18
- seq_length_expand = lengths.unsqueeze(-1)
19
- mask = seq_range_expand >= seq_length_expand
20
- return mask
21
-
22
-
23
-
24
- def subsequent_chunk_mask(
25
- size: int,
26
- chunk_size: int,
27
- num_left_chunks: int = -1,
28
- num_right_chunks: int = 0,
29
- device: torch.device = torch.device("cpu"),
30
- ) -> torch.Tensor:
31
- """
32
- Create mask for subsequent steps (size, size) with chunk size,
33
- this is for streaming encoder
34
-
35
- Examples:
36
- > subsequent_chunk_mask(4, 2)
37
- [[1, 1, 0, 0],
38
- [1, 1, 0, 0],
39
- [1, 1, 1, 1],
40
- [1, 1, 1, 1]]
41
-
42
- :param size: int. size of mask.
43
- :param chunk_size: int. size of chunk.
44
- :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
45
- :param num_right_chunks: int. number of right chunks.
46
- :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
47
- :return: torch.Tensor. mask
48
- """
49
-
50
- ret = torch.zeros(size, size, device=device, dtype=torch.bool)
51
- for i in range(size):
52
- if num_left_chunks < 0:
53
- start = 0
54
- else:
55
- start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
56
- ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
57
- ret[i, start:ending] = True
58
- return ret
59
-
60
-
61
- def main():
62
- chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
63
- print(chunk_mask)
64
-
65
- chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
66
- print(chunk_mask)
67
-
68
- chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
69
- print(chunk_mask)
70
- return
71
-
72
-
73
- if __name__ == '__main__':
74
- main()
 
toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py DELETED
@@ -1,266 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from typing import Dict, Optional, Tuple, List, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask
9
- from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
10
-
11
-
12
- class PositionwiseFeedForward(nn.Module):
13
- def __init__(self,
14
- input_dim: int,
15
- hidden_units: int,
16
- dropout_rate: float,
17
- activation: torch.nn.Module = torch.nn.ReLU()):
18
- """
19
- FeedForward are applied on each position of the sequence.
20
- the output dim is same with the input dim.
21
-
22
- :param input_dim: int. input dimension.
23
- :param hidden_units: int. the number of hidden units.
24
- :param dropout_rate: float. dropout rate.
25
- :param activation: torch.nn.Module. activation function.
26
- """
27
- super(PositionwiseFeedForward, self).__init__()
28
- self.w_1 = torch.nn.Linear(input_dim, hidden_units)
29
- self.activation = activation
30
- self.dropout = torch.nn.Dropout(dropout_rate)
31
- self.w_2 = torch.nn.Linear(hidden_units, input_dim)
32
-
33
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
34
- """
35
- Forward function.
36
- :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
37
- :return: output tensor. shape=(batch_size, max_length, dim).
38
- """
39
- return self.w_2(self.dropout(self.activation(self.w_1(xs))))
40
-
41
-
42
- class TransformerBlock(nn.Module):
43
- def __init__(self,
44
- input_dim: int,
45
- dropout_rate: float = 0.1,
46
- n_heads: int = 4,
47
- max_relative_position: int = 5120
48
- ):
49
- super().__init__()
50
- self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
51
- self.attention = RelativeMultiHeadSelfAttention(
52
- n_head=n_heads,
53
- n_feat=input_dim,
54
- dropout_rate=dropout_rate,
55
- max_relative_position=max_relative_position,
56
- )
57
-
58
- self.dropout1 = nn.Dropout(dropout_rate)
59
- self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
60
- self.ffn = PositionwiseFeedForward(
61
- input_dim=input_dim,
62
- hidden_units=input_dim,
63
- dropout_rate=dropout_rate
64
- )
65
- self.dropout2 = nn.Dropout(dropout_rate)
66
- self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
67
-
68
- def forward(
69
- self,
70
- x: torch.Tensor,
71
- mask: torch.Tensor = None,
72
- attention_cache: torch.Tensor = None,
73
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
74
- """
75
-
76
- :param x: torch.Tensor. shape=(batch_size, time, input_dim).
77
- :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
78
- :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
79
- shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
80
- :return:
81
- torch.Tensor: Output tensor (batch_size, time, input_dim).
82
- torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
83
- """
84
-
85
- xt = self.norm1(x)
86
-
87
- x_att, new_att_cache = self.attention.forward(
88
- xt, mask=mask, cache=attention_cache
89
- )
90
- x = x + self.dropout1(xt)
91
- xt = self.norm2(x)
92
- xt = self.ffn.forward(xt)
93
- x = x + self.dropout2(xt)
94
-
95
- x = self.norm3(x)
96
-
97
- return x, new_att_cache
98
-
99
-
100
- class TransformerEncoder(nn.Module):
101
- """
102
- https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
103
- """
104
- def __init__(self,
105
- input_size: int = 64,
106
- hidden_size: int = 256,
107
- attention_heads: int = 4,
108
- num_blocks: int = 6,
109
- dropout_rate: float = 0.1,
110
- max_relative_position: int = 1024,
111
- chunk_size: int = 1,
112
- num_left_chunks: int = 128,
113
- num_right_chunks: int = 2,
114
- ):
115
- super().__init__()
116
- self.input_size = input_size
117
- self.hidden_size = hidden_size
118
-
119
- self.max_relative_position = max_relative_position
120
- self.chunk_size = chunk_size
121
- self.num_left_chunks = num_left_chunks
122
- self.num_right_chunks = num_right_chunks
123
-
124
- self.input_linear = nn.Linear(
125
- in_features=self.input_size,
126
- out_features=self.hidden_size,
127
- )
128
-
129
- self.encoder_layer_list = torch.nn.ModuleList([
130
- TransformerBlock(
131
- input_dim=hidden_size,
132
- n_heads=attention_heads,
133
- dropout_rate=dropout_rate,
134
- max_relative_position=max_relative_position,
135
- ) for _ in range(num_blocks)
136
- ])
137
-
138
- self.output_linear = nn.Linear(
139
- in_features=self.hidden_size,
140
- out_features=self.input_size,
141
- )
142
-
143
- def forward(self,
144
- xs: torch.Tensor,
145
- ):
146
- """
147
- :param xs: Tensor, shape: [batch_size, time_steps, input_size]
148
- :return: Tensor, shape: [batch_size, time_steps, input_size]
149
- """
150
- batch_size, time_steps, _ = xs.shape
151
- # xs shape: [batch_size, time_steps, input_size]
152
- xs = self.input_linear.forward(xs)
153
- # xs shape: [batch_size, time_steps, hidden_size]
154
-
155
- chunk_masks = subsequent_chunk_mask(
156
- size=time_steps,
157
- chunk_size=self.chunk_size,
158
- num_left_chunks=self.num_left_chunks,
159
- num_right_chunks=self.num_right_chunks,
160
- )
161
- chunk_masks = chunk_masks.to(xs.device)
162
- # chunk_masks shape: [1, time_steps, time_steps]
163
- chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
164
- # chunk_masks shape: [batch_size, time_steps, time_steps]
165
-
166
- for encoder_layer in self.encoder_layer_list:
167
- xs, _ = encoder_layer.forward(xs, chunk_masks)
168
-
169
- # xs shape: [batch_size, time_steps, hidden_size]
170
- xs = self.output_linear.forward(xs)
171
- # xs shape: [batch_size, time_steps, input_size]
172
-
173
- return xs
174
-
175
- def forward_chunk(self,
176
- xs: torch.Tensor,
177
- max_att_cache_length: int,
178
- attention_cache: torch.Tensor = None,
179
- ) -> Tuple[torch.Tensor, torch.Tensor]:
180
- """
181
- Forward just one chunk.
182
- :param xs: torch.Tensor. chunk input, with shape (b=1, time, mel-dim),
183
- where `time == (chunk_size - 1) * subsample_rate + subsample.right_context + 1`
184
- :param max_att_cache_length:
185
- :param attention_cache: torch.Tensor.
186
- :return:
187
- """
188
- # xs shape: [batch_size, time_steps, input_size]
189
- xs = self.input_linear.forward(xs)
190
- # xs shape: [batch_size, time_steps, hidden_size]
191
-
192
- r_att_cache = []
193
- for idx, encoder_layer in enumerate(self.encoder_layer_list):
194
- xs, new_att_cache = encoder_layer.forward(
195
- x=xs, attention_cache=attention_cache[idx: idx+1] if attention_cache is not None else None,
196
- )
197
- if new_att_cache.size(2) > max_att_cache_length:
198
- begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
199
- end = self.num_right_chunks * self.chunk_size
200
- new_att_cache = new_att_cache[:, :, -begin:-end, :]
201
- r_att_cache.append(new_att_cache)
202
-
203
- r_att_cache = torch.cat(r_att_cache, dim=0)
204
-
205
- return xs, r_att_cache
206
-
207
- def forward_chunk_by_chunk(
208
- self,
209
- xs: torch.Tensor,
210
- ) -> torch.Tensor:
211
-
212
- batch_size, time_steps, _ = xs.shape
213
-
214
- # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
215
- max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
216
- attention_cache = None
217
-
218
- outputs = []
219
- for idx in range(0, time_steps - self.chunk_size, self.chunk_size):
220
- begin = idx
221
- end = begin + self.chunk_size * (self.num_right_chunks + 1)
222
- chunk_xs = xs[:, begin:end, :]
223
- # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
224
-
225
- ys, attention_cache = self.forward_chunk(
226
- xs=chunk_xs,
227
- max_att_cache_length=max_att_cache_length,
228
- attention_cache=attention_cache,
229
- )
230
- # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), hidden_size]
231
- ys = ys[:, :self.chunk_size, :]
232
-
233
- # ys shape: [batch_size, chunk_size, hidden_size]
234
- ys = self.output_linear.forward(ys)
235
- # ys shape: [batch_size, chunk_size, input_size]
236
-
237
- outputs.append(ys)
238
-
239
- ys = torch.cat(outputs, 1)
240
- return ys
241
-
242
-
243
- def main():
244
-
245
- encoder = TransformerEncoder(
246
- input_size=64,
247
- hidden_size=256,
248
- attention_heads=4,
249
- num_blocks=6,
250
- dropout_rate=0.1,
251
- )
252
- print(encoder)
253
-
254
- x = torch.ones([4, 200, 64])
255
-
256
- y = encoder.forward(xs=x)
257
- print(y.shape)
258
-
259
- # y = encoder.forward_chunk_by_chunk(xs=x)
260
- # print(y.shape)
261
-
262
- return
263
-
264
-
265
- if __name__ == '__main__':
266
- main()
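Note: forward_chunk_by_chunk above feeds chunk_size * (num_right_chunks + 1) frames per step, keeps only the first chunk_size output frames, and trims the KV cache to (num_left_chunks + num_right_chunks) * chunk_size frames. With the yaml defaults the numbers work out as follows (plain arithmetic, not repo code):

    chunk_size, num_left, num_right = 1, 128, 4
    window = chunk_size * (num_right + 1)             # frames fed to forward_chunk per step
    max_cache = (num_left + num_right) * chunk_size   # upper bound on cached key/value frames
    print(window, max_cache)                          # 5 132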
 
toolbox/torchaudio/models/nx_clean_unet/utils.py DELETED
@@ -1,45 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import torch
4
- import torch.nn as nn
5
-
6
-
7
- class LearnableSigmoid1d(nn.Module):
8
- def __init__(self, in_features, beta=1):
9
- super().__init__()
10
- self.beta = beta
11
- self.slope = nn.Parameter(torch.ones(in_features))
12
- self.slope.requiresGrad = True
13
-
14
- def forward(self, x):
15
- # x shape: [batch_size, time_steps, spec_bins]
16
- return self.beta * torch.sigmoid(self.slope * x)
17
-
18
-
19
- def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
20
-
21
- hann_window = torch.hann_window(win_size).to(y.device)
22
- stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
23
- center=center, pad_mode='reflect', normalized=False, return_complex=True)
24
- stft_spec = torch.view_as_real(stft_spec)
25
- mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
26
- pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
27
- # Magnitude Compression
28
- mag = torch.pow(mag, compress_factor)
29
- com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
30
-
31
- return mag, pha, com
32
-
33
-
34
- def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
35
- # Magnitude Decompression
36
- mag = torch.pow(mag, (1.0/compress_factor))
37
- com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
38
- hann_window = torch.hann_window(win_size).to(com.device)
39
- wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
40
-
41
- return wav
42
-
43
-
44
- if __name__ == '__main__':
45
- pass
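Note: mag_pha_stft / mag_pha_istft above apply power-law magnitude compression (mag ** compress_factor) before the model and undo it before the inverse STFT. A self-contained round trip with the same framing parameters, assuming a 1 s signal at 8 kHz and the compress_factor of 0.3 used elsewhere in this repo (the small numerical-stability offsets of the original are omitted):

    import torch

    n_fft, hop, win, c = 512, 80, 200, 0.3
    y = torch.randn(1, 8000)
    window = torch.hann_window(win)
    spec = torch.stft(y, n_fft, hop_length=hop, win_length=win, window=window,
                      center=True, pad_mode="reflect", return_complex=True)
    mag, pha = spec.abs(), spec.angle()
    mag_compressed = mag.pow(c)                       # what the model would operate on
    restored = torch.istft(torch.polar(mag_compressed.pow(1.0 / c), pha), n_fft,
                           hop_length=hop, win_length=win, window=window, center=True)
    print(restored.shape)                             # torch.Size([1, 8000])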
 
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml DELETED
@@ -1,51 +0,0 @@
1
- model_name: "nx_clean_unet"
2
-
3
- sample_rate: 8000
4
- segment_size: 16000
5
- n_fft: 512
6
- win_size: 200
7
- hop_size: 80
8
- # With hop_size = 80, each STFT step is 10 ms, so the down-sampling is chosen to give a similar time resolution.
9
-
10
- # 2**down_sampling_num_layers,
11
- # e.g. 2**6 = 64 means 64 samples collapse into one time step after down-sampling,
12
- # so one step is 64/sample_rate = 0.008 s.
13
- # Then tsfm_chunk_size=2 corresponds to 16 ms and tsfm_chunk_size=4 to 32 ms.
14
- # Assuming roughly 1 s of left context and 30 ms of look-ahead at each step:
15
- # tsfm_chunk_size=1,tsfm_num_left_chunks=128,tsfm_num_right_chunks=4
16
- # tsfm_chunk_size=2,tsfm_num_left_chunks=64,tsfm_num_right_chunks=2
17
- # tsfm_chunk_size=4,tsfm_num_left_chunks=32,tsfm_num_right_chunks=1
18
- down_sampling_num_layers: 6
19
- down_sampling_in_channels: 1
20
- down_sampling_hidden_channels: 64
21
- down_sampling_kernel_size: 4
22
- down_sampling_stride: 2
23
-
24
- causal_in_channels: 1
25
- causal_out_channels: 1
26
- causal_kernel_size: 3
27
- causal_bias: false
28
- causal_separable: true
29
- causal_f_stride: 1
30
- causal_num_layers: 3
31
-
32
- tsfm_hidden_size: 256
33
- tsfm_attention_heads: 8
34
- tsfm_num_blocks: 6
35
- tsfm_dropout_rate: 0.1
36
- tsfm_max_length: 512
37
- tsfm_chunk_size: 1
38
- tsfm_num_left_chunks: 128
39
- tsfm_num_right_chunks: 4
40
-
41
- discriminator_dim: 32
42
- discriminator_in_channel: 2
43
-
44
- compress_factor: 0.3
45
-
46
- batch_size: 4
47
- learning_rate: 0.0005
48
- adam_b1: 0.8
49
- adam_b2: 0.99
50
- lr_decay: 0.99
51
- seed: 1234
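Note: a quick check of the streaming arithmetic described in the comments of this yaml, using its default values:

    sample_rate = 8000
    down_sampling_num_layers = 6
    samples_per_step = 2 ** down_sampling_num_layers          # 64 samples collapse into one step
    step_ms = 1000 * samples_per_step / sample_rate           # 8 ms per step
    chunk_size, num_left, num_right = 1, 128, 4
    print(step_ms)                                            # 8.0
    print(num_left * chunk_size * step_ms)                    # 1024.0 ms of left context (~1 s)
    print(num_right * chunk_size * step_ms)                   # 32.0 ms of look-ahead (~30 ms)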
 
toolbox/torchaudio/models/nx_denoise/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
 
toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
 
toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py DELETED
@@ -1,281 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import math
4
- import os
5
- from typing import List, Optional, Union, Iterable
6
-
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
- from torch.nn import functional as F
11
-
12
-
13
- norm_layer_dict = {
14
- "batch_norm_2d": torch.nn.BatchNorm2d
15
- }
16
-
17
-
18
- activation_layer_dict = {
19
- "relu": torch.nn.ReLU,
20
- "identity": torch.nn.Identity,
21
- "sigmoid": torch.nn.Sigmoid,
22
- }
23
-
24
-
25
- class CausalConv2d(nn.Module):
26
- def __init__(self,
27
- in_channels: int,
28
- out_channels: int,
29
- kernel_size: Union[int, Iterable[int]],
30
- f_stride: int = 1,
31
- dilation: int = 1,
32
- do_f_pad: bool = True,
33
- bias: bool = True,
34
- separable: bool = False,
35
- norm_layer: str = "batch_norm_2d",
36
- activation_layer: str = "relu",
37
- lookahead: int = 0
38
- ):
39
- super(CausalConv2d, self).__init__()
40
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
41
-
42
- if do_f_pad:
43
- f_pad = kernel_size[1] // 2 + dilation - 1
44
- else:
45
- f_pad = 0
46
-
47
- self.causal_left_pad = kernel_size[0] - 1 - lookahead
48
- self.causal_right_pad = lookahead
49
- self.constant_pad = nn.ConstantPad2d(
50
- padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
51
- value=0.0
52
- )
53
-
54
- groups = math.gcd(in_channels, out_channels) if separable else 1
55
- self.conv1 = nn.Conv2d(
56
- in_channels,
57
- out_channels,
58
- kernel_size=kernel_size,
59
- padding=(0, f_pad),
60
- stride=(1, f_stride),
61
- dilation=(1, dilation),
62
- groups=groups,
63
- bias=bias,
64
- )
65
-
66
- self.conv2 = None
67
- if not any([groups == 1, max(kernel_size) == 1]):
68
- self.conv2 = nn.Conv2d(
69
- out_channels,
70
- out_channels,
71
- kernel_size=1,
72
- bias=False,
73
- )
74
-
75
- self.norm = None
76
- if norm_layer is not None:
77
- norm_layer = norm_layer_dict[norm_layer]
78
- self.norm = norm_layer(out_channels)
79
-
80
- self.activation = None
81
- if activation_layer is not None:
82
- activation_layer = activation_layer_dict[activation_layer]
83
- self.activation = activation_layer()
84
-
85
- def forward(self,
86
- inputs: torch.Tensor,
87
- causal_cache: List[torch.Tensor] = None,
88
- ):
89
-
90
- if causal_cache is None:
91
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
92
- x = self.constant_pad.forward(inputs)
93
- else:
94
- # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
95
- # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
96
- x = torch.concat(tensors=[causal_cache, inputs], dim=2)
97
- # x shape: [batch_size, 1, time_steps2, hidden_size]
98
- # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
99
-
100
- causal_cache = x[:, :, -self.causal_left_pad:, :]
101
-
102
- x = self.conv1.forward(x)
103
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
104
-
105
- if self.conv2:
106
- x = self.conv2.forward(x)
107
-
108
- if self.norm:
109
- x = self.norm(x)
110
- if self.activation:
111
- x = self.activation(x)
112
-
113
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
114
- return x, causal_cache
115
-
116
-
117
- class CausalConv2dEncoder(nn.Module):
118
- def __init__(self,
119
- in_channels: int,
120
- hidden_channels: int,
121
- out_channels: int,
122
- kernel_size: Union[int, Iterable[int]],
123
- f_stride: int = 1,
124
- dilation: int = 1,
125
- do_f_pad: bool = True,
126
- bias: bool = True,
127
- separable: bool = False,
128
- norm_layer: str = "batch_norm_2d",
129
- activation_layer: str = "relu",
130
- lookahead: int = 0,
131
- num_layers: int = 5,
132
- ):
133
- super(CausalConv2dEncoder, self).__init__()
134
- self.num_layers = num_layers
135
-
136
- self.total_causal_left_pad = 0
137
- self.total_causal_right_pad = 0
138
-
139
- self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
140
- for i_layer in range(num_layers):
141
- conv = CausalConv2d(
142
- in_channels=in_channels,
143
- out_channels=hidden_channels,
144
- kernel_size=kernel_size,
145
- f_stride=f_stride,
146
- dilation=dilation,
147
- do_f_pad=do_f_pad,
148
- bias=bias,
149
- separable=separable,
150
- norm_layer=norm_layer,
151
- activation_layer=activation_layer,
152
- lookahead=lookahead,
153
- )
154
- self.causal_conv_list.append(conv)
155
-
156
- self.total_causal_left_pad += conv.causal_left_pad
157
- self.total_causal_right_pad += conv.causal_right_pad
158
-
159
- in_channels = hidden_channels
160
- else:
161
- conv = CausalConv2d(
162
- in_channels=hidden_channels,
163
- out_channels=out_channels,
164
- kernel_size=kernel_size,
165
- f_stride=f_stride,
166
- dilation=dilation,
167
- do_f_pad=do_f_pad,
168
- bias=bias,
169
- separable=separable,
170
- norm_layer=norm_layer,
171
- activation_layer=activation_layer,
172
- lookahead=lookahead,
173
- )
174
- self.causal_conv_list.append(conv)
175
-
176
- self.total_causal_left_pad += conv.causal_left_pad
177
- self.total_causal_right_pad += conv.causal_right_pad
178
-
179
-
180
- def forward(self, inputs: torch.Tensor):
181
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
182
-
183
- x = inputs
184
- for layer in self.causal_conv_list:
185
- x, _ = layer.forward(x)
186
- return x
187
-
188
- def forward_chunk(self,
189
- chunk: torch.Tensor,
190
- causal_cache: List[torch.Tensor] = None,
191
- ):
192
- # causal_cache shape: [self.num_layers, batch_size, 1, causal_left_pad, hidden_size]
193
-
194
- new_causal_cache_list: List[torch.Tensor] = list()
195
- for idx, causal_conv in enumerate(self.causal_conv_list):
196
- chunk, new_causal_cache = causal_conv.forward(
197
- inputs=chunk, causal_cache=causal_cache[idx] if causal_cache is not None else None
198
- )
199
- # print(f"idx: {idx}, new_causal_cache: {new_causal_cache.shape}")
200
- new_causal_cache_list.append(new_causal_cache)
201
-
202
- return chunk, new_causal_cache_list
203
-
204
- def forward_chunk_by_chunk(self, inputs: torch.Tensor):
205
- # inputs shape: [batch_size, 1, time_steps, hidden_size]
206
- # batch_size = 1
207
-
208
- batch_size, channels, time_steps, hidden_size = inputs.shape
209
-
210
- new_causal_cache_list: List[torch.Tensor] = None
211
-
212
- outputs = []
213
- for idx in range(0, time_steps, 1):
214
- begin = idx
215
- end = begin + self.total_causal_right_pad + 1
216
- chunk_xs = inputs[:, :, begin:end, :]
217
-
218
- ys, new_causal_cache_list = self.forward_chunk(
219
- chunk=chunk_xs,
220
- causal_cache=new_causal_cache_list,
221
- )
222
- # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
223
- ys = ys[:, :, :1, :]
224
-
225
- # ys shape: [batch_size, chunk_size, hidden_size]
226
- outputs.append(ys)
227
-
228
- ys = torch.cat(outputs, 2)
229
- return ys
230
-
231
-
232
- def main2():
233
- conv = CausalConv2d(
234
- in_channels=1,
235
- out_channels=64,
236
- kernel_size=3,
237
- bias=False,
238
- separable=True,
239
- f_stride=1,
240
- lookahead=0,
241
- )
242
-
243
- spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
244
- # spec shape: [batch_size, 1, time_steps, hidden_size]
245
- cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
246
-
247
- output, _ = conv.forward(spec)
248
- print(output.shape)
249
-
250
- output, _ = conv.forward(spec, cache)
251
- print(output.shape)
252
-
253
- return
254
-
255
-
256
- def main():
257
- causal = CausalConv2dEncoder(
258
- in_channels=1,
259
- out_channels=1,
260
- kernel_size=3,
261
- bias=False,
262
- separable=True,
263
- f_stride=1,
264
- lookahead=0,
265
- num_layers=3,
266
- )
267
-
268
- spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
269
- # spec shape: [batch_size, 1, time_steps, hidden_size]
270
-
271
- output = causal.forward(spec)
272
- print(output.shape)
273
-
274
- output = causal.forward_chunk_by_chunk(spec)
275
- print(output.shape)
276
-
277
- return
278
-
279
-
280
- if __name__ == '__main__':
281
- main()
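Note: CausalConv2d above keeps the time axis causal by padding kernel_size - 1 - lookahead zeros in front and lookahead zeros behind, then convolving with no built-in time padding. A minimal equivalent sketch, assuming kernel_size=3 and lookahead=0:

    import torch
    import torch.nn as nn

    kernel_size, lookahead = 3, 0
    left_pad, right_pad = kernel_size - 1 - lookahead, lookahead
    x = torch.randn(1, 1, 10, 64)                               # [batch, channel, time, freq]
    pad = nn.ConstantPad2d((0, 0, left_pad, right_pad), 0.0)    # pad only the time axis
    conv = nn.Conv2d(1, 8, kernel_size=(3, 3), padding=(0, 1))  # "same" padding on the freq axis only
    y = conv(pad(x))
    print(y.shape)                                              # torch.Size([1, 8, 10, 64]); time length preserved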
 
toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py DELETED
@@ -1,102 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
-
5
-
6
- class NXDenoiseConfig(PretrainedConfig):
7
- """
8
- https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
9
- """
10
- def __init__(self,
11
- sample_rate: int = 8000,
12
- segment_size: int = 16000,
13
- n_fft: int = 512,
14
- win_length: int = 200,
15
- hop_length: int = 80,
16
-
17
- down_sampling_num_layers: int = 5,
18
- down_sampling_in_channels: int = 1,
19
- down_sampling_hidden_channels: int = 64,
20
- down_sampling_kernel_size: int = 4,
21
- down_sampling_stride: int = 2,
22
-
23
- causal_in_channels: int = 1,
24
- causal_hidden_channels: int = 64,
25
- causal_kernel_size: int = 3,
26
- causal_bias: bool = False,
27
- causal_separable: bool = True,
28
- causal_f_stride: int = 1,
29
- # causal_lookahead: int = 0,
30
- causal_num_layers: int = 3,
31
-
32
- tsfm_hidden_size: int = 256,
33
- tsfm_attention_heads: int = 4,
34
- tsfm_num_blocks: int = 6,
35
- tsfm_dropout_rate: float = 0.1,
36
- tsfm_max_time_relative_position: int = 1024,
37
- tsfm_max_freq_relative_position: int = 128,
38
- tsfm_chunk_size: int = 4,
39
- tsfm_num_left_chunks: int = 128,
40
- tsfm_num_right_chunks: int = 2,
41
-
42
- discriminator_dim: int = 16,
43
- discriminator_in_channel: int = 2,
44
-
45
- compress_factor: float = 0.3,
46
-
47
- batch_size: int = 4,
48
- learning_rate: float = 0.0005,
49
- adam_b1: float = 0.8,
50
- adam_b2: float = 0.99,
51
- lr_decay: float = 0.99,
52
- seed: int = 1234,
53
-
54
- **kwargs
55
- ):
56
- super(NXDenoiseConfig, self).__init__(**kwargs)
57
- self.sample_rate = sample_rate
58
- self.segment_size = segment_size
59
- self.n_fft = n_fft
60
- self.win_length = win_length
61
- self.hop_length = hop_length
62
-
63
- self.down_sampling_num_layers = down_sampling_num_layers
64
- self.down_sampling_in_channels = down_sampling_in_channels
65
- self.down_sampling_hidden_channels = down_sampling_hidden_channels
66
- self.down_sampling_kernel_size = down_sampling_kernel_size
67
- self.down_sampling_stride = down_sampling_stride
68
-
69
- self.causal_in_channels = causal_in_channels
70
- self.causal_hidden_channels = causal_hidden_channels
71
- self.causal_kernel_size = causal_kernel_size
72
- self.causal_bias = causal_bias
73
- self.causal_separable = causal_separable
74
- self.causal_f_stride = causal_f_stride
75
- # self.causal_lookahead = causal_lookahead
76
- self.causal_num_layers = causal_num_layers
77
-
78
- self.tsfm_hidden_size = tsfm_hidden_size
79
- self.tsfm_attention_heads = tsfm_attention_heads
80
- self.tsfm_num_blocks = tsfm_num_blocks
81
- self.tsfm_dropout_rate = tsfm_dropout_rate
82
- self.tsfm_max_time_relative_position = tsfm_max_time_relative_position
83
- self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position
84
- self.tsfm_chunk_size = tsfm_chunk_size
85
- self.tsfm_num_left_chunks = tsfm_num_left_chunks
86
- self.tsfm_num_right_chunks = tsfm_num_right_chunks
87
-
88
- self.discriminator_dim = discriminator_dim
89
- self.discriminator_in_channel = discriminator_in_channel
90
-
91
- self.compress_factor = compress_factor
92
-
93
- self.batch_size = batch_size
94
- self.learning_rate = learning_rate
95
- self.adam_b1 = adam_b1
96
- self.adam_b2 = adam_b2
97
- self.lr_decay = lr_decay
98
- self.seed = seed
99
-
100
-
101
- if __name__ == '__main__':
102
- pass
 
toolbox/torchaudio/models/nx_denoise/discriminator.py DELETED
@@ -1,132 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import os
4
- from typing import Optional, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torchaudio
9
-
10
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
11
- from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
12
- from toolbox.torchaudio.models.nx_denoise.utils import LearnableSigmoid1d
13
-
14
-
15
- class MetricDiscriminator(nn.Module):
16
- def __init__(self, config: NXDenoiseConfig):
17
- super(MetricDiscriminator, self).__init__()
18
- dim = config.discriminator_dim
19
- self.in_channel = config.discriminator_in_channel
20
-
21
- self.n_fft = config.n_fft
22
- self.win_length = config.win_length
23
- self.hop_length = config.hop_length
24
-
25
- self.transform = torchaudio.transforms.Spectrogram(
26
- n_fft=self.n_fft,
27
- win_length=self.win_length,
28
- hop_length=self.hop_length,
29
- power=1.0,
30
- window_fn=torch.hann_window,
31
- # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
32
- )
33
-
34
- self.layers = nn.Sequential(
35
- nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
36
- nn.InstanceNorm2d(dim, affine=True),
37
- nn.PReLU(dim),
38
- nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
39
- nn.InstanceNorm2d(dim*2, affine=True),
40
- nn.PReLU(dim*2),
41
- nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
42
- nn.InstanceNorm2d(dim*4, affine=True),
43
- nn.PReLU(dim*4),
44
- nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
45
- nn.InstanceNorm2d(dim*8, affine=True),
46
- nn.PReLU(dim*8),
47
- nn.AdaptiveMaxPool2d(1),
48
- nn.Flatten(),
49
- nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
50
- nn.Dropout(0.3),
51
- nn.PReLU(dim*4),
52
- nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
53
- LearnableSigmoid1d(1)
54
- )
55
-
56
- def forward(self, x, y):
57
- x = self.transform.forward(x)
58
- y = self.transform.forward(y)
59
-
60
- xy = torch.stack((x, y), dim=1)
61
- return self.layers(xy)
62
-
63
-
64
- MODEL_FILE = "discriminator.pt"
65
-
66
-
67
- class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
68
- def __init__(self,
69
- config: NXDenoiseConfig,
70
- ):
71
- super(MetricDiscriminatorPretrainedModel, self).__init__(
72
- config=config,
73
- )
74
- self.config = config
75
-
76
- @classmethod
77
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
78
- config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
79
-
80
- model = cls(config)
81
-
82
- if os.path.isdir(pretrained_model_name_or_path):
83
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
84
- else:
85
- ckpt_file = pretrained_model_name_or_path
86
-
87
- with open(ckpt_file, "rb") as f:
88
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
89
- model.load_state_dict(state_dict, strict=True)
90
- return model
91
-
92
- def save_pretrained(self,
93
- save_directory: Union[str, os.PathLike],
94
- state_dict: Optional[dict] = None,
95
- ):
96
-
97
- model = self
98
-
99
- if state_dict is None:
100
- state_dict = model.state_dict()
101
-
102
- os.makedirs(save_directory, exist_ok=True)
103
-
104
- # save state dict
105
- model_file = os.path.join(save_directory, MODEL_FILE)
106
- torch.save(state_dict, model_file)
107
-
108
- # save config
109
- config_file = os.path.join(save_directory, CONFIG_FILE)
110
- self.config.to_yaml_file(config_file)
111
- return save_directory
112
-
113
-
114
- def main():
115
- config = NXDenoiseConfig()
116
- discriminator = MetricDiscriminator(config=config)
117
-
118
- # shape: [batch_size, num_samples]
119
- # x = torch.ones([4, int(4.5 * 16000)])
120
- # y = torch.ones([4, int(4.5 * 16000)])
121
- x = torch.ones([4, 16000])
122
- y = torch.ones([4, 16000])
123
-
124
- output = discriminator.forward(x, y)
125
- print(output.shape)
126
- print(output)
127
-
128
- return
129
-
130
-
131
- if __name__ == "__main__":
132
- main()
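
A minimal sketch of the input this deleted discriminator consumed, assuming the STFT settings from the nx_denoise config further down (n_fft=512, win_size=200, hop_size=80): two waveforms become magnitude spectrograms and are stacked on a channel axis, which is why discriminator_in_channel is 2.

import torch
import torchaudio

transform = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    power=1.0,
    window_fn=torch.hann_window,
)

clean = torch.randn(4, 16000)      # [batch_size, num_samples]
enhanced = torch.randn(4, 16000)

clean_mag = transform(clean)       # [batch_size, n_fft//2+1, num_frames]
enhanced_mag = transform(enhanced)

# the stacked pair is the discriminator input: [batch_size, 2, n_fft//2+1, num_frames]
pair = torch.stack((clean_mag, enhanced_mag), dim=1)
print(pair.shape)
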
toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py DELETED
@@ -1,97 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import logging
4
- from pathlib import Path
5
- import shutil
6
- import tempfile
7
- import zipfile
8
-
9
- import librosa
10
- import numpy as np
11
- import torch
12
- import torchaudio
13
-
14
- from project_settings import project_path
15
- from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
16
- from toolbox.torchaudio.models.nx_denoise.modeling_nx_denoise import NXDenoisePretrainedModel, MODEL_FILE
17
-
18
- logger = logging.getLogger("toolbox")
19
-
20
-
21
- class InferenceNXDenoise(object):
22
- def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
23
- self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
24
- self.device = torch.device(device)
25
-
26
- logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
27
- config, model = self.load_models(self.pretrained_model_path_or_zip_file)
28
- logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
29
-
30
- self.config = config
31
- self.model = model
32
- self.model.to(device)
33
- self.model.eval()
34
-
35
- def load_models(self, model_path: str):
36
- model_path = Path(model_path)
37
- if model_path.name.endswith(".zip"):
38
- with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
39
- out_root = Path(tempfile.gettempdir()) / "nx_denoise"
40
- out_root.mkdir(parents=True, exist_ok=True)
41
- f_zip.extractall(path=out_root)
42
- model_path = out_root / model_path.stem
43
-
44
- config = NXDenoiseConfig.from_pretrained(
45
- pretrained_model_name_or_path=model_path.as_posix(),
46
- )
47
- model = NXDenoisePretrainedModel.from_pretrained(
48
- pretrained_model_name_or_path=model_path.as_posix(),
49
- )
50
- model.to(self.device)
51
- model.eval()
52
-
53
- shutil.rmtree(model_path)
54
- return config, model
55
-
56
- def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
57
- if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
58
- raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
59
-
60
- # noisy_audio shape: [batch_size, num_samples]
61
- noisy_audios = noisy_audio.to(self.device)
62
-
63
- with torch.no_grad():
64
- # enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
65
- enhanced_audios = self.model.forward(noisy_audios)
66
- # enhanced_audio shape: [batch_size, n_samples]
67
- # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
68
-
69
- enhanced_audio = enhanced_audios[0]
70
- # enhanced_audio shape: [num_samples,]
71
- return enhanced_audio
72
-
73
-
74
- def main():
75
- model_zip_file = project_path / "trained_models/nx-denoise.zip"
76
- runtime = InferenceNXDenoise(model_zip_file)
77
-
78
- sample_rate = 8000
79
- noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
80
- noisy_audio, _ = librosa.load(
81
- noisy_audio_file.as_posix(),
82
- sr=sample_rate,
83
- )
84
- noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
85
- noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
86
- noisy_audio = noisy_audio.unsqueeze(dim=0)
87
-
88
- enhanced_audio = runtime.enhancement_by_tensor(noisy_audio)
89
-
90
- filename = "enhanced_audio.wav"
91
- torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)
92
-
93
- return
94
-
95
-
96
- if __name__ == '__main__':
97
- main()
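
enhancement_by_tensor above rejects anything outside [-1, 1], so integer PCM has to be scaled before it is passed in. A small sketch of that pre-step, assuming the usual int16 convention (the 32768.0 divisor is not something the class applies for you):

import torch

pcm = torch.randint(-32768, 32767, size=(1, 16000), dtype=torch.int16)
noisy_audio = pcm.to(torch.float32) / 32768.0   # now within [-1, 1]
assert noisy_audio.abs().max() <= 1.0
# noisy_audio can then be handed to InferenceNXDenoise.enhancement_by_tensor
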
toolbox/torchaudio/models/nx_denoise/loss.py DELETED
@@ -1,22 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import numpy as np
4
- import torch
5
-
6
-
7
- def anti_wrapping_function(x):
8
-
9
- return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
10
-
11
-
12
- def phase_losses(phase_r, phase_g):
13
-
14
- ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
15
- gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
16
- iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
17
-
18
- return ip_loss, gd_loss, iaf_loss
19
-
20
-
21
- if __name__ == '__main__':
22
- pass
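
A quick standalone check of the anti-wrapping idea above: a phase error of exactly 2*pi should count as zero while pi stays pi. The helper is re-implemented here so the snippet runs without the deleted module.

import numpy as np
import torch

def anti_wrapping(x: torch.Tensor) -> torch.Tensor:
    # distance to the nearest multiple of 2*pi
    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)

print(anti_wrapping(torch.tensor([0.0, np.pi, 2 * np.pi])))
# tensor([0.0000, 3.1416, 0.0000])
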
toolbox/torchaudio/models/nx_denoise/metrics.py DELETED
@@ -1,80 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from joblib import Parallel, delayed
4
- import numpy as np
5
- from pesq import pesq
6
- from typing import List
7
-
8
- from pesq import cypesq
9
-
10
-
11
- def run_pesq(clean_audio: np.ndarray,
12
- noisy_audio: np.ndarray,
13
- sample_rate: int = 16000,
14
- mode: str = "wb",
15
- ) -> float:
16
- if sample_rate == 8000 and mode == "wb":
17
- raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
18
- try:
19
- pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
20
- except cypesq.NoUtterancesError as e:
21
- pesq_score = -1
22
- except Exception as e:
23
- print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
24
- pesq_score = -1
25
- return pesq_score
26
-
27
-
28
- def run_batch_pesq(clean_audio_list: List[np.ndarray],
29
- noisy_audio_list: List[np.ndarray],
30
- sample_rate: int = 16000,
31
- mode: str = "wb",
32
- n_jobs: int = 4,
33
- ) -> List[float]:
34
- parallel = Parallel(n_jobs=n_jobs)
35
-
36
- parallel_tasks = list()
37
- for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
38
- parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
39
- parallel_tasks.append(parallel_task)
40
-
41
- pesq_score_list = parallel.__call__(parallel_tasks)
42
- return pesq_score_list
43
-
44
-
45
- def run_pesq_score(clean_audio_list: List[np.ndarray],
46
- noisy_audio_list: List[np.ndarray],
47
- sample_rate: int = 16000,
48
- mode: str = "wb",
49
- n_jobs: int = 4,
50
- ) -> List[float]:
51
-
52
- pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
53
- noisy_audio_list=noisy_audio_list,
54
- sample_rate=sample_rate,
55
- mode=mode,
56
- n_jobs=n_jobs,
57
- )
58
-
59
- pesq_score = np.mean(pesq_score_list)
60
- return pesq_score
61
-
62
-
63
- def main():
64
- clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
65
- noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
66
-
67
- clean_audio_list = list(clean_audio)
68
- noisy_audio_list = list(noisy_audio)
69
-
70
- pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
71
- print(pesq_score_list)
72
-
73
- pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
74
- print(pesq_score)
75
-
76
- return
77
-
78
-
79
- if __name__ == "__main__":
80
- main()
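
The same call pattern as run_pesq, shown in isolation. The pesq package is assumed to be installed; like the wrapper above, the call is guarded because PESQ refuses inputs it cannot find speech in (random noise is only a placeholder signal here).

import numpy as np
from pesq import pesq

sample_rate = 16000
clean = np.random.uniform(-1, 1, size=sample_rate).astype(np.float32)
noisy = (clean + 0.01 * np.random.randn(sample_rate)).astype(np.float32)

try:
    score = pesq(sample_rate, clean, noisy, "wb")   # "nb" is required at 8000 Hz
except Exception as e:
    score = -1
    print(f"pesq failed: {e}")
print(score)
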
toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py DELETED
@@ -1,392 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import os
4
- from typing import List, Optional, Union
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- from torch.nn import functional as F
10
-
11
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
12
- from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
13
- from toolbox.torchaudio.models.nx_denoise.causal_convolution.causal_conv2d import CausalConv2dEncoder
14
- from toolbox.torchaudio.models.nx_denoise.transformers.transformers import TSTransformerEncoder
15
-
16
-
17
- class DownSamplingBlock(nn.Module):
18
- def __init__(self,
19
- in_channels: int,
20
- hidden_channels: int,
21
- kernel_size: int,
22
- stride: int,
23
- ):
24
- super(DownSamplingBlock, self).__init__()
25
- self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride)
26
- self.relu = nn.ReLU()
27
- self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
28
- self.glu = nn.GLU(dim=1)
29
-
30
- def forward(self, x: torch.Tensor):
31
- # x shape: [batch_size, 1, num_samples]
32
- x = self.conv1.forward(x)
33
- # x shape: [batch_size, hidden_channels, new_num_samples]
34
- x = self.relu(x)
35
- x = self.conv2.forward(x)
36
- # x shape: [batch_size, hidden_channels*2, new_num_samples]
37
- x = self.glu(x)
38
- # x shape: [batch_size, hidden_channels, new_num_samples]
39
- # new_num_samples = (num_samples-kernel_size) // stride + 1
40
- return x
41
-
42
-
43
- class DownSampling(nn.Module):
44
- def __init__(self,
45
- num_layers: int,
46
- in_channels: int,
47
- hidden_channels: int,
48
- kernel_size: int,
49
- stride: int,
50
- ):
51
- super(DownSampling, self).__init__()
52
- self.num_layers = num_layers
53
-
54
- down_sampling_block_list = list()
55
- for idx in range(self.num_layers):
56
- down_sampling_block = DownSamplingBlock(
57
- in_channels=in_channels,
58
- hidden_channels=hidden_channels,
59
- kernel_size=kernel_size,
60
- stride=stride,
61
- )
62
- down_sampling_block_list.append(down_sampling_block)
63
- in_channels = hidden_channels
64
-
65
- self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)
66
-
67
- def forward(self, x: torch.Tensor):
68
- # x shape: [batch_size, channels, num_samples]
69
- skip_connection_list = list()
70
- for down_sampling_block in self.down_sampling_block_list:
71
- x = down_sampling_block.forward(x)
72
- skip_connection_list.append(x)
73
- # x shape: [batch_size, hidden_channels, num_samples**]
74
- return x, skip_connection_list
75
-
76
-
77
- class UpSamplingBlock(nn.Module):
78
- def __init__(self,
79
- out_channels: int,
80
- hidden_channels: int,
81
- kernel_size: int,
82
- stride: int,
83
- do_relu: bool = True,
84
- ):
85
- super(UpSamplingBlock, self).__init__()
86
- self.do_relu = do_relu
87
-
88
- self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
89
- self.glu = nn.GLU(dim=1)
90
- self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride)
91
- self.relu = nn.ReLU()
92
-
93
- def forward(self, x: torch.Tensor):
94
- # x shape: [batch_size, hidden_channels*2, num_samples]
95
- x = self.conv1.forward(x)
96
- # x shape: [batch_size, hidden_channels, num_samples]
97
- x = self.glu(x)
98
- # x shape: [batch_size, hidden_channels, num_samples]
99
- x = self.convt.forward(x)
100
- # x shape: [batch_size, hidden_channels, new_num_samples]
101
- # new_num_samples = (num_samples - 1) * stride + kernel_size
102
- if self.do_relu:
103
- x = self.relu(x)
104
- return x
105
-
106
-
107
- class UpSampling(nn.Module):
108
- def __init__(self,
109
- num_layers: int,
110
- out_channels: int,
111
- hidden_channels: int,
112
- kernel_size: int,
113
- stride: int,
114
- ):
115
- super(UpSampling, self).__init__()
116
- self.num_layers = num_layers
117
-
118
- up_sampling_block_list = list()
119
- for idx in range(self.num_layers-1):
120
- up_sampling_block = UpSamplingBlock(
121
- out_channels=hidden_channels,
122
- hidden_channels=hidden_channels,
123
- kernel_size=kernel_size,
124
- stride=stride,
125
- do_relu=True,
126
- )
127
- up_sampling_block_list.append(up_sampling_block)
128
- else:
129
- up_sampling_block = UpSamplingBlock(
130
- out_channels=out_channels,
131
- hidden_channels=hidden_channels,
132
- kernel_size=kernel_size,
133
- stride=stride,
134
- do_relu=False,
135
- )
136
- up_sampling_block_list.append(up_sampling_block)
137
- self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
138
-
139
- def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
140
- skip_connection_list = skip_connection_list[::-1]
141
-
142
- # x shape: [batch_size, channels, num_samples]
143
- for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
144
- skip_x = skip_connection_list[idx]
145
- x = x + skip_x
146
- # x = x + skip_x[:, :, :x.size(2)]
147
- x = up_sampling_block.forward(x)
148
- return x
149
-
150
-
151
- def get_padding_length(length, num_layers: int, kernel_size: int, stride: int):
152
- for _ in range(num_layers):
153
- if length < kernel_size:
154
- length = 1
155
- else:
156
- length = 1 + np.ceil((length - kernel_size) / stride)
157
-
158
- for _ in range(num_layers):
159
- length = (length - 1) * stride + kernel_size
160
-
161
- padded_length = int(length)
162
- return padded_length
163
-
164
-
165
- class NXDenoise(nn.Module):
166
- def __init__(self, config: NXDenoiseConfig):
167
- super().__init__()
168
- self.config = config
169
-
170
- self.down_sampling = DownSampling(
171
- num_layers=config.down_sampling_num_layers,
172
- in_channels=config.down_sampling_in_channels,
173
- hidden_channels=config.down_sampling_hidden_channels,
174
- kernel_size=config.down_sampling_kernel_size,
175
- stride=config.down_sampling_stride,
176
- )
177
- self.causal_conv_in = CausalConv2dEncoder(
178
- in_channels=config.causal_in_channels,
179
- hidden_channels=config.causal_hidden_channels,
180
- out_channels=config.causal_hidden_channels,
181
- kernel_size=config.causal_kernel_size,
182
- bias=config.causal_bias,
183
- separable=config.causal_separable,
184
- f_stride=config.causal_f_stride,
185
- lookahead=0,
186
- num_layers=config.causal_num_layers,
187
- )
188
- self.ts_transformer = TSTransformerEncoder(
189
- input_size=config.down_sampling_hidden_channels,
190
- hidden_size=config.tsfm_hidden_size,
191
- attention_heads=config.tsfm_attention_heads,
192
- num_blocks=config.tsfm_num_blocks,
193
- dropout_rate=config.tsfm_dropout_rate,
194
- max_time_relative_position=config.tsfm_max_time_relative_position,
195
- max_freq_relative_position=config.tsfm_max_freq_relative_position,
196
- chunk_size=config.tsfm_chunk_size,
197
- num_left_chunks=config.tsfm_num_left_chunks,
198
- num_right_chunks=config.tsfm_num_right_chunks,
199
- )
200
- self.causal_conv_out = CausalConv2dEncoder(
201
- in_channels=config.causal_hidden_channels,
202
- hidden_channels=config.causal_hidden_channels,
203
- out_channels=config.causal_in_channels,
204
- kernel_size=config.causal_kernel_size,
205
- bias=config.causal_bias,
206
- separable=config.causal_separable,
207
- f_stride=config.causal_f_stride,
208
- lookahead=0,
209
- num_layers=config.causal_num_layers,
210
- )
211
- self.up_sampling = UpSampling(
212
- num_layers=config.down_sampling_num_layers,
213
- out_channels=config.down_sampling_in_channels,
214
- hidden_channels=config.down_sampling_hidden_channels,
215
- kernel_size=config.down_sampling_kernel_size,
216
- stride=config.down_sampling_stride,
217
- )
218
-
219
- def forward(self, noisy_audios: torch.Tensor):
220
- # noisy_audios shape: [batch_size, n_samples]
221
- noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
222
- # noisy_audios shape: [batch_size, 1, n_samples]
223
-
224
- n_samples = noisy_audios.shape[-1]
225
- padded_length = get_padding_length(
226
- n_samples,
227
- num_layers=self.config.down_sampling_num_layers,
228
- kernel_size=self.config.down_sampling_kernel_size,
229
- stride=self.config.down_sampling_stride,
230
- )
231
- noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
232
-
233
- # down sampling
234
- bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
235
- # bottle_neck shape: [batch_size, channels, time_steps]
236
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
237
- # bottle_neck shape: [batch_size, time_steps, channels]
238
- bottle_neck = torch.unsqueeze(bottle_neck, dim=1)
239
- # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
240
-
241
- # causal conv in
242
- bottle_neck = self.causal_conv_in.forward(bottle_neck)
243
- # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
244
-
245
- # ts transformer
246
- # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
247
- bottle_neck = self.ts_transformer.forward(bottle_neck)
248
- # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
249
-
250
- # causal conv out
251
- bottle_neck = self.causal_conv_out.forward(bottle_neck)
252
- # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
253
-
254
- # up sampling
255
- bottle_neck = torch.squeeze(bottle_neck, dim=1)
256
- # bottle_neck shape: [batch_size, time_steps, channels]
257
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
258
- # bottle_neck shape: [batch_size, channels, time_steps]
259
-
260
- enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
261
-
262
- enhanced_audios = enhanced_audios[:, :, :n_samples]
263
- # enhanced_audios shape: [batch_size, 1, n_samples]
264
-
265
- enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
266
- # enhanced_audios shape: [batch_size, n_samples]
267
-
268
- return enhanced_audios
269
-
270
-
271
- def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
272
- # noisy_audios shape: [batch_size, n_samples]
273
- noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
274
- # noisy_audios shape: [batch_size, 1, n_samples]
275
-
276
- n_samples = noisy_audios.shape[-1]
277
- padded_length = get_padding_length(
278
- n_samples,
279
- num_layers=self.config.down_sampling_num_layers,
280
- kernel_size=self.config.down_sampling_kernel_size,
281
- stride=self.config.down_sampling_stride,
282
- )
283
- noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
284
-
285
- # down sampling
286
- bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
287
- # bottle_neck shape: [batch_size, channels, time_steps]
288
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
289
- # bottle_neck shape: [batch_size, time_steps, channels]
290
- bottle_neck = torch.unsqueeze(bottle_neck, dim=1)
291
- # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
292
-
293
- # causal conv in
294
- bottle_neck = self.causal_conv_in.forward_chunk_by_chunk(bottle_neck)
295
- # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
296
-
297
- # ts transformer
298
- # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
299
- bottle_neck = self.ts_transformer.forward_chunk_by_chunk(bottle_neck)
300
- # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
301
-
302
- # causal conv out
303
- bottle_neck = self.causal_conv_out.forward_chunk_by_chunk(bottle_neck)
304
- # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]
305
-
306
- # up sampling
307
- bottle_neck = torch.squeeze(bottle_neck, dim=1)
308
- # bottle_neck shape: [batch_size, time_steps, channels]
309
- bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
310
- # bottle_neck shape: [batch_size, channels, time_steps]
311
-
312
- enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
313
-
314
- enhanced_audios = enhanced_audios[:, :, :n_samples]
315
- # enhanced_audios shape: [batch_size, 1, n_samples]
316
-
317
- enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
318
- # enhanced_audios shape: [batch_size, n_samples]
319
-
320
- return enhanced_audios
321
-
322
-
323
- MODEL_FILE = "generator.pt"
324
-
325
-
326
- class NXDenoisePretrainedModel(NXDenoise):
327
- def __init__(self,
328
- config: NXDenoiseConfig,
329
- ):
330
- super(NXDenoisePretrainedModel, self).__init__(
331
- config=config,
332
- )
333
- self.config = config
334
-
335
- @classmethod
336
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
337
- config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
338
-
339
- model = cls(config)
340
-
341
- if os.path.isdir(pretrained_model_name_or_path):
342
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
343
- else:
344
- ckpt_file = pretrained_model_name_or_path
345
-
346
- with open(ckpt_file, "rb") as f:
347
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
348
- model.load_state_dict(state_dict, strict=True)
349
- return model
350
-
351
- def save_pretrained(self,
352
- save_directory: Union[str, os.PathLike],
353
- state_dict: Optional[dict] = None,
354
- ):
355
-
356
- model = self
357
-
358
- if state_dict is None:
359
- state_dict = model.state_dict()
360
-
361
- os.makedirs(save_directory, exist_ok=True)
362
-
363
- # save state dict
364
- model_file = os.path.join(save_directory, MODEL_FILE)
365
- torch.save(state_dict, model_file)
366
-
367
- # save config
368
- config_file = os.path.join(save_directory, CONFIG_FILE)
369
- self.config.to_yaml_file(config_file)
370
- return save_directory
371
-
372
-
373
- def main():
374
-
375
- config = NXDenoiseConfig()
376
-
377
- # shape: [batch_size, num_samples]
378
- # min length: 94, stride: 32, 32 == 2**5
379
- # x = torch.ones([4, 94])
380
- # x = torch.ones([4, 126])
381
- # x = torch.ones([4, 158])
382
- # x = torch.ones([4, 190])
383
- x = torch.ones([4, 16000])
384
-
385
- model = NXDenoise(config)
386
- enhanced_audios = model.forward(x)
387
- print(enhanced_audios.shape)
388
- return
389
-
390
-
391
- if __name__ == "__main__":
392
- main()
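
To make the "min length: 94" comment in main() concrete, here is get_padding_length re-implemented standalone and evaluated with kernel_size=4, stride=2 and 5 layers (the 2**5 in the comment suggests 5 layers; the YAML further down uses 6, so treat these numbers as assumptions):

import numpy as np

def padding_length(length: int, num_layers: int, kernel_size: int, stride: int) -> int:
    # forward pass: how many bottleneck steps the input produces
    for _ in range(num_layers):
        length = 1 if length < kernel_size else 1 + np.ceil((length - kernel_size) / stride)
    # backward pass: the aligned length that reproduces exactly that many steps
    for _ in range(num_layers):
        length = (length - 1) * stride + kernel_size
    return int(length)

print(padding_length(94, 5, 4, 2))     # 94  -> already aligned, exactly one bottleneck step
print(padding_length(100, 5, 4, 2))    # 126 -> padded up to the next aligned length
print(padding_length(16000, 5, 4, 2))  # what a 2 s clip at 8 kHz is padded to
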
toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py DELETED
@@ -1,9 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- https://arxiv.org/abs/1902.07849
5
- """
6
-
7
-
8
- if __name__ == '__main__':
9
- pass
toolbox/torchaudio/models/nx_denoise/transformers/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
toolbox/torchaudio/models/nx_denoise/transformers/attention.py DELETED
@@ -1,263 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import math
4
- from typing import Tuple
5
-
6
- import torch
7
- import torch.nn as nn
8
-
9
-
10
- class MultiHeadSelfAttention(nn.Module):
11
- def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
12
- """
13
- :param n_head: int. the number of heads.
14
- :param n_feat: int. the number of features.
15
- :param dropout_rate: float. dropout rate.
16
- """
17
- super().__init__()
18
- assert n_feat % n_head == 0
19
- # We assume d_v always equals d_k
20
- self.d_k = n_feat // n_head
21
- self.h = n_head
22
- self.linear_q = nn.Linear(n_feat, n_feat)
23
- self.linear_k = nn.Linear(n_feat, n_feat)
24
- self.linear_v = nn.Linear(n_feat, n_feat)
25
- self.linear_out = nn.Linear(n_feat, n_feat)
26
- self.dropout = nn.Dropout(p=dropout_rate)
27
-
28
- def forward_qkv(self,
29
- query: torch.Tensor,
30
- key: torch.Tensor,
31
- value: torch.Tensor
32
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
33
- """
34
- transform query, key and value.
35
- :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
36
- :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
37
- :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
38
- :return:
39
- """
40
- n_batch = query.size(0)
41
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
42
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
43
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
44
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
45
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
46
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
47
-
48
- return q, k, v
49
-
50
- def forward_attention(self,
51
- value: torch.Tensor,
52
- scores: torch.Tensor,
53
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
54
- ) -> torch.Tensor:
55
- """
56
- compute attention context vector.
57
- :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
58
- :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
59
- :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
60
- (batch_size, time1, time2), (0, 0, 0) means fake mask.
61
- :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
62
- weighted by the attention score (batch_size, time1, time2).
63
- """
64
- n_batch = value.size(0)
65
- # NOTE: When will `if mask.size(2) > 0` be True?
66
- # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
67
- # 1st chunk to ease the onnx export.]
68
- # 2. pytorch training
69
- if mask.size(2) > 0: # time2 > 0
70
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
71
- # For last chunk, time2 might be larger than scores.size(-1)
72
- mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
73
- scores = scores.masked_fill(mask, -float('inf'))
74
- attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2)
75
-
76
- # NOTE: When will `if mask.size(2) > 0` be False?
77
- # 1. onnx(16/-1, -1/-1, 16/0)
78
- # 2. jit (16/-1, -1/-1, 16/0, 16/4)
79
- else:
80
- attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
81
-
82
- p_attn = self.dropout(attn)
83
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
84
- x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
85
-
86
- return self.linear_out(x) # (batch, time1, n_feat)
87
-
88
- def forward(self,
89
- x: torch.Tensor,
90
- mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
91
- cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
92
- ) -> Tuple[torch.Tensor, torch.Tensor]:
93
-
94
- q, k, v = self.forward_qkv(x, x, x)
95
-
96
- if cache.size(0) > 0:
97
- key_cache, value_cache = torch.split(
98
- cache, cache.size(-1) // 2, dim=-1)
99
- k = torch.cat([key_cache, k], dim=2)
100
- v = torch.cat([value_cache, v], dim=2)
101
- # NOTE: We do cache slicing in encoder.forward_chunk, since it's
102
- # non-trivial to calculate `next_cache_start` here.
103
- new_cache = torch.cat((k, v), dim=-1)
104
-
105
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
106
- return self.forward_attention(v, scores, mask), new_cache
107
-
108
-
109
- class RelativeMultiHeadSelfAttention(nn.Module):
110
-
111
- def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
112
- """
113
- :param n_head: int. the number of heads.
114
- :param n_feat: int. the number of features.
115
- :param dropout_rate: float. dropout rate.
116
- :param max_relative_position: int. maximum relative position for relative position encoding.
117
- """
118
- super().__init__()
119
- assert n_feat % n_head == 0
120
- # We assume d_v always equals d_k
121
- self.d_k = n_feat // n_head
122
- self.h = n_head
123
- self.linear_q = nn.Linear(n_feat, n_feat)
124
- self.linear_k = nn.Linear(n_feat, n_feat)
125
- self.linear_v = nn.Linear(n_feat, n_feat)
126
- self.linear_out = nn.Linear(n_feat, n_feat)
127
- self.dropout = nn.Dropout(p=dropout_rate)
128
-
129
- # Relative position encoding
130
- self.max_relative_position = max_relative_position
131
- self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
132
-
133
- def forward_qkv(self,
134
- query: torch.Tensor,
135
- key: torch.Tensor,
136
- value: torch.Tensor
137
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
138
- """
139
- transform query, key and value.
140
- :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
141
- :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
142
- :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
143
- :return:
144
- """
145
- n_batch = query.size(0)
146
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
147
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
148
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
149
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
150
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
151
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
152
-
153
- return q, k, v
154
-
155
- def forward_attention(self,
156
- value: torch.Tensor,
157
- scores: torch.Tensor,
158
- mask: torch.Tensor = None
159
- ) -> torch.Tensor:
160
- """
161
- compute attention context vector.
162
- :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
163
- :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
164
- :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
165
- :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
166
- weighted by the attention score (batch_size, query_time_steps, key_time_steps).
167
- """
168
- n_batch = value.size(0)
169
- if mask is not None:
170
- mask = mask.unsqueeze(1).eq(0)
171
- # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
172
- scores = scores.masked_fill(mask, -float('inf'))
173
- attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
174
- else:
175
- attn = torch.softmax(scores, dim=-1)
176
- # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]
177
-
178
- p_attn = self.dropout(attn)
179
-
180
- x = torch.matmul(p_attn, value)
181
- # x shape: [batch_size, n_head, query_time_steps, d_k]
182
- x = x.transpose(1, 2)
183
- # x shape: [batch_size, query_time_steps, n_head, d_k]
184
-
185
- x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat)
186
- # x shape: [batch_size, query_time_steps, n_head * d_k]
187
- # x shape: [batch_size, query_time_steps, n_feat]
188
-
189
- x = self.linear_out(x)
190
- # x shape: [batch_size, query_time_steps, n_feat]
191
- return x
192
-
193
- def relative_position_encoding(self, length: int) -> torch.Tensor:
194
- """
195
- Generate relative position encoding.
196
- :param length: int. length of the sequence.
197
- :return: torch.Tensor. relative position encoding. shape=(length, length, d_k).
198
- """
199
- range_vec = torch.arange(length)
200
- distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
201
- distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
202
- final_mat = distance_mat_clipped + self.max_relative_position
203
- return final_mat
204
-
205
- def forward(self,
206
- x: torch.Tensor,
207
- mask: torch.Tensor = None,
208
- cache: torch.Tensor = None
209
- ) -> Tuple[torch.Tensor, torch.Tensor]:
210
- # attention! self attention.
211
-
212
- q, k, v = self.forward_qkv(x, x, x)
213
- # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]
214
-
215
- if cache is not None:
216
- key_cache, value_cache = torch.split(
217
- cache, cache.size(-1) // 2, dim=-1)
218
- k = torch.cat([key_cache, k], dim=2)
219
- v = torch.cat([value_cache, v], dim=2)
220
-
221
- # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
222
- new_cache = torch.cat((k, v), dim=-1)
223
-
224
- # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
225
- native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
226
-
227
- # Compute relative position encoding
228
- q_length, k_length = q.size(2), k.size(2)
229
- relative_position = self.relative_position_encoding(k_length)
230
-
231
- relative_position = relative_position[-q_length:]
232
-
233
- relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
234
-
235
- relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k)
236
- relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k)
237
-
238
- relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
239
- # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
240
-
241
- # score
242
- scores = native_scores + relative_position_scores
243
-
244
- return self.forward_attention(v, scores, mask), new_cache
245
-
246
-
247
- def main():
248
- rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
249
-
250
- x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
251
- xt, new_cache = rel_attention.forward(x)
252
-
253
- # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
254
- # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
255
- # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
256
-
257
- print(xt.shape)
258
- print(new_cache.shape)
259
- return
260
-
261
-
262
- if __name__ == '__main__':
263
- main()
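
A small check of what relative_position_encoding produces: pairwise distances j - i, clipped to +/-max_relative_position and shifted to start at 0 so they can index the learned relative_position_k table. max_relative_position=2 and length=5 are arbitrary here.

import torch

length, max_relative_position = 5, 2
rng = torch.arange(length)
distance = rng.unsqueeze(0) - rng.unsqueeze(1)          # distance[i, j] = j - i
index = torch.clamp(distance, -max_relative_position, max_relative_position) + max_relative_position
print(index)
# tensor([[2, 3, 4, 4, 4],
#         [1, 2, 3, 4, 4],
#         [0, 1, 2, 3, 4],
#         [0, 0, 1, 2, 3],
#         [0, 0, 0, 1, 2]])
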
toolbox/torchaudio/models/nx_denoise/transformers/mask.py DELETED
@@ -1,74 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import torch
4
-
5
-
6
- def make_pad_mask(lengths: torch.Tensor,
7
- max_len: int = 0,
8
- ) -> torch.Tensor:
9
- batch_size = lengths.size(0)
10
- max_len = max_len if max_len > 0 else lengths.max().item()
11
- seq_range = torch.arange(
12
- 0,
13
- max_len,
14
- dtype=torch.int64,
15
- device=lengths.device
16
- )
17
- seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
18
- seq_length_expand = lengths.unsqueeze(-1)
19
- mask = seq_range_expand >= seq_length_expand
20
- return mask
21
-
22
-
23
-
24
- def subsequent_chunk_mask(
25
- size: int,
26
- chunk_size: int,
27
- num_left_chunks: int = -1,
28
- num_right_chunks: int = 0,
29
- device: torch.device = torch.device("cpu"),
30
- ) -> torch.Tensor:
31
- """
32
- Create mask for subsequent steps (size, size) with chunk size,
33
- this is for streaming encoder
34
-
35
- Examples:
36
- > subsequent_chunk_mask(4, 2)
37
- [[1, 1, 0, 0],
38
- [1, 1, 0, 0],
39
- [1, 1, 1, 1],
40
- [1, 1, 1, 1]]
41
-
42
- :param size: int. size of mask.
43
- :param chunk_size: int. size of chunk.
44
- :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
45
- :param num_right_chunks: int. number of right chunks.
46
- :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
47
- :return: torch.Tensor. mask
48
- """
49
-
50
- ret = torch.zeros(size, size, device=device, dtype=torch.bool)
51
- for i in range(size):
52
- if num_left_chunks < 0:
53
- start = 0
54
- else:
55
- start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
56
- ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
57
- ret[i, start:ending] = True
58
- return ret
59
-
60
-
61
- def main():
62
- chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
63
- print(chunk_mask)
64
-
65
- chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
66
- print(chunk_mask)
67
-
68
- chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
69
- print(chunk_mask)
70
- return
71
-
72
-
73
- if __name__ == '__main__':
74
- main()
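
One case the docstring above does not show, with limited left context and one chunk of lookahead (the kind of mask the streaming encoder below builds). The helper is re-implemented so this runs standalone; size=6, chunk_size=2, num_left_chunks=1, num_right_chunks=1 are arbitrary.

import torch

def chunk_mask(size, chunk_size, num_left_chunks=-1, num_right_chunks=0):
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        start = 0 if num_left_chunks < 0 else max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

print(chunk_mask(6, 2, num_left_chunks=1, num_right_chunks=1).int())
# tensor([[1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1, 1]], dtype=torch.int32)
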
toolbox/torchaudio/models/nx_denoise/transformers/transformers.py DELETED
@@ -1,479 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from typing import Dict, Optional, Tuple, List, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask
9
- from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
10
-
11
-
12
- class PositionwiseFeedForward(nn.Module):
13
- def __init__(self,
14
- input_dim: int,
15
- hidden_units: int,
16
- dropout_rate: float,
17
- activation: torch.nn.Module = torch.nn.ReLU()):
18
- """
19
- FeedForward are applied on each position of the sequence.
20
- the output dim is same with the input dim.
21
-
22
- :param input_dim: int. input dimension.
23
- :param hidden_units: int. the number of hidden units.
24
- :param dropout_rate: float. dropout rate.
25
- :param activation: torch.nn.Module. activation function.
26
- """
27
- super(PositionwiseFeedForward, self).__init__()
28
- self.w_1 = torch.nn.Linear(input_dim, hidden_units)
29
- self.activation = activation
30
- self.dropout = torch.nn.Dropout(dropout_rate)
31
- self.w_2 = torch.nn.Linear(hidden_units, input_dim)
32
-
33
- def forward(self, xs: torch.Tensor) -> torch.Tensor:
34
- """
35
- Forward function.
36
- :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
37
- :return: output tensor. shape=(batch_size, max_length, dim).
38
- """
39
- return self.w_2(self.dropout(self.activation(self.w_1(xs))))
40
-
41
-
42
- class TransformerBlock(nn.Module):
43
- def __init__(self,
44
- input_dim: int,
45
- dropout_rate: float = 0.1,
46
- n_heads: int = 4,
47
- max_relative_position: int = 5120
48
- ):
49
- super().__init__()
50
- self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
51
- self.attention = RelativeMultiHeadSelfAttention(
52
- n_head=n_heads,
53
- n_feat=input_dim,
54
- dropout_rate=dropout_rate,
55
- max_relative_position=max_relative_position,
56
- )
57
-
58
- self.dropout1 = nn.Dropout(dropout_rate)
59
- self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
60
- self.ffn = PositionwiseFeedForward(
61
- input_dim=input_dim,
62
- hidden_units=input_dim,
63
- dropout_rate=dropout_rate
64
- )
65
- self.dropout2 = nn.Dropout(dropout_rate)
66
- self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
67
-
68
- def forward(
69
- self,
70
- x: torch.Tensor,
71
- mask: torch.Tensor = None,
72
- attention_cache: torch.Tensor = None,
73
- ) -> Tuple[torch.Tensor, torch.Tensor]:
74
- """
75
-
76
- :param x: torch.Tensor. shape=(batch_size, time, input_dim).
77
- :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
78
- :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
79
- shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
80
- :return:
81
- torch.Tensor: Output tensor (batch_size, time, input_dim).
82
- torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
83
- """
84
- xt = self.norm1(x)
85
-
86
- x_att, new_att_cache = self.attention.forward(
87
- xt, mask=mask, cache=attention_cache
88
- )
89
- x = x + self.dropout1(x_att)
90
- xt = self.norm2(x)
91
- xt = self.ffn.forward(xt)
92
- x = x + self.dropout2(xt)
93
-
94
- x = self.norm3(x)
95
-
96
- return x, new_att_cache
97
-
98
-
99
- class TransformerEncoder(nn.Module):
100
- """
101
- https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
102
- """
103
- def __init__(self,
104
- input_size: int = 64,
105
- hidden_size: int = 256,
106
- attention_heads: int = 4,
107
- num_blocks: int = 6,
108
- dropout_rate: float = 0.1,
109
- max_relative_position: int = 1024,
110
- chunk_size: int = 1,
111
- num_left_chunks: int = 128,
112
- num_right_chunks: int = 2,
113
- ):
114
- super().__init__()
115
- self.input_size = input_size
116
- self.hidden_size = hidden_size
117
-
118
- self.max_relative_position = max_relative_position
119
- self.chunk_size = chunk_size
120
- self.num_left_chunks = num_left_chunks
121
- self.num_right_chunks = num_right_chunks
122
-
123
- self.input_linear = nn.Linear(
124
- in_features=self.input_size,
125
- out_features=self.hidden_size,
126
- )
127
-
128
- self.encoder_layer_list = torch.nn.ModuleList([
129
- TransformerBlock(
130
- input_dim=hidden_size,
131
- n_heads=attention_heads,
132
- dropout_rate=dropout_rate,
133
- max_relative_position=max_relative_position,
134
- ) for _ in range(num_blocks)
135
- ])
136
-
137
- self.output_linear = nn.Linear(
138
- in_features=self.hidden_size,
139
- out_features=self.input_size,
140
- )
141
-
142
- def forward(self,
143
- xs: torch.Tensor,
144
- ):
145
- """
146
- :param xs: Tensor, shape: [batch_size, time_steps, input_size]
147
- :return: Tensor, shape: [batch_size, time_steps, input_size]
148
- """
149
- batch_size, time_steps, _ = xs.shape
150
- # xs shape: [batch_size, time_steps, input_size]
151
- xs = self.input_linear.forward(xs)
152
- # xs shape: [batch_size, time_steps, hidden_size]
153
-
154
- chunk_masks = subsequent_chunk_mask(
155
- size=time_steps,
156
- chunk_size=self.chunk_size,
157
- num_left_chunks=self.num_left_chunks,
158
- num_right_chunks=self.num_right_chunks,
159
- )
160
- chunk_masks = chunk_masks.to(xs.device)
161
- # chunk_masks shape: [time_steps, time_steps]
162
- chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
163
- # chunk_masks shape: [batch_size, time_steps, time_steps]
164
-
165
- for encoder_layer in self.encoder_layer_list:
166
- xs, _ = encoder_layer.forward(xs, chunk_masks)
167
-
168
- # xs shape: [batch_size, time_steps, hidden_size]
169
- xs = self.output_linear.forward(xs)
170
- # xs shape: [batch_size, time_steps, input_size]
171
-
172
- return xs
173
-
174
- def forward_chunk(self,
175
- xs: torch.Tensor,
176
- max_att_cache_length: int,
177
- attention_cache: torch.Tensor = None,
178
- ) -> Tuple[torch.Tensor, torch.Tensor]:
179
- """
180
-
181
- :param xs:
182
- :param max_att_cache_length:
183
- :param attention_cache: Tensor, [num_layers, ...]
184
- :return:
185
- """
186
- # xs shape: [batch_size, time_steps, input_size]
187
- xs = self.input_linear.forward(xs)
188
- # xs shape: [batch_size, time_steps, hidden_size]
189
-
190
- r_att_cache = []
191
- for idx, encoder_layer in enumerate(self.encoder_layer_list):
192
- xs, new_att_cache = encoder_layer.forward(
193
- x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
194
- )
195
- # new_att_cache shape: [batch_size, n_heads, time_steps, dim]
196
- if new_att_cache.size(2) > max_att_cache_length:
197
- begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
198
- end = self.num_right_chunks * self.chunk_size
199
- new_att_cache = new_att_cache[:, :, -begin:-end, :]
200
- r_att_cache.append(new_att_cache)
201
-
202
- r_att_cache = torch.stack(r_att_cache, dim=0)
203
-
204
- # xs shape: [batch_size, time_steps, hidden_size]
205
- xs = self.output_linear.forward(xs)
206
- # xs shape: [batch_size, time_steps, input_size]
207
-
208
- return xs, r_att_cache
209
-
210
- def forward_chunk_by_chunk(
211
- self,
212
- xs: torch.Tensor,
213
- ) -> torch.Tensor:
214
-
215
- batch_size, time_steps, _ = xs.shape
216
-
217
- # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
218
- max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
219
- attention_cache = None
220
-
221
- outputs = []
222
- for idx in range(0, time_steps, self.chunk_size):
223
- begin = idx
224
- end = begin + self.chunk_size * (self.num_right_chunks + 1)
225
- chunk_xs = xs[:, begin:end, :]
226
- # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
227
-
228
- ys, attention_cache = self.forward_chunk(
229
- xs=chunk_xs,
230
- max_att_cache_length=max_att_cache_length,
231
- attention_cache=attention_cache,
232
- )
233
-
234
- # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), input_size]
235
- ys = ys[:, :self.chunk_size, :]
236
-
237
- outputs.append(ys)
238
-
239
- ys = torch.cat(outputs, 1)
240
- return ys
241
-
242
-
243
- class TSTransformerBlock(nn.Module):
244
- def __init__(self,
245
- input_dim: int,
246
- dropout_rate: float = 0.1,
247
- n_heads: int = 4,
248
- max_time_relative_position: int = 1024,
249
- max_freq_relative_position: int = 128,
250
- ):
251
- super(TSTransformerBlock, self).__init__()
252
- self.time_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_time_relative_position)
253
- self.freq_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_freq_relative_position)
254
-
255
- def forward(self,
256
- x: torch.Tensor,
257
- mask: torch.Tensor = None,
258
- attention_cache: torch.Tensor = None,
259
- ):
260
- """
261
-
262
- :param x: Tensor. shape: [batch_size, hidden_size, time_steps, input_size]
263
- :param mask: Tensor. shape: [time_steps, time_steps]
264
- :param attention_cache:
265
- :return:
266
- """
267
- b, c, t, f = x.size()
268
-
269
- mask = None if mask is None else torch.broadcast_to(mask, size=(b*f, t, t))
270
-
271
- x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
272
- x_, new_att_cache = self.time_transformer.forward(x, mask, attention_cache)
273
- x = x_ + x
274
- x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
275
- x_, _ = self.freq_transformer.forward(x)
276
- x = x_ + x
277
- x = x.view(b, t, f, c).permute(0, 3, 1, 2)
278
- return x, new_att_cache
279
-
280
-
281
- class TSTransformerEncoder(nn.Module):
282
- def __init__(self,
283
- input_size: int = 64,
284
- hidden_size: int = 256,
285
- attention_heads: int = 4,
286
- num_blocks: int = 6,
287
- dropout_rate: float = 0.1,
288
- max_time_relative_position: int = 1024,
289
- max_freq_relative_position: int = 128,
290
- chunk_size: int = 1,
291
- num_left_chunks: int = 128,
292
- num_right_chunks: int = 2,
293
- ):
294
- super().__init__()
295
- self.input_size = input_size
296
- self.hidden_size = hidden_size
297
-
298
- self.max_time_relative_position = max_time_relative_position
299
- self.max_freq_relative_position = max_freq_relative_position
300
- self.chunk_size = chunk_size
301
- self.num_left_chunks = num_left_chunks
302
- self.num_right_chunks = num_right_chunks
303
-
304
- self.input_linear = nn.Linear(
305
- in_features=self.input_size,
306
- out_features=self.hidden_size,
307
- )
308
-
309
- self.encoder_layer_list = torch.nn.ModuleList([
310
- TSTransformerBlock(
311
- input_dim=hidden_size,
312
- n_heads=attention_heads,
313
- dropout_rate=dropout_rate,
314
- max_time_relative_position=max_time_relative_position,
315
- max_freq_relative_position=max_freq_relative_position,
316
- ) for _ in range(num_blocks)
317
- ])
318
-
319
- self.output_linear = nn.Linear(
320
- in_features=self.hidden_size,
321
- out_features=self.input_size,
322
- )
323
-
324
- def forward(self,
325
- xs: torch.Tensor,
326
- ):
327
- """
328
- :param xs: Tensor, shape: [batch_size, channels, time_steps, input_size]
329
- :return: Tensor, shape: [batch_size, channels, time_steps, input_size]
330
- """
331
- batch_size, channels, time_steps, _ = xs.shape
332
- # xs shape: [batch_size, channels, time_steps, input_size]
333
- xs = xs.permute(0, 3, 2, 1)
334
- # xs shape: [batch_size, input_size, time_steps, channels]
335
- xs = self.input_linear.forward(xs)
336
- # xs shape: [batch_size, input_size, time_steps, hidden_size]
337
- xs = xs.permute(0, 3, 2, 1)
338
- # xs shape: [batch_size, hidden_size, time_steps, input_size]
339
-
340
- chunk_masks = subsequent_chunk_mask(
341
- size=time_steps,
342
- chunk_size=self.chunk_size,
343
- num_left_chunks=self.num_left_chunks,
344
- num_right_chunks=self.num_right_chunks,
345
- )
346
- chunk_masks = chunk_masks.to(xs.device)
347
- # chunk_masks shape: [time_steps, time_steps]
348
-
349
- for encoder_layer in self.encoder_layer_list:
350
- xs, _ = encoder_layer.forward(xs, chunk_masks)
351
- # xs shape: [batch_size, hidden_size, time_steps, input_size]
352
- xs = xs.permute(0, 3, 2, 1)
353
- # xs shape: [batch_size, input_size, time_steps, hidden_size]
354
- xs = self.output_linear.forward(xs)
355
- # xs shape: [batch_size, input_size, time_steps, channels]
356
- xs = xs.permute(0, 3, 2, 1)
357
- # xs shape: [batch_size, channels, time_steps, input_size]
358
-
359
- return xs
360
-
361
- def forward_chunk(self,
362
- xs: torch.Tensor,
363
- max_att_cache_length: int,
364
- attention_cache: torch.Tensor = None,
365
- ) -> Tuple[torch.Tensor, torch.Tensor]:
366
- """
367
-
368
- :param xs:
369
- :param max_att_cache_length:
370
- :param attention_cache: Tensor, shape: [num_layers, ...]
371
- :return:
372
- """
373
- # xs shape: [batch_size, channels, time_steps, input_size]
374
- xs = xs.permute(0, 3, 2, 1)
375
- xs = self.input_linear.forward(xs)
376
- xs = xs.permute(0, 3, 2, 1)
377
- # xs shape: [batch_size, hidden_size, time_steps, input_size]
378
-
379
- r_att_cache = []
380
- for idx, encoder_layer in enumerate(self.encoder_layer_list):
381
- xs, new_att_cache = encoder_layer.forward(
382
- x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
383
- )
384
- # new_att_cache shape: [b*f, n_heads, time_steps, dim]
385
- if new_att_cache.size(2) > max_att_cache_length:
386
- begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
387
- end = self.num_right_chunks * self.chunk_size
388
- new_att_cache = new_att_cache[:, :, -begin:-end, :]
389
- r_att_cache.append(new_att_cache)
390
-
391
- r_att_cache = torch.stack(r_att_cache, dim=0)
392
-
393
- # xs shape: [batch_size, hidden_size, time_steps, input_size]
394
- xs = xs.permute(0, 3, 2, 1)
395
- xs = self.output_linear.forward(xs)
396
- xs = xs.permute(0, 3, 2, 1)
397
- # xs shape: [batch_size, channels, time_steps, input_size]
398
-
399
- return xs, r_att_cache
400
-
401
- def forward_chunk_by_chunk(
402
- self,
403
- xs: torch.Tensor,
404
- ) -> torch.Tensor:
405
-
406
- batch_size, channels, time_steps, _ = xs.shape
407
-
408
- max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
409
- attention_cache = None
410
-
411
- outputs = []
412
- for idx in range(0, time_steps, self.chunk_size):
413
- begin = idx
414
- end = begin + self.chunk_size * (self.num_right_chunks + 1)
415
- chunk_xs = xs[:, :, begin:end, :]
416
- # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
417
-
418
- ys, attention_cache = self.forward_chunk(
419
- xs=chunk_xs,
420
- max_att_cache_length=max_att_cache_length,
421
- attention_cache=attention_cache,
422
- )
423
- # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
424
- ys = ys[:, :, :self.chunk_size, :]
425
-
426
- outputs.append(ys)
427
-
428
- ys = torch.cat(outputs, dim=2)
429
- return ys
430
-
431
-
432
- def main2():
433
-
434
- encoder = TransformerEncoder(
435
- input_size=64,
436
- hidden_size=256,
437
- attention_heads=4,
438
- num_blocks=6,
439
- dropout_rate=0.1,
440
- )
441
- print(encoder)
442
-
443
- x = torch.ones([4, 200, 64])
444
-
445
- x = torch.ones([4, 200, 64])
446
- y = encoder.forward(xs=x)
447
- print(y.shape)
448
-
449
- x = torch.ones([4, 200, 64])
450
- y = encoder.forward_chunk_by_chunk(xs=x)
451
- print(y.shape)
452
-
453
- return
454
-
455
-
456
- def main():
457
-
458
- encoder = TSTransformerEncoder(
459
- input_size=8,
460
- hidden_size=16,
461
- attention_heads=2,
462
- num_blocks=2,
463
- dropout_rate=0.1,
464
- )
465
- # print(encoder)
466
-
467
- x = torch.ones([4, 8, 200, 8])
468
- y = encoder.forward(xs=x)
469
- print(y.shape)
470
-
471
- x = torch.ones([4, 8, 200, 8])
472
- y = encoder.forward_chunk_by_chunk(xs=x)
473
- print(y.shape)
474
-
475
- return
476
-
477
-
478
- if __name__ == '__main__':
479
- main()
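
A shape-only walk-through of the two-stage attention in TSTransformerBlock above: time attention treats every frequency bin as its own sequence over frames, frequency attention treats every frame as its own sequence over bins. The dimensions are arbitrary and the transformer layers themselves are left out.

import torch

b, c, t, f = 2, 16, 10, 8
x = torch.randn(b, c, t, f)

# time attention input: one sequence of length t per (batch, frequency) pair
x_time = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)

# frequency attention input: one sequence of length f per (batch, frame) pair
x_freq = x_time.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)

print(x_time.shape)  # torch.Size([16, 10, 16])
print(x_freq.shape)  # torch.Size([20, 8, 16])
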
toolbox/torchaudio/models/nx_denoise/utils.py DELETED
@@ -1,45 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import torch
4
- import torch.nn as nn
5
-
6
-
7
- class LearnableSigmoid1d(nn.Module):
8
- def __init__(self, in_features, beta=1):
9
- super().__init__()
10
- self.beta = beta
11
- self.slope = nn.Parameter(torch.ones(in_features))
12
- self.slope.requires_grad = True
13
-
14
- def forward(self, x):
15
- # x shape: [batch_size, time_steps, spec_bins]
16
- return self.beta * torch.sigmoid(self.slope * x)
17
-
18
-
19
- def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
20
-
21
- hann_window = torch.hann_window(win_size).to(y.device)
22
- stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
23
- center=center, pad_mode='reflect', normalized=False, return_complex=True)
24
- stft_spec = torch.view_as_real(stft_spec)
25
- mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
26
- pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
27
- # Magnitude Compression
28
- mag = torch.pow(mag, compress_factor)
29
- com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
30
-
31
- return mag, pha, com
32
-
33
-
34
- def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
35
- # Magnitude Decompression
36
- mag = torch.pow(mag, (1.0/compress_factor))
37
- com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
38
- hann_window = torch.hann_window(win_size).to(com.device)
39
- wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
40
-
41
- return wav
42
-
43
-
44
- if __name__ == '__main__':
45
- pass
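
A round-trip sketch for mag_pha_stft / mag_pha_istft written with plain torch calls, using the n_fft=512, win_size=200, hop_size=80 and compress_factor=0.3 values from the YAML below (treat them as assumptions). Compressing and then decompressing the magnitude should give back the waveform up to numerical error.

import torch

n_fft, win_size, hop_size, compress_factor = 512, 200, 80, 0.3

y = torch.randn(1, 16000)
window = torch.hann_window(win_size)

spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=window,
                  center=True, pad_mode="reflect", return_complex=True)
mag = torch.sqrt(spec.real ** 2 + spec.imag ** 2 + 1e-9) ** compress_factor
pha = torch.angle(spec)

com = torch.polar(mag ** (1.0 / compress_factor), pha)   # decompress, rebuild the complex spectrum
y_hat = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size,
                    window=window, center=True)

n = min(y.shape[-1], y_hat.shape[-1])
print((y[..., :n] - y_hat[..., :n]).abs().max())   # close to zero (float32 round-off)
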
toolbox/torchaudio/models/nx_denoise/yaml/config.yaml DELETED
@@ -1,51 +0,0 @@
1
- model_name: "nx_denoise"
2
-
3
- sample_rate: 8000
4
- segment_size: 16000
5
- n_fft: 512
6
- win_size: 200
7
- hop_size: 80
8
- # With hop_size set to 80, each STFT step is 10 ms, so the down-sampling is chosen to give a comparable time resolution.
9
-
10
- # 2**down_sampling_num_layers,
11
- # e.g. 2**6=64 means 64 samples become one time step after down-sampling,
12
- # so one step is 64/sample_rate = 0.008 seconds.
13
- # Then tsfm_chunk_size=2 corresponds to 16 ms, and tsfm_chunk_size=4 to 32 ms.
14
- # Assuming roughly 1 second of left context and 30 ms of lookahead per step, then:
15
- # tsfm_chunk_size=1,tsfm_num_left_chunks=128,tsfm_num_right_chunks=4
16
- # tsfm_chunk_size=2,tsfm_num_left_chunks=64,tsfm_num_right_chunks=2
17
- # tsfm_chunk_size=4,tsfm_num_left_chunks=32,tsfm_num_right_chunks=1
18
- down_sampling_num_layers: 6
19
- down_sampling_in_channels: 1
20
- down_sampling_hidden_channels: 64
21
- down_sampling_kernel_size: 4
22
- down_sampling_stride: 2
23
-
24
- causal_in_channels: 1
25
- causal_out_channels: 64
26
- causal_kernel_size: 3
27
- causal_bias: false
28
- causal_separable: true
29
- causal_f_stride: 1
30
- causal_num_layers: 3
31
-
32
- tsfm_hidden_size: 256
33
- tsfm_attention_heads: 8
34
- tsfm_num_blocks: 6
35
- tsfm_dropout_rate: 0.1
36
- tsfm_max_length: 512
37
- tsfm_chunk_size: 1
38
- tsfm_num_left_chunks: 128
39
- tsfm_num_right_chunks: 4
40
-
41
- discriminator_dim: 32
42
- discriminator_in_channel: 2
43
-
44
- compress_factor: 0.3
45
-
46
- batch_size: 4
47
- learning_rate: 0.0005
48
- adam_b1: 0.8
49
- adam_b2: 0.99
50
- lr_decay: 0.99
51
- seed: 1234
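
To make the receptive-field comment in this config concrete, the arithmetic behind the chunk sizes can be sketched as follows (values taken from the config above; a back-of-the-envelope check, not part of the original code):

    sample_rate = 8000
    down_sampling_num_layers = 6
    tsfm_chunk_size, tsfm_num_left_chunks, tsfm_num_right_chunks = 1, 128, 4

    step_s = 2 ** down_sampling_num_layers / sample_rate  # 64 / 8000 = 0.008 s per down-sampled step
    chunk_s = tsfm_chunk_size * step_s                     # 8 ms per attention chunk
    left_context_s = tsfm_num_left_chunks * chunk_s        # 128 * 8 ms = 1.024 s of left context
    right_context_s = tsfm_num_right_chunks * chunk_s      # 4 * 8 ms = 32 ms of lookahead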
 
 
toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py DELETED
@@ -1,102 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from typing import Tuple
4
-
5
- from toolbox.torchaudio.configuration_utils import PretrainedConfig
6
-
7
-
8
- class NXDfNetConfig(PretrainedConfig):
9
- def __init__(self,
10
- sample_rate: int = 8000,
11
- freq_bins: int = 256,
12
- win_size: int = 200,
13
- hop_size: int = 100,
14
-
15
- conv_channels: int = 64,
16
- conv_kernel_size_input: Tuple[int, int] = (3, 3),
17
- conv_kernel_size_inner: Tuple[int, int] = (1, 3),
18
- conv_lookahead: int = 0,
19
-
20
- convt_kernel_size_inner: Tuple[int, int] = (1, 3),
21
-
22
- embedding_hidden_size: int = 256,
23
- encoder_combine_op: str = "concat",
24
-
25
- encoder_emb_skip_op: str = "none",
26
- encoder_emb_linear_groups: int = 16,
27
- encoder_emb_hidden_size: int = 256,
28
-
29
- encoder_linear_groups: int = 32,
30
-
31
- lsnr_max: int = 30,
32
- lsnr_min: int = -15,
33
- norm_tau: float = 1.,
34
-
35
- decoder_emb_num_layers: int = 3,
36
- decoder_emb_skip_op: str = "none",
37
- decoder_emb_linear_groups: int = 16,
38
- decoder_emb_hidden_size: int = 256,
39
-
40
- df_decoder_hidden_size: int = 256,
41
- df_num_layers: int = 2,
42
- df_order: int = 5,
43
- df_bins: int = 96,
44
- df_gru_skip: str = "grouped_linear",
45
- df_decoder_linear_groups: int = 16,
46
- df_pathway_kernel_size_t: int = 5,
47
- df_lookahead: int = 2,
48
-
49
- use_post_filter: bool = False,
50
- **kwargs
51
- ):
52
- super(NXDfNetConfig, self).__init__(**kwargs)
53
- # transform
54
- self.sample_rate = sample_rate
55
- self.freq_bins = freq_bins
56
- self.win_size = win_size
57
- self.hop_size = hop_size
58
-
59
- # conv
60
- self.conv_channels = conv_channels
61
- self.conv_kernel_size_input = conv_kernel_size_input
62
- self.conv_kernel_size_inner = conv_kernel_size_inner
63
- self.conv_lookahead = conv_lookahead
64
-
65
- self.convt_kernel_size_inner = convt_kernel_size_inner
66
-
67
- self.embedding_hidden_size = embedding_hidden_size
68
-
69
- # encoder
70
- self.encoder_emb_skip_op = encoder_emb_skip_op
71
- self.encoder_emb_linear_groups = encoder_emb_linear_groups
72
- self.encoder_emb_hidden_size = encoder_emb_hidden_size
73
-
74
- self.encoder_linear_groups = encoder_linear_groups
75
- self.encoder_combine_op = encoder_combine_op
76
-
77
- self.lsnr_max = lsnr_max
78
- self.lsnr_min = lsnr_min
79
- self.norm_tau = norm_tau
80
-
81
- # decoder
82
- self.decoder_emb_num_layers = decoder_emb_num_layers
83
- self.decoder_emb_skip_op = decoder_emb_skip_op
84
- self.decoder_emb_linear_groups = decoder_emb_linear_groups
85
- self.decoder_emb_hidden_size = decoder_emb_hidden_size
86
-
87
- # df decoder
88
- self.df_decoder_hidden_size = df_decoder_hidden_size
89
- self.df_num_layers = df_num_layers
90
- self.df_order = df_order
91
- self.df_bins = df_bins
92
- self.df_gru_skip = df_gru_skip
93
- self.df_decoder_linear_groups = df_decoder_linear_groups
94
- self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
95
- self.df_lookahead = df_lookahead
96
-
97
- # runtime
98
- self.use_post_filter = use_post_filter
99
-
100
-
101
- if __name__ == "__main__":
102
- pass
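
Since NXDfNetConfig only stores keyword arguments, a hedged instantiation sketch (defaults only, assuming PretrainedConfig accepts the call exactly as written above):

    config = NXDfNetConfig(sample_rate=8000, freq_bins=256, win_size=200, hop_size=100)
    print(config.df_bins, config.df_order, config.conv_channels)  # 96 5 64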
 
 
toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py DELETED
@@ -1,989 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import os
4
- import math
5
- from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6
-
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
- from torch.nn import functional as F
11
- import torchaudio
12
-
13
- from toolbox.torchaudio.models.nx_dfnet.utils import overlap_and_add
14
- from toolbox.torchaudio.models.nx_dfnet.configuration_nx_dfnet import NXDfNetConfig
15
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
16
-
17
-
18
- MODEL_FILE = "model.pt"
19
-
20
-
21
- norm_layer_dict = {
22
- "batch_norm_2d": torch.nn.BatchNorm2d
23
- }
24
-
25
-
26
- activation_layer_dict = {
27
- "relu": torch.nn.ReLU,
28
- "identity": torch.nn.Identity,
29
- "sigmoid": torch.nn.Sigmoid,
30
- }
31
-
32
-
33
- class CausalConv2d(nn.Sequential):
34
- def __init__(self,
35
- in_channels: int,
36
- out_channels: int,
37
- kernel_size: Union[int, Iterable[int]],
38
- fstride: int = 1,
39
- dilation: int = 1,
40
- fpad: bool = True,
41
- bias: bool = True,
42
- separable: bool = False,
43
- norm_layer: str = "batch_norm_2d",
44
- activation_layer: str = "relu",
45
- lookahead: int = 0
46
- ):
47
- """
48
- Causal Conv2d by delaying the signal for any lookahead.
49
-
50
- Expected input format: [batch_size, channels, time_steps, spec_dim]
51
-
52
- :param in_channels:
53
- :param out_channels:
54
- :param kernel_size:
55
- :param fstride:
56
- :param dilation:
57
- :param fpad:
58
- """
59
- super(CausalConv2d, self).__init__()
60
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
61
-
62
- if fpad:
63
- fpad_ = kernel_size[1] // 2 + dilation - 1
64
- else:
65
- fpad_ = 0
66
-
67
- # for last 2 dim, pad (left, right, top, bottom).
68
- pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
69
-
70
- layers = list()
71
- if any(x > 0 for x in pad):
72
- layers.append(nn.ConstantPad2d(pad, 0.0))
73
-
74
- groups = math.gcd(in_channels, out_channels) if separable else 1
75
- if groups == 1:
76
- separable = False
77
- if max(kernel_size) == 1:
78
- separable = False
79
-
80
- layers.append(
81
- nn.Conv2d(
82
- in_channels,
83
- out_channels,
84
- kernel_size=kernel_size,
85
- padding=(0, fpad_),
86
- stride=(1, fstride), # stride over time is always 1
87
- dilation=(1, dilation), # dilation over time is always 1
88
- groups=groups,
89
- bias=bias,
90
- )
91
- )
92
-
93
- if separable:
94
- layers.append(
95
- nn.Conv2d(
96
- out_channels,
97
- out_channels,
98
- kernel_size=1,
99
- bias=False,
100
- )
101
- )
102
-
103
- if norm_layer is not None:
104
- norm_layer = norm_layer_dict[norm_layer]
105
- layers.append(norm_layer(out_channels))
106
-
107
- if activation_layer is not None:
108
- activation_layer = activation_layer_dict[activation_layer]
109
- layers.append(activation_layer())
110
-
111
- super().__init__(*layers)
112
-
113
- def forward(self, inputs):
114
- for module in self:
115
- inputs = module(inputs)
116
- return inputs
117
-
118
-
119
- class CausalConvTranspose2d(nn.Sequential):
120
- def __init__(self,
121
- in_channels: int,
122
- out_channels: int,
123
- kernel_size: Union[int, Iterable[int]],
124
- fstride: int = 1,
125
- dilation: int = 1,
126
- fpad: bool = True,
127
- bias: bool = True,
128
- separable: bool = False,
129
- norm_layer: str = "batch_norm_2d",
130
- activation_layer: str = "relu",
131
- lookahead: int = 0
132
- ):
133
- """
134
- Causal ConvTranspose2d.
135
-
136
- Expected input format: [batch_size, channels, time_steps, spec_dim]
137
- """
138
- super(CausalConvTranspose2d, self).__init__()
139
-
140
- kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
141
-
142
- if fpad:
143
- fpad_ = kernel_size[1] // 2
144
- else:
145
- fpad_ = 0
146
-
147
- # for last 2 dim, pad (left, right, top, bottom).
148
- pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
149
-
150
- layers = []
151
- if any(x > 0 for x in pad):
152
- layers.append(nn.ConstantPad2d(pad, 0.0))
153
-
154
- groups = math.gcd(in_channels, out_channels) if separable else 1
155
- if groups == 1:
156
- separable = False
157
-
158
- layers.append(
159
- nn.ConvTranspose2d(
160
- in_channels,
161
- out_channels,
162
- kernel_size=kernel_size,
163
- padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
164
- output_padding=(0, fpad_),
165
- stride=(1, fstride), # stride over time is always 1
166
- dilation=(1, dilation), # dilation over time is always 1
167
- groups=groups,
168
- bias=bias,
169
- )
170
- )
171
-
172
- if separable:
173
- layers.append(
174
- nn.Conv2d(
175
- out_channels,
176
- out_channels,
177
- kernel_size=1,
178
- bias=False,
179
- )
180
- )
181
-
182
- if norm_layer is not None:
183
- norm_layer = norm_layer_dict[norm_layer]
184
- layers.append(norm_layer(out_channels))
185
-
186
- if activation_layer is not None:
187
- activation_layer = activation_layer_dict[activation_layer]
188
- layers.append(activation_layer())
189
-
190
- super().__init__(*layers)
191
-
192
-
193
- class GroupedLinear(nn.Module):
194
-
195
- def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
196
- super().__init__()
197
- # self.weight: Tensor
198
- self.input_size = input_size
199
- self.hidden_size = hidden_size
200
- self.groups = groups
201
- assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
202
- assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
203
- self.ws = input_size // groups
204
- self.register_parameter(
205
- "weight",
206
- torch.nn.Parameter(
207
- torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
208
- ),
209
- )
210
- self.reset_parameters()
211
-
212
- def reset_parameters(self):
213
- nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
214
-
215
- def forward(self, x: torch.Tensor) -> torch.Tensor:
216
- # x: [..., I]
217
- b, t, _ = x.shape
218
- # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
219
- new_shape = (b, t, self.groups, self.ws)
220
- x = x.view(new_shape)
221
- # The better way, but not supported by torchscript
222
- # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
223
- x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
224
- x = x.flatten(2, 3) # [B, T, H]
225
- return x
226
-
227
- def __repr__(self):
228
- cls = self.__class__.__name__
229
- return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
230
-
231
-
232
- class SqueezedGRU_S(nn.Module):
233
- """
234
- SGE net: Video object detection with squeezed GRU and information entropy map
235
- https://arxiv.org/abs/2106.07224
236
- """
237
-
238
- def __init__(
239
- self,
240
- input_size: int,
241
- hidden_size: int,
242
- output_size: Optional[int] = None,
243
- num_layers: int = 1,
244
- linear_groups: int = 8,
245
- batch_first: bool = True,
246
- skip_op: str = "none",
247
- activation_layer: str = "identity",
248
- ):
249
- super().__init__()
250
- self.input_size = input_size
251
- self.hidden_size = hidden_size
252
-
253
- self.linear_in = nn.Sequential(
254
- GroupedLinear(
255
- input_size=input_size,
256
- hidden_size=hidden_size,
257
- groups=linear_groups,
258
- ),
259
- activation_layer_dict[activation_layer](),
260
- )
261
-
262
- # gru skip operator
263
- self.gru_skip_op = None
264
-
265
- if skip_op == "none":
266
- self.gru_skip_op = None
267
- elif skip_op == "identity":
268
- if not input_size != output_size:
269
- raise AssertionError("Dimensions do not match")
270
- self.gru_skip_op = nn.Identity()
271
- elif skip_op == "grouped_linear":
272
- self.gru_skip_op = GroupedLinear(
273
- input_size=hidden_size,
274
- hidden_size=hidden_size,
275
- groups=linear_groups,
276
- )
277
- else:
278
- raise NotImplementedError()
279
-
280
- self.gru = nn.GRU(
281
- input_size=hidden_size,
282
- hidden_size=hidden_size,
283
- num_layers=num_layers,
284
- batch_first=batch_first,
285
- bidirectional=False,
286
- )
287
-
288
- if output_size is not None:
289
- self.linear_out = nn.Sequential(
290
- GroupedLinear(
291
- input_size=hidden_size,
292
- hidden_size=output_size,
293
- groups=linear_groups,
294
- ),
295
- activation_layer_dict[activation_layer](),
296
- )
297
- else:
298
- self.linear_out = nn.Identity()
299
-
300
- def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
301
- x = self.linear_in(inputs)
302
-
303
- x, h = self.gru.forward(x, h)
304
-
305
- x = self.linear_out(x)
306
-
307
- if self.gru_skip_op is not None:
308
- x = x + self.gru_skip_op(inputs)
309
-
310
- return x, h
311
-
312
-
313
- class Add(nn.Module):
314
- def forward(self, a, b):
315
- return a + b
316
-
317
-
318
- class Concat(nn.Module):
319
- def forward(self, a, b):
320
- return torch.cat((a, b), dim=-1)
321
-
322
-
323
- class DeepSTFT(nn.Module):
324
- def __init__(self, win_size: int, freq_bins: int):
325
- super(DeepSTFT, self).__init__()
326
- self.win_size = win_size
327
- self.freq_bins = freq_bins
328
-
329
- self.conv1d_U = nn.Conv1d(
330
- in_channels=1,
331
- out_channels=freq_bins * 2,
332
- kernel_size=win_size,
333
- stride=win_size // 2,
334
- bias=False
335
- )
336
-
337
- def forward(self, signal: torch.Tensor):
338
- """
339
- :param signal: Tensor, shape: [batch_size, num_samples]
340
- :return: v, Tensor, shape: [batch_size, freq_bins, time_steps, 2],
341
- where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1
342
- """
343
- signal = torch.unsqueeze(signal, 1)
344
- # signal shape: [batch_size, 1, num_samples]
345
- spec = F.relu(self.conv1d_U(signal))
346
- # spec shape: [batch_size, freq_bins * 2, time_steps]
347
- b, f2, t = spec.shape
348
- spec = spec.view(b, f2//2, 2, t).permute(0, 1, 3, 2)
349
- # spec shape: [batch_size, freq_bins, time_steps, 2]
350
- return spec
351
-
352
-
353
- class DeepISTFT(nn.Module):
354
- def __init__(self, win_size: int, freq_bins: int):
355
- super(DeepISTFT, self).__init__()
356
- self.win_size = win_size
357
- self.freq_bins = freq_bins
358
-
359
- self.basis_signals = nn.Linear(
360
- in_features=freq_bins * 2,
361
- out_features=win_size,
362
- bias=False
363
- )
364
-
365
- def forward(self,
366
- spec: torch.Tensor,
367
- ):
368
- """
369
- :param spec: Tensor, shape: [batch_size, freq_bins, time_steps, 2],
370
- where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1
371
- :return: Tensor, shape: [batch_size, c, num_samples],
372
- """
373
- b, f, t, _ = spec.shape
374
- # spec shape: [b, f, t, 2]
375
- spec = spec.permute(0, 2, 1, 3)
376
- # spec shape: [b, t, f, 2]
377
- spec = spec.view(b, 1, t, -1)
378
- # spec shape: [b, 1, t, f2]
379
- signal = self.basis_signals(spec)
380
- # signal shape: [b, 1, t, win_size]
381
- signal = overlap_and_add(signal, self.win_size//2)
382
- # signal shape: [b, 1, num_samples]
383
- return signal
384
-
385
-
386
- class Encoder(nn.Module):
387
- def __init__(self, config: NXDfNetConfig):
388
- super(Encoder, self).__init__()
389
- self.embedding_input_size = config.conv_channels * config.freq_bins // 4
390
- self.embedding_output_size = config.conv_channels * config.freq_bins // 4
391
- self.embedding_hidden_size = config.embedding_hidden_size
392
-
393
- self.spec_conv0 = CausalConv2d(
394
- in_channels=1,
395
- out_channels=config.conv_channels,
396
- kernel_size=config.conv_kernel_size_input,
397
- bias=False,
398
- separable=True,
399
- fstride=1,
400
- lookahead=config.conv_lookahead,
401
- )
402
- self.spec_conv1 = CausalConv2d(
403
- in_channels=config.conv_channels,
404
- out_channels=config.conv_channels,
405
- kernel_size=config.conv_kernel_size_inner,
406
- bias=False,
407
- separable=True,
408
- fstride=2,
409
- lookahead=config.conv_lookahead,
410
- )
411
- self.spec_conv2 = CausalConv2d(
412
- in_channels=config.conv_channels,
413
- out_channels=config.conv_channels,
414
- kernel_size=config.conv_kernel_size_inner,
415
- bias=False,
416
- separable=True,
417
- fstride=2,
418
- lookahead=config.conv_lookahead,
419
- )
420
- self.spec_conv3 = CausalConv2d(
421
- in_channels=config.conv_channels,
422
- out_channels=config.conv_channels,
423
- kernel_size=config.conv_kernel_size_inner,
424
- bias=False,
425
- separable=True,
426
- fstride=1,
427
- lookahead=config.conv_lookahead,
428
- )
429
-
430
- self.df_conv0 = CausalConv2d(
431
- in_channels=2,
432
- out_channels=config.conv_channels,
433
- kernel_size=config.conv_kernel_size_input,
434
- bias=False,
435
- separable=True,
436
- fstride=1,
437
- )
438
- self.df_conv1 = CausalConv2d(
439
- in_channels=config.conv_channels,
440
- out_channels=config.conv_channels,
441
- kernel_size=config.conv_kernel_size_inner,
442
- bias=False,
443
- separable=True,
444
- fstride=2,
445
- )
446
- self.df_fc_emb = nn.Sequential(
447
- GroupedLinear(
448
- config.conv_channels * config.df_bins // 2,
449
- self.embedding_input_size,
450
- groups=config.encoder_linear_groups
451
- ),
452
- nn.ReLU(inplace=True)
453
- )
454
-
455
- if config.encoder_combine_op == "concat":
456
- self.embedding_input_size *= 2
457
- self.combine = Concat()
458
- else:
459
- self.combine = Add()
460
-
461
- # emb_gru
462
- if config.freq_bins % 8 != 0:
463
- raise AssertionError("freq_bins should be divisible by 8")
464
-
465
- self.emb_gru = SqueezedGRU_S(
466
- self.embedding_input_size,
467
- self.embedding_hidden_size,
468
- output_size=self.embedding_output_size,
469
- num_layers=1,
470
- batch_first=True,
471
- skip_op=config.encoder_emb_skip_op,
472
- linear_groups=config.encoder_emb_linear_groups,
473
- activation_layer="relu",
474
- )
475
-
476
- # lsnr
477
- self.lsnr_fc = nn.Sequential(
478
- nn.Linear(self.embedding_output_size, 1),
479
- nn.Sigmoid()
480
- )
481
- self.lsnr_scale = config.lsnr_max - config.lsnr_min
482
- self.lsnr_offset = config.lsnr_min
483
-
484
- def forward(self,
485
- power_spec: torch.Tensor,
486
- df_spec: torch.Tensor,
487
- hidden_state: torch.Tensor = None,
488
- ):
489
- # power_spec shape: (batch_size, 1, time_steps, spec_dim)
490
- e0 = self.spec_conv0.forward(power_spec)
491
- e1 = self.spec_conv1.forward(e0)
492
- e2 = self.spec_conv2.forward(e1)
493
- e3 = self.spec_conv3.forward(e2)
494
- # e0 shape: [batch_size, channels, time_steps, spec_dim]
495
- # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
496
- # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
497
- # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
498
-
499
- # df_spec, shape: (batch_size, 2, time_steps, df_bins)
500
- c0 = self.df_conv0(df_spec)
501
- c1 = self.df_conv1(c0)
502
- # c0 shape: [batch_size, channels, time_steps, df_bins]
503
- # c1 shape: [batch_size, channels, time_steps, df_bins // 2]
504
-
505
- cemb = c1.permute(0, 2, 3, 1)
506
- # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
507
- cemb = cemb.flatten(2)
508
- # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
509
- cemb = self.df_fc_emb(cemb)
510
- # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
511
-
512
- # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
513
- emb = e3.permute(0, 2, 3, 1)
514
- # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
515
- emb = emb.flatten(2)
516
- # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
517
-
518
- emb = self.combine(emb, cemb)
519
- # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
520
- # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
521
-
522
- emb, h = self.emb_gru.forward(emb, hidden_state)
523
- # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
524
- # h shape: [batch_size, 1, spec_dim]
525
-
526
- lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
527
- # lsnr shape: [batch_size, time_steps, 1]
528
-
529
- return e0, e1, e2, e3, emb, c0, lsnr, h
530
-
531
-
532
- class Decoder(nn.Module):
533
- def __init__(self, config: NXDfNetConfig):
534
- super(Decoder, self).__init__()
535
-
536
- if config.freq_bins % 8 != 0:
537
- raise AssertionError("freq_bins should be divisible by 8")
538
-
539
- self.emb_in_dim = config.conv_channels * config.freq_bins // 4
540
- self.emb_out_dim = config.conv_channels * config.freq_bins // 4
541
- self.emb_hidden_dim = config.decoder_emb_hidden_size
542
-
543
- self.emb_gru = SqueezedGRU_S(
544
- self.emb_in_dim,
545
- self.emb_hidden_dim,
546
- output_size=self.emb_out_dim,
547
- num_layers=config.decoder_emb_num_layers - 1,
548
- batch_first=True,
549
- skip_op=config.decoder_emb_skip_op,
550
- linear_groups=config.decoder_emb_linear_groups,
551
- activation_layer="relu",
552
- )
553
- self.conv3p = CausalConv2d(
554
- in_channels=config.conv_channels,
555
- out_channels=config.conv_channels,
556
- kernel_size=1,
557
- bias=False,
558
- separable=True,
559
- fstride=1,
560
- lookahead=config.conv_lookahead,
561
- )
562
- self.convt3 = CausalConv2d(
563
- in_channels=config.conv_channels,
564
- out_channels=config.conv_channels,
565
- kernel_size=config.conv_kernel_size_inner,
566
- bias=False,
567
- separable=True,
568
- fstride=1,
569
- lookahead=config.conv_lookahead,
570
- )
571
- self.conv2p = CausalConv2d(
572
- in_channels=config.conv_channels,
573
- out_channels=config.conv_channels,
574
- kernel_size=1,
575
- bias=False,
576
- separable=True,
577
- fstride=1,
578
- lookahead=config.conv_lookahead,
579
- )
580
- self.convt2 = CausalConvTranspose2d(
581
- in_channels=config.conv_channels,
582
- out_channels=config.conv_channels,
583
- kernel_size=config.convt_kernel_size_inner,
584
- bias=False,
585
- separable=True,
586
- fstride=2,
587
- lookahead=config.conv_lookahead,
588
- )
589
- self.conv1p = CausalConv2d(
590
- in_channels=config.conv_channels,
591
- out_channels=config.conv_channels,
592
- kernel_size=1,
593
- bias=False,
594
- separable=True,
595
- fstride=1,
596
- lookahead=config.conv_lookahead,
597
- )
598
- self.convt1 = CausalConvTranspose2d(
599
- in_channels=config.conv_channels,
600
- out_channels=config.conv_channels,
601
- kernel_size=config.convt_kernel_size_inner,
602
- bias=False,
603
- separable=True,
604
- fstride=2,
605
- lookahead=config.conv_lookahead,
606
- )
607
- self.conv0p = CausalConv2d(
608
- in_channels=config.conv_channels,
609
- out_channels=config.conv_channels,
610
- kernel_size=1,
611
- bias=False,
612
- separable=True,
613
- fstride=1,
614
- lookahead=config.conv_lookahead,
615
- )
616
- self.conv0_out = CausalConv2d(
617
- in_channels=config.conv_channels,
618
- out_channels=1,
619
- kernel_size=config.conv_kernel_size_inner,
620
- activation_layer="sigmoid",
621
- bias=False,
622
- separable=True,
623
- fstride=1,
624
- lookahead=config.conv_lookahead,
625
- )
626
-
627
- def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
628
- # Estimates erb mask
629
- b, _, t, f8 = e3.shape
630
-
631
- # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
632
- emb, _ = self.emb_gru(emb)
633
- # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
634
- emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
635
- e3 = self.convt3(self.conv3p(e3) + emb)
636
- # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
637
- e2 = self.convt2(self.conv2p(e2) + e3)
638
- # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
639
- e1 = self.convt1(self.conv1p(e1) + e2)
640
- # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
641
- mask = self.conv0_out(self.conv0p(e0) + e1)
642
- # mask shape: [batch_size, 1, time_steps, freq_dim]
643
- return mask
644
-
645
-
646
- class DfDecoder(nn.Module):
647
- def __init__(self, config: NXDfNetConfig):
648
- super(DfDecoder, self).__init__()
649
-
650
- self.embedding_input_size = config.conv_channels * config.freq_bins // 4
651
- self.df_decoder_hidden_size = config.df_decoder_hidden_size
652
- self.df_num_layers = config.df_num_layers
653
-
654
- self.df_order = config.df_order
655
-
656
- self.df_bins = config.df_bins
657
- self.df_out_ch = config.df_order * 2
658
-
659
- self.df_convp = CausalConv2d(
660
- config.conv_channels,
661
- self.df_out_ch,
662
- fstride=1,
663
- kernel_size=(config.df_pathway_kernel_size_t, 1),
664
- separable=True,
665
- bias=False,
666
- )
667
- self.df_gru = SqueezedGRU_S(
668
- self.embedding_input_size,
669
- self.df_decoder_hidden_size,
670
- num_layers=self.df_num_layers,
671
- batch_first=True,
672
- skip_op="none",
673
- activation_layer="relu",
674
- )
675
-
676
- if config.df_gru_skip == "none":
677
- self.df_skip = None
678
- elif config.df_gru_skip == "identity":
679
- if config.embedding_hidden_size != config.df_decoder_hidden_size:
680
- raise AssertionError("Dimensions do not match")
681
- self.df_skip = nn.Identity()
682
- elif config.df_gru_skip == "grouped_linear":
683
- self.df_skip = GroupedLinear(
684
- self.embedding_input_size,
685
- self.df_decoder_hidden_size,
686
- groups=config.df_decoder_linear_groups
687
- )
688
- else:
689
- raise NotImplementedError()
690
-
691
- self.df_out: nn.Module
692
- out_dim = self.df_bins * self.df_out_ch
693
-
694
- self.df_out = nn.Sequential(
695
- GroupedLinear(
696
- input_size=self.df_decoder_hidden_size,
697
- hidden_size=out_dim,
698
- groups=config.df_decoder_linear_groups
699
- ),
700
- nn.Tanh()
701
- )
702
- self.df_fc_a = nn.Sequential(
703
- nn.Linear(self.df_decoder_hidden_size, 1),
704
- nn.Sigmoid()
705
- )
706
-
707
- def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
708
- # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
709
- b, t, _ = emb.shape
710
- df_coefs, _ = self.df_gru(emb)
711
- if self.df_skip is not None:
712
- df_coefs = df_coefs + self.df_skip(emb)
713
- # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
714
-
715
- # c0 shape: [batch_size, channels, time_steps, df_bins]
716
- c0 = self.df_convp(c0)
717
- # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
718
- c0 = c0.permute(0, 2, 3, 1)
719
- # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
720
-
721
- df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
722
- # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
723
- df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
724
- # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
725
- df_coefs = df_coefs + c0
726
- # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
727
- return df_coefs
728
-
729
-
730
- class DfOutputReshapeMF(nn.Module):
731
- """Coefficients output reshape for multiframe/MultiFrameModule
732
-
733
- Requires input of shape B, C, T, F, 2.
734
- """
735
-
736
- def __init__(self, df_order: int, df_bins: int):
737
- super().__init__()
738
- self.df_order = df_order
739
- self.df_bins = df_bins
740
-
741
- def forward(self, coefs: torch.Tensor) -> torch.Tensor:
742
- # [B, T, F, O*2] -> [B, O, T, F, 2]
743
- new_shape = list(coefs.shape)
744
- new_shape[-1] = -1
745
- new_shape.append(2)
746
- coefs = coefs.view(new_shape)
747
- coefs = coefs.permute(0, 3, 1, 2, 4)
748
- return coefs
749
-
750
-
751
- class Mask(nn.Module):
752
- def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
753
- super().__init__()
754
- self.use_post_filter = use_post_filter
755
- self.eps = eps
756
-
757
- def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
758
- """
759
- Post-Filter
760
-
761
- A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
762
- https://arxiv.org/abs/2008.04259
763
-
764
- :param mask: Real valued mask, typically of shape [B, C, T, F].
765
- :param beta: Global gain factor.
766
- :return:
767
- """
768
- mask_sin = mask * torch.sin(np.pi * mask / 2)
769
- mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
770
- return mask_pf
771
-
772
- def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
773
- # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
774
-
775
- if not self.training and self.use_post_filter:
776
- mask = self.post_filter(mask)
777
-
778
- # mask shape: [batch_size, 1, time_steps, freq_bins]
779
- mask = mask.unsqueeze(4)
780
- # mask shape: [batch_size, 1, time_steps, freq_bins, 1]
781
- return spec * mask
782
-
783
-
784
- class DeepFiltering(nn.Module):
785
- def __init__(self,
786
- df_bins: int,
787
- df_order: int,
788
- lookahead: int = 0,
789
- ):
790
- super(DeepFiltering, self).__init__()
791
- self.df_bins = df_bins
792
- self.df_order = df_order
793
- self.need_unfold = df_order > 1
794
- self.lookahead = lookahead
795
-
796
- self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
797
-
798
- def spec_unfold(self, spec: torch.Tensor):
799
- """
800
- Pads and unfolds the spectrogram according to frame_size.
801
- :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
802
- :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
803
- """
804
- if self.need_unfold:
805
- # spec shape: [batch_size, freq_bins, time_steps]
806
- spec_pad = self.pad(spec)
807
- # spec_pad shape: [batch_size, 1, time_steps_pad, freq_bins]
808
- spec_unfold = spec_pad.unfold(2, self.df_order, 1)
809
- # spec_unfold shape: [batch_size, 1, time_steps, freq_bins, df_order]
810
- return spec_unfold
811
- else:
812
- return spec.unsqueeze(-1)
813
-
814
- def forward(self,
815
- spec: torch.Tensor,
816
- coefs: torch.Tensor,
817
- ):
818
- # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
819
- spec = spec.contiguous()
820
- spec_u = self.spec_unfold(torch.view_as_complex(spec))
821
- # spec_u shape: [batch_size, 1, time_steps, freq_bins, df_order]
822
-
823
- # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
824
- coefs = torch.view_as_complex(coefs)
825
- # coefs shape: [batch_size, df_order, time_steps, df_bins]
826
- spec_f = spec_u.narrow(-2, 0, self.df_bins)
827
- # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
828
-
829
- coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
830
- # coefs shape: [batch_size, 1, df_order, time_steps, df_bins]
831
-
832
- spec_f = self.df(spec_f, coefs)
833
- # spec_f shape: [batch_size, 1, time_steps, df_bins]
834
-
835
- if self.training:
836
- spec = spec.clone()
837
- spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
838
- # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
839
- return spec
840
-
841
- @staticmethod
842
- def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
843
- """
844
- Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
845
- :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
846
- :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
847
- :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
848
- """
849
- return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
850
-
851
-
852
- class NXDfNet(nn.Module):
853
- def __init__(self, config: NXDfNetConfig):
854
- super(NXDfNet, self).__init__()
855
- self.config = config
856
-
857
- self.stft = DeepSTFT(win_size=config.win_size, freq_bins=config.freq_bins)
858
- self.istft = DeepISTFT(win_size=config.win_size, freq_bins=config.freq_bins)
859
-
860
- self.encoder = Encoder(config)
861
- self.decoder = Decoder(config)
862
-
863
- self.df_decoder = DfDecoder(config)
864
- self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
865
- self.df_op = DeepFiltering(
866
- df_bins=config.df_bins,
867
- df_order=config.df_order,
868
- lookahead=config.df_lookahead,
869
- )
870
-
871
- self.mask = Mask(use_post_filter=config.use_post_filter)
872
-
873
- def forward(self,
874
- noisy: torch.Tensor,
875
- ):
876
- """
877
- :param noisy: Tensor, shape: [batch_size, num_samples]
878
- :return:
879
- """
880
- spec = self.stft.forward(noisy)
881
- # spec shape: [batch_size, freq_bins, time_steps, 2]
882
- power_spec = torch.sum(torch.square(spec), dim=-1)
883
- power_spec = power_spec.unsqueeze(1).permute(0, 1, 3, 2)
884
- # power_spec shape: [batch_size, freq_bins, time_steps]
885
- # power_spec shape: [batch_size, 1, freq_bins, time_steps]
886
- # power_spec shape: [batch_size, 1, time_steps, freq_bins]
887
-
888
- df_spec = spec.permute(0, 3, 2, 1)
889
- # df_spec shape: [batch_size, 2, time_steps, freq_bins]
890
- df_spec = df_spec[..., :self.df_decoder.df_bins]
891
- # df_spec shape: [batch_size, 2, time_steps, df_bins]
892
-
893
- # spec shape: [batch_size, freq_bins, time_steps, 2]
894
- spec = torch.transpose(spec, dim0=1, dim1=2)
895
- # spec shape: [batch_size, time_steps, freq_bins, 2]
896
- spec = torch.unsqueeze(spec, dim=1)
897
- # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
898
-
899
- e0, e1, e2, e3, emb, c0, _, h = self.encoder.forward(power_spec, df_spec)
900
-
901
- mask = self.decoder.forward(emb, e3, e2, e1, e0)
902
- # mask shape: [batch_size, 1, time_steps, freq_bins]
903
- if torch.any(mask > 1) or torch.any(mask < 0):
904
- raise AssertionError
905
-
906
- spec_m = self.mask.forward(spec, mask)
907
-
908
- # lsnr shape: [batch_size, time_steps, 1]
909
- # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
910
- # lsnr shape: [batch_size, 1, time_steps]
911
-
912
- df_coefs = self.df_decoder.forward(emb, c0)
913
- df_coefs = self.df_out_transform(df_coefs)
914
- # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
915
-
916
- spec_e = self.df_op.forward(spec.clone(), df_coefs)
917
- # spec_e shape: [batch_size, 1, time_steps, freq_bins, 2]
918
-
919
- spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
920
-
921
- spec_e = torch.squeeze(spec_e, dim=1)
922
- spec_e = spec_e.permute(0, 2, 1, 3)
923
- # spec_e shape: [batch_size, freq_bins, time_steps, 2]
924
-
925
- denoise = self.istft.forward(spec_e)
926
- # spec_e shape: [batch_size, freq_bins, time_steps, 2]
927
- return denoise
928
-
929
-
930
- class NXDfNetPretrainedModel(NXDfNet):
931
- def __init__(self,
932
- config: NXDfNetConfig,
933
- ):
934
- super(NXDfNetPretrainedModel, self).__init__(
935
- config=config,
936
- )
937
-
938
- @classmethod
939
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
940
- config = NXDfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
941
-
942
- model = cls(config)
943
-
944
- if os.path.isdir(pretrained_model_name_or_path):
945
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
946
- else:
947
- ckpt_file = pretrained_model_name_or_path
948
-
949
- with open(ckpt_file, "rb") as f:
950
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
951
- model.load_state_dict(state_dict, strict=True)
952
- return model
953
-
954
- def save_pretrained(self,
955
- save_directory: Union[str, os.PathLike],
956
- state_dict: Optional[dict] = None,
957
- ):
958
-
959
- model = self
960
-
961
- if state_dict is None:
962
- state_dict = model.state_dict()
963
-
964
- os.makedirs(save_directory, exist_ok=True)
965
-
966
- # save state dict
967
- model_file = os.path.join(save_directory, MODEL_FILE)
968
- torch.save(state_dict, model_file)
969
-
970
- # save config
971
- config_file = os.path.join(save_directory, CONFIG_FILE)
972
- self.config.to_yaml_file(config_file)
973
- return save_directory
974
-
975
-
976
- def main():
977
-
978
- config = NXDfNetConfig()
979
- model = NXDfNet(config=config)
980
-
981
- inputs = torch.randn(size=(1, 16000), dtype=torch.float32)
982
-
983
- denoise = model.forward(inputs)
984
- print(denoise.shape)
985
- return
986
-
987
-
988
- if __name__ == "__main__":
989
- main()
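
The heart of the deleted model is the DeepFiltering.df contraction, a complex multiply-accumulate over df_order neighbouring frames. A small stand-alone check of the einsum with random complex tensors (shapes as documented in the df docstring; illustrative only):

    import torch

    b, c, t, f, n = 1, 1, 10, 96, 5  # batch, channels, time_steps, df_bins, df_order
    spec = torch.randn(b, c, t, f, n, dtype=torch.complex64)   # unfolded spectrogram
    coefs = torch.randn(b, c, n, t, f, dtype=torch.complex64)  # deep-filter coefficients
    out = torch.einsum("...tfn,...ntf->...tf", spec, coefs)
    print(out.shape)  # torch.Size([1, 1, 10, 96])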
 
 
toolbox/torchaudio/models/nx_dfnet/utils.py DELETED
@@ -1,55 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- """
4
- https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
5
- """
6
- import math
7
- import torch
8
-
9
-
10
- def overlap_and_add(signal: torch.Tensor, frame_step: int):
11
- """
12
- Reconstructs a signal from a framed representation.
13
-
14
- Adds potentially overlapping frames of a signal with shape
15
- `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
16
- The resulting tensor has shape `[..., output_size]` where
17
-
18
- output_size = (frames - 1) * frame_step + frame_length
19
-
20
- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
21
-
22
- :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2.
23
- :param frame_step: int, overlap offsets. Must be less than or equal to frame_length.
24
- :return: Tensor, shape: [..., output_size].
25
- containing the overlap-added frames of signal's inner-most two dimensions.
26
- output_size = (frames - 1) * frame_step + frame_length
27
- """
28
- outer_dimensions = signal.size()[:-2]
29
- frames, frame_length = signal.size()[-2:]
30
-
31
- subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
32
- subframe_step = frame_step // subframe_length
33
- subframes_per_frame = frame_length // subframe_length
34
-
35
- output_size = frame_step * (frames - 1) + frame_length
36
- output_subframes = output_size // subframe_length
37
-
38
- subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
39
-
40
- frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
41
-
42
- frame = frame.clone().detach()
43
- frame = frame.to(signal.device)
44
- frame = frame.long()
45
-
46
- frame = frame.contiguous().view(-1)
47
-
48
- result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
49
- result.index_add_(-2, frame, subframe_signal)
50
- result = result.view(*outer_dimensions, -1)
51
- return result
52
-
53
-
54
- if __name__ == "__main__":
55
- pass
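
A short sanity check of the output_size formula documented above, output_size = (frames - 1) * frame_step + frame_length (illustrative shapes only):

    import torch

    frames, frame_length, frame_step = 9, 200, 100
    signal = torch.randn(1, 1, frames, frame_length)
    out = overlap_and_add(signal, frame_step)
    print(out.shape)  # torch.Size([1, 1, 1000]) == (9 - 1) * 100 + 200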
 
 
toolbox/torchaudio/models/nx_mpnet/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
 
 
 
 
 
 
 
toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
-
4
-
5
- if __name__ == '__main__':
6
- pass
 
 
 
 
 
 
 
toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py DELETED
@@ -1,445 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from typing import List, Tuple, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
-
8
- from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid2d
9
-
10
-
11
- class SPConvTranspose2d(nn.Module):
12
- def __init__(self,
13
- in_channels: int,
14
- out_channels: int,
15
- kernel_size: Union[int, Tuple[int]],
16
- r=1
17
- ):
18
- super(SPConvTranspose2d, self).__init__()
19
- self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
20
- self.out_channels = out_channels
21
- self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
22
- self.r = r
23
-
24
- def forward(self, x: torch.Tensor):
25
- x = self.pad_freq(x)
26
- out = self.conv(x)
27
-
28
- b, c, t, f = out.shape
29
-
30
- out = out.view((b, self.r, c // self.r, t, f))
31
- out = out.permute(0, 2, 3, 4, 1)
32
- out = out.contiguous().view((b, c // self.r, t, -1))
33
- return out
34
-
35
-
36
- class CausalConv2dBlock(nn.Module):
37
- def __init__(self,
38
- in_channels: int,
39
- out_channels: int,
40
- dilation: int,
41
- kernel_size: Tuple[int, int] = (2, 3),
42
- ):
43
- super(CausalConv2dBlock, self).__init__()
44
- self.pad_length = dilation
45
-
46
- self.pad_time = nn.ConstantPad2d((0, 0, self.pad_length, 0), value=0.)
47
- self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
48
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=(dilation, 1))
49
- self.norm = nn.InstanceNorm2d(out_channels, affine=True)
50
- self.activation = nn.PReLU(out_channels)
51
-
52
- def forward(self,
53
- x: torch.Tensor,
54
- cache_pad: torch.Tensor = None,
55
- ):
56
- """
57
-
58
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
59
- :param cache_pad:
60
- :return:
61
- """
62
- if cache_pad is None:
63
- x = self.pad_time(x)
64
- else:
65
- x = torch.concat(tensors=[cache_pad, x], dim=2)
66
- new_cache_pad = x[:, :, -self.pad_length:, :]
67
-
68
- x = self.pad_freq(x)
69
-
70
- x = self.conv(x)
71
- x = self.norm(x)
72
- x = self.activation(x)
73
- return x, new_cache_pad
74
-
75
-
76
- class CausalConv2dEncoder(nn.Module):
77
- def __init__(self,
78
- num_blocks: int,
79
- hidden_size: int,
80
- ):
81
- super(CausalConv2dEncoder, self).__init__()
82
- self.num_blocks = num_blocks
83
-
84
- self.blocks: List[CausalConv2dBlock] = nn.ModuleList([])
85
- for idx in range(num_blocks):
86
- in_channels = hidden_size * (idx+1)
87
- dilation = 2 ** idx
88
- block = CausalConv2dBlock(
89
- in_channels=in_channels,
90
- out_channels=hidden_size,
91
- dilation=dilation,
92
- kernel_size=(2, 3),
93
- )
94
- self.blocks.append(block)
95
-
96
- def forward(self,
97
- x: torch.Tensor,
98
- cache_pad_list: List[torch.Tensor] = None,
99
- ):
100
- """
101
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
102
- :param cache_pad_list: List[Tensor]
103
- :return:
104
- """
105
- new_cache_pad_list = list()
106
-
107
- skip = x
108
- for idx, block in enumerate(self.blocks):
109
- x, new_cache_pad = block.forward(
110
- skip,
111
- cache_pad=None if cache_pad_list is None else cache_pad_list[idx]
112
- )
113
- new_cache_pad_list.append(new_cache_pad)
114
- skip = torch.cat([x, skip], dim=1)
115
- # x shape: [batch_size, channels, time_steps, dim].
116
- return x, new_cache_pad_list
117
-
118
- def forward_chunk(self,
119
- chunk: torch.Tensor,
120
- cache_pad_list: List[torch.Tensor] = None,
121
- ):
122
- return self.forward(chunk, cache_pad_list)
123
-
124
- def forward_chunk_by_chunk(self,
125
- x: torch.Tensor,
126
- ):
127
- """
128
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
129
- :return:
130
- """
131
- batch_size, channels, time_steps, _ = x.shape
132
-
133
- cache_pad_list = None
134
-
135
- outputs = list()
136
- for idx in range(time_steps):
137
- chunk = x[:, :, idx:idx+1, :]
138
-
139
- y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
140
- outputs.append(y)
141
-
142
- outputs = torch.concat(outputs, dim=2)
143
- return outputs
144
-
145
-
146
- class DenseEncoder(nn.Module):
147
- def __init__(self,
148
- num_blocks: int,
149
- in_channels: int,
150
- out_channels: int,
151
- ):
152
- super(DenseEncoder, self).__init__()
153
- self.dense_conv_1 = nn.Sequential(
154
- nn.Conv2d(in_channels, out_channels, (1, 1)),
155
- nn.InstanceNorm2d(out_channels, affine=True),
156
- nn.PReLU(out_channels)
157
- )
158
- self.dense_block = CausalConv2dEncoder(
159
- num_blocks=num_blocks, hidden_size=out_channels,
160
- )
161
- self.dense_conv_2 = nn.Sequential(
162
- nn.Conv2d(out_channels, out_channels, (1, 3), (1, 2), padding=(0, 1)),
163
- nn.InstanceNorm2d(out_channels, affine=True),
164
- nn.PReLU(out_channels)
165
- )
166
-
167
- def forward(self,
168
- x: torch.Tensor,
169
- ):
170
- """
171
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
172
- :return:
173
- """
174
- x = self.dense_conv_1(x)
175
- x, _ = self.dense_block.forward(x)
176
- x = self.dense_conv_2(x)
177
- # x shape: [b, c, t, f//2]
178
- return x
179
-
180
- def forward_chunk(self,
181
- x: torch.Tensor,
182
- cache_pad_list: List[torch.Tensor] = None,
183
- ):
184
- x = self.dense_conv_1(x)
185
- x, new_cache_pad_list = self.dense_block.forward(x, cache_pad_list)
186
- x = self.dense_conv_2(x)
187
- # x shape: [b, c, t, f//2]
188
- return x, new_cache_pad_list
189
-
190
- def forward_chunk_by_chunk(self,
191
- x: torch.Tensor,
192
- ):
193
- """
194
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
195
- :return:
196
- """
197
- batch_size, channels, time_steps, _ = x.shape
198
-
199
- cache_pad_list = None
200
-
201
- outputs = list()
202
- for idx in range(time_steps):
203
- chunk = x[:, :, idx:idx+1, :]
204
-
205
- y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
206
- outputs.append(y)
207
-
208
- outputs = torch.concat(outputs, dim=2)
209
- return outputs
210
-
211
-
212
- class MaskDecoder(nn.Module):
213
- def __init__(self,
214
- num_blocks: int,
215
- hidden_size: int,
216
- out_channels: int = 1,
217
- beta: float = 2.0,
218
- n_fft: int = 512,
219
- ):
220
- super(MaskDecoder, self).__init__()
221
- self.dense_block = CausalConv2dEncoder(
222
- num_blocks=num_blocks, hidden_size=hidden_size,
223
- )
224
- self.mask_conv = nn.Sequential(
225
- SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2),
226
- nn.InstanceNorm2d(hidden_size, affine=True),
227
- nn.PReLU(hidden_size),
228
- nn.Conv2d(hidden_size, out_channels, (1, 2))
229
- )
230
- self.lsigmoid = LearnableSigmoid2d(n_fft//2+1, beta=beta)
231
-
232
- def forward(self,
233
- x: torch.Tensor,
234
- ):
235
- """
236
-
237
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
238
- :return:
239
- """
240
- x, _ = self.dense_block(x)
241
- x = self.mask_conv(x)
242
- # x shape: [batch_size, 1, time_steps, dim*2-1]
243
- x = x.permute(0, 3, 2, 1).squeeze(-1)
244
- # x shape: [b, f, t]
245
- x = self.lsigmoid(x)
246
- return x
247
-
248
- def forward_chunk(self,
249
- x: torch.Tensor,
250
- cache_pad_list: List[torch.Tensor] = None,
251
- ):
252
- x, new_cache_pad_list = self.dense_block(x, cache_pad_list)
253
- x = self.mask_conv(x)
254
- # x shape: [batch_size, 1, time_steps, dim*2-1]
255
- x = x.permute(0, 3, 2, 1).squeeze(-1)
256
- # x shape: [b, f, t]
257
- x = self.lsigmoid(x)
258
- return x, new_cache_pad_list
259
-
260
- def forward_chunk_by_chunk(self,
261
- x: torch.Tensor,
262
- ):
263
- """
264
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
265
- :return:
266
- """
267
- batch_size, channels, time_steps, _ = x.shape
268
-
269
- cache_pad_list = None
270
-
271
- outputs = list()
272
- for idx in range(time_steps):
273
- chunk = x[:, :, idx:idx+1, :]
274
-
275
- y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
276
- outputs.append(y)
277
-
278
- outputs = torch.concat(outputs, dim=2)
279
- return outputs
280
-
281
-
282
- class PhaseDecoder(nn.Module):
283
- def __init__(self,
284
- num_blocks: int,
285
- hidden_size: int,
286
- out_channels: int = 1,
287
- ):
288
- super(PhaseDecoder, self).__init__()
289
- self.dense_block = CausalConv2dEncoder(
290
- num_blocks=num_blocks, hidden_size=hidden_size,
291
- )
292
-
293
- self.phase_conv = nn.Sequential(
294
- SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2),
295
- nn.InstanceNorm2d(hidden_size, affine=True),
296
- nn.PReLU(hidden_size)
297
- )
298
- self.phase_conv_r = nn.Conv2d(hidden_size, out_channels, (1, 2))
299
- self.phase_conv_i = nn.Conv2d(hidden_size, out_channels, (1, 2))
300
-
301
- def forward(self,
302
- x: torch.Tensor,
303
- ):
304
- """
305
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
306
- :return:
307
- """
308
- x, _ = self.dense_block(x)
309
-
310
- x = self.phase_conv(x)
311
- x_r = self.phase_conv_r(x)
312
- x_i = self.phase_conv_i(x)
313
- x = torch.atan2(x_i, x_r)
314
- x = x.permute(0, 3, 2, 1).squeeze(-1)
315
- # x shape: [b, f, t]
316
- return x
317
-
318
- def forward_chunk(self,
319
- x: torch.Tensor,
320
- cache_pad_list: List[torch.Tensor] = None,
321
- ):
322
- x, new_cache_pad_list = self.dense_block(x, cache_pad_list)
323
-
324
- x = self.phase_conv(x)
325
- x_r = self.phase_conv_r(x)
326
- x_i = self.phase_conv_i(x)
327
- x = torch.atan2(x_i, x_r)
328
- x = x.permute(0, 3, 2, 1).squeeze(-1)
329
- # x shape: [b, f, t]
330
- return x, new_cache_pad_list
331
-
332
- def forward_chunk_by_chunk(self,
333
- x: torch.Tensor,
334
- ):
335
- """
336
- :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
337
- :return:
338
- """
339
- batch_size, channels, time_steps, _ = x.shape
340
-
341
- cache_pad_list = None
342
-
343
- outputs = list()
344
- for idx in range(time_steps):
345
- chunk = x[:, :, idx:idx+1, :]
346
-
347
- y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
348
- outputs.append(y)
349
-
350
- outputs = torch.concat(outputs, dim=2)
351
- return outputs
352
-
353
-
354
- def main1():
355
-
356
- encoder = CausalConv2dEncoder(
357
- num_blocks=3, hidden_size=8,
358
- )
359
-
360
- # x shape: [batch_size, channels, time_steps, dim]
361
- x = torch.rand(size=(1, 8, 200, 32))
362
- x, new_cache_pad_list = encoder.forward(x)
363
- print(x.shape)
364
- for new_cache_pad in new_cache_pad_list:
365
- print(new_cache_pad.shape)
366
-
367
- x = torch.rand(size=(1, 8, 200, 32))
368
- x = encoder.forward_chunk_by_chunk(x)
369
- print(x.shape)
370
-
371
- return
372
-
373
-
374
- def main2():
375
-
376
- encoder = DenseEncoder(
377
- num_blocks=3, in_channels=8, out_channels=8
378
- )
379
-
380
- # x shape: [batch_size, channels, time_steps, dim]
381
- x = torch.rand(size=(1, 8, 200, 32))
382
- x, new_cache_pad_list = encoder.forward(x)
383
- print(x.shape)
384
- for new_cache_pad in new_cache_pad_list:
385
- print(new_cache_pad.shape)
386
-
387
- x = torch.rand(size=(1, 8, 200, 32))
388
- x = encoder.forward_chunk_by_chunk(x)
389
- print(x.shape)
390
-
391
- return
392
-
393
-
394
- def main3():
395
-
396
- encoder = MaskDecoder(
397
- num_blocks=3, hidden_size=64, out_channels=1,
398
- n_fft=512,
399
- )
400
-
401
- # 512 // 2 + 1 = 257
402
- # 129 * 2 - 1 = 257
403
- # 257 // 2 + 1 = 129
404
-
405
- # x shape: [batch_size, channels, time_steps, dim]
406
- x = torch.rand(size=(1, 64, 201, 129))
407
- x, new_cache_pad_list = encoder.forward(x)
408
- print(x.shape)
409
- for new_cache_pad in new_cache_pad_list:
410
- print(new_cache_pad.shape)
411
-
412
- x = torch.rand(size=(1, 64, 201, 129))
413
- x = encoder.forward_chunk_by_chunk(x)
414
- print(x.shape)
415
-
416
- return
417
-
418
-
419
-
420
- def main():
421
-
422
- encoder = PhaseDecoder(
423
- num_blocks=3, hidden_size=64, out_channels=1,
424
- )
425
-
426
- # 512 // 2 + 1 = 257
427
- # 129 * 2 - 1 = 257
428
- # 257 // 2 + 1 = 129
429
-
430
- # x shape: [batch_size, channels, time_steps, dim]
431
- x = torch.rand(size=(1, 64, 201, 129))
432
- x, new_cache_pad_list = encoder.forward(x)
433
- print(x.shape)
434
- for new_cache_pad in new_cache_pad_list:
435
- print(new_cache_pad.shape)
436
-
437
- x = torch.rand(size=(1, 64, 201, 129))
438
- x = encoder.forward_chunk_by_chunk(x)
439
- print(x.shape)
440
-
441
- return
442
-
443
-
444
- if __name__ == "__main__":
445
- main()
 
 
toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py DELETED
@@ -1,90 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- from toolbox.torchaudio.configuration_utils import PretrainedConfig
4
-
5
-
6
- class NXMPNetConfig(PretrainedConfig):
7
- """
8
- https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
9
- """
10
- def __init__(self,
11
- sample_rate: int = 8000,
12
- segment_size: int = 16000,
13
- n_fft: int = 512,
14
- win_size: int = 200,
15
- hop_size: int = 80,
16
-
17
- dense_num_blocks: int = 4,
18
- dense_hidden_size: int = 64,
19
-
20
- mask_num_blocks: int = 4,
21
- mask_hidden_size: int = 64,
22
-
23
- phase_num_blocks: int = 4,
24
- phase_hidden_size: int = 64,
25
-
26
- tsfm_hidden_size: int = 64,
27
- tsfm_attention_heads: int = 4,
28
- tsfm_num_blocks: int = 4,
29
- tsfm_dropout_rate: float = 0.0,
30
- tsfm_max_time_relative_position: int = 2048,
31
- tsfm_max_freq_relative_position: int = 256,
32
- tsfm_chunk_size: int = 1,
33
- tsfm_num_left_chunks: int = 64,
34
- tsfm_num_right_chunks: int = 2,
35
-
36
- discriminator_dim: int = 32,
37
- discriminator_in_channel: int = 2,
38
-
39
- compress_factor: float = 0.3,
40
-
41
- batch_size: int = 4,
42
- learning_rate: float = 0.0005,
43
- adam_b1: float = 0.8,
44
- adam_b2: float = 0.99,
45
- lr_decay: float = 0.99,
46
- seed: int = 1234,
47
-
48
- **kwargs
49
- ):
50
- super(NXMPNetConfig, self).__init__(**kwargs)
51
- self.sample_rate = sample_rate
52
- self.segment_size = segment_size
53
- self.n_fft = n_fft
54
- self.win_size = win_size
55
- self.hop_size = hop_size
56
-
57
- self.dense_num_blocks = dense_num_blocks
58
- self.dense_hidden_size = dense_hidden_size
59
-
60
- self.mask_num_blocks = mask_num_blocks
61
- self.mask_hidden_size = mask_hidden_size
62
-
63
- self.phase_num_blocks = phase_num_blocks
64
- self.phase_hidden_size = phase_hidden_size
65
-
66
- self.tsfm_hidden_size = tsfm_hidden_size
67
- self.tsfm_attention_heads = tsfm_attention_heads
68
- self.tsfm_num_blocks = tsfm_num_blocks
69
- self.tsfm_dropout_rate = tsfm_dropout_rate
70
- self.tsfm_max_time_relative_position = tsfm_max_time_relative_position
71
- self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position
72
- self.tsfm_chunk_size = tsfm_chunk_size
73
- self.tsfm_num_left_chunks = tsfm_num_left_chunks
74
- self.tsfm_num_right_chunks = tsfm_num_right_chunks
75
-
76
- self.discriminator_dim = discriminator_dim
77
- self.discriminator_in_channel = discriminator_in_channel
78
-
79
- self.compress_factor = compress_factor
80
-
81
- self.batch_size = batch_size
82
- self.learning_rate = learning_rate
83
- self.adam_b1 = adam_b1
84
- self.adam_b2 = adam_b2
85
- self.lr_decay = lr_decay
86
- self.seed = seed
87
-
88
-
89
- if __name__ == '__main__':
90
- pass
 
 
toolbox/torchaudio/models/nx_mpnet/discriminator.py DELETED
@@ -1,102 +0,0 @@
1
- #!/usr/bin/python3
2
- # -*- coding: utf-8 -*-
3
- import os
4
- from typing import Optional, Union
5
-
6
- import torch
7
- import torch.nn as nn
8
- import numpy as np
9
- import torch.nn.functional as F
10
- from pesq import pesq
11
- from joblib import Parallel, delayed
12
-
13
- from toolbox.torchaudio.configuration_utils import CONFIG_FILE
14
- from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
15
- from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d
16
-
17
-
18
- class MetricDiscriminator(nn.Module):
19
- def __init__(self, config: NXMPNetConfig):
20
- super(MetricDiscriminator, self).__init__()
21
- dim = config.discriminator_dim
22
- in_channel = config.discriminator_in_channel
23
-
24
- self.layers = nn.Sequential(
25
- nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
26
- nn.InstanceNorm2d(dim, affine=True),
27
- nn.PReLU(dim),
28
- nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
29
- nn.InstanceNorm2d(dim*2, affine=True),
30
- nn.PReLU(dim*2),
31
- nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
32
- nn.InstanceNorm2d(dim*4, affine=True),
33
- nn.PReLU(dim*4),
34
- nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
35
- nn.InstanceNorm2d(dim*8, affine=True),
36
- nn.PReLU(dim*8),
37
- nn.AdaptiveMaxPool2d(1),
38
- nn.Flatten(),
39
- nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
40
- nn.Dropout(0.3),
41
- nn.PReLU(dim*4),
42
- nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
43
- LearnableSigmoid1d(1)
44
- )
45
-
46
- def forward(self, x, y):
47
- xy = torch.stack((x, y), dim=1)
48
- return self.layers(xy)
49
-
50
-
51
- MODEL_FILE = "discriminator.pt"
52
-
53
-
54
- class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
55
- def __init__(self,
56
- config: NXMPNetConfig,
57
- ):
58
- super(MetricDiscriminatorPretrainedModel, self).__init__(
59
- config=config,
60
- )
61
- self.config = config
62
-
63
- @classmethod
64
- def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
65
- config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
66
-
67
- model = cls(config)
68
-
69
- if os.path.isdir(pretrained_model_name_or_path):
70
- ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
71
- else:
72
- ckpt_file = pretrained_model_name_or_path
73
-
74
- with open(ckpt_file, "rb") as f:
75
- state_dict = torch.load(f, map_location="cpu", weights_only=True)
76
- model.load_state_dict(state_dict, strict=True)
77
- return model
78
-
79
- def save_pretrained(self,
80
- save_directory: Union[str, os.PathLike],
81
- state_dict: Optional[dict] = None,
82
- ):
83
-
84
- model = self
85
-
86
- if state_dict is None:
87
- state_dict = model.state_dict()
88
-
89
- os.makedirs(save_directory, exist_ok=True)
90
-
91
- # save state dict
92
- model_file = os.path.join(save_directory, MODEL_FILE)
93
- torch.save(state_dict, model_file)
94
-
95
- # save config
96
- config_file = os.path.join(save_directory, CONFIG_FILE)
97
- self.config.to_yaml_file(config_file)
98
- return save_directory
99
-
100
-
101
- if __name__ == '__main__':
102
- pass
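
A hedged usage sketch for the metric discriminator (random magnitude spectrograms standing in for clean and enhanced speech; NXMPNetConfig defaults assumed, and LearnableSigmoid1d available as imported above):

    import torch

    config = NXMPNetConfig()
    discriminator = MetricDiscriminator(config)

    clean_mag = torch.randn(4, 257, 201)     # [batch_size, freq_bins, time_steps], 257 = n_fft // 2 + 1
    enhanced_mag = torch.randn(4, 257, 201)
    score = discriminator.forward(clean_mag, enhanced_mag)
    print(score.shape)  # torch.Size([4, 1]), one quality-like score per pair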