update

This view is limited to 50 files because it contains too many changes.
- examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py +0 -83
- examples/dfnet2/run.sh +3 -3
- examples/dtln/run.sh +9 -2
- examples/frcrn/run.sh +3 -3
- main.py +20 -13
- toolbox/torchaudio/models/{nx_clean_unet/transformers → dccrn}/__init__.py +1 -1
- toolbox/torchaudio/models/{nx_denoise/stftnet/istftnet.py → dccrn/modeling_dccrn.py} +6 -3
- toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py +97 -84
- toolbox/torchaudio/models/dtln/modeling_dtln.py +9 -2
- toolbox/torchaudio/models/ehnet/modeling_ehnet.py +0 -1
- toolbox/torchaudio/models/nx_clean_unet/__init__.py +0 -6
- toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py +0 -6
- toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py +0 -261
- toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py +0 -100
- toolbox/torchaudio/models/nx_clean_unet/discriminator.py +0 -132
- toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav +0 -0
- toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py +0 -96
- toolbox/torchaudio/models/nx_clean_unet/loss.py +0 -22
- toolbox/torchaudio/models/nx_clean_unet/metrics.py +0 -80
- toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +0 -401
- toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py +0 -270
- toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py +0 -74
- toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py +0 -266
- toolbox/torchaudio/models/nx_clean_unet/utils.py +0 -45
- toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +0 -51
- toolbox/torchaudio/models/nx_denoise/__init__.py +0 -6
- toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py +0 -6
- toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py +0 -281
- toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py +0 -102
- toolbox/torchaudio/models/nx_denoise/discriminator.py +0 -132
- toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py +0 -97
- toolbox/torchaudio/models/nx_denoise/loss.py +0 -22
- toolbox/torchaudio/models/nx_denoise/metrics.py +0 -80
- toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py +0 -392
- toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py +0 -6
- toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py +0 -9
- toolbox/torchaudio/models/nx_denoise/transformers/__init__.py +0 -6
- toolbox/torchaudio/models/nx_denoise/transformers/attention.py +0 -263
- toolbox/torchaudio/models/nx_denoise/transformers/mask.py +0 -74
- toolbox/torchaudio/models/nx_denoise/transformers/transformers.py +0 -479
- toolbox/torchaudio/models/nx_denoise/utils.py +0 -45
- toolbox/torchaudio/models/nx_denoise/yaml/config.yaml +0 -51
- toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py +0 -102
- toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py +0 -989
- toolbox/torchaudio/models/nx_dfnet/utils.py +0 -55
- toolbox/torchaudio/models/nx_mpnet/__init__.py +0 -6
- toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py +0 -6
- toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py +0 -445
- toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py +0 -90
- toolbox/torchaudio/models/nx_mpnet/discriminator.py +0 -102
examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py
DELETED
@@ -1,83 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import argparse
-import os
-from pathlib import Path
-import sys
-
-from gradio_client import Client, handle_file
-import numpy as np
-from tqdm import tqdm
-import shutil
-
-pwd = os.path.abspath(os.path.dirname(__file__))
-sys.path.append(os.path.join(pwd, "../../"))
-
-import librosa
-from scipy.io import wavfile
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--src_dir",
-        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-PH",
-        # default=r"/data/tianxing/HuggingDatasets/nx_noise/data/speech/en-PH",
-        type=str
-    )
-    parser.add_argument(
-        "--tgt_dir",
-        default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech-denoise\en-PH",
-        # default=r"/data/tianxing/HuggingDatasets/nx_noise/data/speech-denoise/en-PH",
-        type=str
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main():
-    args = get_args()
-
-    # client = Client(src="http://10.75.27.247:7865/")
-    client = Client(src="http://127.0.0.1:7865/")
-
-    src_dir = Path(args.src_dir)
-    tgt_dir = Path(args.tgt_dir)
-    tgt_dir.mkdir(parents=True, exist_ok=True)
-
-    tgt_date_list = list(sorted([date.name for date in src_dir.glob("*") if not date.name.endswith(".zip")]))
-    finished_date_set = set(tgt_date_list[:-1])
-    current_date = tgt_date_list[-1]
-
-    print(f"finished_date_set: {finished_date_set}")
-    print(f"current_date: {current_date}")
-
-    finished_set = set()
-    for filename in (tgt_dir / current_date).glob("*.wav"):
-        name = filename.name
-        finished_set.add(name)
-
-    src_date_list = list(sorted([date.name for date in src_dir.glob("*")]))
-    for date in src_date_list:
-        if date in finished_date_set:
-            continue
-        for filename in (src_dir / current_date).glob("**/*.wav"):
-            result = client.predict(
-                noisy_audio_file_t=handle_file(filename.as_posix()),
-                noisy_audio_microphone_t=None,
-                engine="frcrn-dns3",
-                api_name="/when_click_denoise_button"
-            )
-            denoise_file = result[0]
-            tgt_file = tgt_dir / current_date / f"{filename.name}"
-            tgt_file.parent.mkdir(parents=True, exist_ok=True)
-
-            shutil.move(denoise_file, tgt_file)
-            print(denoise_file)
-            exit(0)
-
-    return
-
-
-if __name__ == "__main__":
-    main()
examples/dfnet2/run.sh
CHANGED
@@ -10,9 +10,9 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
 
-sh run.sh --stage
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
 
 
 END
examples/dtln/run.sh
CHANGED
@@ -7,16 +7,23 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
 
+
 sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \
 --config_file "yaml/config-512.yaml" \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech"
 
 
-sh run.sh --stage
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \
 --config_file "yaml/config-1024.yaml" \
 --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
+
+
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-1024 --final_model_name dtln-1024-nx-devoice \
+--config_file "yaml/config-1024.yaml" \
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise"
 
 
 END
examples/frcrn/run.sh
CHANGED
@@ -9,10 +9,10 @@ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name fi
 --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech"
 
 
-sh run.sh --stage
+sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \
 --config_file "yaml/config-10.yaml" \
---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/
---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/
+--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \
+--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2"
 
 END
 
main.py
CHANGED
@@ -177,14 +177,10 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
         infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs)
 
         begin = time.time()
-
+        denoise_audio = infer_engine.enhancement_by_ndarray(noisy_audio)
         time_cost = time.time() - begin
 
-        noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy")
-        denoise_mag_db = generate_spectrogram(enhanced_audio, title="denoise")
-
         fpr = time_cost / audio_duration
-
         info = {
             "time_cost": round(time_cost, 4),
             "audio_duration": round(audio_duration, 4),
@@ -192,12 +188,21 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_
         }
         message = json.dumps(info, ensure_ascii=False, indent=4)
 
-
+        noise_audio = noisy_audio - denoise_audio
+
+        noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy")
+        denoise_mag_db = generate_spectrogram(denoise_audio, title="denoise")
+        noise_mag_db = generate_spectrogram(noise_audio, title="noise")
+
+        denoise_audio = np.array(denoise_audio * (1 << 15), dtype=np.int16)
+        noise_audio = np.array(noise_audio * (1 << 15), dtype=np.int16)
+
     except Exception as e:
         raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.")
 
-
-
+    denoise_audio_t = (sample_rate, denoise_audio)
+    noise_audio_t = (sample_rate, noise_audio)
+    return denoise_audio_t, noise_audio_t, message, noisy_mag_db, denoise_mag_db, noise_mag_db
 
 
 def main():
@@ -255,21 +260,23 @@ def main():
            with gr.Column(variant="panel", scale=5):
                with gr.Tabs():
                    with gr.TabItem("audio"):
-
+                        dn_denoise_audio = gr.Audio(label="denoise_audio")
+                        dn_noise_audio = gr.Audio(label="noise_audio")
                        dn_message = gr.Textbox(lines=1, max_lines=20, label="message")
                    with gr.TabItem("mag_db"):
                        dn_noisy_mag_db = gr.Image(label="noisy_mag_db")
                        dn_denoise_mag_db = gr.Image(label="denoise_mag_db")
+                        dn_noise_mag_db = gr.Image(label="noise_mag_db")
 
            dn_button.click(
                when_click_denoise_button,
                inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
-                outputs=[
+                outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db]
            )
            gr.Examples(
                examples=examples,
                inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine],
-                outputs=[
+                outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db],
                fn=when_click_denoise_button,
                # cache_examples=True,
                # cache_mode="lazy",
@@ -289,8 +296,8 @@ def main():
    # http://127.0.0.1:7865/
    # http://10.75.27.247:7865/
    blocks.queue().launch(
-        share=True,
-
+        # share=True,
+        share=False if platform.system() == "Windows" else False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=args.server_port
    )
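Note: the reworked callback above hands each track back to Gradio as a (sample_rate, int16 ndarray) tuple, which is a value format gr.Audio accepts, and derives the residual noise track by subtracting the denoised signal from the noisy input. A minimal, standalone sketch of just that conversion (the stand-in waveforms below are illustrative; only the scaling mirrors the new code):

    import numpy as np

    sample_rate = 8000
    noisy_audio = np.random.uniform(-0.5, 0.5, size=(sample_rate,)).astype(np.float32)  # stand-in float waveform in [-1, 1]
    denoise_audio = 0.5 * noisy_audio                                                   # stand-in for the model output
    noise_audio = noisy_audio - denoise_audio                                           # residual shown in the new noise_audio output

    # scale to 16-bit PCM and package as the (sample_rate, ndarray) tuples gr.Audio expects
    denoise_audio_t = (sample_rate, np.array(denoise_audio * (1 << 15), dtype=np.int16))
    noise_audio_t = (sample_rate, np.array(noise_audio * (1 << 15), dtype=np.int16))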
toolbox/torchaudio/models/{nx_clean_unet/transformers → dccrn}/__init__.py
RENAMED
@@ -2,5 +2,5 @@
 # -*- coding: utf-8 -*-
 
 
-if __name__ ==
+if __name__ == "__main__":
     pass
toolbox/torchaudio/models/{nx_denoise/stftnet/istftnet.py → dccrn/modeling_dccrn.py}
RENAMED
@@ -1,9 +1,12 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
-https://arxiv.org/abs/2203.02395
-"""
 
+https://arxiv.org/abs/2008.00264
+
+https://github.com/huyanxin/DeepComplexCRN
+
+"""
 
-if __name__ ==
+if __name__ == "__main__":
     pass
toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py
CHANGED
@@ -11,7 +11,6 @@ https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd444
 """
 import os
 import math
-from collections import defaultdict
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
@@ -109,7 +108,7 @@
         else:
             self.activation = nn.Identity()
 
-    def forward(self, inputs: torch.Tensor, cache:
+    def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
         """
         :param inputs: shape: [b, c, t, f]
         :param cache: shape: [b, c, lookback, f];
@@ -560,15 +559,14 @@
                 feat_spec: torch.Tensor,
                 cache_dict: dict = None,
                 ):
-
-
-
-
-
-
-
-
-        cache6 = cache_dict["cache6"]
+        cache_dict = cache_dict or dict()
+        cache0 = cache_dict.get("cache0", None)
+        cache1 = cache_dict.get("cache1", None)
+        cache2 = cache_dict.get("cache2", None)
+        cache3 = cache_dict.get("cache3", None)
+        cache4 = cache_dict.get("cache4", None)
+        cache5 = cache_dict.get("cache5", None)
+        cache6 = cache_dict.get("cache6", None)
 
         # feat_erb shape: (b, 1, t, erb_bins)
         e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0)
@@ -716,13 +714,12 @@
         )
 
     def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> torch.Tensor:
-
-
-
-
-
-
-        cache4 = cache_dict["cache4"]
+        cache_dict = cache_dict or dict()
+        cache0 = cache_dict.get("cache0", None)
+        cache1 = cache_dict.get("cache1", None)
+        cache2 = cache_dict.get("cache2", None)
+        cache3 = cache_dict.get("cache3", None)
+        cache4 = cache_dict.get("cache4", None)
 
         # Estimates erb mask
         b, _, t, f8 = e3.shape
@@ -814,10 +811,9 @@
         )
 
     def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> torch.Tensor:
-
-
-
-        cache1 = cache_dict["cache1"]
+        cache_dict = cache_dict or dict()
+        cache0 = cache_dict.get("cache0", None)
+        cache1 = cache_dict.get("cache1", None)
 
         # emb shape: [batch_size, time_steps, df_bins // 4 * channels]
         b, t, _ = emb.shape
@@ -995,10 +991,9 @@
                 coefs: torch.Tensor,
                 cache_dict: dict = None,
                 ):
-
-
-
-        cache1 = cache_dict["cache1"]
+        cache_dict = cache_dict or dict()
+        cache0 = cache_dict.get("cache0", None)
+        cache1 = cache_dict.get("cache1", None)
 
         # spec shape: [b, 1, t, spec_bins, 2]
         spec_c = torch.view_as_complex(spec.contiguous())
@@ -1163,10 +1158,9 @@
         return spec, feat_erb, feat_spec
 
     def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None):
-
-
-
-        cache1 = cache_dict["cache1"]
+        cache_dict = cache_dict or dict()
+        cache0 = cache_dict.get("cache0", None)
+        cache1 = cache_dict.get("cache1", None)
 
         feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0)
         feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1)
@@ -1249,6 +1243,65 @@
 
         return est_spec, est_wav, est_mask, lsnr
 
+    def forward_chunk(self,
+                      sub_noisy: torch.Tensor,
+                      cache_dict0: dict = None,
+                      cache_dict1: dict = None,
+                      cache_dict2: dict = None,
+                      cache_dict3: dict = None,
+                      cache_dict4: dict = None,
+                      cache_dict5: dict = None,
+                      cache_dict6: dict = None,
+                      ):
+
+        spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy)
+        # spec shape: [b, 1, t, f, 2]
+        # feat_erb shape: [b, 1, t, erb_bins]
+        # feat_spec shape: [b, 2, t, df_bins]
+        if self.config.use_ema_norm:
+            feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0)
+
+        e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1)
+
+        mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2)
+        # mask shape: [b, 1, t, erb_bins]
+        mask = self.erb_bands.erb_scale_inv(mask)
+        # mask shape: [b, 1, t, f]
+
+        spec_m = self.mask.forward(spec, mask)
+        # spec_m shape: [b, 1, t, f, 2]
+        spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
+        # spec_m shape: [b, 1, t, spec_bins, 2]
+
+        # lsnr shape: [b, t, 1]
+        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
+        # lsnr shape: [b, 1, t]
+
+        df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3)
+        df_coefs = self.df_out_transform(df_coefs)
+        # df_coefs shape: [b, df_order, t, df_bins, 2]
+
+        spec_ = spec[:, :, :, :self.config.spec_bins, :]
+        # spec shape: [b, 1, t, spec_bins, 2]
+        spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4)
+        # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
+
+        spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5)
+
+        spec_e = torch.squeeze(spec_e, dim=1)
+        spec_e = spec_e.permute(0, 2, 1, 3)
+        # spec_e shape: [b, spec_bins, t, 2]
+
+        # spec_e shape: [b, spec_bins, t, 2]
+        est_spec = torch.view_as_complex(spec_e.contiguous())
+        # est_spec shape: [b, spec_bins, t], torch.complex64
+        est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
+        # est_spec shape: [b, f, t], torch.complex64
+
+        est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6)
+        # est_wav shape: [b, 1, hop_size]
+        return est_wav, cache_dict0, cache_dict1, cache_dict2, cache_dict3, cache_dict4, cache_dict5, cache_dict6
+
     def forward_chunk_by_chunk(self,
                                noisy: torch.Tensor,
                                ):
@@ -1275,52 +1328,13 @@
             end = begin + self.win_size
             sub_noisy = noisy[:, :, begin: end]
 
-
-
-
-
-
-
-
-            e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1)
-
-            mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2)
-            # mask shape: [b, 1, t, erb_bins]
-            mask = self.erb_bands.erb_scale_inv(mask)
-            # mask shape: [b, 1, t, f]
-
-            spec_m = self.mask.forward(spec, mask)
-            # spec_m shape: [b, 1, t, f, 2]
-            spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
-            # spec_m shape: [b, 1, t, spec_bins, 2]
-
-            # lsnr shape: [b, t, 1]
-            lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
-            # lsnr shape: [b, 1, t]
-
-            df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3)
-            df_coefs = self.df_out_transform(df_coefs)
-            # df_coefs shape: [b, df_order, t, df_bins, 2]
-
-            spec_ = spec[:, :, :, :self.config.spec_bins, :]
-            # spec shape: [b, 1, t, spec_bins, 2]
-            spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4)
-            # spec_f shape: [b, 1, t, df_bins, 2], torch.float32
-
-            spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5)
-
-            spec_e = torch.squeeze(spec_e, dim=1)
-            spec_e = spec_e.permute(0, 2, 1, 3)
-            # spec_e shape: [b, spec_bins, t, 2]
-
-            # spec_e shape: [b, spec_bins, t, 2]
-            est_spec = torch.view_as_complex(spec_e.contiguous())
-            # est_spec shape: [b, spec_bins, t], torch.complex64
-            est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
-            # est_spec shape: [b, f, t], torch.complex64
-
-            est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6)
-            # est_wav shape: [b, 1, hop_size]
+            (est_wav,
+             cache_dict0, cache_dict1, cache_dict2, cache_dict3,
+             cache_dict4, cache_dict5, cache_dict6) = self.forward_chunk(
+                sub_noisy,
+                cache_dict0, cache_dict1, cache_dict2, cache_dict3,
+                cache_dict4, cache_dict5, cache_dict6
+            )
 
             waveform_list.append(est_wav)
 
@@ -1335,27 +1349,26 @@
         :param cache_dict:
         :return:
         """
-
-
-        cache_spec_m = cache_dict["cache_spec_m"]
+        cache_dict = cache_dict or dict()
+        cache0 = cache_dict.get("cache0", None)
 
-        if
+        if cache0 is None:
             b, c, t, f, _ = spec_m.shape
-
+            cache0 = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2))
         # cache0 shape: [b, 1, lookahead, f, 2]
         spec_m_cat = torch.concat(tensors=[
-
+            cache0, spec_m,
        ], dim=2)
 
        spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :]
-
+        new_cache0 = spec_m_cat[:, :, -self.config.df_lookahead:, :, :]
 
        spec_e = torch.concat(tensors=[
            spec_f, spec_m[..., self.df_decoder.df_bins:, :]
        ], dim=3)
 
        new_cache_dict = {
-            "
+            "cache0": new_cache0,
        }
        return spec_e, new_cache_dict
 
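The forward_chunk / forward_chunk_by_chunk pair added above turns DfNet2 into a streaming denoiser: each window of win_size samples is processed on its own, and the per-module cache dicts carry the causal state (convolution lookback, EMA-norm state, deep-filter history, iSTFT overlap) from one chunk to the next. A rough caller sketch of that pattern, assuming an already-constructed model and a noisy tensor of shape [b, 1, num_samples] (the loop bounds and attribute names here are illustrative, not copied from the file):

    import torch

    cache_dicts = [None] * 7                      # cache_dict0 ... cache_dict6, created lazily by the model
    waveform_list = []
    for begin in range(0, noisy.shape[-1] - model.win_size + 1, model.hop_size):
        sub_noisy = noisy[:, :, begin: begin + model.win_size]
        est_wav, *cache_dicts = model.forward_chunk(sub_noisy, *cache_dicts)
        waveform_list.append(est_wav)             # each est_wav holds hop_size new samples
    denoise_audio = torch.concat(waveform_list, dim=-1)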
toolbox/torchaudio/models/dtln/modeling_dtln.py
CHANGED
@@ -1,9 +1,17 @@
 #!/usr/bin/python3
 # -*- coding: utf-8 -*-
 """
+https://www.isca-archive.org/interspeech_2020/westhausen20_interspeech.pdf
+
 https://github.com/AkenoSyuRi/DTLNPytorch
 
 https://github.com/breizhn/DTLN
+
+数据集: DNS3 DNS-Challenge
+信噪比 -5 到 25 dB
+5 到 30 dB
+窗长 32ms, 窗移 8ms
+
 在 dns3 500个小时的数据上训练, 在 dns3 的测试集上达到了 pesq 3.04 的水平。
 
 """
@@ -245,13 +253,12 @@ class DTLNModel(nn.Module):
         # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
 
         t = (num_samples_pad - self.fft_size) // self.hop_size + 1
+        overlap_size = self.fft_size - self.hop_size
 
         denoise_list = list()
         out_state1 = None
         out_state2 = None
-        overlap_size = self.fft_size - self.hop_size
         denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype)
-        # denoise_list.append(torch.clone(denoise_cache))
         for i in range(t):
             begin = i * self.hop_size
             end = begin + self.fft_size
toolbox/torchaudio/models/ehnet/modeling_ehnet.py
CHANGED
@@ -71,7 +71,6 @@ class CausalTransConvBlock(nn.Module):
         return x
 
 
-
 class CRN(nn.Module):
     """
     Input: [batch size, channels=1, T, n_fft]
toolbox/torchaudio/models/nx_clean_unet/__init__.py
DELETED
@@ -1,6 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-
-
-if __name__ == '__main__':
-    pass
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py
DELETED
@@ -1,6 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-
-
-if __name__ == '__main__':
-    pass
toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py
DELETED
@@ -1,261 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import math
-import os
-from typing import List, Optional, Union, Iterable
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-
-norm_layer_dict = {
-    "batch_norm_2d": torch.nn.BatchNorm2d
-}
-
-
-activation_layer_dict = {
-    "relu": torch.nn.ReLU,
-    "identity": torch.nn.Identity,
-    "sigmoid": torch.nn.Sigmoid,
-}
-
-
-class CausalConv2d(nn.Module):
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Iterable[int]],
-                 f_stride: int = 1,
-                 dilation: int = 1,
-                 do_f_pad: bool = True,
-                 bias: bool = True,
-                 separable: bool = False,
-                 norm_layer: str = "batch_norm_2d",
-                 activation_layer: str = "relu",
-                 lookahead: int = 0
-                 ):
-        super(CausalConv2d, self).__init__()
-        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-
-        if do_f_pad:
-            f_pad = kernel_size[1] // 2 + dilation - 1
-        else:
-            f_pad = 0
-
-        self.causal_left_pad = kernel_size[0] - 1 - lookahead
-        self.causal_right_pad = lookahead
-        self.constant_pad = nn.ConstantPad2d(
-            padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
-            value=0.0
-        )
-
-        groups = math.gcd(in_channels, out_channels) if separable else 1
-        self.conv1 = nn.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            padding=(0, f_pad),
-            stride=(1, f_stride),
-            dilation=(1, dilation),
-            groups=groups,
-            bias=bias,
-        )
-
-        self.conv2 = None
-        if not any([groups == 1, max(kernel_size) == 1]):
-            self.conv2 = nn.Conv2d(
-                out_channels,
-                out_channels,
-                kernel_size=1,
-                bias=False,
-            )
-
-        self.norm = None
-        if norm_layer is not None:
-            norm_layer = norm_layer_dict[norm_layer]
-            self.norm = norm_layer(out_channels)
-
-        self.activation = None
-        if activation_layer is not None:
-            activation_layer = activation_layer_dict[activation_layer]
-            self.activation = activation_layer()
-
-    def forward(self,
-                inputs: torch.Tensor,
-                causal_cache: torch.Tensor = None,
-                ):
-
-        if causal_cache is None:
-            # inputs shape: [batch_size, 1, time_steps, hidden_size]
-            x = self.constant_pad.forward(inputs)
-        else:
-            # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
-            # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
-            x = torch.concat(tensors=[causal_cache, inputs], dim=2)
-            # x shape: [batch_size, 1, time_steps2, hidden_size]
-            # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
-
-        x = self.conv1.forward(x)
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-
-        if self.conv2:
-            x = self.conv2.forward(x)
-
-        if self.norm:
-            x = self.norm(x)
-        if self.activation:
-            x = self.activation(x)
-
-        causal_cache = x[:, :, -self.causal_left_pad:, :]
-
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-        return x, causal_cache
-
-
-class CausalConv2dEncoder(nn.Module):
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Iterable[int]],
-                 f_stride: int = 1,
-                 dilation: int = 1,
-                 do_f_pad: bool = True,
-                 bias: bool = True,
-                 separable: bool = False,
-                 norm_layer: str = "batch_norm_2d",
-                 activation_layer: str = "relu",
-                 lookahead: int = 0,
-                 num_layers: int = 5,
-                 ):
-        super(CausalConv2dEncoder, self).__init__()
-        self.num_layers = num_layers
-
-        self.total_causal_left_pad = 0
-        self.total_causal_right_pad = 0
-
-        self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
-        for i_layer in range(num_layers):
-            conv = CausalConv2d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                f_stride=f_stride,
-                dilation=dilation,
-                do_f_pad=do_f_pad,
-                bias=bias,
-                separable=separable,
-                norm_layer=norm_layer,
-                activation_layer=activation_layer,
-                lookahead=lookahead,
-            )
-            self.causal_conv_list.append(conv)
-
-            self.total_causal_left_pad += conv.causal_left_pad
-            self.total_causal_right_pad += conv.causal_right_pad
-
-            in_channels = out_channels
-
-    def forward(self, inputs: torch.Tensor):
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-
-        x = inputs
-        for layer in self.causal_conv_list:
-            x, _ = layer.forward(x)
-        return x
-
-    def forward_chunk(self,
-                      chunk: torch.Tensor,
-                      causal_cache: torch.Tensor = None,
-                      ):
-        # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size]
-
-        new_causal_cache_list = list()
-        for idx, causal_conv in enumerate(self.causal_conv_list):
-            chunk, new_causal_cache = causal_conv.forward(
-                inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None
-            )
-            new_causal_cache_list.append(new_causal_cache)
-
-        new_causal_cache = torch.cat(new_causal_cache_list, dim=0)
-        return chunk, new_causal_cache
-
-    def forward_chunk_by_chunk(self, inputs: torch.Tensor):
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-        # batch_size = 1
-
-        batch_size, channels, time_steps, hidden_size = inputs.shape
-
-        causal_cache = None
-
-        outputs = []
-        for idx in range(0, time_steps, 1):
-            begin = idx
-            end = begin + self.total_causal_right_pad + 1
-            chunk_xs = inputs[:, :, begin:end, :]
-
-            ys, attention_cache = self.forward_chunk(
-                chunk=chunk_xs,
-                causal_cache=causal_cache,
-            )
-            # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
-            ys = ys[:, :, :1, :]
-
-            # ys shape: [batch_size, chunk_size, hidden_size]
-            outputs.append(ys)
-
-        ys = torch.cat(outputs, 2)
-        return ys
-
-
-def main2():
-    conv = CausalConv2d(
-        in_channels=1,
-        out_channels=64,
-        kernel_size=3,
-        bias=False,
-        separable=True,
-        f_stride=1,
-        lookahead=0,
-    )
-
-    spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
-    # spec shape: [batch_size, 1, time_steps, hidden_size]
-    cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
-
-    output, _ = conv.forward(spec)
-    print(output.shape)
-
-    output, _ = conv.forward(spec, cache)
-    print(output.shape)
-
-    return
-
-
-def main():
-    causal = CausalConv2dEncoder(
-        in_channels=1,
-        out_channels=1,
-        kernel_size=3,
-        bias=False,
-        separable=True,
-        f_stride=1,
-        lookahead=0,
-        num_layers=3,
-    )
-
-    spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
-    # spec shape: [batch_size, 1, time_steps, hidden_size]
-
-    output = causal.forward(spec)
-    print(output.shape)
-
-    output = causal.forward_chunk_by_chunk(spec)
-    print(output.shape)
-
-    return
-
-
-if __name__ == '__main__':
-    main()
toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py
DELETED
@@ -1,100 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-from toolbox.torchaudio.configuration_utils import PretrainedConfig
-
-
-class NXCleanUNetConfig(PretrainedConfig):
-    """
-    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
-    """
-    def __init__(self,
-                 sample_rate: int = 8000,
-                 segment_size: int = 16000,
-                 n_fft: int = 512,
-                 win_length: int = 200,
-                 hop_length: int = 80,
-
-                 down_sampling_num_layers: int = 5,
-                 down_sampling_in_channels: int = 1,
-                 down_sampling_hidden_channels: int = 64,
-                 down_sampling_kernel_size: int = 4,
-                 down_sampling_stride: int = 2,
-
-                 causal_in_channels: int = 64,
-                 causal_out_channels: int = 64,
-                 causal_kernel_size: int = 3,
-                 causal_bias: bool = False,
-                 causal_separable: bool = True,
-                 causal_f_stride: int = 1,
-                 # causal_lookahead: int = 0,
-                 causal_num_layers: int = 3,
-
-                 tsfm_hidden_size: int = 256,
-                 tsfm_attention_heads: int = 4,
-                 tsfm_num_blocks: int = 6,
-                 tsfm_dropout_rate: float = 0.1,
-                 tsfm_max_length: int = 1024,
-                 tsfm_chunk_size: int = 4,
-                 tsfm_num_left_chunks: int = 128,
-                 tsfm_num_right_chunks: int = 2,
-
-                 discriminator_dim: int = 16,
-                 discriminator_in_channel: int = 2,
-
-                 compress_factor: float = 0.3,
-
-                 batch_size: int = 4,
-                 learning_rate: float = 0.0005,
-                 adam_b1: float = 0.8,
-                 adam_b2: float = 0.99,
-                 lr_decay: float = 0.99,
-                 seed: int = 1234,
-
-                 **kwargs
-                 ):
-        super(NXCleanUNetConfig, self).__init__(**kwargs)
-        self.sample_rate = sample_rate
-        self.segment_size = segment_size
-        self.n_fft = n_fft
-        self.win_length = win_length
-        self.hop_length = hop_length
-
-        self.down_sampling_num_layers = down_sampling_num_layers
-        self.down_sampling_in_channels = down_sampling_in_channels
-        self.down_sampling_hidden_channels = down_sampling_hidden_channels
-        self.down_sampling_kernel_size = down_sampling_kernel_size
-        self.down_sampling_stride = down_sampling_stride
-
-        self.causal_in_channels = causal_in_channels
-        self.causal_out_channels = causal_out_channels
-        self.causal_kernel_size = causal_kernel_size
-        self.causal_bias = causal_bias
-        self.causal_separable = causal_separable
-        self.causal_f_stride = causal_f_stride
-        # self.causal_lookahead = causal_lookahead
-        self.causal_num_layers = causal_num_layers
-
-        self.tsfm_hidden_size = tsfm_hidden_size
-        self.tsfm_attention_heads = tsfm_attention_heads
-        self.tsfm_num_blocks = tsfm_num_blocks
-        self.tsfm_dropout_rate = tsfm_dropout_rate
-        self.tsfm_max_length = tsfm_max_length
-        self.tsfm_chunk_size = tsfm_chunk_size
-        self.tsfm_num_left_chunks = tsfm_num_left_chunks
-        self.tsfm_num_right_chunks = tsfm_num_right_chunks
-
-        self.discriminator_dim = discriminator_dim
-        self.discriminator_in_channel = discriminator_in_channel
-
-        self.compress_factor = compress_factor
-
-        self.batch_size = batch_size
-        self.learning_rate = learning_rate
-        self.adam_b1 = adam_b1
-        self.adam_b2 = adam_b2
-        self.lr_decay = lr_decay
-        self.seed = seed
-
-
-if __name__ == '__main__':
-    pass
toolbox/torchaudio/models/nx_clean_unet/discriminator.py
DELETED
@@ -1,132 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import os
-from typing import Optional, Union
-
-import torch
-import torch.nn as nn
-import torchaudio
-
-from toolbox.torchaudio.configuration_utils import CONFIG_FILE
-from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
-from toolbox.torchaudio.models.nx_clean_unet.utils import LearnableSigmoid1d
-
-
-class MetricDiscriminator(nn.Module):
-    def __init__(self, config: NXCleanUNetConfig):
-        super(MetricDiscriminator, self).__init__()
-        dim = config.discriminator_dim
-        self.in_channel = config.discriminator_in_channel
-
-        self.n_fft = config.n_fft
-        self.win_length = config.win_length
-        self.hop_length = config.hop_length
-
-        self.transform = torchaudio.transforms.Spectrogram(
-            n_fft=self.n_fft,
-            win_length=self.win_length,
-            hop_length=self.hop_length,
-            power=1.0,
-            window_fn=torch.hann_window,
-            # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
-        )
-
-        self.layers = nn.Sequential(
-            nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
-            nn.InstanceNorm2d(dim, affine=True),
-            nn.PReLU(dim),
-            nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
-            nn.InstanceNorm2d(dim*2, affine=True),
-            nn.PReLU(dim*2),
-            nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
-            nn.InstanceNorm2d(dim*4, affine=True),
-            nn.PReLU(dim*4),
-            nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
-            nn.InstanceNorm2d(dim*8, affine=True),
-            nn.PReLU(dim*8),
-            nn.AdaptiveMaxPool2d(1),
-            nn.Flatten(),
-            nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
-            nn.Dropout(0.3),
-            nn.PReLU(dim*4),
-            nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
-            LearnableSigmoid1d(1)
-        )
-
-    def forward(self, x, y):
-        x = self.transform.forward(x)
-        y = self.transform.forward(y)
-
-        xy = torch.stack((x, y), dim=1)
-        return self.layers(xy)
-
-
-MODEL_FILE = "discriminator.pt"
-
-
-class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
-    def __init__(self,
-                 config: NXCleanUNetConfig,
-                 ):
-        super(MetricDiscriminatorPretrainedModel, self).__init__(
-            config=config,
-        )
-        self.config = config
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model = cls(config)
-
-        if os.path.isdir(pretrained_model_name_or_path):
-            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
-        else:
-            ckpt_file = pretrained_model_name_or_path
-
-        with open(ckpt_file, "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        model.load_state_dict(state_dict, strict=True)
-        return model
-
-    def save_pretrained(self,
-                        save_directory: Union[str, os.PathLike],
-                        state_dict: Optional[dict] = None,
-                        ):
-
-        model = self
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        os.makedirs(save_directory, exist_ok=True)
-
-        # save state dict
-        model_file = os.path.join(save_directory, MODEL_FILE)
-        torch.save(state_dict, model_file)
-
-        # save config
-        config_file = os.path.join(save_directory, CONFIG_FILE)
-        self.config.to_yaml_file(config_file)
-        return save_directory
-
-
-def main():
-    config = NXCleanUNetConfig()
-    discriminator = MetricDiscriminator(config=config)
-
-    # shape: [batch_size, num_samples]
-    # x = torch.ones([4, int(4.5 * 16000)])
-    # y = torch.ones([4, int(4.5 * 16000)])
-    x = torch.ones([4, 16000])
-    y = torch.ones([4, 16000])
-
-    output = discriminator.forward(x, y)
-    print(output.shape)
-    print(output)
-
-    return
-
-
-if __name__ == "__main__":
-    main()
toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav
DELETED
Binary file (63.8 kB)
toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py
DELETED
@@ -1,96 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import logging
-from pathlib import Path
-import shutil
-import tempfile
-import zipfile
-
-import librosa
-import numpy as np
-import torch
-import torchaudio
-
-from project_settings import project_path
-from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
-from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNetPretrainedModel, MODEL_FILE
-
-logger = logging.getLogger("toolbox")
-
-
-class InferenceNXCleanUNet(object):
-    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
-        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
-        self.device = torch.device(device)
-
-        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
-        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
-        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")
-
-        self.config = config
-        self.model = model
-        self.model.to(device)
-        self.model.eval()
-
-    def load_models(self, model_path: str):
-        model_path = Path(model_path)
-        if model_path.name.endswith(".zip"):
-            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
-                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
-                out_root.mkdir(parents=True, exist_ok=True)
-                f_zip.extractall(path=out_root)
-            model_path = out_root / model_path.stem
-
-        config = NXCleanUNetConfig.from_pretrained(
-            pretrained_model_name_or_path=model_path.as_posix(),
-        )
-        model = NXCleanUNetPretrainedModel.from_pretrained(
-            pretrained_model_name_or_path=model_path.as_posix(),
-        )
-        model.to(self.device)
-        model.eval()
-
-        shutil.rmtree(model_path)
-        return config, model
-
-    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
-        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
-            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")
-
-        # noisy_audio shape: [batch_size, num_samples]
-        noisy_audios = noisy_audio.to(self.device)
-
-        with torch.no_grad():
-            enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
-            # enhanced_audios = self.model.forward(noisy_audios)
-            # enhanced_audio shape: [batch_size, n_samples]
-            # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
-
-        enhanced_audio = enhanced_audios[0]
-        # enhanced_audio shape: [num_samples,]
-        return enhanced_audio
-
-def main():
-    model_zip_file = project_path / "trained_models/nx-clean-unet-14-epoch.zip"
-    infer_nx_clean_unet = InferenceNXCleanUNet(model_zip_file)
-
-    sample_rate = 8000
-    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
-    noisy_audio, _ = librosa.load(
-        noisy_audio_file.as_posix(),
-        sr=sample_rate,
-    )
-    noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
-    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
-    noisy_audio = noisy_audio.unsqueeze(dim=0)
-
-    enhanced_audio = infer_nx_clean_unet.enhancement_by_tensor(noisy_audio)
-
-    filename = "enhanced_audio.wav"
-    torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)
-
-    return
-
-
-if __name__ == '__main__':
-    main()
toolbox/torchaudio/models/nx_clean_unet/loss.py
DELETED
@@ -1,22 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import numpy as np
-import torch
-
-
-def anti_wrapping_function(x):
-
-    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
-
-
-def phase_losses(phase_r, phase_g):
-
-    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
-    gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
-    iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
-
-    return ip_loss, gd_loss, iaf_loss
-
-
-if __name__ == '__main__':
-    pass
toolbox/torchaudio/models/nx_clean_unet/metrics.py
DELETED
@@ -1,80 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-from joblib import Parallel, delayed
-import numpy as np
-from pesq import pesq
-from typing import List
-
-from pesq import cypesq
-
-
-def run_pesq(clean_audio: np.ndarray,
-             noisy_audio: np.ndarray,
-             sample_rate: int = 16000,
-             mode: str = "wb",
-             ) -> float:
-    if sample_rate == 8000 and mode == "wb":
-        raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
-    try:
-        pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
-    except cypesq.NoUtterancesError as e:
-        pesq_score = -1
-    except Exception as e:
-        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
-        pesq_score = -1
-    return pesq_score
-
-
-def run_batch_pesq(clean_audio_list: List[np.ndarray],
-                   noisy_audio_list: List[np.ndarray],
-                   sample_rate: int = 16000,
-                   mode: str = "wb",
-                   n_jobs: int = 4,
-                   ) -> List[float]:
-    parallel = Parallel(n_jobs=n_jobs)
-
-    parallel_tasks = list()
-    for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
-        parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
-        parallel_tasks.append(parallel_task)
-
-    pesq_score_list = parallel.__call__(parallel_tasks)
-    return pesq_score_list
-
-
-def run_pesq_score(clean_audio_list: List[np.ndarray],
-                   noisy_audio_list: List[np.ndarray],
-                   sample_rate: int = 16000,
-                   mode: str = "wb",
-                   n_jobs: int = 4,
-                   ) -> List[float]:
-
-    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
-                                     noisy_audio_list=noisy_audio_list,
-                                     sample_rate=sample_rate,
-                                     mode=mode,
-                                     n_jobs=n_jobs,
-                                     )
-
-    pesq_score = np.mean(pesq_score_list)
-    return pesq_score
-
-
-def main():
-    clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
-    noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
-
-    clean_audio_list = list(clean_audio)
-    noisy_audio_list = list(noisy_audio)
-
-    pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
-    print(pesq_score_list)
-
-    pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
-    print(pesq_score)
-
-    return
-
-
-if __name__ == "__main__":
-    main()
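Note (illustrative sketch): run_pesq above rejects wideband mode at 8 kHz, so scoring 8 kHz clips looks like the snippet below; the random signals are placeholders only.

import numpy as np

clean = np.random.uniform(low=-1, high=1, size=(16000,)).astype(np.float32)  # placeholder, 2 s at 8 kHz
noisy = np.random.uniform(low=-1, high=1, size=(16000,)).astype(np.float32)

# mode must be "nb" when sample_rate is 8000; "wb" raises AssertionError above
score = run_pesq(clean, noisy, sample_rate=8000, mode="nb")
print(score)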
toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py
DELETED
@@ -1,401 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import os
-from typing import List, Optional, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-from toolbox.torchaudio.configuration_utils import CONFIG_FILE
-from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig
-from toolbox.torchaudio.models.nx_clean_unet.transformers.transformers import TransformerEncoder
-from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder
-
-
-class DownSamplingBlock(nn.Module):
-    def __init__(self,
-                 in_channels: int,
-                 hidden_channels: int,
-                 kernel_size: int,
-                 stride: int,
-                 ):
-        super(DownSamplingBlock, self).__init__()
-        self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride)
-        self.relu = nn.ReLU()
-        self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
-        self.glu = nn.GLU(dim=1)
-
-    def forward(self, x: torch.Tensor):
-        # x shape: [batch_size, 1, num_samples]
-        x = self.conv1.forward(x)
-        # x shape: [batch_size, hidden_channels, new_num_samples]
-        x = self.relu(x)
-        x = self.conv2.forward(x)
-        # x shape: [batch_size, hidden_channels*2, new_num_samples]
-        x = self.glu(x)
-        # x shape: [batch_size, hidden_channels, new_num_samples]
-        # new_num_samples = (num_samples-kernel_size) // stride + 1
-        return x
-
-
-class DownSampling(nn.Module):
-    def __init__(self,
-                 num_layers: int,
-                 in_channels: int,
-                 hidden_channels: int,
-                 kernel_size: int,
-                 stride: int,
-                 ):
-        super(DownSampling, self).__init__()
-        self.num_layers = num_layers
-
-        down_sampling_block_list = list()
-        for idx in range(self.num_layers):
-            down_sampling_block = DownSamplingBlock(
-                in_channels=in_channels,
-                hidden_channels=hidden_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-            )
-            down_sampling_block_list.append(down_sampling_block)
-            in_channels = hidden_channels
-
-        self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)
-
-    def forward(self, x: torch.Tensor):
-        # x shape: [batch_size, channels, num_samples]
-        skip_connection_list = list()
-        for down_sampling_block in self.down_sampling_block_list:
-            x = down_sampling_block.forward(x)
-            skip_connection_list.append(x)
-        # x shape: [batch_size, hidden_channels, num_samples**]
-        return x, skip_connection_list
-
-
-class UpSamplingBlock(nn.Module):
-    def __init__(self,
-                 out_channels: int,
-                 hidden_channels: int,
-                 kernel_size: int,
-                 stride: int,
-                 do_relu: bool = True,
-                 ):
-        super(UpSamplingBlock, self).__init__()
-        self.do_relu = do_relu
-
-        self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
-        self.glu = nn.GLU(dim=1)
-        self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride)
-        self.relu = nn.ReLU()
-
-    def forward(self, x: torch.Tensor):
-        # x shape: [batch_size, hidden_channels*2, num_samples]
-        x = self.conv1.forward(x)
-        # x shape: [batch_size, hidden_channels, num_samples]
-        x = self.glu(x)
-        # x shape: [batch_size, hidden_channels, num_samples]
-        x = self.convt.forward(x)
-        # x shape: [batch_size, hidden_channels, new_num_samples]
-        # new_num_samples = (num_samples - 1) * stride + kernel_size
-        if self.do_relu:
-            x = self.relu(x)
-        return x
-
-
-class UpSampling(nn.Module):
-    def __init__(self,
-                 num_layers: int,
-                 out_channels: int,
-                 hidden_channels: int,
-                 kernel_size: int,
-                 stride: int,
-                 ):
-        super(UpSampling, self).__init__()
-        self.num_layers = num_layers
-
-        up_sampling_block_list = list()
-        for idx in range(self.num_layers-1):
-            up_sampling_block = UpSamplingBlock(
-                out_channels=hidden_channels,
-                hidden_channels=hidden_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                do_relu=True,
-            )
-            up_sampling_block_list.append(up_sampling_block)
-        else:
-            up_sampling_block = UpSamplingBlock(
-                out_channels=out_channels,
-                hidden_channels=hidden_channels,
-                kernel_size=kernel_size,
-                stride=stride,
-                do_relu=False,
-            )
-            up_sampling_block_list.append(up_sampling_block)
-        self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)
-
-    def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
-        skip_connection_list = skip_connection_list[::-1]
-
-        # x shape: [batch_size, channels, num_samples]
-        for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
-            skip_x = skip_connection_list[idx]
-            x = x + skip_x
-            # x = x + skip_x[:, :, :x.size(2)]
-            x = up_sampling_block.forward(x)
-        return x
-
-
-def get_padding_length(length, num_layers: int, kernel_size: int, stride: int):
-    for _ in range(num_layers):
-        if length < kernel_size:
-            length = 1
-        else:
-            length = 1 + np.ceil((length - kernel_size) / stride)
-
-    for _ in range(num_layers):
-        length = (length - 1) * stride + kernel_size
-
-    padded_length = int(length)
-    return padded_length
-
-
-class NXCleanUNet(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-
-        self.down_sampling = DownSampling(
-            num_layers=config.down_sampling_num_layers,
-            in_channels=config.down_sampling_in_channels,
-            hidden_channels=config.down_sampling_hidden_channels,
-            kernel_size=config.down_sampling_kernel_size,
-            stride=config.down_sampling_stride,
-        )
-        self.causal_encoder = CausalConv2dEncoder(
-            in_channels=config.causal_in_channels,
-            out_channels=config.causal_out_channels,
-            kernel_size=config.causal_kernel_size,
-            bias=config.causal_bias,
-            separable=config.causal_separable,
-            f_stride=config.causal_f_stride,
-            lookahead=0,
-            num_layers=config.causal_num_layers,
-        )
-        self.transformer = TransformerEncoder(
-            input_size=config.down_sampling_hidden_channels,
-            hidden_size=config.tsfm_hidden_size,
-            attention_heads=config.tsfm_attention_heads,
-            num_blocks=config.tsfm_num_blocks,
-            dropout_rate=config.tsfm_dropout_rate,
-            chunk_size=config.tsfm_chunk_size,
-            num_left_chunks=config.tsfm_num_left_chunks,
-            num_right_chunks=config.tsfm_num_right_chunks,
-        )
-        self.up_sampling = UpSampling(
-            num_layers=config.down_sampling_num_layers,
-            out_channels=config.down_sampling_in_channels,
-            hidden_channels=config.down_sampling_hidden_channels,
-            kernel_size=config.down_sampling_kernel_size,
-            stride=config.down_sampling_stride,
-        )
-
-    def forward(self, noisy_audios: torch.Tensor):
-        # noisy_audios shape: [batch_size, n_samples]
-        noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
-        # noisy_audios shape: [batch_size, 1, n_samples]
-
-        n_samples = noisy_audios.shape[-1]
-        padded_length = get_padding_length(
-            n_samples,
-            num_layers=self.config.down_sampling_num_layers,
-            kernel_size=self.config.down_sampling_kernel_size,
-            stride=self.config.down_sampling_stride,
-        )
-        noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
-
-        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
-        # bottle_neck shape: [batch_size, channels, time_steps]
-
-        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
-        # bottle_neck shape: [batch_size, time_steps, input_size]
-
-        bottle_neck = bottle_neck.unsqueeze(dim=1)
-        bottle_neck = self.causal_encoder.forward(bottle_neck)
-        bottle_neck = bottle_neck.squeeze(dim=1)
-        # bottle_neck shape: [batch_size, time_steps, input_size]
-
-        bottle_neck = self.transformer.forward(bottle_neck)
-        # bottle_neck shape: [batch_size, time_steps, input_size]
-
-        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
-        # bottle_neck shape: [batch_size, channels, time_steps]
-
-        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
-
-        enhanced_audios = enhanced_audios[:, :, :n_samples]
-        # enhanced_audios shape: [batch_size, 1, n_samples]
-
-        enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
-        # enhanced_audios shape: [batch_size, n_samples]
-
-        return enhanced_audios
-
-    def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
-        # noisy_audios shape: [batch_size, n_samples]
-        noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
-        # noisy_audios shape: [batch_size, 1, n_samples]
-
-        n_samples = noisy_audios.shape[-1]
-        padded_length = get_padding_length(
-            n_samples,
-            num_layers=self.config.down_sampling_num_layers,
-            kernel_size=self.config.down_sampling_kernel_size,
-            stride=self.config.down_sampling_stride,
-        )
-        noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)
-
-        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
-        # bottle_neck shape: [batch_size, channels, time_steps]
-
-        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
-        # bottle_neck shape: [batch_size, time_steps, input_size]
-
-        bottle_neck = bottle_neck.unsqueeze(dim=1)
-        bottle_neck = self.causal_encoder.forward_chunk_by_chunk(bottle_neck)
-        bottle_neck = bottle_neck.squeeze(dim=1)
-        # bottle_neck shape: [batch_size, time_steps, input_size]
-
-        bottle_neck = self.transformer.forward_chunk_by_chunk(bottle_neck)
-        # bottle_neck shape: [batch_size, time_steps, input_size]
-
-        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
-        # bottle_neck shape: [batch_size, channels, time_steps]
-
-        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)
-
-        enhanced_audios = enhanced_audios[:, :, :n_samples]
-        # enhanced_audios shape: [batch_size, 1, n_samples]
-
-        enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
-        # enhanced_audios shape: [batch_size, n_samples]
-
-        return enhanced_audios
-
-
-MODEL_FILE = "generator.pt"
-
-
-class NXCleanUNetPretrainedModel(NXCleanUNet):
-    def __init__(self,
-                 config: NXCleanUNetConfig,
-                 ):
-        super(NXCleanUNetPretrainedModel, self).__init__(
-            config=config,
-        )
-        self.config = config
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-        model = cls(config)
-
-        if os.path.isdir(pretrained_model_name_or_path):
-            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
-        else:
-            ckpt_file = pretrained_model_name_or_path
-
-        with open(ckpt_file, "rb") as f:
-            state_dict = torch.load(f, map_location="cpu", weights_only=True)
-        model.load_state_dict(state_dict, strict=True)
-        return model
-
-    def save_pretrained(self,
-                        save_directory: Union[str, os.PathLike],
-                        state_dict: Optional[dict] = None,
-                        ):
-
-        model = self
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        os.makedirs(save_directory, exist_ok=True)
-
-        # save state dict
-        model_file = os.path.join(save_directory, MODEL_FILE)
-        torch.save(state_dict, model_file)
-
-        # save config
-        config_file = os.path.join(save_directory, CONFIG_FILE)
-        self.config.to_yaml_file(config_file)
-        return save_directory
-
-
-def main2():
-
-    config = NXCleanUNetConfig()
-    down_sampling = DownSampling(
-        num_layers=config.down_sampling_num_layers,
-        in_channels=config.down_sampling_in_channels,
-        hidden_channels=config.down_sampling_hidden_channels,
-        kernel_size=config.down_sampling_kernel_size,
-        stride=config.down_sampling_stride,
-    )
-    up_sampling = UpSampling(
-        num_layers=config.down_sampling_num_layers,
-        out_channels=config.down_sampling_in_channels,
-        hidden_channels=config.down_sampling_hidden_channels,
-        kernel_size=config.down_sampling_kernel_size,
-        stride=config.down_sampling_stride,
-    )
-
-    # shape: [batch_size, channels, num_samples]
-    # min length: 94, stride: 32, 32 == 2**5
-    # x = torch.ones([4, 1, 94])
-    # x = torch.ones([4, 1, 126])
-    # x = torch.ones([4, 1, 158])
-    x = torch.ones([4, 1, 190])
-
-    length = x.shape[-1]
-    padded_length = get_padding_length(
-        length,
-        num_layers=config.down_sampling_num_layers,
-        kernel_size=config.down_sampling_kernel_size,
-        stride=config.down_sampling_stride,
-    )
-    x = F.pad(input=x, pad=(0, padded_length - length), mode="constant", value=0)
-    # print(x)
-    print(x.shape)
-    bottle_neck = down_sampling.forward(x)
-    print("-" * 150)
-    x = up_sampling.forward(bottle_neck)
-    print(x.shape)
-    return
-
-
-def main():
-
-    config = NXCleanUNetConfig()
-
-    # shape: [batch_size, channels, num_samples]
-    # min length: 94, stride: 32, 32 == 2**5
-    # x = torch.ones([4, 94])
-    # x = torch.ones([4, 126])
-    # x = torch.ones([4, 158])
-    # x = torch.ones([4, 190])
-    x = torch.ones([4, 16000])
-
-    model = NXCleanUNet(config)
-    enhanced_audios = model.forward(x)
-    print(enhanced_audios.shape)
-    return
-
-
-if __name__ == "__main__":
-    main2()
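Note (illustrative): get_padding_length above pads the waveform to the nearest length that the down-sampling/up-sampling pair can reproduce exactly. A quick check with the down-sampling settings from yaml/config.yaml below (6 layers, kernel_size=4, stride=2), reproduced here only for illustration:

# sketch only; values copied from yaml/config.yaml later in this diff
padded_length = get_padding_length(16000, num_layers=6, kernel_size=4, stride=2)
print(padded_length)  # expected 16062, i.e. F.pad appends 62 zero samples before the encoder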
toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py
DELETED
@@ -1,270 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import math
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-
-class MultiHeadSelfAttention(nn.Module):
-    def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
-        """
-        :param n_head: int. the number of heads.
-        :param n_feat: int. the number of features.
-        :param dropout_rate: float. dropout rate.
-        """
-        super().__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        self.linear_k = nn.Linear(n_feat, n_feat)
-        self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self,
-                    query: torch.Tensor,
-                    key: torch.Tensor,
-                    value: torch.Tensor
-                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        transform query, key and value.
-        :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
-        :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
-        :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
-        :return:
-        """
-        n_batch = query.size(0)
-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q, k, v
-
-    def forward_attention(self,
-                          value: torch.Tensor,
-                          scores: torch.Tensor,
-                          mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
-                          ) -> torch.Tensor:
-        """
-        compute attention context vector.
-        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
-        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
-        :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
-        (batch_size, time1, time2), (0, 0, 0) means fake mask.
-        :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
-        weighted by the attention score (batch_size, time1, time2).
-        """
-        n_batch = value.size(0)
-        # NOTE: When will `if mask.size(2) > 0` be True?
-        # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
-        # 1st chunk to ease the onnx export.]
-        # 2. pytorch training
-        if mask.size(2) > 0:  # time2 > 0
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            # For last chunk, time2 might be larger than scores.size(-1)
-            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
-            scores = scores.masked_fill(mask, -float('inf'))
-            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
-
-        # NOTE: When will `if mask.size(2) > 0` be False?
-        # 1. onnx(16/-1, -1/-1, 16/0)
-        # 2. jit (16/-1, -1/-1, 16/0, 16/4)
-        else:
-            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
-
-        return self.linear_out(x)  # (batch, time1, n_feat)
-
-    def forward(self,
-                x: torch.Tensor,
-                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
-                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
-                ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-        q, k, v = self.forward_qkv(x, x, x)
-
-        if cache.size(0) > 0:
-            key_cache, value_cache = torch.split(
-                cache, cache.size(-1) // 2, dim=-1)
-            k = torch.cat([key_cache, k], dim=2)
-            v = torch.cat([value_cache, v], dim=2)
-        # NOTE: We do cache slicing in encoder.forward_chunk, since it's
-        # non-trivial to calculate `next_cache_start` here.
-        new_cache = torch.cat((k, v), dim=-1)
-
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask), new_cache
-
-
-class RelativeMultiHeadSelfAttention(nn.Module):
-
-    def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
-        """
-        :param n_head: int. the number of heads.
-        :param n_feat: int. the number of features.
-        :param dropout_rate: float. dropout rate.
-        :param max_relative_position: int. maximum relative position for relative position encoding.
-        """
-        super().__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        self.linear_k = nn.Linear(n_feat, n_feat)
-        self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        # Relative position encoding
-        self.max_relative_position = max_relative_position
-        self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))
-
-    def forward_qkv(self,
-                    query: torch.Tensor,
-                    key: torch.Tensor,
-                    value: torch.Tensor
-                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        transform query, key and value.
-        :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
-        :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
-        :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
-        :return:
-        """
-        n_batch = query.size(0)
-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q, k, v
-
-    def forward_attention(self,
-                          value: torch.Tensor,
-                          scores: torch.Tensor,
-                          mask: torch.Tensor = None
-                          ) -> torch.Tensor:
-        """
-        compute attention context vector.
-        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
-        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
-        :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
-        :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
-        weighted by the attention score (batch_size, query_time_steps, key_time_steps).
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)
-            # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
-            scores = scores.masked_fill(mask, -float('inf'))
-            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-        else:
-            attn = torch.softmax(scores, dim=-1)
-        # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]
-
-        p_attn = self.dropout(attn)
-
-        x = torch.matmul(p_attn, value)
-        # x shape: [batch_size, n_head, query_time_steps, d_k]
-        x = x.transpose(1, 2)
-        # x shape: [batch_size, query_time_steps, n_head, d_k]
-
-        x = x.contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
-        # x shape: [batch_size, query_time_steps, n_head * d_k]
-        # x shape: [batch_size, query_time_steps, n_feat]
-
-        x = self.linear_out(x)
-        # x shape: [batch_size, query_time_steps, n_feat]
-        return x
-
-    def relative_position_encoding(self, length: int) -> torch.Tensor:
-        """
-        Generate relative position encoding.
-        :param length: int. length of the sequence.
-        :return: torch.Tensor. relative position encoding. shape=(length, length, d_k).
-        """
-        range_vec = torch.arange(length)
-        distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
-        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
-        final_mat = distance_mat_clipped + self.max_relative_position
-        return final_mat
-
-    def forward(self,
-                x: torch.Tensor,
-                mask: torch.Tensor = None,
-                cache: torch.Tensor = None
-                ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-
-        :param x:
-        :param mask:
-        :param cache: Tensor, shape: [1, n_heads, time_steps, dim]
-        :return:
-        """
-        # attention! self attention.
-
-        q, k, v = self.forward_qkv(x, x, x)
-        # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]
-
-        if cache is not None:
-            key_cache, value_cache = torch.split(
-                cache, cache.size(-1) // 2, dim=-1)
-            k = torch.cat([key_cache, k], dim=2)
-            v = torch.cat([value_cache, v], dim=2)
-
-        # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
-        new_cache = torch.cat((k, v), dim=-1)
-
-        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
-        native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-
-        # Compute relative position encoding
-        q_length, k_length = q.size(2), k.size(2)
-        relative_position = self.relative_position_encoding(k_length)
-
-        relative_position = relative_position[-q_length:]
-
-        relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)
-
-        relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, q_length, k_length, d_k)
-        relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, q_length, k_length, d_k)
-
-        relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
-        # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
-
-        # score
-        scores = native_scores + relative_position_scores
-
-        return self.forward_attention(v, scores, mask), new_cache
-
-
-def main():
-    rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)
-
-    x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
-    xt, new_cache = rel_attention.forward(x, x, x)
-
-    # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
-    # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
-    # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)
-
-    print(xt.shape)
-    print(new_cache.shape)
-    return
-
-
-if __name__ == '__main__':
-    main()
toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py
DELETED
@@ -1,74 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import torch
-
-
-def make_pad_mask(lengths: torch.Tensor,
-                  max_len: int = 0,
-                  ) -> torch.Tensor:
-    batch_size = lengths.size(0)
-    max_len = max_len if max_len > 0 else lengths.max().item()
-    seq_range = torch.arange(
-        0,
-        max_len,
-        dtype=torch.int64,
-        device=lengths.device
-    )
-    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
-    seq_length_expand = lengths.unsqueeze(-1)
-    mask = seq_range_expand >= seq_length_expand
-    return mask
-
-
-def subsequent_chunk_mask(
-        size: int,
-        chunk_size: int,
-        num_left_chunks: int = -1,
-        num_right_chunks: int = 0,
-        device: torch.device = torch.device("cpu"),
-) -> torch.Tensor:
-    """
-    Create mask for subsequent steps (size, size) with chunk size,
-    this is for streaming encoder
-
-    Examples:
-    > subsequent_chunk_mask(4, 2)
-    [[1, 1, 0, 0],
-     [1, 1, 0, 0],
-     [1, 1, 1, 1],
-     [1, 1, 1, 1]]
-
-    :param size: int. size of mask.
-    :param chunk_size: int. size of chunk.
-    :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
-    :param num_right_chunks: int. number of right chunks.
-    :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
-    :return: torch.Tensor. mask
-    """
-
-    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
-    for i in range(size):
-        if num_left_chunks < 0:
-            start = 0
-        else:
-            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
-        ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
-        ret[i, start:ending] = True
-    return ret
-
-
-def main():
-    chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
-    print(chunk_mask)
-
-    chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
-    print(chunk_mask)
-
-    chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
-    print(chunk_mask)
-    return
-
-
-if __name__ == '__main__':
-    main()
toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py
DELETED
@@ -1,266 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-from typing import Dict, Optional, Tuple, List, Union
-
-import torch
-import torch.nn as nn
-
-from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask
-from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
-
-
-class PositionwiseFeedForward(nn.Module):
-    def __init__(self,
-                 input_dim: int,
-                 hidden_units: int,
-                 dropout_rate: float,
-                 activation: torch.nn.Module = torch.nn.ReLU()):
-        """
-        FeedForward are applied on each position of the sequence.
-        the output dim is same with the input dim.
-
-        :param input_dim: int. input dimension.
-        :param hidden_units: int. the number of hidden units.
-        :param dropout_rate: float. dropout rate.
-        :param activation: torch.nn.Module. activation function.
-        """
-        super(PositionwiseFeedForward, self).__init__()
-        self.w_1 = torch.nn.Linear(input_dim, hidden_units)
-        self.activation = activation
-        self.dropout = torch.nn.Dropout(dropout_rate)
-        self.w_2 = torch.nn.Linear(hidden_units, input_dim)
-
-    def forward(self, xs: torch.Tensor) -> torch.Tensor:
-        """
-        Forward function.
-        :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
-        :return: output tensor. shape=(batch_size, max_length, dim).
-        """
-        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
-
-
-class TransformerBlock(nn.Module):
-    def __init__(self,
-                 input_dim: int,
-                 dropout_rate: float = 0.1,
-                 n_heads: int = 4,
-                 max_relative_position: int = 5120
-                 ):
-        super().__init__()
-        self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
-        self.attention = RelativeMultiHeadSelfAttention(
-            n_head=n_heads,
-            n_feat=input_dim,
-            dropout_rate=dropout_rate,
-            max_relative_position=max_relative_position,
-        )
-
-        self.dropout1 = nn.Dropout(dropout_rate)
-        self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
-        self.ffn = PositionwiseFeedForward(
-            input_dim=input_dim,
-            hidden_units=input_dim,
-            dropout_rate=dropout_rate
-        )
-        self.dropout2 = nn.Dropout(dropout_rate)
-        self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
-
-    def forward(
-            self,
-            x: torch.Tensor,
-            mask: torch.Tensor = None,
-            attention_cache: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-
-        :param x: torch.Tensor. shape=(batch_size, time, input_dim).
-        :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
-        :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
-        shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
-        :return:
-        torch.Tensor: Output tensor (batch_size, time, input_dim).
-        torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
-        """
-
-        xt = self.norm1(x)
-
-        x_att, new_att_cache = self.attention.forward(
-            xt, mask=mask, cache=attention_cache
-        )
-        x = x + self.dropout1(xt)
-        xt = self.norm2(x)
-        xt = self.ffn.forward(xt)
-        x = x + self.dropout2(xt)
-
-        x = self.norm3(x)
-
-        return x, new_att_cache
-
-
-class TransformerEncoder(nn.Module):
-    """
-    https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
-    """
-    def __init__(self,
-                 input_size: int = 64,
-                 hidden_size: int = 256,
-                 attention_heads: int = 4,
-                 num_blocks: int = 6,
-                 dropout_rate: float = 0.1,
-                 max_relative_position: int = 1024,
-                 chunk_size: int = 1,
-                 num_left_chunks: int = 128,
-                 num_right_chunks: int = 2,
-                 ):
-        super().__init__()
-        self.input_size = input_size
-        self.hidden_size = hidden_size
-
-        self.max_relative_position = max_relative_position
-        self.chunk_size = chunk_size
-        self.num_left_chunks = num_left_chunks
-        self.num_right_chunks = num_right_chunks
-
-        self.input_linear = nn.Linear(
-            in_features=self.input_size,
-            out_features=self.hidden_size,
-        )
-
-        self.encoder_layer_list = torch.nn.ModuleList([
-            TransformerBlock(
-                input_dim=hidden_size,
-                n_heads=attention_heads,
-                dropout_rate=dropout_rate,
-                max_relative_position=max_relative_position,
-            ) for _ in range(num_blocks)
-        ])
-
-        self.output_linear = nn.Linear(
-            in_features=self.hidden_size,
-            out_features=self.input_size,
-        )
-
-    def forward(self,
-                xs: torch.Tensor,
-                ):
-        """
-        :param xs: Tensor, shape: [batch_size, time_steps, input_size]
-        :return: Tensor, shape: [batch_size, time_steps, input_size]
-        """
-        batch_size, time_steps, _ = xs.shape
-        # xs shape: [batch_size, time_steps, input_size]
-        xs = self.input_linear.forward(xs)
-        # xs shape: [batch_size, time_steps, hidden_size]
-
-        chunk_masks = subsequent_chunk_mask(
-            size=time_steps,
-            chunk_size=self.chunk_size,
-            num_left_chunks=self.num_left_chunks,
-            num_right_chunks=self.num_right_chunks,
-        )
-        chunk_masks = chunk_masks.to(xs.device)
-        # chunk_masks shape: [1, time_steps, time_steps]
-        chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
-        # chunk_masks shape: [batch_size, time_steps, time_steps]
-
-        for encoder_layer in self.encoder_layer_list:
-            xs, _ = encoder_layer.forward(xs, chunk_masks)
-
-        # xs shape: [batch_size, time_steps, hidden_size]
-        xs = self.output_linear.forward(xs)
-        # xs shape: [batch_size, time_steps, input_size]
-
-        return xs
-
-    def forward_chunk(self,
-                      xs: torch.Tensor,
-                      max_att_cache_length: int,
-                      attention_cache: torch.Tensor = None,
-                      ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Forward just one chunk.
-        :param xs: torch.Tensor. chunk input, with shape (b=1, time, mel-dim),
-        where `time == (chunk_size - 1) * subsample_rate + subsample.right_context + 1`
-        :param max_att_cache_length:
-        :param attention_cache: torch.Tensor.
-        :return:
-        """
-        # xs shape: [batch_size, time_steps, input_size]
-        xs = self.input_linear.forward(xs)
-        # xs shape: [batch_size, time_steps, hidden_size]
-
-        r_att_cache = []
-        for idx, encoder_layer in enumerate(self.encoder_layer_list):
-            xs, new_att_cache = encoder_layer.forward(
-                x=xs, attention_cache=attention_cache[idx: idx+1] if attention_cache is not None else None,
-            )
-            if new_att_cache.size(2) > max_att_cache_length:
-                begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
-                end = self.num_right_chunks * self.chunk_size
-                new_att_cache = new_att_cache[:, :, -begin:-end, :]
-            r_att_cache.append(new_att_cache)
-
-        r_att_cache = torch.cat(r_att_cache, dim=0)
-
-        return xs, r_att_cache
-
-    def forward_chunk_by_chunk(
-            self,
-            xs: torch.Tensor,
-    ) -> torch.Tensor:
-
-        batch_size, time_steps, _ = xs.shape
-
-        # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
-        max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
-        attention_cache = None
-
-        outputs = []
-        for idx in range(0, time_steps - self.chunk_size, self.chunk_size):
-            begin = idx
-            end = begin + self.chunk_size * (self.num_right_chunks + 1)
-            chunk_xs = xs[:, begin:end, :]
-            # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
-
-            ys, attention_cache = self.forward_chunk(
-                xs=chunk_xs,
-                max_att_cache_length=max_att_cache_length,
-                attention_cache=attention_cache,
-            )
-            # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), hidden_size]
-            ys = ys[:, :self.chunk_size, :]
-
-            # ys shape: [batch_size, chunk_size, hidden_size]
-            ys = self.output_linear.forward(ys)
-            # ys shape: [batch_size, chunk_size, input_size]
-
-            outputs.append(ys)
-
-        ys = torch.cat(outputs, 1)
-        return ys
-
-
-def main():
-
-    encoder = TransformerEncoder(
-        input_size=64,
-        hidden_size=256,
-        attention_heads=4,
-        num_blocks=6,
-        dropout_rate=0.1,
-    )
-    print(encoder)
-
-    x = torch.ones([4, 200, 64])
-
-    y = encoder.forward(xs=x)
-    print(y.shape)
-
-    # y = encoder.forward_chunk_by_chunk(xs=x)
-    # print(y.shape)
-
-    return
-
-
-if __name__ == '__main__':
-    main()
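Note (illustrative sketch, not from the diff): streaming use of the encoder above. forward_chunk_by_chunk steps through the sequence with the configured chunk and look-ahead sizes and its loop stops chunk_size steps before the end, so its output comes out slightly shorter than the offline forward output.

encoder = TransformerEncoder(input_size=64, hidden_size=256, attention_heads=4, num_blocks=2)
encoder.eval()

x = torch.ones([1, 200, 64])
with torch.no_grad():
    y_offline = encoder.forward(xs=x)                    # [1, 200, 64]
    y_streaming = encoder.forward_chunk_by_chunk(xs=x)   # a few steps shorter than 200
print(y_offline.shape, y_streaming.shape)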
toolbox/torchaudio/models/nx_clean_unet/utils.py
DELETED
@@ -1,45 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import torch
-import torch.nn as nn
-
-
-class LearnableSigmoid1d(nn.Module):
-    def __init__(self, in_features, beta=1):
-        super().__init__()
-        self.beta = beta
-        self.slope = nn.Parameter(torch.ones(in_features))
-        self.slope.requiresGrad = True
-
-    def forward(self, x):
-        # x shape: [batch_size, time_steps, spec_bins]
-        return self.beta * torch.sigmoid(self.slope * x)
-
-
-def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
-
-    hann_window = torch.hann_window(win_size).to(y.device)
-    stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
-                           center=center, pad_mode='reflect', normalized=False, return_complex=True)
-    stft_spec = torch.view_as_real(stft_spec)
-    mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
-    pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
-    # Magnitude Compression
-    mag = torch.pow(mag, compress_factor)
-    com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
-
-    return mag, pha, com
-
-
-def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
-    # Magnitude Decompression
-    mag = torch.pow(mag, (1.0/compress_factor))
-    com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
-    hann_window = torch.hann_window(win_size).to(com.device)
-    wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
-
-    return wav
-
-
-if __name__ == '__main__':
-    pass
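Note (illustrative): a round-trip through mag_pha_stft / mag_pha_istft above, using the STFT settings listed in yaml/config.yaml below (n_fft=512, win_size=200, hop_size=80, compress_factor=0.3); the random waveform is a placeholder.

signal = torch.randn(1, 16000)  # placeholder waveform, 2 s at 8 kHz
mag, pha, com = mag_pha_stft(signal, n_fft=512, hop_size=80, win_size=200, compress_factor=0.3)
wav = mag_pha_istft(mag, pha, n_fft=512, hop_size=80, win_size=200, compress_factor=0.3)
print(mag.shape, pha.shape, wav.shape)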
toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml
DELETED
@@ -1,51 +0,0 @@
-model_name: "nx_clean_unet"
-
-sample_rate: 8000
-segment_size: 16000
-n_fft: 512
-win_size: 200
-hop_size: 80
-# With hop_size 80 the STFT advances 10 ms per frame, so the down-sampling is chosen to give a similar time resolution.
-
-# 2**down_sampling_num_layers,
-# e.g. 2**6 = 64 means 64 samples collapse into one time step after down-sampling,
-# so one step is 64 / sample_rate = 0.008 s.
-# Then tsfm_chunk_size=2 is 16 ms and tsfm_chunk_size=4 is 32 ms.
-# Assuming roughly 1 s of left context and 30 ms of look-ahead:
-# tsfm_chunk_size=1, tsfm_num_left_chunks=128, tsfm_num_right_chunks=4
-# tsfm_chunk_size=2, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2
-# tsfm_chunk_size=4, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1
-down_sampling_num_layers: 6
-down_sampling_in_channels: 1
-down_sampling_hidden_channels: 64
-down_sampling_kernel_size: 4
-down_sampling_stride: 2
-
-causal_in_channels: 1
-causal_out_channels: 1
-causal_kernel_size: 3
-causal_bias: false
-causal_separable: true
-causal_f_stride: 1
-causal_num_layers: 3
-
-tsfm_hidden_size: 256
-tsfm_attention_heads: 8
-tsfm_num_blocks: 6
-tsfm_dropout_rate: 0.1
-tsfm_max_length: 512
-tsfm_chunk_size: 1
-tsfm_num_left_chunks: 128
-tsfm_num_right_chunks: 4
-
-discriminator_dim: 32
-discriminator_in_channel: 2
-
-compress_factor: 0.3
-
-batch_size: 4
-learning_rate: 0.0005
-adam_b1: 0.8
-adam_b2: 0.99
-lr_decay: 0.99
-seed: 1234
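Note (illustrative): the arithmetic behind the comments in the config above, spelled out as a small Python sketch; it is not read by the config loader.

sample_rate = 8000
down_sampling_num_layers = 6
step_seconds = (2 ** down_sampling_num_layers) / sample_rate  # 64 / 8000 = 0.008 s per bottleneck step

chunk_size = 1
left_context_ms = 128 * chunk_size * step_seconds * 1000   # ~1024 ms looked back
right_context_ms = 4 * chunk_size * step_seconds * 1000    # ~32 ms look-ahead
print(step_seconds, left_context_ms, right_context_ms)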
toolbox/torchaudio/models/nx_denoise/__init__.py
DELETED
@@ -1,6 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-
-
-if __name__ == '__main__':
-    pass
toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py
DELETED
@@ -1,6 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-
-
-if __name__ == '__main__':
-    pass
toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py
DELETED
@@ -1,281 +0,0 @@
-#!/usr/bin/python3
-# -*- coding: utf-8 -*-
-import math
-import os
-from typing import List, Optional, Union, Iterable
-
-import numpy as np
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-
-norm_layer_dict = {
-    "batch_norm_2d": torch.nn.BatchNorm2d
-}
-
-
-activation_layer_dict = {
-    "relu": torch.nn.ReLU,
-    "identity": torch.nn.Identity,
-    "sigmoid": torch.nn.Sigmoid,
-}
-
-
-class CausalConv2d(nn.Module):
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Iterable[int]],
-                 f_stride: int = 1,
-                 dilation: int = 1,
-                 do_f_pad: bool = True,
-                 bias: bool = True,
-                 separable: bool = False,
-                 norm_layer: str = "batch_norm_2d",
-                 activation_layer: str = "relu",
-                 lookahead: int = 0
-                 ):
-        super(CausalConv2d, self).__init__()
-        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
-
-        if do_f_pad:
-            f_pad = kernel_size[1] // 2 + dilation - 1
-        else:
-            f_pad = 0
-
-        self.causal_left_pad = kernel_size[0] - 1 - lookahead
-        self.causal_right_pad = lookahead
-        self.constant_pad = nn.ConstantPad2d(
-            padding=(0, 0, self.causal_left_pad, self.causal_right_pad),
-            value=0.0
-        )
-
-        groups = math.gcd(in_channels, out_channels) if separable else 1
-        self.conv1 = nn.Conv2d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            padding=(0, f_pad),
-            stride=(1, f_stride),
-            dilation=(1, dilation),
-            groups=groups,
-            bias=bias,
-        )
-
-        self.conv2 = None
-        if not any([groups == 1, max(kernel_size) == 1]):
-            self.conv2 = nn.Conv2d(
-                out_channels,
-                out_channels,
-                kernel_size=1,
-                bias=False,
-            )
-
-        self.norm = None
-        if norm_layer is not None:
-            norm_layer = norm_layer_dict[norm_layer]
-            self.norm = norm_layer(out_channels)
-
-        self.activation = None
-        if activation_layer is not None:
-            activation_layer = activation_layer_dict[activation_layer]
-            self.activation = activation_layer()
-
-    def forward(self,
-                inputs: torch.Tensor,
-                causal_cache: List[torch.Tensor] = None,
-                ):
-
-        if causal_cache is None:
-            # inputs shape: [batch_size, 1, time_steps, hidden_size]
-            x = self.constant_pad.forward(inputs)
-        else:
-            # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size]
-            # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size]
-            x = torch.concat(tensors=[causal_cache, inputs], dim=2)
-            # x shape: [batch_size, 1, time_steps2, hidden_size]
-            # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad
-
-        causal_cache = x[:, :, -self.causal_left_pad:, :]
-
-        x = self.conv1.forward(x)
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-
-        if self.conv2:
-            x = self.conv2.forward(x)
-
-        if self.norm:
-            x = self.norm(x)
-        if self.activation:
-            x = self.activation(x)
-
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-        return x, causal_cache
-
-
-class CausalConv2dEncoder(nn.Module):
-    def __init__(self,
-                 in_channels: int,
-                 hidden_channels: int,
-                 out_channels: int,
-                 kernel_size: Union[int, Iterable[int]],
-                 f_stride: int = 1,
-                 dilation: int = 1,
-                 do_f_pad: bool = True,
-                 bias: bool = True,
-                 separable: bool = False,
-                 norm_layer: str = "batch_norm_2d",
-                 activation_layer: str = "relu",
-                 lookahead: int = 0,
-                 num_layers: int = 5,
-                 ):
-        super(CausalConv2dEncoder, self).__init__()
-        self.num_layers = num_layers
-
-        self.total_causal_left_pad = 0
-        self.total_causal_right_pad = 0
-
-        self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[])
-        for i_layer in range(num_layers):
-            conv = CausalConv2d(
-                in_channels=in_channels,
-                out_channels=hidden_channels,
-                kernel_size=kernel_size,
-                f_stride=f_stride,
-                dilation=dilation,
-                do_f_pad=do_f_pad,
-                bias=bias,
-                separable=separable,
-                norm_layer=norm_layer,
-                activation_layer=activation_layer,
-                lookahead=lookahead,
-            )
-            self.causal_conv_list.append(conv)
-
-            self.total_causal_left_pad += conv.causal_left_pad
-            self.total_causal_right_pad += conv.causal_right_pad
-
-            in_channels = hidden_channels
-        else:
-            conv = CausalConv2d(
-                in_channels=hidden_channels,
-                out_channels=out_channels,
-                kernel_size=kernel_size,
-                f_stride=f_stride,
-                dilation=dilation,
-                do_f_pad=do_f_pad,
-                bias=bias,
-                separable=separable,
-                norm_layer=norm_layer,
-                activation_layer=activation_layer,
-                lookahead=lookahead,
-            )
-            self.causal_conv_list.append(conv)
-
-            self.total_causal_left_pad += conv.causal_left_pad
-            self.total_causal_right_pad += conv.causal_right_pad
-
-
-    def forward(self, inputs: torch.Tensor):
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-
-        x = inputs
-        for layer in self.causal_conv_list:
-            x, _ = layer.forward(x)
-        return x
-
-    def forward_chunk(self,
-                      chunk: torch.Tensor,
-                      causal_cache: List[torch.Tensor] = None,
-                      ):
-        # causal_cache shape: [self.num_layers, batch_size, 1, causal_left_pad, hidden_size]
-
-        new_causal_cache_list: List[torch.Tensor] = list()
-        for idx, causal_conv in enumerate(self.causal_conv_list):
-            chunk, new_causal_cache = causal_conv.forward(
-                inputs=chunk, causal_cache=causal_cache[idx] if causal_cache is not None else None
-            )
-            # print(f"idx: {idx}, new_causal_cache: {new_causal_cache.shape}")
-            new_causal_cache_list.append(new_causal_cache)
-
-        return chunk, new_causal_cache_list
-
-    def forward_chunk_by_chunk(self, inputs: torch.Tensor):
-        # inputs shape: [batch_size, 1, time_steps, hidden_size]
-        # batch_size = 1
-
-        batch_size, channels, time_steps, hidden_size = inputs.shape
-
-        new_causal_cache_list: List[torch.Tensor] = None
-
-        outputs = []
-        for idx in range(0, time_steps, 1):
-            begin = idx
-            end = begin + self.total_causal_right_pad + 1
-            chunk_xs = inputs[:, :, begin:end, :]
-
-            ys, new_causal_cache_list = self.forward_chunk(
-                chunk=chunk_xs,
-                causal_cache=new_causal_cache_list,
-            )
-            # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size]
-            ys = ys[:, :, :1, :]
-
-            # ys shape: [batch_size, chunk_size, hidden_size]
-            outputs.append(ys)
-
-        ys = torch.cat(outputs, 2)
-        return ys
-
-
-def main2():
-    conv = CausalConv2d(
-        in_channels=1,
-        out_channels=64,
-        kernel_size=3,
-        bias=False,
-        separable=True,
-        f_stride=1,
-        lookahead=0,
-    )
-
-    spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
-    # spec shape: [batch_size, 1, time_steps, hidden_size]
-    cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32)
-
-    output, _ = conv.forward(spec)
-    print(output.shape)
-
-    output, _ = conv.forward(spec, cache)
-    print(output.shape)
-
-    return
-
-
-def main():
-    causal = CausalConv2dEncoder(
-        in_channels=1,
-        out_channels=1,
-        kernel_size=3,
-        bias=False,
-        separable=True,
-        f_stride=1,
-        lookahead=0,
-        num_layers=3,
-    )
-
-    spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32)
-    # spec shape: [batch_size, 1, time_steps, hidden_size]
|
270 |
-
|
271 |
-
output = causal.forward(spec)
|
272 |
-
print(output.shape)
|
273 |
-
|
274 |
-
output = causal.forward_chunk_by_chunk(spec)
|
275 |
-
print(output.shape)
|
276 |
-
|
277 |
-
return
|
278 |
-
|
279 |
-
|
280 |
-
if __name__ == '__main__':
|
281 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
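Note (not part of the deleted file): a minimal sketch of how the streaming path above can be cross-checked against the offline path, assuming an older revision where CausalConv2d / CausalConv2dEncoder are still importable. All sizes below are illustrative.

import torch

encoder = CausalConv2dEncoder(
    in_channels=1,
    hidden_channels=8,
    out_channels=1,
    kernel_size=3,
    bias=False,
    separable=True,
    f_stride=1,
    lookahead=0,
    num_layers=2,
)
encoder.eval()  # batch-norm layers must use running stats for a fair comparison

spec = torch.randn(1, 1, 50, 16)  # [batch_size, 1, time_steps, hidden_size]
with torch.no_grad():
    full = encoder.forward(spec)                      # whole utterance at once
    streamed = encoder.forward_chunk_by_chunk(spec)   # frame by frame, with per-layer caches

# With lookahead=0 the cached path is expected to track the offline path closely.
print((full - streamed).abs().max())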
toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py
DELETED
@@ -1,102 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from toolbox.torchaudio.configuration_utils import PretrainedConfig


class NXDenoiseConfig(PretrainedConfig):
    """
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 segment_size: int = 16000,
                 n_fft: int = 512,
                 win_length: int = 200,
                 hop_length: int = 80,

                 down_sampling_num_layers: int = 5,
                 down_sampling_in_channels: int = 1,
                 down_sampling_hidden_channels: int = 64,
                 down_sampling_kernel_size: int = 4,
                 down_sampling_stride: int = 2,

                 causal_in_channels: int = 1,
                 causal_hidden_channels: int = 64,
                 causal_kernel_size: int = 3,
                 causal_bias: bool = False,
                 causal_separable: bool = True,
                 causal_f_stride: int = 1,
                 # causal_lookahead: int = 0,
                 causal_num_layers: int = 3,

                 tsfm_hidden_size: int = 256,
                 tsfm_attention_heads: int = 4,
                 tsfm_num_blocks: int = 6,
                 tsfm_dropout_rate: float = 0.1,
                 tsfm_max_time_relative_position: int = 1024,
                 tsfm_max_freq_relative_position: int = 128,
                 tsfm_chunk_size: int = 4,
                 tsfm_num_left_chunks: int = 128,
                 tsfm_num_right_chunks: int = 2,

                 discriminator_dim: int = 16,
                 discriminator_in_channel: int = 2,

                 compress_factor: float = 0.3,

                 batch_size: int = 4,
                 learning_rate: float = 0.0005,
                 adam_b1: float = 0.8,
                 adam_b2: float = 0.99,
                 lr_decay: float = 0.99,
                 seed: int = 1234,

                 **kwargs
                 ):
        super(NXDenoiseConfig, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length

        self.down_sampling_num_layers = down_sampling_num_layers
        self.down_sampling_in_channels = down_sampling_in_channels
        self.down_sampling_hidden_channels = down_sampling_hidden_channels
        self.down_sampling_kernel_size = down_sampling_kernel_size
        self.down_sampling_stride = down_sampling_stride

        self.causal_in_channels = causal_in_channels
        self.causal_hidden_channels = causal_hidden_channels
        self.causal_kernel_size = causal_kernel_size
        self.causal_bias = causal_bias
        self.causal_separable = causal_separable
        self.causal_f_stride = causal_f_stride
        # self.causal_lookahead = causal_lookahead
        self.causal_num_layers = causal_num_layers

        self.tsfm_hidden_size = tsfm_hidden_size
        self.tsfm_attention_heads = tsfm_attention_heads
        self.tsfm_num_blocks = tsfm_num_blocks
        self.tsfm_dropout_rate = tsfm_dropout_rate
        self.tsfm_max_time_relative_position = tsfm_max_time_relative_position
        self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position
        self.tsfm_chunk_size = tsfm_chunk_size
        self.tsfm_num_left_chunks = tsfm_num_left_chunks
        self.tsfm_num_right_chunks = tsfm_num_right_chunks

        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel

        self.compress_factor = compress_factor

        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed


if __name__ == '__main__':
    pass
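Note (not part of the deleted file): the config is a plain keyword-argument container, so a usage sketch is just instantiation with overrides; the import assumes an older revision where this module still exists.

from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig

# Override a few STFT settings; every other field keeps the defaults listed above.
config = NXDenoiseConfig(sample_rate=8000, n_fft=512, win_length=200, hop_length=80)
print(config.sample_rate, config.tsfm_hidden_size)  # 8000 256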
toolbox/torchaudio/models/nx_denoise/discriminator.py
DELETED
@@ -1,132 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union

import torch
import torch.nn as nn
import torchaudio

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
from toolbox.torchaudio.models.nx_denoise.utils import LearnableSigmoid1d


class MetricDiscriminator(nn.Module):
    def __init__(self, config: NXDenoiseConfig):
        super(MetricDiscriminator, self).__init__()
        dim = config.discriminator_dim
        self.in_channel = config.discriminator_in_channel

        self.n_fft = config.n_fft
        self.win_length = config.win_length
        self.hop_length = config.hop_length

        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=1.0,
            window_fn=torch.hann_window,
            # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
        )

        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)),
            nn.InstanceNorm2d(dim*2, affine=True),
            nn.PReLU(dim*2),
            nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)),
            nn.InstanceNorm2d(dim*4, affine=True),
            nn.PReLU(dim*4),
            nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)),
            nn.InstanceNorm2d(dim*8, affine=True),
            nn.PReLU(dim*8),
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
            nn.Dropout(0.3),
            nn.PReLU(dim*4),
            nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
            LearnableSigmoid1d(1)
        )

    def forward(self, x, y):
        x = self.transform.forward(x)
        y = self.transform.forward(y)

        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)


MODEL_FILE = "discriminator.pt"


class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    def __init__(self,
                 config: NXDenoiseConfig,
                 ):
        super(MetricDiscriminatorPretrainedModel, self).__init__(
            config=config,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):

        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    config = NXDenoiseConfig()
    discriminator = MetricDiscriminator(config=config)

    # shape: [batch_size, num_samples]
    # x = torch.ones([4, int(4.5 * 16000)])
    # y = torch.ones([4, int(4.5 * 16000)])
    x = torch.ones([4, 16000])
    y = torch.ones([4, 16000])

    output = discriminator.forward(x, y)
    print(output.shape)
    print(output)

    return


if __name__ == "__main__":
    main()
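Note (not part of the deleted file): the discriminator maps a pair of waveforms to a score in (0, 1) via the magnitude spectrograms and LearnableSigmoid1d. In MetricGAN-style training it is commonly regressed toward a normalized quality score; the exact training loop is not part of this diff, so the snippet below is only a rough sketch with made-up target values.

import torch
import torch.nn.functional as F

config = NXDenoiseConfig()
discriminator = MetricDiscriminator(config=config)

clean = torch.randn(2, 16000).clamp(-1, 1)
enhanced = torch.randn(2, 16000).clamp(-1, 1)

# Hypothetical targets: PESQ scores normalized to [0, 1] (e.g. computed with metrics.run_pesq_score).
pesq_norm = torch.tensor([0.8, 0.6])

score = discriminator.forward(clean, enhanced)       # shape: [2, 1]
loss_d = F.mse_loss(score.squeeze(-1), pesq_norm)    # MetricGAN-style regression loss
loss_d.backward()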
toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py
DELETED
@@ -1,97 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
from pathlib import Path
import shutil
import tempfile
import zipfile

import librosa
import numpy as np
import torch
import torchaudio

from project_settings import project_path
from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
from toolbox.torchaudio.models.nx_denoise.modeling_nx_denoise import NXDenoisePretrainedModel, MODEL_FILE

logger = logging.getLogger("toolbox")


class InferenceNXDenoise(object):
    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, model = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.model = model
        self.model.to(device)
        self.model.eval()

    def load_models(self, model_path: str):
        model_path = Path(model_path)
        if model_path.name.endswith(".zip"):
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                out_root = Path(tempfile.gettempdir()) / "nx_denoise"
                out_root.mkdir(parents=True, exist_ok=True)
                f_zip.extractall(path=out_root)
            model_path = out_root / model_path.stem

        config = NXDenoiseConfig.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model = NXDenoisePretrainedModel.from_pretrained(
            pretrained_model_name_or_path=model_path.as_posix(),
        )
        model.to(self.device)
        model.eval()

        shutil.rmtree(model_path)
        return config, model

    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError(f"The value range of audio samples should be between -1 and 1.")

        # noisy_audio shape: [batch_size, num_samples]
        noisy_audios = noisy_audio.to(self.device)

        with torch.no_grad():
            # enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios)
            enhanced_audios = self.model.forward(noisy_audios)
            # enhanced_audio shape: [batch_size, n_samples]
            # enhanced_audios = torch.squeeze(enhanced_audios, dim=1)

        enhanced_audio = enhanced_audios[0]
        # enhanced_audio shape: [num_samples,]
        return enhanced_audio


def main():
    model_zip_file = project_path / "trained_models/nx-denoise.zip"
    runtime = InferenceNXDenoise(model_zip_file)

    sample_rate = 8000
    noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
    noisy_audio, _ = librosa.load(
        noisy_audio_file.as_posix(),
        sr=sample_rate,
    )
    noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)]
    noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32)
    noisy_audio = noisy_audio.unsqueeze(dim=0)

    enhanced_audio = runtime.enhancement_by_tensor(noisy_audio)

    filename = "enhanced_audio.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate)

    return


if __name__ == '__main__':
    main()
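Note (not part of the deleted file): enhancement_by_tensor expects a [batch_size, num_samples] tensor with values already in [-1, 1] and returns a single 1-D waveform. A sketch with synthetic audio, assuming a trained checkpoint zip exists at the (hypothetical) path below:

import torch

runtime = InferenceNXDenoise("trained_models/nx-denoise.zip", device="cpu")

noisy = (0.1 * torch.randn(1, 2 * 8000)).clamp(-1.0, 1.0)  # 2 seconds at 8 kHz, [1, num_samples]
enhanced = runtime.enhancement_by_tensor(noisy)             # [num_samples]
print(enhanced.shape)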
toolbox/torchaudio/models/nx_denoise/loss.py
DELETED
@@ -1,22 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import numpy as np
import torch


def anti_wrapping_function(x):

    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)


def phase_losses(phase_r, phase_g):

    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
    gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
    iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))

    return ip_loss, gd_loss, iaf_loss


if __name__ == '__main__':
    pass
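Note (not part of the deleted file): following the MP-SENet convention referenced in the config, the three terms compare instantaneous phase, group delay (difference along dim=1, frequency) and instantaneous angular frequency (difference along dim=2, time), so the phase tensors are assumed to be laid out as [batch, freq, time]. A small sketch:

import torch

phase_r = torch.rand(2, 257, 201) * 2 * torch.pi - torch.pi  # reference phase, [batch, freq, time]
phase_g = torch.rand(2, 257, 201) * 2 * torch.pi - torch.pi  # generated phase
ip_loss, gd_loss, iaf_loss = phase_losses(phase_r, phase_g)
print(ip_loss.item(), gd_loss.item(), iaf_loss.item())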
toolbox/torchaudio/models/nx_denoise/metrics.py
DELETED
@@ -1,80 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from joblib import Parallel, delayed
import numpy as np
from pesq import pesq
from typing import List

from pesq import cypesq


def run_pesq(clean_audio: np.ndarray,
             noisy_audio: np.ndarray,
             sample_rate: int = 16000,
             mode: str = "wb",
             ) -> float:
    if sample_rate == 8000 and mode == "wb":
        raise AssertionError(f"mode should be `nb` when sample_rate is 8000")
    try:
        pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode)
    except cypesq.NoUtterancesError as e:
        pesq_score = -1
    except Exception as e:
        print(f"pesq failed. error type: {type(e)}, error text: {str(e)}")
        pesq_score = -1
    return pesq_score


def run_batch_pesq(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:
    parallel = Parallel(n_jobs=n_jobs)

    parallel_tasks = list()
    for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list):
        parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode)
        parallel_tasks.append(parallel_task)

    pesq_score_list = parallel.__call__(parallel_tasks)
    return pesq_score_list


def run_pesq_score(clean_audio_list: List[np.ndarray],
                   noisy_audio_list: List[np.ndarray],
                   sample_rate: int = 16000,
                   mode: str = "wb",
                   n_jobs: int = 4,
                   ) -> List[float]:

    pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list,
                                     noisy_audio_list=noisy_audio_list,
                                     sample_rate=sample_rate,
                                     mode=mode,
                                     n_jobs=n_jobs,
                                     )

    pesq_score = np.mean(pesq_score_list)
    return pesq_score


def main():
    clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))
    noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,))

    clean_audio_list = list(clean_audio)
    noisy_audio_list = list(noisy_audio)

    pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list)
    print(pesq_score_list)

    pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list)
    print(pesq_score)

    return


if __name__ == "__main__":
    main()
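Note (not part of the deleted file): run_pesq refuses wide-band mode at 8000 Hz, which is the sample rate used elsewhere in this project, so narrow-band mode has to be requested explicitly. A small sketch:

import numpy as np

clean = np.random.uniform(low=-1, high=1, size=(16000,)).astype(np.float32)  # 2 s at 8 kHz
noisy = np.random.uniform(low=-1, high=1, size=(16000,)).astype(np.float32)

# 8000 Hz requires mode="nb"; mode="wb" would raise the AssertionError above.
score = run_pesq(clean, noisy, sample_rate=8000, mode="nb")
print(score)  # -1 if pesq reports NoUtterancesError on this random signal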
toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py
DELETED
@@ -1,392 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig
from toolbox.torchaudio.models.nx_denoise.causal_convolution.causal_conv2d import CausalConv2dEncoder
from toolbox.torchaudio.models.nx_denoise.transformers.transformers import TSTransformerEncoder


class DownSamplingBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 kernel_size: int,
                 stride: int,
                 ):
        super(DownSamplingBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
        self.glu = nn.GLU(dim=1)

    def forward(self, x: torch.Tensor):
        # x shape: [batch_size, 1, num_samples]
        x = self.conv1.forward(x)
        # x shape: [batch_size, hidden_channels, new_num_samples]
        x = self.relu(x)
        x = self.conv2.forward(x)
        # x shape: [batch_size, hidden_channels*2, new_num_samples]
        x = self.glu(x)
        # x shape: [batch_size, hidden_channels, new_num_samples]
        # new_num_samples = (num_samples-kernel_size) // stride + 1
        return x


class DownSampling(nn.Module):
    def __init__(self,
                 num_layers: int,
                 in_channels: int,
                 hidden_channels: int,
                 kernel_size: int,
                 stride: int,
                 ):
        super(DownSampling, self).__init__()
        self.num_layers = num_layers

        down_sampling_block_list = list()
        for idx in range(self.num_layers):
            down_sampling_block = DownSamplingBlock(
                in_channels=in_channels,
                hidden_channels=hidden_channels,
                kernel_size=kernel_size,
                stride=stride,
            )
            down_sampling_block_list.append(down_sampling_block)
            in_channels = hidden_channels

        self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list)

    def forward(self, x: torch.Tensor):
        # x shape: [batch_size, channels, num_samples]
        skip_connection_list = list()
        for down_sampling_block in self.down_sampling_block_list:
            x = down_sampling_block.forward(x)
            skip_connection_list.append(x)
        # x shape: [batch_size, hidden_channels, num_samples**]
        return x, skip_connection_list


class UpSamplingBlock(nn.Module):
    def __init__(self,
                 out_channels: int,
                 hidden_channels: int,
                 kernel_size: int,
                 stride: int,
                 do_relu: bool = True,
                 ):
        super(UpSamplingBlock, self).__init__()
        self.do_relu = do_relu

        self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor):
        # x shape: [batch_size, hidden_channels*2, num_samples]
        x = self.conv1.forward(x)
        # x shape: [batch_size, hidden_channels, num_samples]
        x = self.glu(x)
        # x shape: [batch_size, hidden_channels, num_samples]
        x = self.convt.forward(x)
        # x shape: [batch_size, hidden_channels, new_num_samples]
        # new_num_samples = (num_samples - 1) * stride + kernel_size
        if self.do_relu:
            x = self.relu(x)
        return x


class UpSampling(nn.Module):
    def __init__(self,
                 num_layers: int,
                 out_channels: int,
                 hidden_channels: int,
                 kernel_size: int,
                 stride: int,
                 ):
        super(UpSampling, self).__init__()
        self.num_layers = num_layers

        up_sampling_block_list = list()
        for idx in range(self.num_layers-1):
            up_sampling_block = UpSamplingBlock(
                out_channels=hidden_channels,
                hidden_channels=hidden_channels,
                kernel_size=kernel_size,
                stride=stride,
                do_relu=True,
            )
            up_sampling_block_list.append(up_sampling_block)
        else:
            up_sampling_block = UpSamplingBlock(
                out_channels=out_channels,
                hidden_channels=hidden_channels,
                kernel_size=kernel_size,
                stride=stride,
                do_relu=False,
            )
            up_sampling_block_list.append(up_sampling_block)
        self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list)

    def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]):
        skip_connection_list = skip_connection_list[::-1]

        # x shape: [batch_size, channels, num_samples]
        for idx, up_sampling_block in enumerate(self.up_sampling_block_list):
            skip_x = skip_connection_list[idx]
            x = x + skip_x
            # x = x + skip_x[:, :, :x.size(2)]
            x = up_sampling_block.forward(x)
        return x


def get_padding_length(length, num_layers: int, kernel_size: int, stride: int):
    for _ in range(num_layers):
        if length < kernel_size:
            length = 1
        else:
            length = 1 + np.ceil((length - kernel_size) / stride)

    for _ in range(num_layers):
        length = (length - 1) * stride + kernel_size

    padded_length = int(length)
    return padded_length


class NXDenoise(nn.Module):
    def __init__(self, config: NXDenoiseConfig):
        super().__init__()
        self.config = config

        self.down_sampling = DownSampling(
            num_layers=config.down_sampling_num_layers,
            in_channels=config.down_sampling_in_channels,
            hidden_channels=config.down_sampling_hidden_channels,
            kernel_size=config.down_sampling_kernel_size,
            stride=config.down_sampling_stride,
        )
        self.causal_conv_in = CausalConv2dEncoder(
            in_channels=config.causal_in_channels,
            hidden_channels=config.causal_hidden_channels,
            out_channels=config.causal_hidden_channels,
            kernel_size=config.causal_kernel_size,
            bias=config.causal_bias,
            separable=config.causal_separable,
            f_stride=config.causal_f_stride,
            lookahead=0,
            num_layers=config.causal_num_layers,
        )
        self.ts_transformer = TSTransformerEncoder(
            input_size=config.down_sampling_hidden_channels,
            hidden_size=config.tsfm_hidden_size,
            attention_heads=config.tsfm_attention_heads,
            num_blocks=config.tsfm_num_blocks,
            dropout_rate=config.tsfm_dropout_rate,
            max_time_relative_position=config.tsfm_max_time_relative_position,
            max_freq_relative_position=config.tsfm_max_freq_relative_position,
            chunk_size=config.tsfm_chunk_size,
            num_left_chunks=config.tsfm_num_left_chunks,
            num_right_chunks=config.tsfm_num_right_chunks,
        )
        self.causal_conv_out = CausalConv2dEncoder(
            in_channels=config.causal_hidden_channels,
            hidden_channels=config.causal_hidden_channels,
            out_channels=config.causal_in_channels,
            kernel_size=config.causal_kernel_size,
            bias=config.causal_bias,
            separable=config.causal_separable,
            f_stride=config.causal_f_stride,
            lookahead=0,
            num_layers=config.causal_num_layers,
        )
        self.up_sampling = UpSampling(
            num_layers=config.down_sampling_num_layers,
            out_channels=config.down_sampling_in_channels,
            hidden_channels=config.down_sampling_hidden_channels,
            kernel_size=config.down_sampling_kernel_size,
            stride=config.down_sampling_stride,
        )

    def forward(self, noisy_audios: torch.Tensor):
        # noisy_audios shape: [batch_size, n_samples]
        noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
        # noisy_audios shape: [batch_size, 1, n_samples]

        n_samples = noisy_audios.shape[-1]
        padded_length = get_padding_length(
            n_samples,
            num_layers=self.config.down_sampling_num_layers,
            kernel_size=self.config.down_sampling_kernel_size,
            stride=self.config.down_sampling_stride,
        )
        noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)

        # down sampling
        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
        # bottle_neck shape: [batch_size, channels, time_steps]
        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
        # bottle_neck shape: [batch_size, time_steps, channels]
        bottle_neck = torch.unsqueeze(bottle_neck, dim=1)
        # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]

        # causal conv in
        bottle_neck = self.causal_conv_in.forward(bottle_neck)
        # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]

        # ts transformer
        # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
        bottle_neck = self.ts_transformer.forward(bottle_neck)
        # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]

        # causal conv out
        bottle_neck = self.causal_conv_out.forward(bottle_neck)
        # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]

        # up sampling
        bottle_neck = torch.squeeze(bottle_neck, dim=1)
        # bottle_neck shape: [batch_size, time_steps, channels]
        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
        # bottle_neck shape: [batch_size, channels, time_steps]

        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)

        enhanced_audios = enhanced_audios[:, :, :n_samples]
        # enhanced_audios shape: [batch_size, 1, n_samples]

        enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
        # enhanced_audios shape: [batch_size, n_samples]

        return enhanced_audios

    def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor):
        # noisy_audios shape: [batch_size, n_samples]
        noisy_audios = torch.unsqueeze(noisy_audios, dim=1)
        # noisy_audios shape: [batch_size, 1, n_samples]

        n_samples = noisy_audios.shape[-1]
        padded_length = get_padding_length(
            n_samples,
            num_layers=self.config.down_sampling_num_layers,
            kernel_size=self.config.down_sampling_kernel_size,
            stride=self.config.down_sampling_stride,
        )
        noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0)

        # down sampling
        bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded)
        # bottle_neck shape: [batch_size, channels, time_steps]
        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
        # bottle_neck shape: [batch_size, time_steps, channels]
        bottle_neck = torch.unsqueeze(bottle_neck, dim=1)
        # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]

        # causal conv in
        bottle_neck = self.causal_conv_in.forward_chunk_by_chunk(bottle_neck)
        # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]

        # ts transformer
        # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]
        bottle_neck = self.ts_transformer.forward_chunk_by_chunk(bottle_neck)
        # bottle_neck shape: [batch_size, channels, time_steps, freq_dim]

        # causal conv out
        bottle_neck = self.causal_conv_out.forward_chunk_by_chunk(bottle_neck)
        # bottle_neck shape: [batch_size, 1, time_steps, freq_dim]

        # up sampling
        bottle_neck = torch.squeeze(bottle_neck, dim=1)
        # bottle_neck shape: [batch_size, time_steps, channels]
        bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1)
        # bottle_neck shape: [batch_size, channels, time_steps]

        enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list)

        enhanced_audios = enhanced_audios[:, :, :n_samples]
        # enhanced_audios shape: [batch_size, 1, n_samples]

        enhanced_audios = torch.squeeze(enhanced_audios, dim=1)
        # enhanced_audios shape: [batch_size, n_samples]

        return enhanced_audios


MODEL_FILE = "generator.pt"


class NXDenoisePretrainedModel(NXDenoise):
    def __init__(self,
                 config: NXDenoiseConfig,
                 ):
        super(NXDenoisePretrainedModel, self).__init__(
            config=config,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):

        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():

    config = NXDenoiseConfig()

    # shape: [batch_size, channels, num_samples]
    # min length: 94, stride: 32, 32 == 2**5
    # x = torch.ones([4, 94])
    # x = torch.ones([4, 126])
    # x = torch.ones([4, 158])
    # x = torch.ones([4, 190])
    x = torch.ones([4, 16000])

    model = NXDenoise(config)
    enhanced_audios = model.forward(x)
    print(enhanced_audios.shape)
    return


if __name__ == "__main__":
    main()
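Note (not part of the deleted file): get_padding_length rounds the waveform length up so that the five stride-2, kernel-4 down-sampling layers followed by the matching transposed convolutions cover at least the original length; the arithmetic follows directly from the two loops above. A quick check with the default config values:

padded = get_padding_length(16000, num_layers=5, kernel_size=4, stride=2)
print(padded)  # 16030 with these settings, i.e. 30 samples of right padding for a 16000-sample input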
toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py
DELETED
@@ -1,6 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py
DELETED
@@ -1,9 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/1902.07849
"""


if __name__ == '__main__':
    pass
toolbox/torchaudio/models/nx_denoise/transformers/__init__.py
DELETED
@@ -1,6 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/models/nx_denoise/transformers/attention.py
DELETED
@@ -1,263 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import math
from typing import Tuple

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
        """
        :param n_head: int. the number of heads.
        :param n_feat: int. the number of features.
        :param dropout_rate: float. dropout rate.
        """
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(self,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor
                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        transform query, key and value.
        :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
        :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
        :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
        :return:
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self,
                          value: torch.Tensor,
                          scores: torch.Tensor,
                          mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
                          ) -> torch.Tensor:
        """
        compute attention context vector.
        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k).
        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2).
        :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or
            (batch_size, time1, time2), (0, 0, 0) means fake mask.
        :return: torch.Tensor. transformed value. (batch_size, time1, d_model).
            weighted by the attention score (batch_size, time1, time2).
        """
        n_batch = value.size(0)
        # NOTE: When will `if mask.size(2) > 0` be True?
        # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #    1st chunk to ease the onnx export.]
        # 2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)

        # NOTE: When will `if mask.size(2) > 0` be False?
        # 1. onnx(16/-1, -1/-1, 16/0)
        # 2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)

        return self.linear_out(x)  # (batch, time1, n_feat)

    def forward(self,
                x: torch.Tensor,
                mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
                cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
                ) -> Tuple[torch.Tensor, torch.Tensor]:

        q, k, v = self.forward_qkv(x, x, x)

        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE: We do cache slicing in encoder.forward_chunk, since it's
        # non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelativeMultiHeadSelfAttention(nn.Module):

    def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120):
        """
        :param n_head: int. the number of heads.
        :param n_feat: int. the number of features.
        :param dropout_rate: float. dropout rate.
        :param max_relative_position: int. maximum relative position for relative position encoding.
        """
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

        # Relative position encoding
        self.max_relative_position = max_relative_position
        self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k))

    def forward_qkv(self,
                    query: torch.Tensor,
                    key: torch.Tensor,
                    value: torch.Tensor
                    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        transform query, key and value.
        :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat).
        :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat).
        :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat).
        :return:
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(self,
                          value: torch.Tensor,
                          scores: torch.Tensor,
                          mask: torch.Tensor = None
                          ) -> torch.Tensor:
        """
        compute attention context vector.
        :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k).
        :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps).
        :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps).
        :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model).
            weighted by the attention score (batch_size, query_time_steps, key_time_steps).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)
            # mask shape: [batch_size, 1, query_time_steps, key_time_steps]
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(scores, dim=-1)
        # attn shape: [batch_size, n_head, query_time_steps, key_time_steps]

        p_attn = self.dropout(attn)

        x = torch.matmul(p_attn, value)
        # x shape: [batch_size, n_head, query_time_steps, d_k]
        x = x.transpose(1, 2)
        # x shape: [batch_size, query_time_steps, n_head, d_k]

        x = x.contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, n_feat)
        # x shape: [batch_size, query_time_steps, n_head * d_k]
        # x shape: [batch_size, query_time_steps, n_feat]

        x = self.linear_out(x)
        # x shape: [batch_size, query_time_steps, n_feat]
        return x

    def relative_position_encoding(self, length: int) -> torch.Tensor:
        """
        Generate relative position encoding.
        :param length: int. length of the sequence.
        :return: torch.Tensor. relative position encoding. shape=(length, length, d_k).
        """
        range_vec = torch.arange(length)
        distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position
        return final_mat

    def forward(self,
                x: torch.Tensor,
                mask: torch.Tensor = None,
                cache: torch.Tensor = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        # attention! self attention.

        q, k, v = self.forward_qkv(x, x, x)
        # q k v shape: [batch_size, self.h, query_time_steps, self.d_k]

        if cache is not None:
            key_cache, value_cache = torch.split(
                cache, cache.size(-1) // 2, dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)

        # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2]
        new_cache = torch.cat((k, v), dim=-1)

        # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]
        native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Compute relative position encoding
        q_length, k_length = q.size(2), k.size(2)
        relative_position = self.relative_position_encoding(k_length)

        relative_position = relative_position[-q_length:]

        relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1)

        relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0)  # (1, 1, q_length, k_length, d_k)
        relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1)  # (batch, head, q_length, k_length, d_k)

        relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k)
        # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps]

        # score
        scores = native_scores + relative_position_scores

        return self.forward_attention(v, scores, mask), new_cache


def main():
    rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1)

    x = torch.ones(size=(1, 200, 256), dtype=torch.float32)
    xt, new_cache = rel_attention.forward(x, x, x)

    # x = torch.ones(size=(1, 1, 256), dtype=torch.float32)
    # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32)
    # xt, new_cache = rel_attention.forward(x, x, x, cache=cache)

    print(xt.shape)
    print(new_cache.shape)
    return


if __name__ == '__main__':
    main()
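Note (not part of the deleted file): relative_position_encoding builds an index matrix whose entry (i, j) is (j - i), clipped to ±max_relative_position and shifted to be non-negative, which then indexes the learned relative_position_k table. A small standalone worked example mirroring the method above:

import torch

max_relative_position = 2
length = 4
range_vec = torch.arange(length)
distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)  # entry (i, j) = j - i
final_mat = torch.clamp(distance_mat, -max_relative_position, max_relative_position) + max_relative_position
print(final_mat)
# tensor([[2, 3, 4, 4],
#         [1, 2, 3, 4],
#         [0, 1, 2, 3],
#         [0, 0, 1, 2]])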
toolbox/torchaudio/models/nx_denoise/transformers/mask.py
DELETED
@@ -1,74 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch


def make_pad_mask(lengths: torch.Tensor,
                  max_len: int = 0,
                  ) -> torch.Tensor:
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(
        0,
        max_len,
        dtype=torch.int64,
        device=lengths.device
    )
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        num_right_chunks: int = 0,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """
    Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder

    Examples:
        > subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]

    :param size: int. size of mask.
    :param chunk_size: int. size of chunk.
    :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks.
    :param num_right_chunks: int. number of right chunks.
    :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device.
    :return: torch.Tensor. mask
    """

    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
        ret[i, start:ending] = True
    return ret


def main():
    chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2)
    print(chunk_mask)

    chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
    print(chunk_mask)

    chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1)
    print(chunk_mask)
    return


if __name__ == '__main__':
    main()
toolbox/torchaudio/models/nx_denoise/transformers/transformers.py
DELETED
@@ -1,479 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
from typing import Dict, Optional, Tuple, List, Union
|
4 |
-
|
5 |
-
import torch
|
6 |
-
import torch.nn as nn
|
7 |
-
|
8 |
-
from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask
|
9 |
-
from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention
|
10 |
-
|
11 |
-
|
12 |
-
class PositionwiseFeedForward(nn.Module):
|
13 |
-
def __init__(self,
|
14 |
-
input_dim: int,
|
15 |
-
hidden_units: int,
|
16 |
-
dropout_rate: float,
|
17 |
-
activation: torch.nn.Module = torch.nn.ReLU()):
|
18 |
-
"""
|
19 |
-
FeedForward are applied on each position of the sequence.
|
20 |
-
the output dim is same with the input dim.
|
21 |
-
|
22 |
-
:param input_dim: int. input dimension.
|
23 |
-
:param hidden_units: int. the number of hidden units.
|
24 |
-
:param dropout_rate: float. dropout rate.
|
25 |
-
:param activation: torch.nn.Module. activation function.
|
26 |
-
"""
|
27 |
-
super(PositionwiseFeedForward, self).__init__()
|
28 |
-
self.w_1 = torch.nn.Linear(input_dim, hidden_units)
|
29 |
-
self.activation = activation
|
30 |
-
self.dropout = torch.nn.Dropout(dropout_rate)
|
31 |
-
self.w_2 = torch.nn.Linear(hidden_units, input_dim)
|
32 |
-
|
33 |
-
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
34 |
-
"""
|
35 |
-
Forward function.
|
36 |
-
:param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim).
|
37 |
-
:return: output tensor. shape=(batch_size, max_length, dim).
|
38 |
-
"""
|
39 |
-
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
40 |
-
|
41 |
-
|
42 |
-
class TransformerBlock(nn.Module):
|
43 |
-
def __init__(self,
|
44 |
-
input_dim: int,
|
45 |
-
dropout_rate: float = 0.1,
|
46 |
-
n_heads: int = 4,
|
47 |
-
max_relative_position: int = 5120
|
48 |
-
):
|
49 |
-
super().__init__()
|
50 |
-
self.norm1 = nn.LayerNorm(input_dim, eps=1e-5)
|
51 |
-
self.attention = RelativeMultiHeadSelfAttention(
|
52 |
-
n_head=n_heads,
|
53 |
-
n_feat=input_dim,
|
54 |
-
dropout_rate=dropout_rate,
|
55 |
-
max_relative_position=max_relative_position,
|
56 |
-
)
|
57 |
-
|
58 |
-
self.dropout1 = nn.Dropout(dropout_rate)
|
59 |
-
self.norm2 = nn.LayerNorm(input_dim, eps=1e-5)
|
60 |
-
self.ffn = PositionwiseFeedForward(
|
61 |
-
input_dim=input_dim,
|
62 |
-
hidden_units=input_dim,
|
63 |
-
dropout_rate=dropout_rate
|
64 |
-
)
|
65 |
-
self.dropout2 = nn.Dropout(dropout_rate)
|
66 |
-
self.norm3 = nn.LayerNorm(input_dim, eps=1e-5)
|
67 |
-
|
68 |
-
def forward(
|
69 |
-
self,
|
70 |
-
x: torch.Tensor,
|
71 |
-
mask: torch.Tensor = None,
|
72 |
-
attention_cache: torch.Tensor = None,
|
73 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
74 |
-
"""
|
75 |
-
|
76 |
-
:param x: torch.Tensor. shape=(batch_size, time, input_dim).
|
77 |
-
:param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time).
|
78 |
-
:param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE
|
79 |
-
shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim.
|
80 |
-
:return:
|
81 |
-
torch.Tensor: Output tensor (batch_size, time, input_dim).
|
82 |
-
torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2).
|
83 |
-
"""
|
84 |
-
xt = self.norm1(x)
|
85 |
-
|
86 |
-
x_att, new_att_cache = self.attention.forward(
|
87 |
-
xt, mask=mask, cache=attention_cache
|
88 |
-
)
|
89 |
-
x = x + self.dropout1(xt)
|
90 |
-
xt = self.norm2(x)
|
91 |
-
xt = self.ffn.forward(xt)
|
92 |
-
x = x + self.dropout2(xt)
|
93 |
-
|
94 |
-
x = self.norm3(x)
|
95 |
-
|
96 |
-
return x, new_att_cache
|
97 |
-
|
98 |
-
|
99 |
-
class TransformerEncoder(nn.Module):
|
100 |
-
"""
|
101 |
-
https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364
|
102 |
-
"""
|
103 |
-
def __init__(self,
|
104 |
-
input_size: int = 64,
|
105 |
-
hidden_size: int = 256,
|
106 |
-
attention_heads: int = 4,
|
107 |
-
num_blocks: int = 6,
|
108 |
-
dropout_rate: float = 0.1,
|
109 |
-
max_relative_position: int = 1024,
|
110 |
-
chunk_size: int = 1,
|
111 |
-
num_left_chunks: int = 128,
|
112 |
-
num_right_chunks: int = 2,
|
113 |
-
):
|
114 |
-
super().__init__()
|
115 |
-
self.input_size = input_size
|
116 |
-
self.hidden_size = hidden_size
|
117 |
-
|
118 |
-
self.max_relative_position = max_relative_position
|
119 |
-
self.chunk_size = chunk_size
|
120 |
-
self.num_left_chunks = num_left_chunks
|
121 |
-
self.num_right_chunks = num_right_chunks
|
122 |
-
|
123 |
-
self.input_linear = nn.Linear(
|
124 |
-
in_features=self.input_size,
|
125 |
-
out_features=self.hidden_size,
|
126 |
-
)
|
127 |
-
|
128 |
-
self.encoder_layer_list = torch.nn.ModuleList([
|
129 |
-
TransformerBlock(
|
130 |
-
input_dim=hidden_size,
|
131 |
-
n_heads=attention_heads,
|
132 |
-
dropout_rate=dropout_rate,
|
133 |
-
max_relative_position=max_relative_position,
|
134 |
-
) for _ in range(num_blocks)
|
135 |
-
])
|
136 |
-
|
137 |
-
self.output_linear = nn.Linear(
|
138 |
-
in_features=self.hidden_size,
|
139 |
-
out_features=self.input_size,
|
140 |
-
)
|
141 |
-
|
142 |
-
def forward(self,
|
143 |
-
xs: torch.Tensor,
|
144 |
-
):
|
145 |
-
"""
|
146 |
-
:param xs: Tensor, shape: [batch_size, time_steps, input_size]
|
147 |
-
:return: Tensor, shape: [batch_size, time_steps, input_size]
|
148 |
-
"""
|
149 |
-
batch_size, time_steps, _ = xs.shape
|
150 |
-
# xs shape: [batch_size, time_steps, input_size]
|
151 |
-
xs = self.input_linear.forward(xs)
|
152 |
-
# xs shape: [batch_size, time_steps, hidden_size]
|
153 |
-
|
154 |
-
chunk_masks = subsequent_chunk_mask(
|
155 |
-
size=time_steps,
|
156 |
-
chunk_size=self.chunk_size,
|
157 |
-
num_left_chunks=self.num_left_chunks,
|
158 |
-
num_right_chunks=self.num_right_chunks,
|
159 |
-
)
|
160 |
-
chunk_masks = chunk_masks.to(xs.device)
|
161 |
-
# chunk_masks shape: [time_steps, time_steps]
|
162 |
-
chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps))
|
163 |
-
# chunk_masks shape: [batch_size, time_steps, time_steps]
|
164 |
-
|
165 |
-
for encoder_layer in self.encoder_layer_list:
|
166 |
-
xs, _ = encoder_layer.forward(xs, chunk_masks)
|
167 |
-
|
168 |
-
# xs shape: [batch_size, time_steps, hidden_size]
|
169 |
-
xs = self.output_linear.forward(xs)
|
170 |
-
# xs shape: [batch_size, time_steps, input_size]
|
171 |
-
|
172 |
-
return xs
|
173 |
-
|
174 |
-
def forward_chunk(self,
|
175 |
-
xs: torch.Tensor,
|
176 |
-
max_att_cache_length: int,
|
177 |
-
attention_cache: torch.Tensor = None,
|
178 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
179 |
-
"""
|
180 |
-
|
181 |
-
:param xs:
|
182 |
-
:param max_att_cache_length:
|
183 |
-
:param attention_cache: Tensor, [num_layers, ...]
|
184 |
-
:return:
|
185 |
-
"""
|
186 |
-
# xs shape: [batch_size, time_steps, input_size]
|
187 |
-
xs = self.input_linear.forward(xs)
|
188 |
-
# xs shape: [batch_size, time_steps, hidden_size]
|
189 |
-
|
190 |
-
r_att_cache = []
|
191 |
-
for idx, encoder_layer in enumerate(self.encoder_layer_list):
|
192 |
-
xs, new_att_cache = encoder_layer.forward(
|
193 |
-
x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
|
194 |
-
)
|
195 |
-
# new_att_cache shape: [batch_size, n_heads, time_steps, dim]
|
196 |
-
if new_att_cache.size(2) > max_att_cache_length:
|
197 |
-
begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
198 |
-
end = self.num_right_chunks * self.chunk_size
|
199 |
-
new_att_cache = new_att_cache[:, :, -begin:-end, :]
|
200 |
-
r_att_cache.append(new_att_cache)
|
201 |
-
|
202 |
-
r_att_cache = torch.stack(r_att_cache, dim=0)
|
203 |
-
|
204 |
-
# xs shape: [batch_size, time_steps, hidden_size]
|
205 |
-
xs = self.output_linear.forward(xs)
|
206 |
-
# xs shape: [batch_size, time_steps, input_size]
|
207 |
-
|
208 |
-
return xs, r_att_cache
|
209 |
-
|
210 |
-
def forward_chunk_by_chunk(
|
211 |
-
self,
|
212 |
-
xs: torch.Tensor,
|
213 |
-
) -> torch.Tensor:
|
214 |
-
|
215 |
-
batch_size, time_steps, _ = xs.shape
|
216 |
-
|
217 |
-
# attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2]
|
218 |
-
max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
219 |
-
attention_cache = None
|
220 |
-
|
221 |
-
outputs = []
|
222 |
-
for idx in range(0, time_steps, self.chunk_size):
|
223 |
-
begin = idx
|
224 |
-
end = begin + self.chunk_size * (self.num_right_chunks + 1)
|
225 |
-
chunk_xs = xs[:, begin:end, :]
|
226 |
-
# print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}")
|
227 |
-
|
228 |
-
ys, attention_cache = self.forward_chunk(
|
229 |
-
xs=chunk_xs,
|
230 |
-
max_att_cache_length=max_att_cache_length,
|
231 |
-
attention_cache=attention_cache,
|
232 |
-
)
|
233 |
-
|
234 |
-
# ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
235 |
-
ys = ys[:, :self.chunk_size, :]
|
236 |
-
|
237 |
-
outputs.append(ys)
|
238 |
-
|
239 |
-
ys = torch.cat(outputs, 1)
|
240 |
-
return ys
|
241 |
-
|
242 |
-
|
243 |
-
class TSTransformerBlock(nn.Module):
|
244 |
-
def __init__(self,
|
245 |
-
input_dim: int,
|
246 |
-
dropout_rate: float = 0.1,
|
247 |
-
n_heads: int = 4,
|
248 |
-
max_time_relative_position: int = 1024,
|
249 |
-
max_freq_relative_position: int = 128,
|
250 |
-
):
|
251 |
-
super(TSTransformerBlock, self).__init__()
|
252 |
-
self.time_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_time_relative_position)
|
253 |
-
self.freq_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_freq_relative_position)
|
254 |
-
|
255 |
-
def forward(self,
|
256 |
-
x: torch.Tensor,
|
257 |
-
mask: torch.Tensor = None,
|
258 |
-
attention_cache: torch.Tensor = None,
|
259 |
-
):
|
260 |
-
"""
|
261 |
-
|
262 |
-
:param x: Tensor. shape: [batch_size, hidden_size, time_steps, input_size]
|
263 |
-
:param mask: Tensor. shape: [time_steps, time_steps]
|
264 |
-
:param attention_cache:
|
265 |
-
:return:
|
266 |
-
"""
|
267 |
-
b, c, t, f = x.size()
|
268 |
-
|
269 |
-
mask = None if mask is None else torch.broadcast_to(mask, size=(b*f, t, t))
|
270 |
-
|
271 |
-
x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
|
272 |
-
x_, new_att_cache = self.time_transformer.forward(x, mask, attention_cache)
|
273 |
-
x = x_ + x
|
274 |
-
x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
|
275 |
-
x_, _ = self.freq_transformer.forward(x)
|
276 |
-
x = x_ + x
|
277 |
-
x = x.view(b, t, f, c).permute(0, 3, 1, 2)
|
278 |
-
return x, new_att_cache
|
279 |
-
|
280 |
-
|
281 |
-
class TSTransformerEncoder(nn.Module):
|
282 |
-
def __init__(self,
|
283 |
-
input_size: int = 64,
|
284 |
-
hidden_size: int = 256,
|
285 |
-
attention_heads: int = 4,
|
286 |
-
num_blocks: int = 6,
|
287 |
-
dropout_rate: float = 0.1,
|
288 |
-
max_time_relative_position: int = 1024,
|
289 |
-
max_freq_relative_position: int = 128,
|
290 |
-
chunk_size: int = 1,
|
291 |
-
num_left_chunks: int = 128,
|
292 |
-
num_right_chunks: int = 2,
|
293 |
-
):
|
294 |
-
super().__init__()
|
295 |
-
self.input_size = input_size
|
296 |
-
self.hidden_size = hidden_size
|
297 |
-
|
298 |
-
self.max_time_relative_position = max_time_relative_position
|
299 |
-
self.max_freq_relative_position = max_freq_relative_position
|
300 |
-
self.chunk_size = chunk_size
|
301 |
-
self.num_left_chunks = num_left_chunks
|
302 |
-
self.num_right_chunks = num_right_chunks
|
303 |
-
|
304 |
-
self.input_linear = nn.Linear(
|
305 |
-
in_features=self.input_size,
|
306 |
-
out_features=self.hidden_size,
|
307 |
-
)
|
308 |
-
|
309 |
-
self.encoder_layer_list = torch.nn.ModuleList([
|
310 |
-
TSTransformerBlock(
|
311 |
-
input_dim=hidden_size,
|
312 |
-
n_heads=attention_heads,
|
313 |
-
dropout_rate=dropout_rate,
|
314 |
-
max_time_relative_position=max_time_relative_position,
|
315 |
-
max_freq_relative_position=max_freq_relative_position,
|
316 |
-
) for _ in range(num_blocks)
|
317 |
-
])
|
318 |
-
|
319 |
-
self.output_linear = nn.Linear(
|
320 |
-
in_features=self.hidden_size,
|
321 |
-
out_features=self.input_size,
|
322 |
-
)
|
323 |
-
|
324 |
-
def forward(self,
|
325 |
-
xs: torch.Tensor,
|
326 |
-
):
|
327 |
-
"""
|
328 |
-
:param xs: Tensor, shape: [batch_size, channels, time_steps, input_size]
|
329 |
-
:return: Tensor, shape: [batch_size, channels, time_steps, input_size]
|
330 |
-
"""
|
331 |
-
batch_size, channels, time_steps, _ = xs.shape
|
332 |
-
# xs shape: [batch_size, channels, time_steps, input_size]
|
333 |
-
xs = xs.permute(0, 3, 2, 1)
|
334 |
-
# xs shape: [batch_size, input_size, time_steps, channels]
|
335 |
-
xs = self.input_linear.forward(xs)
|
336 |
-
# xs shape: [batch_size, input_size, time_steps, hidden_size]
|
337 |
-
xs = xs.permute(0, 3, 2, 1)
|
338 |
-
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
339 |
-
|
340 |
-
chunk_masks = subsequent_chunk_mask(
|
341 |
-
size=time_steps,
|
342 |
-
chunk_size=self.chunk_size,
|
343 |
-
num_left_chunks=self.num_left_chunks,
|
344 |
-
num_right_chunks=self.num_right_chunks,
|
345 |
-
)
|
346 |
-
chunk_masks = chunk_masks.to(xs.device)
|
347 |
-
# chunk_masks shape: [time_steps, time_steps]
|
348 |
-
|
349 |
-
for encoder_layer in self.encoder_layer_list:
|
350 |
-
xs, _ = encoder_layer.forward(xs, chunk_masks)
|
351 |
-
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
352 |
-
xs = xs.permute(0, 3, 2, 1)
|
353 |
-
# xs shape: [batch_size, input_size, time_steps, hidden_size]
|
354 |
-
xs = self.output_linear.forward(xs)
|
355 |
-
# xs shape: [batch_size, input_size, time_steps, channels]
|
356 |
-
xs = xs.permute(0, 3, 2, 1)
|
357 |
-
# xs shape: [batch_size, channels, time_steps, input_size]
|
358 |
-
|
359 |
-
return xs
|
360 |
-
|
361 |
-
def forward_chunk(self,
|
362 |
-
xs: torch.Tensor,
|
363 |
-
max_att_cache_length: int,
|
364 |
-
attention_cache: torch.Tensor = None,
|
365 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
366 |
-
"""
|
367 |
-
|
368 |
-
:param xs:
|
369 |
-
:param max_att_cache_length:
|
370 |
-
:param attention_cache: Tensor, shape: [num_layers, ...]
|
371 |
-
:return:
|
372 |
-
"""
|
373 |
-
# xs shape: [batch_size, channels, time_steps, input_size]
|
374 |
-
xs = xs.permute(0, 3, 2, 1)
|
375 |
-
xs = self.input_linear.forward(xs)
|
376 |
-
xs = xs.permute(0, 3, 2, 1)
|
377 |
-
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
378 |
-
|
379 |
-
r_att_cache = []
|
380 |
-
for idx, encoder_layer in enumerate(self.encoder_layer_list):
|
381 |
-
xs, new_att_cache = encoder_layer.forward(
|
382 |
-
x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None,
|
383 |
-
)
|
384 |
-
# new_att_cache shape: [b*f, n_heads, time_steps, dim]
|
385 |
-
if new_att_cache.size(2) > max_att_cache_length:
|
386 |
-
begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
387 |
-
end = self.num_right_chunks * self.chunk_size
|
388 |
-
new_att_cache = new_att_cache[:, :, -begin:-end, :]
|
389 |
-
r_att_cache.append(new_att_cache)
|
390 |
-
|
391 |
-
r_att_cache = torch.stack(r_att_cache, dim=0)
|
392 |
-
|
393 |
-
# xs shape: [batch_size, hidden_size, time_steps, input_size]
|
394 |
-
xs = xs.permute(0, 3, 2, 1)
|
395 |
-
xs = self.output_linear.forward(xs)
|
396 |
-
xs = xs.permute(0, 3, 2, 1)
|
397 |
-
# xs shape: [batch_size, channels, time_steps, input_size]
|
398 |
-
|
399 |
-
return xs, r_att_cache
|
400 |
-
|
401 |
-
def forward_chunk_by_chunk(
|
402 |
-
self,
|
403 |
-
xs: torch.Tensor,
|
404 |
-
) -> torch.Tensor:
|
405 |
-
|
406 |
-
batch_size, channels, time_steps, _ = xs.shape
|
407 |
-
|
408 |
-
max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size
|
409 |
-
attention_cache = None
|
410 |
-
|
411 |
-
outputs = []
|
412 |
-
for idx in range(0, time_steps, self.chunk_size):
|
413 |
-
begin = idx
|
414 |
-
end = begin + self.chunk_size * (self.num_right_chunks + 1)
|
415 |
-
chunk_xs = xs[:, :, begin:end, :]
|
416 |
-
# chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
417 |
-
|
418 |
-
ys, attention_cache = self.forward_chunk(
|
419 |
-
xs=chunk_xs,
|
420 |
-
max_att_cache_length=max_att_cache_length,
|
421 |
-
attention_cache=attention_cache,
|
422 |
-
)
|
423 |
-
# ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size]
|
424 |
-
ys = ys[:, :, :self.chunk_size, :]
|
425 |
-
|
426 |
-
outputs.append(ys)
|
427 |
-
|
428 |
-
ys = torch.cat(outputs, dim=2)
|
429 |
-
return ys
|
430 |
-
|
431 |
-
|
432 |
-
def main2():
|
433 |
-
|
434 |
-
encoder = TransformerEncoder(
|
435 |
-
input_size=64,
|
436 |
-
hidden_size=256,
|
437 |
-
attention_heads=4,
|
438 |
-
num_blocks=6,
|
439 |
-
dropout_rate=0.1,
|
440 |
-
)
|
441 |
-
print(encoder)
|
442 |
-
|
443 |
-
x = torch.ones([4, 200, 64])
|
444 |
-
|
445 |
-
x = torch.ones([4, 200, 64])
|
446 |
-
y = encoder.forward(xs=x)
|
447 |
-
print(y.shape)
|
448 |
-
|
449 |
-
x = torch.ones([4, 200, 64])
|
450 |
-
y = encoder.forward_chunk_by_chunk(xs=x)
|
451 |
-
print(y.shape)
|
452 |
-
|
453 |
-
return
|
454 |
-
|
455 |
-
|
456 |
-
def main():
|
457 |
-
|
458 |
-
encoder = TSTransformerEncoder(
|
459 |
-
input_size=8,
|
460 |
-
hidden_size=16,
|
461 |
-
attention_heads=2,
|
462 |
-
num_blocks=2,
|
463 |
-
dropout_rate=0.1,
|
464 |
-
)
|
465 |
-
# print(encoder)
|
466 |
-
|
467 |
-
x = torch.ones([4, 8, 200, 8])
|
468 |
-
y = encoder.forward(xs=x)
|
469 |
-
print(y.shape)
|
470 |
-
|
471 |
-
x = torch.ones([4, 8, 200, 8])
|
472 |
-
y = encoder.forward_chunk_by_chunk(xs=x)
|
473 |
-
print(y.shape)
|
474 |
-
|
475 |
-
return
|
476 |
-
|
477 |
-
|
478 |
-
if __name__ == '__main__':
|
479 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
toolbox/torchaudio/models/nx_denoise/utils.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
|
6 |
-
|
7 |
-
class LearnableSigmoid1d(nn.Module):
|
8 |
-
def __init__(self, in_features, beta=1):
|
9 |
-
super().__init__()
|
10 |
-
self.beta = beta
|
11 |
-
self.slope = nn.Parameter(torch.ones(in_features))
|
12 |
-
self.slope.requiresGrad = True
|
13 |
-
|
14 |
-
def forward(self, x):
|
15 |
-
# x shape: [batch_size, time_steps, spec_bins]
|
16 |
-
return self.beta * torch.sigmoid(self.slope * x)
|
17 |
-
|
18 |
-
|
19 |
-
def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
|
20 |
-
|
21 |
-
hann_window = torch.hann_window(win_size).to(y.device)
|
22 |
-
stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
|
23 |
-
center=center, pad_mode='reflect', normalized=False, return_complex=True)
|
24 |
-
stft_spec = torch.view_as_real(stft_spec)
|
25 |
-
mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
|
26 |
-
pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5)
|
27 |
-
# Magnitude Compression
|
28 |
-
mag = torch.pow(mag, compress_factor)
|
29 |
-
com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1)
|
30 |
-
|
31 |
-
return mag, pha, com
|
32 |
-
|
33 |
-
|
34 |
-
def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
|
35 |
-
# Magnitude Decompression
|
36 |
-
mag = torch.pow(mag, (1.0/compress_factor))
|
37 |
-
com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha))
|
38 |
-
hann_window = torch.hann_window(win_size).to(com.device)
|
39 |
-
wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center)
|
40 |
-
|
41 |
-
return wav
|
42 |
-
|
43 |
-
|
44 |
-
if __name__ == '__main__':
|
45 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
toolbox/torchaudio/models/nx_denoise/yaml/config.yaml
DELETED
@@ -1,51 +0,0 @@
|
|
1 |
-
model_name: "nx_denoise"
|
2 |
-
|
3 |
-
sample_rate: 8000
|
4 |
-
segment_size: 16000
|
5 |
-
n_fft: 512
|
6 |
-
win_size: 200
|
7 |
-
hop_size: 80
|
8 |
-
# 因为 hop_size 取 80,则相当于 stft 的时间步是 10ms 一步,所以降采样也考虑到差不多的分辨率。
|
9 |
-
|
10 |
-
# 2**down_sampling_num_layers,
|
11 |
-
# 例如 2**6=64 就意味着 64 个值在降采样之后是一个时间步,
|
12 |
-
# 则一步是 64/sample_rate = 0.008秒。
|
13 |
-
# 那么 tsfm_chunk_size=2 则为16ms,tsfm_chunk_size=4 则为32ms
|
14 |
-
# 假设每次向左看1秒,向右看30ms,则:
|
15 |
-
# tsfm_chunk_size=1,tsfm_num_left_chunks=128,tsfm_num_right_chunks=4
|
16 |
-
# tsfm_chunk_size=2,tsfm_num_left_chunks=64,tsfm_num_right_chunks=2
|
17 |
-
# tsfm_chunk_size=4,tsfm_num_left_chunks=32,tsfm_num_right_chunks=1
|
18 |
-
down_sampling_num_layers: 6
|
19 |
-
down_sampling_in_channels: 1
|
20 |
-
down_sampling_hidden_channels: 64
|
21 |
-
down_sampling_kernel_size: 4
|
22 |
-
down_sampling_stride: 2
|
23 |
-
|
24 |
-
causal_in_channels: 1
|
25 |
-
causal_out_channels: 64
|
26 |
-
causal_kernel_size: 3
|
27 |
-
causal_bias: false
|
28 |
-
causal_separable: true
|
29 |
-
causal_f_stride: 1
|
30 |
-
causal_num_layers: 3
|
31 |
-
|
32 |
-
tsfm_hidden_size: 256
|
33 |
-
tsfm_attention_heads: 8
|
34 |
-
tsfm_num_blocks: 6
|
35 |
-
tsfm_dropout_rate: 0.1
|
36 |
-
tsfm_max_length: 512
|
37 |
-
tsfm_chunk_size: 1
|
38 |
-
tsfm_num_left_chunks: 128
|
39 |
-
tsfm_num_right_chunks: 4
|
40 |
-
|
41 |
-
discriminator_dim: 32
|
42 |
-
discriminator_in_channel: 2
|
43 |
-
|
44 |
-
compress_factor: 0.3
|
45 |
-
|
46 |
-
batch_size: 4
|
47 |
-
learning_rate: 0.0005
|
48 |
-
adam_b1: 0.8
|
49 |
-
adam_b2: 0.99
|
50 |
-
lr_decay: 0.99
|
51 |
-
seed: 1234
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py
DELETED
@@ -1,102 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
from typing import Tuple
|
4 |
-
|
5 |
-
from toolbox.torchaudio.configuration_utils import PretrainedConfig
|
6 |
-
|
7 |
-
|
8 |
-
class NXDfNetConfig(PretrainedConfig):
|
9 |
-
def __init__(self,
|
10 |
-
sample_rate: int = 8000,
|
11 |
-
freq_bins: int = 256,
|
12 |
-
win_size: int = 200,
|
13 |
-
hop_size: int = 100,
|
14 |
-
|
15 |
-
conv_channels: int = 64,
|
16 |
-
conv_kernel_size_input: Tuple[int, int] = (3, 3),
|
17 |
-
conv_kernel_size_inner: Tuple[int, int] = (1, 3),
|
18 |
-
conv_lookahead: int = 0,
|
19 |
-
|
20 |
-
convt_kernel_size_inner: Tuple[int, int] = (1, 3),
|
21 |
-
|
22 |
-
embedding_hidden_size: int = 256,
|
23 |
-
encoder_combine_op: str = "concat",
|
24 |
-
|
25 |
-
encoder_emb_skip_op: str = "none",
|
26 |
-
encoder_emb_linear_groups: int = 16,
|
27 |
-
encoder_emb_hidden_size: int = 256,
|
28 |
-
|
29 |
-
encoder_linear_groups: int = 32,
|
30 |
-
|
31 |
-
lsnr_max: int = 30,
|
32 |
-
lsnr_min: int = -15,
|
33 |
-
norm_tau: float = 1.,
|
34 |
-
|
35 |
-
decoder_emb_num_layers: int = 3,
|
36 |
-
decoder_emb_skip_op: str = "none",
|
37 |
-
decoder_emb_linear_groups: int = 16,
|
38 |
-
decoder_emb_hidden_size: int = 256,
|
39 |
-
|
40 |
-
df_decoder_hidden_size: int = 256,
|
41 |
-
df_num_layers: int = 2,
|
42 |
-
df_order: int = 5,
|
43 |
-
df_bins: int = 96,
|
44 |
-
df_gru_skip: str = "grouped_linear",
|
45 |
-
df_decoder_linear_groups: int = 16,
|
46 |
-
df_pathway_kernel_size_t: int = 5,
|
47 |
-
df_lookahead: int = 2,
|
48 |
-
|
49 |
-
use_post_filter: bool = False,
|
50 |
-
**kwargs
|
51 |
-
):
|
52 |
-
super(NXDfNetConfig, self).__init__(**kwargs)
|
53 |
-
# transform
|
54 |
-
self.sample_rate = sample_rate
|
55 |
-
self.freq_bins = freq_bins
|
56 |
-
self.win_size = win_size
|
57 |
-
self.hop_size = hop_size
|
58 |
-
|
59 |
-
# conv
|
60 |
-
self.conv_channels = conv_channels
|
61 |
-
self.conv_kernel_size_input = conv_kernel_size_input
|
62 |
-
self.conv_kernel_size_inner = conv_kernel_size_inner
|
63 |
-
self.conv_lookahead = conv_lookahead
|
64 |
-
|
65 |
-
self.convt_kernel_size_inner = convt_kernel_size_inner
|
66 |
-
|
67 |
-
self.embedding_hidden_size = embedding_hidden_size
|
68 |
-
|
69 |
-
# encoder
|
70 |
-
self.encoder_emb_skip_op = encoder_emb_skip_op
|
71 |
-
self.encoder_emb_linear_groups = encoder_emb_linear_groups
|
72 |
-
self.encoder_emb_hidden_size = encoder_emb_hidden_size
|
73 |
-
|
74 |
-
self.encoder_linear_groups = encoder_linear_groups
|
75 |
-
self.encoder_combine_op = encoder_combine_op
|
76 |
-
|
77 |
-
self.lsnr_max = lsnr_max
|
78 |
-
self.lsnr_min = lsnr_min
|
79 |
-
self.norm_tau = norm_tau
|
80 |
-
|
81 |
-
# decoder
|
82 |
-
self.decoder_emb_num_layers = decoder_emb_num_layers
|
83 |
-
self.decoder_emb_skip_op = decoder_emb_skip_op
|
84 |
-
self.decoder_emb_linear_groups = decoder_emb_linear_groups
|
85 |
-
self.decoder_emb_hidden_size = decoder_emb_hidden_size
|
86 |
-
|
87 |
-
# df decoder
|
88 |
-
self.df_decoder_hidden_size = df_decoder_hidden_size
|
89 |
-
self.df_num_layers = df_num_layers
|
90 |
-
self.df_order = df_order
|
91 |
-
self.df_bins = df_bins
|
92 |
-
self.df_gru_skip = df_gru_skip
|
93 |
-
self.df_decoder_linear_groups = df_decoder_linear_groups
|
94 |
-
self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
|
95 |
-
self.df_lookahead = df_lookahead
|
96 |
-
|
97 |
-
# runtime
|
98 |
-
self.use_post_filter = use_post_filter
|
99 |
-
|
100 |
-
|
101 |
-
if __name__ == "__main__":
|
102 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py
DELETED
@@ -1,989 +0,0 @@
|
|
1 |
-
#!/usr/bin/python3
|
2 |
-
# -*- coding: utf-8 -*-
|
3 |
-
import os
|
4 |
-
import math
|
5 |
-
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
6 |
-
|
7 |
-
import numpy as np
|
8 |
-
import torch
|
9 |
-
import torch.nn as nn
|
10 |
-
from torch.nn import functional as F
|
11 |
-
import torchaudio
|
12 |
-
|
13 |
-
from toolbox.torchaudio.models.nx_dfnet.utils import overlap_and_add
|
14 |
-
from toolbox.torchaudio.models.nx_dfnet.configuration_nx_dfnet import NXDfNetConfig
|
15 |
-
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
|
16 |
-
|
17 |
-
|
18 |
-
MODEL_FILE = "model.pt"
|
19 |
-
|
20 |
-
|
21 |
-
norm_layer_dict = {
|
22 |
-
"batch_norm_2d": torch.nn.BatchNorm2d
|
23 |
-
}
|
24 |
-
|
25 |
-
|
26 |
-
activation_layer_dict = {
|
27 |
-
"relu": torch.nn.ReLU,
|
28 |
-
"identity": torch.nn.Identity,
|
29 |
-
"sigmoid": torch.nn.Sigmoid,
|
30 |
-
}
|
31 |
-
|
32 |
-
|
33 |
-
class CausalConv2d(nn.Sequential):
|
34 |
-
def __init__(self,
|
35 |
-
in_channels: int,
|
36 |
-
out_channels: int,
|
37 |
-
kernel_size: Union[int, Iterable[int]],
|
38 |
-
fstride: int = 1,
|
39 |
-
dilation: int = 1,
|
40 |
-
fpad: bool = True,
|
41 |
-
bias: bool = True,
|
42 |
-
separable: bool = False,
|
43 |
-
norm_layer: str = "batch_norm_2d",
|
44 |
-
activation_layer: str = "relu",
|
45 |
-
lookahead: int = 0
|
46 |
-
):
|
47 |
-
"""
|
48 |
-
Causal Conv2d by delaying the signal for any lookahead.
|
49 |
-
|
50 |
-
Expected input format: [batch_size, channels, time_steps, spec_dim]
|
51 |
-
|
52 |
-
:param in_channels:
|
53 |
-
:param out_channels:
|
54 |
-
:param kernel_size:
|
55 |
-
:param fstride:
|
56 |
-
:param dilation:
|
57 |
-
:param fpad:
|
58 |
-
"""
|
59 |
-
super(CausalConv2d, self).__init__()
|
60 |
-
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
61 |
-
|
62 |
-
if fpad:
|
63 |
-
fpad_ = kernel_size[1] // 2 + dilation - 1
|
64 |
-
else:
|
65 |
-
fpad_ = 0
|
66 |
-
|
67 |
-
# for last 2 dim, pad (left, right, top, bottom).
|
68 |
-
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
|
69 |
-
|
70 |
-
layers = list()
|
71 |
-
if any(x > 0 for x in pad):
|
72 |
-
layers.append(nn.ConstantPad2d(pad, 0.0))
|
73 |
-
|
74 |
-
groups = math.gcd(in_channels, out_channels) if separable else 1
|
75 |
-
if groups == 1:
|
76 |
-
separable = False
|
77 |
-
if max(kernel_size) == 1:
|
78 |
-
separable = False
|
79 |
-
|
80 |
-
layers.append(
|
81 |
-
nn.Conv2d(
|
82 |
-
in_channels,
|
83 |
-
out_channels,
|
84 |
-
kernel_size=kernel_size,
|
85 |
-
padding=(0, fpad_),
|
86 |
-
stride=(1, fstride), # stride over time is always 1
|
87 |
-
dilation=(1, dilation), # dilation over time is always 1
|
88 |
-
groups=groups,
|
89 |
-
bias=bias,
|
90 |
-
)
|
91 |
-
)
|
92 |
-
|
93 |
-
if separable:
|
94 |
-
layers.append(
|
95 |
-
nn.Conv2d(
|
96 |
-
out_channels,
|
97 |
-
out_channels,
|
98 |
-
kernel_size=1,
|
99 |
-
bias=False,
|
100 |
-
)
|
101 |
-
)
|
102 |
-
|
103 |
-
if norm_layer is not None:
|
104 |
-
norm_layer = norm_layer_dict[norm_layer]
|
105 |
-
layers.append(norm_layer(out_channels))
|
106 |
-
|
107 |
-
if activation_layer is not None:
|
108 |
-
activation_layer = activation_layer_dict[activation_layer]
|
109 |
-
layers.append(activation_layer())
|
110 |
-
|
111 |
-
super().__init__(*layers)
|
112 |
-
|
113 |
-
def forward(self, inputs):
|
114 |
-
for module in self:
|
115 |
-
inputs = module(inputs)
|
116 |
-
return inputs
|
117 |
-
|
118 |
-
|
119 |
-
class CausalConvTranspose2d(nn.Sequential):
|
120 |
-
def __init__(self,
|
121 |
-
in_channels: int,
|
122 |
-
out_channels: int,
|
123 |
-
kernel_size: Union[int, Iterable[int]],
|
124 |
-
fstride: int = 1,
|
125 |
-
dilation: int = 1,
|
126 |
-
fpad: bool = True,
|
127 |
-
bias: bool = True,
|
128 |
-
separable: bool = False,
|
129 |
-
norm_layer: str = "batch_norm_2d",
|
130 |
-
activation_layer: str = "relu",
|
131 |
-
lookahead: int = 0
|
132 |
-
):
|
133 |
-
"""
|
134 |
-
Causal ConvTranspose2d.
|
135 |
-
|
136 |
-
Expected input format: [batch_size, channels, time_steps, spec_dim]
|
137 |
-
"""
|
138 |
-
super(CausalConvTranspose2d, self).__init__()
|
139 |
-
|
140 |
-
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
141 |
-
|
142 |
-
if fpad:
|
143 |
-
fpad_ = kernel_size[1] // 2
|
144 |
-
else:
|
145 |
-
fpad_ = 0
|
146 |
-
|
147 |
-
# for last 2 dim, pad (left, right, top, bottom).
|
148 |
-
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
|
149 |
-
|
150 |
-
layers = []
|
151 |
-
if any(x > 0 for x in pad):
|
152 |
-
layers.append(nn.ConstantPad2d(pad, 0.0))
|
153 |
-
|
154 |
-
groups = math.gcd(in_channels, out_channels) if separable else 1
|
155 |
-
if groups == 1:
|
156 |
-
separable = False
|
157 |
-
|
158 |
-
layers.append(
|
159 |
-
nn.ConvTranspose2d(
|
160 |
-
in_channels,
|
161 |
-
out_channels,
|
162 |
-
kernel_size=kernel_size,
|
163 |
-
padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
|
164 |
-
output_padding=(0, fpad_),
|
165 |
-
stride=(1, fstride), # stride over time is always 1
|
166 |
-
dilation=(1, dilation), # dilation over time is always 1
|
167 |
-
groups=groups,
|
168 |
-
bias=bias,
|
169 |
-
)
|
170 |
-
)
|
171 |
-
|
172 |
-
if separable:
|
173 |
-
layers.append(
|
174 |
-
nn.Conv2d(
|
175 |
-
out_channels,
|
176 |
-
out_channels,
|
177 |
-
kernel_size=1,
|
178 |
-
bias=False,
|
179 |
-
)
|
180 |
-
)
|
181 |
-
|
182 |
-
if norm_layer is not None:
|
183 |
-
norm_layer = norm_layer_dict[norm_layer]
|
184 |
-
layers.append(norm_layer(out_channels))
|
185 |
-
|
186 |
-
if activation_layer is not None:
|
187 |
-
activation_layer = activation_layer_dict[activation_layer]
|
188 |
-
layers.append(activation_layer())
|
189 |
-
|
190 |
-
super().__init__(*layers)
|
191 |
-
|
192 |
-
|
193 |
-
class GroupedLinear(nn.Module):
|
194 |
-
|
195 |
-
def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
|
196 |
-
super().__init__()
|
197 |
-
# self.weight: Tensor
|
198 |
-
self.input_size = input_size
|
199 |
-
self.hidden_size = hidden_size
|
200 |
-
self.groups = groups
|
201 |
-
assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
|
202 |
-
assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
|
203 |
-
self.ws = input_size // groups
|
204 |
-
self.register_parameter(
|
205 |
-
"weight",
|
206 |
-
torch.nn.Parameter(
|
207 |
-
torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
|
208 |
-
),
|
209 |
-
)
|
210 |
-
self.reset_parameters()
|
211 |
-
|
212 |
-
def reset_parameters(self):
|
213 |
-
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
|
214 |
-
|
215 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
216 |
-
# x: [..., I]
|
217 |
-
b, t, _ = x.shape
|
218 |
-
# new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
|
219 |
-
new_shape = (b, t, self.groups, self.ws)
|
220 |
-
x = x.view(new_shape)
|
221 |
-
# The better way, but not supported by torchscript
|
222 |
-
# x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
|
223 |
-
x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
|
224 |
-
x = x.flatten(2, 3) # [B, T, H]
|
225 |
-
return x
|
226 |
-
|
227 |
-
def __repr__(self):
|
228 |
-
cls = self.__class__.__name__
|
229 |
-
return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
|
230 |
-
|
231 |
-
|
232 |
-
class SqueezedGRU_S(nn.Module):
|
233 |
-
"""
|
234 |
-
SGE net: Video object detection with squeezed GRU and information entropy map
|
235 |
-
https://arxiv.org/abs/2106.07224
|
236 |
-
"""
|
237 |
-
|
238 |
-
def __init__(
|
239 |
-
self,
|
240 |
-
input_size: int,
|
241 |
-
hidden_size: int,
|
242 |
-
output_size: Optional[int] = None,
|
243 |
-
num_layers: int = 1,
|
244 |
-
linear_groups: int = 8,
|
245 |
-
batch_first: bool = True,
|
246 |
-
skip_op: str = "none",
|
247 |
-
activation_layer: str = "identity",
|
248 |
-
):
|
249 |
-
super().__init__()
|
250 |
-
self.input_size = input_size
|
251 |
-
self.hidden_size = hidden_size
|
252 |
-
|
253 |
-
self.linear_in = nn.Sequential(
|
254 |
-
GroupedLinear(
|
255 |
-
input_size=input_size,
|
256 |
-
hidden_size=hidden_size,
|
257 |
-
groups=linear_groups,
|
258 |
-
),
|
259 |
-
activation_layer_dict[activation_layer](),
|
260 |
-
)
|
261 |
-
|
262 |
-
# gru skip operator
|
263 |
-
self.gru_skip_op = None
|
264 |
-
|
265 |
-
if skip_op == "none":
|
266 |
-
self.gru_skip_op = None
|
267 |
-
elif skip_op == "identity":
|
268 |
-
if not input_size != output_size:
|
269 |
-
raise AssertionError("Dimensions do not match")
|
270 |
-
self.gru_skip_op = nn.Identity()
|
271 |
-
elif skip_op == "grouped_linear":
|
272 |
-
self.gru_skip_op = GroupedLinear(
|
273 |
-
input_size=hidden_size,
|
274 |
-
hidden_size=hidden_size,
|
275 |
-
groups=linear_groups,
|
276 |
-
)
|
277 |
-
else:
|
278 |
-
raise NotImplementedError()
|
279 |
-
|
280 |
-
self.gru = nn.GRU(
|
281 |
-
input_size=hidden_size,
|
282 |
-
hidden_size=hidden_size,
|
283 |
-
num_layers=num_layers,
|
284 |
-
batch_first=batch_first,
|
285 |
-
bidirectional=False,
|
286 |
-
)
|
287 |
-
|
288 |
-
if output_size is not None:
|
289 |
-
self.linear_out = nn.Sequential(
|
290 |
-
GroupedLinear(
|
291 |
-
input_size=hidden_size,
|
292 |
-
hidden_size=output_size,
|
293 |
-
groups=linear_groups,
|
294 |
-
),
|
295 |
-
activation_layer_dict[activation_layer](),
|
296 |
-
)
|
297 |
-
else:
|
298 |
-
self.linear_out = nn.Identity()
|
299 |
-
|
300 |
-
def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
|
301 |
-
x = self.linear_in(inputs)
|
302 |
-
|
303 |
-
x, h = self.gru.forward(x, h)
|
304 |
-
|
305 |
-
x = self.linear_out(x)
|
306 |
-
|
307 |
-
if self.gru_skip_op is not None:
|
308 |
-
x = x + self.gru_skip_op(inputs)
|
309 |
-
|
310 |
-
return x, h
|
311 |
-
|
312 |
-
|
313 |
-
class Add(nn.Module):
|
314 |
-
def forward(self, a, b):
|
315 |
-
return a + b
|
316 |
-
|
317 |
-
|
318 |
-
class Concat(nn.Module):
|
319 |
-
def forward(self, a, b):
|
320 |
-
return torch.cat((a, b), dim=-1)
|
321 |
-
|
322 |
-
|
323 |
-
class DeepSTFT(nn.Module):
|
324 |
-
def __init__(self, win_size: int, freq_bins: int):
|
325 |
-
super(DeepSTFT, self).__init__()
|
326 |
-
self.win_size = win_size
|
327 |
-
self.freq_bins = freq_bins
|
328 |
-
|
329 |
-
self.conv1d_U = nn.Conv1d(
|
330 |
-
in_channels=1,
|
331 |
-
out_channels=freq_bins * 2,
|
332 |
-
kernel_size=win_size,
|
333 |
-
stride=win_size // 2,
|
334 |
-
bias=False
|
335 |
-
)
|
336 |
-
|
337 |
-
def forward(self, signal: torch.Tensor):
|
338 |
-
"""
|
339 |
-
:param signal: Tensor, shape: [batch_size, num_samples]
|
340 |
-
:return: v, Tensor, shape: [batch_size, freq_bins, time_steps, 2],
|
341 |
-
where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1
|
342 |
-
"""
|
343 |
-
signal = torch.unsqueeze(signal, 1)
|
344 |
-
# signal shape: [batch_size, 1, num_samples]
|
345 |
-
spec = F.relu(self.conv1d_U(signal))
|
346 |
-
# spec shape: [batch_size, freq_bins * 2, time_steps]
|
347 |
-
b, f2, t = spec.shape
|
348 |
-
spec = spec.view(b, f2//2, 2, t).permute(0, 1, 3, 2)
|
349 |
-
# spec shape: [batch_size, freq_bins, time_steps, 2]
|
350 |
-
return spec
|
351 |
-
|
352 |
-
|
353 |
-
class DeepISTFT(nn.Module):
|
354 |
-
def __init__(self, win_size: int, freq_bins: int):
|
355 |
-
super(DeepISTFT, self).__init__()
|
356 |
-
self.win_size = win_size
|
357 |
-
self.freq_bins = freq_bins
|
358 |
-
|
359 |
-
self.basis_signals = nn.Linear(
|
360 |
-
in_features=freq_bins * 2,
|
361 |
-
out_features=win_size,
|
362 |
-
bias=False
|
363 |
-
)
|
364 |
-
|
365 |
-
def forward(self,
|
366 |
-
spec: torch.Tensor,
|
367 |
-
):
|
368 |
-
"""
|
369 |
-
:param spec: Tensor, shape: [batch_size, freq_bins, time_steps, 2],
|
370 |
-
where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1
|
371 |
-
:return: Tensor, shape: [batch_size, c, num_samples],
|
372 |
-
"""
|
373 |
-
b, f, t, _ = spec.shape
|
374 |
-
# spec shape: [b, f, t, 2]
|
375 |
-
spec = spec.permute(0, 2, 1, 3)
|
376 |
-
# spec shape: [b, t, f, 2]
|
377 |
-
spec = spec.view(b, 1, t, -1)
|
378 |
-
# spec shape: [b, 1, t, f2]
|
379 |
-
signal = self.basis_signals(spec)
|
380 |
-
# signal shape: [b, 1, t, win_size]
|
381 |
-
signal = overlap_and_add(signal, self.win_size//2)
|
382 |
-
# signal shape: [b, 1, num_samples]
|
383 |
-
return signal
|
384 |
-
|
385 |
-
|
386 |
-
class Encoder(nn.Module):
|
387 |
-
def __init__(self, config: NXDfNetConfig):
|
388 |
-
super(Encoder, self).__init__()
|
389 |
-
self.embedding_input_size = config.conv_channels * config.freq_bins // 4
|
390 |
-
self.embedding_output_size = config.conv_channels * config.freq_bins // 4
|
391 |
-
self.embedding_hidden_size = config.embedding_hidden_size
|
392 |
-
|
393 |
-
self.spec_conv0 = CausalConv2d(
|
394 |
-
in_channels=1,
|
395 |
-
out_channels=config.conv_channels,
|
396 |
-
kernel_size=config.conv_kernel_size_input,
|
397 |
-
bias=False,
|
398 |
-
separable=True,
|
399 |
-
fstride=1,
|
400 |
-
lookahead=config.conv_lookahead,
|
401 |
-
)
|
402 |
-
self.spec_conv1 = CausalConv2d(
|
403 |
-
in_channels=config.conv_channels,
|
404 |
-
out_channels=config.conv_channels,
|
405 |
-
kernel_size=config.conv_kernel_size_inner,
|
406 |
-
bias=False,
|
407 |
-
separable=True,
|
408 |
-
fstride=2,
|
409 |
-
lookahead=config.conv_lookahead,
|
410 |
-
)
|
411 |
-
self.spec_conv2 = CausalConv2d(
|
412 |
-
in_channels=config.conv_channels,
|
413 |
-
out_channels=config.conv_channels,
|
414 |
-
kernel_size=config.conv_kernel_size_inner,
|
415 |
-
bias=False,
|
416 |
-
separable=True,
|
417 |
-
fstride=2,
|
418 |
-
lookahead=config.conv_lookahead,
|
419 |
-
)
|
420 |
-
self.spec_conv3 = CausalConv2d(
|
421 |
-
in_channels=config.conv_channels,
|
422 |
-
out_channels=config.conv_channels,
|
423 |
-
kernel_size=config.conv_kernel_size_inner,
|
424 |
-
bias=False,
|
425 |
-
separable=True,
|
426 |
-
fstride=1,
|
427 |
-
lookahead=config.conv_lookahead,
|
428 |
-
)
|
429 |
-
|
430 |
-
self.df_conv0 = CausalConv2d(
|
431 |
-
in_channels=2,
|
432 |
-
out_channels=config.conv_channels,
|
433 |
-
kernel_size=config.conv_kernel_size_input,
|
434 |
-
bias=False,
|
435 |
-
separable=True,
|
436 |
-
fstride=1,
|
437 |
-
)
|
438 |
-
self.df_conv1 = CausalConv2d(
|
439 |
-
in_channels=config.conv_channels,
|
440 |
-
out_channels=config.conv_channels,
|
441 |
-
kernel_size=config.conv_kernel_size_inner,
|
442 |
-
bias=False,
|
443 |
-
separable=True,
|
444 |
-
fstride=2,
|
445 |
-
)
|
446 |
-
self.df_fc_emb = nn.Sequential(
|
447 |
-
GroupedLinear(
|
448 |
-
config.conv_channels * config.df_bins // 2,
|
449 |
-
self.embedding_input_size,
|
450 |
-
groups=config.encoder_linear_groups
|
451 |
-
),
|
452 |
-
nn.ReLU(inplace=True)
|
453 |
-
)
|
454 |
-
|
455 |
-
if config.encoder_combine_op == "concat":
|
456 |
-
self.embedding_input_size *= 2
|
457 |
-
self.combine = Concat()
|
458 |
-
else:
|
459 |
-
self.combine = Add()
|
460 |
-
|
461 |
-
# emb_gru
|
462 |
-
if config.freq_bins % 8 != 0:
|
463 |
-
raise AssertionError("freq_bins should be divisible by 8")
|
464 |
-
|
465 |
-
self.emb_gru = SqueezedGRU_S(
|
466 |
-
self.embedding_input_size,
|
467 |
-
self.embedding_hidden_size,
|
468 |
-
output_size=self.embedding_output_size,
|
469 |
-
num_layers=1,
|
470 |
-
batch_first=True,
|
471 |
-
skip_op=config.encoder_emb_skip_op,
|
472 |
-
linear_groups=config.encoder_emb_linear_groups,
|
473 |
-
activation_layer="relu",
|
474 |
-
)
|
475 |
-
|
476 |
-
# lsnr
|
477 |
-
self.lsnr_fc = nn.Sequential(
|
478 |
-
nn.Linear(self.embedding_output_size, 1),
|
479 |
-
nn.Sigmoid()
|
480 |
-
)
|
481 |
-
self.lsnr_scale = config.lsnr_max - config.lsnr_min
|
482 |
-
self.lsnr_offset = config.lsnr_min
|
483 |
-
|
484 |
-
def forward(self,
|
485 |
-
power_spec: torch.Tensor,
|
486 |
-
df_spec: torch.Tensor,
|
487 |
-
hidden_state: torch.Tensor = None,
|
488 |
-
):
|
489 |
-
# power_spec shape: (batch_size, 1, time_steps, spec_dim)
|
490 |
-
e0 = self.spec_conv0.forward(power_spec)
|
491 |
-
e1 = self.spec_conv1.forward(e0)
|
492 |
-
e2 = self.spec_conv2.forward(e1)
|
493 |
-
e3 = self.spec_conv3.forward(e2)
|
494 |
-
# e0 shape: [batch_size, channels, time_steps, spec_dim]
|
495 |
-
# e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
|
496 |
-
# e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
497 |
-
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
498 |
-
|
499 |
-
# df_spec, shape: (batch_size, 2, time_steps, df_bins)
|
500 |
-
c0 = self.df_conv0(df_spec)
|
501 |
-
c1 = self.df_conv1(c0)
|
502 |
-
# c0 shape: [batch_size, channels, time_steps, df_bins]
|
503 |
-
# c1 shape: [batch_size, channels, time_steps, df_bins // 2]
|
504 |
-
|
505 |
-
cemb = c1.permute(0, 2, 3, 1)
|
506 |
-
# cemb shape: [batch_size, time_steps, df_bins // 2, channels]
|
507 |
-
cemb = cemb.flatten(2)
|
508 |
-
# cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
|
509 |
-
cemb = self.df_fc_emb(cemb)
|
510 |
-
# cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
511 |
-
|
512 |
-
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
|
513 |
-
emb = e3.permute(0, 2, 3, 1)
|
514 |
-
# emb shape: [batch_size, time_steps, spec_dim // 4, channels]
|
515 |
-
emb = emb.flatten(2)
|
516 |
-
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
517 |
-
|
518 |
-
emb = self.combine(emb, cemb)
|
519 |
-
# if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
|
520 |
-
# if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
521 |
-
|
522 |
-
emb, h = self.emb_gru.forward(emb, hidden_state)
|
523 |
-
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
|
524 |
-
# h shape: [batch_size, 1, spec_dim]
|
525 |
-
|
526 |
-
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
|
527 |
-
# lsnr shape: [batch_size, time_steps, 1]
|
528 |
-
|
529 |
-
return e0, e1, e2, e3, emb, c0, lsnr, h
|
530 |
-
|
531 |
-
|
532 |
-
class Decoder(nn.Module):
|
533 |
-
def __init__(self, config: NXDfNetConfig):
|
534 |
-
super(Decoder, self).__init__()
|
535 |
-
|
536 |
-
if config.freq_bins % 8 != 0:
|
537 |
-
raise AssertionError("freq_bins should be divisible by 8")
|
538 |
-
|
539 |
-
self.emb_in_dim = config.conv_channels * config.freq_bins // 4
|
540 |
-
self.emb_out_dim = config.conv_channels * config.freq_bins // 4
|
541 |
-
self.emb_hidden_dim = config.decoder_emb_hidden_size
|
542 |
-
|
543 |
-
self.emb_gru = SqueezedGRU_S(
|
544 |
-
self.emb_in_dim,
|
545 |
-
self.emb_hidden_dim,
|
546 |
-
output_size=self.emb_out_dim,
|
547 |
-
num_layers=config.decoder_emb_num_layers - 1,
|
548 |
-
batch_first=True,
|
549 |
-
skip_op=config.decoder_emb_skip_op,
|
550 |
-
linear_groups=config.decoder_emb_linear_groups,
|
551 |
-
activation_layer="relu",
|
552 |
-
)
|
553 |
-
self.conv3p = CausalConv2d(
|
554 |
-
in_channels=config.conv_channels,
|
555 |
-
out_channels=config.conv_channels,
|
556 |
-
kernel_size=1,
|
557 |
-
bias=False,
|
558 |
-
separable=True,
|
559 |
-
fstride=1,
|
560 |
-
lookahead=config.conv_lookahead,
|
561 |
-
)
|
562 |
-
self.convt3 = CausalConv2d(
|
563 |
-
in_channels=config.conv_channels,
|
564 |
-
out_channels=config.conv_channels,
|
565 |
-
kernel_size=config.conv_kernel_size_inner,
|
566 |
-
bias=False,
|
567 |
-
separable=True,
|
568 |
-
fstride=1,
|
569 |
-
lookahead=config.conv_lookahead,
|
570 |
-
)
|
571 |
-
self.conv2p = CausalConv2d(
|
572 |
-
in_channels=config.conv_channels,
|
573 |
-
out_channels=config.conv_channels,
|
574 |
-
kernel_size=1,
|
575 |
-
bias=False,
|
576 |
-
separable=True,
|
577 |
-
fstride=1,
|
578 |
-
lookahead=config.conv_lookahead,
|
579 |
-
)
|
580 |
-
self.convt2 = CausalConvTranspose2d(
|
581 |
-
in_channels=config.conv_channels,
|
582 |
-
out_channels=config.conv_channels,
|
583 |
-
kernel_size=config.convt_kernel_size_inner,
|
584 |
-
bias=False,
|
585 |
-
separable=True,
|
586 |
-
fstride=2,
|
587 |
-
lookahead=config.conv_lookahead,
|
588 |
-
)
|
589 |
-
self.conv1p = CausalConv2d(
|
590 |
-
in_channels=config.conv_channels,
|
591 |
-
out_channels=config.conv_channels,
|
592 |
-
kernel_size=1,
|
593 |
-
bias=False,
|
594 |
-
separable=True,
|
595 |
-
fstride=1,
|
596 |
-
lookahead=config.conv_lookahead,
|
597 |
-
)
|
598 |
-
self.convt1 = CausalConvTranspose2d(
|
599 |
-
in_channels=config.conv_channels,
|
600 |
-
out_channels=config.conv_channels,
|
601 |
-
kernel_size=config.convt_kernel_size_inner,
|
602 |
-
bias=False,
|
603 |
-
separable=True,
|
604 |
-
fstride=2,
|
605 |
-
lookahead=config.conv_lookahead,
|
606 |
-
)
|
607 |
-
self.conv0p = CausalConv2d(
|
608 |
-
in_channels=config.conv_channels,
|
609 |
-
out_channels=config.conv_channels,
|
610 |
-
kernel_size=1,
|
611 |
-
bias=False,
|
612 |
-
separable=True,
|
613 |
-
fstride=1,
|
614 |
-
lookahead=config.conv_lookahead,
|
615 |
-
)
|
616 |
-
self.conv0_out = CausalConv2d(
|
617 |
-
in_channels=config.conv_channels,
|
618 |
-
out_channels=1,
|
619 |
-
kernel_size=config.conv_kernel_size_inner,
|
620 |
-
activation_layer="sigmoid",
|
621 |
-
bias=False,
|
622 |
-
separable=True,
|
623 |
-
fstride=1,
|
624 |
-
lookahead=config.conv_lookahead,
|
625 |
-
)
|
626 |
-
|
627 |
-
def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
|
628 |
-
# Estimates erb mask
|
629 |
-
b, _, t, f8 = e3.shape
|
630 |
-
|
631 |
-
# emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
|
632 |
-
emb, _ = self.emb_gru(emb)
|
633 |
-
# emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
|
634 |
-
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
|
635 |
-
e3 = self.convt3(self.conv3p(e3) + emb)
|
636 |
-
# e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
|
637 |
-
e2 = self.convt2(self.conv2p(e2) + e3)
|
638 |
-
# e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
|
639 |
-
e1 = self.convt1(self.conv1p(e1) + e2)
|
640 |
-
# e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
|
641 |
-
mask = self.conv0_out(self.conv0p(e0) + e1)
|
642 |
-
# mask shape: [batch_size, 1, time_steps, freq_dim]
|
643 |
-
return mask
|
644 |
-
|
645 |
-
|
646 |
-
class DfDecoder(nn.Module):
|
647 |
-
def __init__(self, config: NXDfNetConfig):
|
648 |
-
super(DfDecoder, self).__init__()
|
649 |
-
|
650 |
-
self.embedding_input_size = config.conv_channels * config.freq_bins // 4
|
651 |
-
self.df_decoder_hidden_size = config.df_decoder_hidden_size
|
652 |
-
self.df_num_layers = config.df_num_layers
|
653 |
-
|
654 |
-
self.df_order = config.df_order
|
655 |
-
|
656 |
-
self.df_bins = config.df_bins
|
657 |
-
self.df_out_ch = config.df_order * 2
|
658 |
-
|
659 |
-
self.df_convp = CausalConv2d(
|
660 |
-
config.conv_channels,
|
661 |
-
self.df_out_ch,
|
662 |
-
fstride=1,
|
663 |
-
kernel_size=(config.df_pathway_kernel_size_t, 1),
|
664 |
-
separable=True,
|
665 |
-
bias=False,
|
666 |
-
)
|
667 |
-
self.df_gru = SqueezedGRU_S(
|
668 |
-
self.embedding_input_size,
|
669 |
-
self.df_decoder_hidden_size,
|
670 |
-
num_layers=self.df_num_layers,
|
671 |
-
batch_first=True,
|
672 |
-
skip_op="none",
|
673 |
-
activation_layer="relu",
|
674 |
-
)
|
675 |
-
|
676 |
-
if config.df_gru_skip == "none":
|
677 |
-
self.df_skip = None
|
678 |
-
elif config.df_gru_skip == "identity":
|
679 |
-
if config.embedding_hidden_size != config.df_decoder_hidden_size:
|
680 |
-
raise AssertionError("Dimensions do not match")
|
681 |
-
self.df_skip = nn.Identity()
|
682 |
-
elif config.df_gru_skip == "grouped_linear":
|
683 |
-
self.df_skip = GroupedLinear(
|
684 |
-
self.embedding_input_size,
|
685 |
-
self.df_decoder_hidden_size,
|
686 |
-
groups=config.df_decoder_linear_groups
|
687 |
-
)
|
688 |
-
else:
|
689 |
-
raise NotImplementedError()
|
690 |
-
|
691 |
-
self.df_out: nn.Module
|
692 |
-
out_dim = self.df_bins * self.df_out_ch
|
693 |
-
|
694 |
-
self.df_out = nn.Sequential(
|
695 |
-
GroupedLinear(
|
696 |
-
input_size=self.df_decoder_hidden_size,
|
697 |
-
hidden_size=out_dim,
|
698 |
-
groups=config.df_decoder_linear_groups
|
699 |
-
),
|
700 |
-
nn.Tanh()
|
701 |
-
)
|
702 |
-
self.df_fc_a = nn.Sequential(
|
703 |
-
nn.Linear(self.df_decoder_hidden_size, 1),
|
704 |
-
nn.Sigmoid()
|
705 |
-
)
|
706 |
-
|
707 |
-
def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
|
708 |
-
# emb shape: [batch_size, time_steps, df_bins // 4 * channels]
|
709 |
-
b, t, _ = emb.shape
|
710 |
-
df_coefs, _ = self.df_gru(emb)
|
711 |
-
if self.df_skip is not None:
|
712 |
-
df_coefs = df_coefs + self.df_skip(emb)
|
713 |
-
# df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
|
714 |
-
|
715 |
-
# c0 shape: [batch_size, channels, time_steps, df_bins]
|
716 |
-
c0 = self.df_convp(c0)
|
717 |
-
# c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
|
718 |
-
c0 = c0.permute(0, 2, 3, 1)
|
719 |
-
# c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
|
720 |
-
|
721 |
-
df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
|
722 |
-
# df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
|
723 |
-
df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
|
724 |
-
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
|
725 |
-
df_coefs = df_coefs + c0
|
726 |
-
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
|
727 |
-
return df_coefs
|
728 |
-
|
729 |
-
|
730 |
-
class DfOutputReshapeMF(nn.Module):
|
731 |
-
"""Coefficients output reshape for multiframe/MultiFrameModule
|
732 |
-
|
733 |
-
Requires input of shape B, C, T, F, 2.
|
734 |
-
"""
|
735 |
-
|
736 |
-
def __init__(self, df_order: int, df_bins: int):
|
737 |
-
super().__init__()
|
738 |
-
self.df_order = df_order
|
739 |
-
self.df_bins = df_bins
|
740 |
-
|
741 |
-
def forward(self, coefs: torch.Tensor) -> torch.Tensor:
|
742 |
-
# [B, T, F, O*2] -> [B, O, T, F, 2]
|
743 |
-
new_shape = list(coefs.shape)
|
744 |
-
new_shape[-1] = -1
|
745 |
-
new_shape.append(2)
|
746 |
-
coefs = coefs.view(new_shape)
|
747 |
-
coefs = coefs.permute(0, 3, 1, 2, 4)
|
748 |
-
return coefs
|
749 |
-
|
750 |
-
|
751 |
-
class Mask(nn.Module):
|
752 |
-
def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
|
753 |
-
super().__init__()
|
754 |
-
self.use_post_filter = use_post_filter
|
755 |
-
self.eps = eps
|
756 |
-
|
757 |
-
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
|
758 |
-
"""
|
759 |
-
Post-Filter
|
760 |
-
|
761 |
-
A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
|
762 |
-
https://arxiv.org/abs/2008.04259
|
763 |
-
|
764 |
-
:param mask: Real valued mask, typically of shape [B, C, T, F].
|
765 |
-
:param beta: Global gain factor.
|
766 |
-
:return:
|
767 |
-
"""
|
768 |
-
mask_sin = mask * torch.sin(np.pi * mask / 2)
|
769 |
-
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
|
770 |
-
return mask_pf
|
771 |
-
|
772 |
-
def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
773 |
-
# spec shape: [batch_size, 1, time_steps, freq_bins, 2]
|
774 |
-
|
775 |
-
if not self.training and self.use_post_filter:
|
776 |
-
mask = self.post_filter(mask)
|
777 |
-
|
778 |
-
# mask shape: [batch_size, 1, time_steps, freq_bins]
|
779 |
-
mask = mask.unsqueeze(4)
|
780 |
-
# mask shape: [batch_size, 1, time_steps, freq_bins, 1]
|
781 |
-
return spec * mask
|
782 |
-
|
783 |
-
|
784 |
-
class DeepFiltering(nn.Module):
    def __init__(self,
                 df_bins: int,
                 df_order: int,
                 lookahead: int = 0,
                 ):
        super(DeepFiltering, self).__init__()
        self.df_bins = df_bins
        self.df_order = df_order
        self.need_unfold = df_order > 1
        self.lookahead = lookahead

        self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)

    def spec_unfold(self, spec: torch.Tensor):
        """
        Pads and unfolds the spectrogram according to frame_size.
        :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
        :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
        """
        if self.need_unfold:
            # spec shape: [batch_size, 1, time_steps, freq_bins]
            spec_pad = self.pad(spec)
            # spec_pad shape: [batch_size, 1, time_steps_pad, freq_bins]
            spec_unfold = spec_pad.unfold(2, self.df_order, 1)
            # spec_unfold shape: [batch_size, 1, time_steps, freq_bins, df_order]
            return spec_unfold
        else:
            return spec.unsqueeze(-1)

    def forward(self,
                spec: torch.Tensor,
                coefs: torch.Tensor,
                ):
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
        spec = spec.contiguous()
        spec_u = self.spec_unfold(torch.view_as_complex(spec))
        # spec_u shape: [batch_size, 1, time_steps, freq_bins, df_order]

        # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
        coefs = torch.view_as_complex(coefs)
        # coefs shape: [batch_size, df_order, time_steps, df_bins]
        spec_f = spec_u.narrow(-2, 0, self.df_bins)
        # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]

        coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
        # coefs shape: [batch_size, 1, df_order, time_steps, df_bins]

        spec_f = self.df(spec_f, coefs)
        # spec_f shape: [batch_size, 1, time_steps, df_bins]

        if self.training:
            spec = spec.clone()
        spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]
        return spec

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)

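The einsum in DeepFiltering.df is just a complex multiply-accumulate over the df_order unfolded frames; the loop below (a small sketch with toy sizes) computes the same thing explicitly:

# The deep-filter einsum, spelled out as an explicit loop over the filter taps.
import torch

B, C, T, F, N = 2, 1, 10, 96, 5
spec = torch.randn(B, C, T, F, N, dtype=torch.complex64)    # unfolded spectrogram
coefs = torch.randn(B, C, N, T, F, dtype=torch.complex64)   # filter coefficients

out = DeepFiltering.df(spec, coefs)                          # [B, C, T, F]

ref = torch.zeros(B, C, T, F, dtype=torch.complex64)
for n in range(N):
    ref += spec[..., n] * coefs[:, :, n]                     # accumulate over taps
assert torch.allclose(out, ref, atol=1e-5)
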
class NXDfNet(nn.Module):
    def __init__(self, config: NXDfNetConfig):
        super(NXDfNet, self).__init__()
        self.config = config

        self.stft = DeepSTFT(win_size=config.win_size, freq_bins=config.freq_bins)
        self.istft = DeepISTFT(win_size=config.win_size, freq_bins=config.freq_bins)

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.df_decoder = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
        self.df_op = DeepFiltering(
            df_bins=config.df_bins,
            df_order=config.df_order,
            lookahead=config.df_lookahead,
        )

        self.mask = Mask(use_post_filter=config.use_post_filter)

    def forward(self,
                noisy: torch.Tensor,
                ):
        """
        :param noisy: Tensor, shape: [batch_size, num_samples]
        :return:
        """
        spec = self.stft.forward(noisy)
        # spec shape: [batch_size, freq_bins, time_steps, 2]
        power_spec = torch.sum(torch.square(spec), dim=-1)
        # power_spec shape: [batch_size, freq_bins, time_steps]
        power_spec = power_spec.unsqueeze(1).permute(0, 1, 3, 2)
        # power_spec shape: [batch_size, 1, time_steps, freq_bins]

        df_spec = spec.permute(0, 3, 2, 1)
        # df_spec shape: [batch_size, 2, time_steps, freq_bins]
        df_spec = df_spec[..., :self.df_decoder.df_bins]
        # df_spec shape: [batch_size, 2, time_steps, df_bins]

        # spec shape: [batch_size, freq_bins, time_steps, 2]
        spec = torch.transpose(spec, dim0=1, dim1=2)
        # spec shape: [batch_size, time_steps, freq_bins, 2]
        spec = torch.unsqueeze(spec, dim=1)
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]

        e0, e1, e2, e3, emb, c0, _, h = self.encoder.forward(power_spec, df_spec)

        mask = self.decoder.forward(emb, e3, e2, e1, e0)
        # mask shape: [batch_size, 1, time_steps, freq_bins]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError

        spec_m = self.mask.forward(spec, mask)

        # lsnr shape: [batch_size, time_steps, 1]
        # lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
        # lsnr shape: [batch_size, 1, time_steps]

        df_coefs = self.df_decoder.forward(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]

        spec_e = self.df_op.forward(spec.clone(), df_coefs)
        # spec_e shape: [batch_size, 1, time_steps, freq_bins, 2]

        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]

        spec_e = torch.squeeze(spec_e, dim=1)
        spec_e = spec_e.permute(0, 2, 1, 3)
        # spec_e shape: [batch_size, freq_bins, time_steps, 2]

        denoise = self.istft.forward(spec_e)
        # denoise shape: [batch_size, num_samples]
        return denoise

class NXDfNetPretrainedModel(NXDfNet):
    def __init__(self,
                 config: NXDfNetConfig,
                 ):
        super(NXDfNetPretrainedModel, self).__init__(
            config=config,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = NXDfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):

        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():

    config = NXDfNetConfig()
    model = NXDfNet(config=config)

    inputs = torch.randn(size=(1, 16000), dtype=torch.float32)

    denoise = model.forward(inputs)
    print(denoise.shape)
    return


if __name__ == "__main__":
    main()
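A minimal save/load round trip with the pretrained wrapper might look like the following (the directory path is only a placeholder):

# Hypothetical round trip through save_pretrained / from_pretrained.
config = NXDfNetConfig()
model = NXDfNetPretrainedModel(config)

save_dir = "./nx_dfnet_checkpoint"          # placeholder path
model.save_pretrained(save_dir)             # writes the weights file and the yaml config
restored = NXDfNetPretrainedModel.from_pretrained(save_dir)
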
toolbox/torchaudio/models/nx_dfnet/utils.py
DELETED
@@ -1,55 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
"""
import math
import torch


def overlap_and_add(signal: torch.Tensor, frame_step: int):
    """
    Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py

    :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2.
    :param frame_step: int, overlap offsets. Must be less than or equal to frame_length.
    :return: Tensor, shape: [..., output_size].
        containing the overlap-added frames of signal's inner-most two dimensions.
        output_size = (frames - 1) * frame_step + frame_length
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    subframe_length = math.gcd(frame_length, frame_step)  # gcd = Greatest Common Divisor
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length

    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)

    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)

    frame = frame.clone().detach()
    frame = frame.to(signal.device)
    frame = frame.long()

    frame = frame.contiguous().view(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    result = result.view(*outer_dimensions, -1)
    return result


if __name__ == "__main__":
    pass
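As a concrete check of the output-size formula, three frames of length 4 with a hop of 2 reconstruct to (3 - 1) * 2 + 4 = 8 samples; the sketch below compares the function against a naive loop:

# overlap_and_add on a tiny example: 3 frames of length 4, hop 2 -> 8 samples.
import torch

frames = torch.arange(12, dtype=torch.float32).reshape(3, 4)   # [frames, frame_length]
out = overlap_and_add(frames, frame_step=2)
assert out.shape == (8,)

ref = torch.zeros(8)
for i in range(3):
    ref[i * 2: i * 2 + 4] += frames[i]                          # naive overlap-add
assert torch.equal(out, ref)
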
toolbox/torchaudio/models/nx_mpnet/__init__.py
DELETED
@@ -1,6 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py
DELETED
@@ -1,6 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-


if __name__ == '__main__':
    pass
toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py
DELETED
@@ -1,445 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import List, Tuple, Union

import torch
import torch.nn as nn

from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid2d


class SPConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 r=1
                 ):
        super(SPConvTranspose2d, self).__init__()
        self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
        self.r = r

    def forward(self, x: torch.Tensor):
        x = self.pad_freq(x)
        out = self.conv(x)

        b, c, t, f = out.shape

        out = out.view((b, self.r, c // self.r, t, f))
        out = out.permute(0, 2, 3, 4, 1)
        out = out.contiguous().view((b, c // self.r, t, -1))
        return out

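SPConvTranspose2d is a sub-pixel upsampler along the frequency axis: the conv produces r groups of channels and the view/permute interleaves them, so the last dimension grows by a factor of r. A small shape sketch with toy sizes; in MaskDecoder below, this is what turns 129 encoder bins into 258, before the final (1, 2) conv brings them to 257 = n_fft // 2 + 1.

# Sub-pixel transposed conv: with r=2 the frequency dimension doubles.
import torch

sp = SPConvTranspose2d(in_channels=8, out_channels=8, kernel_size=(1, 3), r=2)
x = torch.rand(1, 8, 100, 129)          # [B, C, T, F]
y = sp.forward(x)
assert y.shape == (1, 8, 100, 258)      # F * r = 129 * 2
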
class CausalConv2dBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 dilation: int,
                 kernel_size: Tuple[int, int] = (2, 3),
                 ):
        super(CausalConv2dBlock, self).__init__()
        self.pad_length = dilation

        self.pad_time = nn.ConstantPad2d((0, 0, self.pad_length, 0), value=0.)
        self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=(dilation, 1))
        self.norm = nn.InstanceNorm2d(out_channels, affine=True)
        self.activation = nn.PReLU(out_channels)

    def forward(self,
                x: torch.Tensor,
                cache_pad: torch.Tensor = None,
                ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
        :param cache_pad: Tensor, cached tail of the previous chunk, shape: [batch_size, channels, pad_length, dim]
        :return:
        """
        if cache_pad is None:
            x = self.pad_time(x)
        else:
            x = torch.concat(tensors=[cache_pad, x], dim=2)
        new_cache_pad = x[:, :, -self.pad_length:, :]

        x = self.pad_freq(x)

        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x, new_cache_pad

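The block pads `dilation` frames of zeros at the front of the time axis (or splices in the cached tail of the previous chunk), so the output keeps the input's time length and the new cache always holds the last `dilation` input frames. A rough streaming sketch with toy sizes; note that InstanceNorm2d statistics still depend on the chunk length, so chunked and full-utterance outputs are not bit-identical.

# Causal padding and the pad cache: output time length == input time length,
# and the cache carries the last `dilation` frames into the next chunk.
import torch

block = CausalConv2dBlock(in_channels=8, out_channels=8, dilation=2)
x1 = torch.rand(1, 8, 10, 32)                     # first chunk: [B, C, T, F]
y1, cache = block.forward(x1)
assert y1.shape == (1, 8, 10, 32) and cache.shape == (1, 8, 2, 32)

x2 = torch.rand(1, 8, 4, 32)                      # next chunk
y2, cache = block.forward(x2, cache_pad=cache)    # left context comes from the cache
assert y2.shape == (1, 8, 4, 32)
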
class CausalConv2dEncoder(nn.Module):
    def __init__(self,
                 num_blocks: int,
                 hidden_size: int,
                 ):
        super(CausalConv2dEncoder, self).__init__()
        self.num_blocks = num_blocks

        self.blocks: List[CausalConv2dBlock] = nn.ModuleList([])
        for idx in range(num_blocks):
            in_channels = hidden_size * (idx + 1)
            dilation = 2 ** idx
            block = CausalConv2dBlock(
                in_channels=in_channels,
                out_channels=hidden_size,
                dilation=dilation,
                kernel_size=(2, 3),
            )
            self.blocks.append(block)

    def forward(self,
                x: torch.Tensor,
                cache_pad_list: List[torch.Tensor] = None,
                ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
        :param cache_pad_list: List[Tensor]
        :return:
        """
        new_cache_pad_list = list()

        skip = x
        for idx, block in enumerate(self.blocks):
            x, new_cache_pad = block.forward(
                skip,
                cache_pad=None if cache_pad_list is None else cache_pad_list[idx]
            )
            new_cache_pad_list.append(new_cache_pad)
            skip = torch.cat([x, skip], dim=1)
        # x shape: [batch_size, channels, time_steps, dim].
        return x, new_cache_pad_list

    def forward_chunk(self,
                      chunk: torch.Tensor,
                      cache_pad_list: List[torch.Tensor] = None,
                      ):
        return self.forward(chunk, cache_pad_list)

    def forward_chunk_by_chunk(self,
                               x: torch.Tensor,
                               ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
        :return:
        """
        batch_size, channels, time_steps, _ = x.shape

        cache_pad_list = None

        outputs = list()
        for idx in range(time_steps):
            chunk = x[:, :, idx:idx+1, :]

            y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
            outputs.append(y)

        outputs = torch.concat(outputs, dim=2)
        return outputs

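Each block consumes the running concatenation of the input and all previous block outputs, which is why block idx expects hidden_size * (idx + 1) input channels, while the time dilation doubles per block. A small sketch printing the layout (illustrative hyper-parameters):

# Dense connectivity: input channels grow linearly, dilation grows exponentially.
encoder = CausalConv2dEncoder(num_blocks=4, hidden_size=64)
for idx, block in enumerate(encoder.blocks):
    print(idx, block.conv.in_channels, block.conv.dilation)
# 0 64  (1, 1)
# 1 128 (2, 1)
# 2 192 (4, 1)
# 3 256 (8, 1)
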
class DenseEncoder(nn.Module):
    def __init__(self,
                 num_blocks: int,
                 in_channels: int,
                 out_channels: int,
                 ):
        super(DenseEncoder, self).__init__()
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, 1)),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.PReLU(out_channels)
        )
        self.dense_block = CausalConv2dEncoder(
            num_blocks=num_blocks, hidden_size=out_channels,
        )
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, (1, 3), (1, 2), padding=(0, 1)),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.PReLU(out_channels)
        )

    def forward(self,
                x: torch.Tensor,
                ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
        :return:
        """
        x = self.dense_conv_1(x)
        x, _ = self.dense_block.forward(x)
        x = self.dense_conv_2(x)
        # x shape: [b, c, t, f//2]
        return x

    def forward_chunk(self,
                      x: torch.Tensor,
                      cache_pad_list: List[torch.Tensor] = None,
                      ):
        x = self.dense_conv_1(x)
        x, new_cache_pad_list = self.dense_block.forward(x, cache_pad_list)
        x = self.dense_conv_2(x)
        # x shape: [b, c, t, f//2]
        return x, new_cache_pad_list

    def forward_chunk_by_chunk(self,
                               x: torch.Tensor,
                               ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
        :return:
        """
        batch_size, channels, time_steps, _ = x.shape

        cache_pad_list = None

        outputs = list()
        for idx in range(time_steps):
            chunk = x[:, :, idx:idx+1, :]

            y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
            outputs.append(y)

        outputs = torch.concat(outputs, dim=2)
        return outputs

class MaskDecoder(nn.Module):
    def __init__(self,
                 num_blocks: int,
                 hidden_size: int,
                 out_channels: int = 1,
                 beta: float = 2.0,
                 n_fft: int = 512,
                 ):
        super(MaskDecoder, self).__init__()
        self.dense_block = CausalConv2dEncoder(
            num_blocks=num_blocks, hidden_size=hidden_size,
        )
        self.mask_conv = nn.Sequential(
            SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2),
            nn.InstanceNorm2d(hidden_size, affine=True),
            nn.PReLU(hidden_size),
            nn.Conv2d(hidden_size, out_channels, (1, 2))
        )
        self.lsigmoid = LearnableSigmoid2d(n_fft//2+1, beta=beta)

    def forward(self,
                x: torch.Tensor,
                ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
        :return:
        """
        x, _ = self.dense_block(x)
        x = self.mask_conv(x)
        # x shape: [batch_size, 1, time_steps, dim*2-1]
        x = x.permute(0, 3, 2, 1).squeeze(-1)
        # x shape: [b, f, t]
        x = self.lsigmoid(x)
        return x

    def forward_chunk(self,
                      x: torch.Tensor,
                      cache_pad_list: List[torch.Tensor] = None,
                      ):
        x, new_cache_pad_list = self.dense_block(x, cache_pad_list)
        x = self.mask_conv(x)
        # x shape: [batch_size, 1, time_steps, dim*2-1]
        x = x.permute(0, 3, 2, 1).squeeze(-1)
        # x shape: [b, f, t]
        x = self.lsigmoid(x)
        return x, new_cache_pad_list

    def forward_chunk_by_chunk(self,
                               x: torch.Tensor,
                               ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
        :return:
        """
        batch_size, channels, time_steps, _ = x.shape

        cache_pad_list = None

        outputs = list()
        for idx in range(time_steps):
            chunk = x[:, :, idx:idx+1, :]

            y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
            outputs.append(y)

        outputs = torch.concat(outputs, dim=2)
        return outputs

class PhaseDecoder(nn.Module):
    def __init__(self,
                 num_blocks: int,
                 hidden_size: int,
                 out_channels: int = 1,
                 ):
        super(PhaseDecoder, self).__init__()
        self.dense_block = CausalConv2dEncoder(
            num_blocks=num_blocks, hidden_size=hidden_size,
        )

        self.phase_conv = nn.Sequential(
            SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2),
            nn.InstanceNorm2d(hidden_size, affine=True),
            nn.PReLU(hidden_size)
        )
        self.phase_conv_r = nn.Conv2d(hidden_size, out_channels, (1, 2))
        self.phase_conv_i = nn.Conv2d(hidden_size, out_channels, (1, 2))

    def forward(self,
                x: torch.Tensor,
                ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim]
        :return:
        """
        x, _ = self.dense_block(x)

        x = self.phase_conv(x)
        x_r = self.phase_conv_r(x)
        x_i = self.phase_conv_i(x)
        x = torch.atan2(x_i, x_r)
        x = x.permute(0, 3, 2, 1).squeeze(-1)
        # x shape: [b, f, t]
        return x

    def forward_chunk(self,
                      x: torch.Tensor,
                      cache_pad_list: List[torch.Tensor] = None,
                      ):
        x, new_cache_pad_list = self.dense_block(x, cache_pad_list)

        x = self.phase_conv(x)
        x_r = self.phase_conv_r(x)
        x_i = self.phase_conv_i(x)
        x = torch.atan2(x_i, x_r)
        x = x.permute(0, 3, 2, 1).squeeze(-1)
        # x shape: [b, f, t]
        return x, new_cache_pad_list

    def forward_chunk_by_chunk(self,
                               x: torch.Tensor,
                               ):
        """
        :param x: Tensor, shape: [batch_size, channels, time_steps, dim].
        :return:
        """
        batch_size, channels, time_steps, _ = x.shape

        cache_pad_list = None

        outputs = list()
        for idx in range(time_steps):
            chunk = x[:, :, idx:idx+1, :]

            y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list)
            outputs.append(y)

        outputs = torch.concat(outputs, dim=2)
        return outputs

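Predicting a real/imaginary pair and taking atan2 keeps the phase estimate inside (-pi, pi] without any explicit wrapping, which is presumably why the decoder has two parallel output convs. A tiny illustration:

# atan2 maps an (imaginary, real) pair to an angle in (-pi, pi].
import torch

x_r = torch.tensor([1.0, -1.0, 0.0])
x_i = torch.tensor([0.0,  0.0, -1.0])
print(torch.atan2(x_i, x_r))    # tensor([ 0.0000,  3.1416, -1.5708])
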
def main1():

    encoder = CausalConv2dEncoder(
        num_blocks=3, hidden_size=8,
    )

    # x shape: [batch_size, channels, time_steps, dim]
    x = torch.rand(size=(1, 8, 200, 32))
    x, new_cache_pad_list = encoder.forward(x)
    print(x.shape)
    for new_cache_pad in new_cache_pad_list:
        print(new_cache_pad.shape)

    x = torch.rand(size=(1, 8, 200, 32))
    x = encoder.forward_chunk_by_chunk(x)
    print(x.shape)

    return


def main2():

    encoder = DenseEncoder(
        num_blocks=3, in_channels=8, out_channels=8
    )

    # x shape: [batch_size, channels, time_steps, dim]
    x = torch.rand(size=(1, 8, 200, 32))
    x, new_cache_pad_list = encoder.forward_chunk(x)
    print(x.shape)
    for new_cache_pad in new_cache_pad_list:
        print(new_cache_pad.shape)

    x = torch.rand(size=(1, 8, 200, 32))
    x = encoder.forward_chunk_by_chunk(x)
    print(x.shape)

    return


def main3():

    encoder = MaskDecoder(
        num_blocks=3, hidden_size=64, out_channels=1,
        n_fft=512,
    )

    # 512 // 2 + 1 = 257
    # 129 * 2 - 1 = 257
    # 257 // 2 + 1 = 129

    # x shape: [batch_size, channels, time_steps, dim]
    x = torch.rand(size=(1, 64, 201, 129))
    x, new_cache_pad_list = encoder.forward_chunk(x)
    print(x.shape)
    for new_cache_pad in new_cache_pad_list:
        print(new_cache_pad.shape)

    x = torch.rand(size=(1, 64, 201, 129))
    x = encoder.forward_chunk_by_chunk(x)
    print(x.shape)

    return


def main():

    encoder = PhaseDecoder(
        num_blocks=3, hidden_size=64, out_channels=1,
    )

    # 512 // 2 + 1 = 257
    # 129 * 2 - 1 = 257
    # 257 // 2 + 1 = 129

    # x shape: [batch_size, channels, time_steps, dim]
    x = torch.rand(size=(1, 64, 201, 129))
    x, new_cache_pad_list = encoder.forward_chunk(x)
    print(x.shape)
    for new_cache_pad in new_cache_pad_list:
        print(new_cache_pad.shape)

    x = torch.rand(size=(1, 64, 201, 129))
    x = encoder.forward_chunk_by_chunk(x)
    print(x.shape)

    return


if __name__ == "__main__":
    main()
toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py
DELETED
@@ -1,90 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from toolbox.torchaudio.configuration_utils import PretrainedConfig


class NXMPNetConfig(PretrainedConfig):
    """
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 segment_size: int = 16000,
                 n_fft: int = 512,
                 win_size: int = 200,
                 hop_size: int = 80,

                 dense_num_blocks: int = 4,
                 dense_hidden_size: int = 64,

                 mask_num_blocks: int = 4,
                 mask_hidden_size: int = 64,

                 phase_num_blocks: int = 4,
                 phase_hidden_size: int = 64,

                 tsfm_hidden_size: int = 64,
                 tsfm_attention_heads: int = 4,
                 tsfm_num_blocks: int = 4,
                 tsfm_dropout_rate: float = 0.0,
                 tsfm_max_time_relative_position: int = 2048,
                 tsfm_max_freq_relative_position: int = 256,
                 tsfm_chunk_size: int = 1,
                 tsfm_num_left_chunks: int = 64,
                 tsfm_num_right_chunks: int = 2,

                 discriminator_dim: int = 32,
                 discriminator_in_channel: int = 2,

                 compress_factor: float = 0.3,

                 batch_size: int = 4,
                 learning_rate: float = 0.0005,
                 adam_b1: float = 0.8,
                 adam_b2: float = 0.99,
                 lr_decay: float = 0.99,
                 seed: int = 1234,

                 **kwargs
                 ):
        super(NXMPNetConfig, self).__init__(**kwargs)
        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_size = hop_size

        self.dense_num_blocks = dense_num_blocks
        self.dense_hidden_size = dense_hidden_size

        self.mask_num_blocks = mask_num_blocks
        self.mask_hidden_size = mask_hidden_size

        self.phase_num_blocks = phase_num_blocks
        self.phase_hidden_size = phase_hidden_size

        self.tsfm_hidden_size = tsfm_hidden_size
        self.tsfm_attention_heads = tsfm_attention_heads
        self.tsfm_num_blocks = tsfm_num_blocks
        self.tsfm_dropout_rate = tsfm_dropout_rate
        self.tsfm_max_time_relative_position = tsfm_max_time_relative_position
        self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position
        self.tsfm_chunk_size = tsfm_chunk_size
        self.tsfm_num_left_chunks = tsfm_num_left_chunks
        self.tsfm_num_right_chunks = tsfm_num_right_chunks

        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel

        self.compress_factor = compress_factor

        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed


if __name__ == '__main__':
    pass
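With the defaults above, win_size=200 and hop_size=80 correspond to 25 ms analysis windows with a 10 ms hop at the 8 kHz sample rate, and n_fft=512 gives 257 frequency bins. A minimal sketch of overriding a few fields (the chosen values are only illustrative):

# Instantiating the config with a few overridden hyper-parameters.
config = NXMPNetConfig(
    sample_rate=8000,
    n_fft=512,
    win_size=200,       # 200 / 8000 = 25 ms analysis window
    hop_size=80,        # 80 / 8000 = 10 ms hop
    tsfm_num_blocks=2,  # smaller transformer stack for a quick experiment
)
print(config.n_fft // 2 + 1)    # 257 frequency bins
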
toolbox/torchaudio/models/nx_mpnet/discriminator.py
DELETED
@@ -1,102 +0,0 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from pesq import pesq
from joblib import Parallel, delayed

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d


class MetricDiscriminator(nn.Module):
    def __init__(self, config: NXMPNetConfig):
        super(MetricDiscriminator, self).__init__()
        dim = config.discriminator_dim
        in_channel = config.discriminator_in_channel

        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim*2, affine=True),
            nn.PReLU(dim*2),
            nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim*4, affine=True),
            nn.PReLU(dim*4),
            nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim*8, affine=True),
            nn.PReLU(dim*8),
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)),
            nn.Dropout(0.3),
            nn.PReLU(dim*4),
            nn.utils.spectral_norm(nn.Linear(dim*4, 1)),
            LearnableSigmoid1d(1)
        )

    def forward(self, x, y):
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)


MODEL_FILE = "discriminator.pt"


class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    def __init__(self,
                 config: NXMPNetConfig,
                 ):
        super(MetricDiscriminatorPretrainedModel, self).__init__(
            config=config,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):

        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


if __name__ == '__main__':
    pass
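The discriminator scores a (reference, estimate) spectrogram pair with a single value per example, squashed by the learnable sigmoid; judging from the pesq import, it is presumably trained MetricGAN-style against a normalized PESQ target, although that training loop lives elsewhere. A shape-level sketch (spectrogram sizes and orientation are illustrative only):

# Scoring a pair of magnitude spectrograms; output is one scalar per example.
import torch

config = NXMPNetConfig()
discriminator = MetricDiscriminator(config)

clean_mag = torch.rand(4, 257, 201)       # [B, F, T], illustrative sizes
enhanced_mag = torch.rand(4, 257, 201)
score = discriminator.forward(clean_mag, enhanced_mag)
print(score.shape)                         # torch.Size([4, 1])
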