diff --git a/examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py b/examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py deleted file mode 100644 index 8a5d6d5bfc88072d2c034c748baa3444239cfd08..0000000000000000000000000000000000000000 --- a/examples/data_preprocess/nx_speech_denoise/nx_speech_denoise.py +++ /dev/null @@ -1,83 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import argparse -import os -from pathlib import Path -import sys - -from gradio_client import Client, handle_file -import numpy as np -from tqdm import tqdm -import shutil - -pwd = os.path.abspath(os.path.dirname(__file__)) -sys.path.append(os.path.join(pwd, "../../")) - -import librosa -from scipy.io import wavfile - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--src_dir", - default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech\en-PH", - # default=r"/data/tianxing/HuggingDatasets/nx_noise/data/speech/en-PH", - type=str - ) - parser.add_argument( - "--tgt_dir", - default=r"E:\Users\tianx\HuggingDatasets\nx_noise\data\speech-denoise\en-PH", - # default=r"/data/tianxing/HuggingDatasets/nx_noise/data/speech-denoise/en-PH", - type=str - ) - args = parser.parse_args() - return args - - -def main(): - args = get_args() - - # client = Client(src="http://10.75.27.247:7865/") - client = Client(src="http://127.0.0.1:7865/") - - src_dir = Path(args.src_dir) - tgt_dir = Path(args.tgt_dir) - tgt_dir.mkdir(parents=True, exist_ok=True) - - tgt_date_list = list(sorted([date.name for date in src_dir.glob("*") if not date.name.endswith(".zip")])) - finished_date_set = set(tgt_date_list[:-1]) - current_date = tgt_date_list[-1] - - print(f"finished_date_set: {finished_date_set}") - print(f"current_date: {current_date}") - - finished_set = set() - for filename in (tgt_dir / current_date).glob("*.wav"): - name = filename.name - finished_set.add(name) - - src_date_list = list(sorted([date.name for date in src_dir.glob("*")])) - for date in src_date_list: - if date in finished_date_set: - continue - for filename in (src_dir / current_date).glob("**/*.wav"): - result = client.predict( - noisy_audio_file_t=handle_file(filename.as_posix()), - noisy_audio_microphone_t=None, - engine="frcrn-dns3", - api_name="/when_click_denoise_button" - ) - denoise_file = result[0] - tgt_file = tgt_dir / current_date / f"{filename.name}" - tgt_file.parent.mkdir(parents=True, exist_ok=True) - - shutil.move(denoise_file, tgt_file) - print(denoise_file) - exit(0) - - return - - -if __name__ == "__main__": - main() diff --git a/examples/dfnet2/run.sh b/examples/dfnet2/run.sh index 302c9d65ce356806a456182e6f0d878c19b89395..27e663100bcd06d2c4beea356d3946a3829bdd42 100644 --- a/examples/dfnet2/run.sh +++ b/examples/dfnet2/run.sh @@ -10,9 +10,9 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" -sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx-devoice \ ---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech" \ ---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name dfnet2-nx2 \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ +--speech_dir 
"/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" END diff --git a/examples/dtln/run.sh b/examples/dtln/run.sh index 232d2fa1dd5c2b7a8571ad650c085e53d0c8b507..cf045a8bafb0fdc1d7cdf42f008e5bf4882f7749 100644 --- a/examples/dtln/run.sh +++ b/examples/dtln/run.sh @@ -7,16 +7,23 @@ sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name fi --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" + sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-512 --final_model_name dtln-512-nx-dns3 \ --config_file "yaml/config-512.yaml" \ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise" \ --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/dns3-speech" -sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir-1024 --final_model_name dtln-1024-nx \ +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name dtnl-1024-nx2 --final_model_name dtln-1024-nx2 \ --config_file "yaml/config-1024.yaml" \ --noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ ---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech" +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" + + +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir-1024 --final_model_name dtln-1024-nx-devoice \ +--config_file "yaml/config-1024.yaml" \ +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" END diff --git a/examples/frcrn/run.sh b/examples/frcrn/run.sh index bb1fbdb1c9abbe8d8e0adcdd618163c5fab53962..0cdd5a10c69a4827bfdb443d594a42886c35ea1f 100644 --- a/examples/frcrn/run.sh +++ b/examples/frcrn/run.sh @@ -9,10 +9,10 @@ sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name fi --speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech" -sh run.sh --stage 1 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx-devoice \ +sh run.sh --stage 2 --stop_stage 2 --system_version centos --file_folder_name file_dir --final_model_name frcrn-10-nx2 \ --config_file "yaml/config-10.yaml" \ ---noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech" \ ---speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" +--noise_dir "/data/tianxing/HuggingDatasets/nx_noise/data/noise/nx-noise" \ +--speech_dir "/data/tianxing/HuggingDatasets/nx_noise/data/speech/nx-speech2" END diff --git a/main.py b/main.py index b3a0a079719255a04c37776c9caa7812f1e46ae9..6c472cdbc2f22a9758cbe056963d870ec848410e 100644 --- a/main.py +++ b/main.py @@ -177,14 +177,10 @@ def when_click_denoise_button(noisy_audio_file_t = None, noisy_audio_microphone_ infer_engine = load_denoise_model(infer_cls=infer_cls, **kwargs) begin = time.time() - enhanced_audio = infer_engine.enhancement_by_ndarray(noisy_audio) + denoise_audio = infer_engine.enhancement_by_ndarray(noisy_audio) time_cost = time.time() - begin - noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy") - denoise_mag_db = generate_spectrogram(enhanced_audio, title="denoise") - fpr = time_cost / audio_duration - info = { "time_cost": round(time_cost, 4), "audio_duration": round(audio_duration, 4), @@ -192,12 +188,21 @@ def when_click_denoise_button(noisy_audio_file_t = None, 
noisy_audio_microphone_ } message = json.dumps(info, ensure_ascii=False, indent=4) - enhanced_audio = np.array(enhanced_audio * (1 << 15), dtype=np.int16) + noise_audio = noisy_audio - denoise_audio + + noisy_mag_db = generate_spectrogram(noisy_audio, title="noisy") + denoise_mag_db = generate_spectrogram(denoise_audio, title="denoise") + noise_mag_db = generate_spectrogram(noise_audio, title="noise") + + denoise_audio = np.array(denoise_audio * (1 << 15), dtype=np.int16) + noise_audio = np.array(noise_audio * (1 << 15), dtype=np.int16) + except Exception as e: raise gr.Error(f"enhancement failed, error type: {type(e)}, error text: {str(e)}.") - enhanced_audio_t = (sample_rate, enhanced_audio) - return enhanced_audio_t, message, noisy_mag_db, denoise_mag_db + denoise_audio_t = (sample_rate, denoise_audio) + noise_audio_t = (sample_rate, noise_audio) + return denoise_audio_t, noise_audio_t, message, noisy_mag_db, denoise_mag_db, noise_mag_db def main(): @@ -255,21 +260,23 @@ def main(): with gr.Column(variant="panel", scale=5): with gr.Tabs(): with gr.TabItem("audio"): - dn_enhanced_audio = gr.Audio(label="enhanced_audio") + dn_denoise_audio = gr.Audio(label="denoise_audio") + dn_noise_audio = gr.Audio(label="noise_audio") dn_message = gr.Textbox(lines=1, max_lines=20, label="message") with gr.TabItem("mag_db"): dn_noisy_mag_db = gr.Image(label="noisy_mag_db") dn_denoise_mag_db = gr.Image(label="denoise_mag_db") + dn_noise_mag_db = gr.Image(label="noise_mag_db") dn_button.click( when_click_denoise_button, inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine], - outputs=[dn_enhanced_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db] + outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db] ) gr.Examples( examples=examples, inputs=[dn_noisy_audio_file, dn_noisy_audio_microphone, dn_engine], - outputs=[dn_enhanced_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db], + outputs=[dn_denoise_audio, dn_noise_audio, dn_message, dn_noisy_mag_db, dn_denoise_mag_db, dn_noise_mag_db], fn=when_click_denoise_button, # cache_examples=True, # cache_mode="lazy", @@ -289,8 +296,8 @@ def main(): # http://127.0.0.1:7865/ # http://10.75.27.247:7865/ blocks.queue().launch( - share=True, - # share=False if platform.system() == "Windows" else False, + # share=True, + share=False if platform.system() == "Windows" else False, server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0", server_port=args.server_port ) diff --git a/toolbox/torchaudio/models/nx_clean_unet/transformers/__init__.py b/toolbox/torchaudio/models/dccrn/__init__.py similarity index 66% rename from toolbox/torchaudio/models/nx_clean_unet/transformers/__init__.py rename to toolbox/torchaudio/models/dccrn/__init__.py index 8bc5155c67cae42f80e8126d1727b0edc1e02398..81a66fc40cec5e1bad20c94ebc03002f9772eb07 100644 --- a/toolbox/torchaudio/models/nx_clean_unet/transformers/__init__.py +++ b/toolbox/torchaudio/models/dccrn/__init__.py @@ -2,5 +2,5 @@ # -*- coding: utf-8 -*- -if __name__ == '__main__': +if __name__ == "__main__": pass diff --git a/toolbox/torchaudio/models/dccrn/modeling_dccrn.py b/toolbox/torchaudio/models/dccrn/modeling_dccrn.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6306616753aca2fc7379a8f2406b819adb830a --- /dev/null +++ b/toolbox/torchaudio/models/dccrn/modeling_dccrn.py @@ -0,0 +1,12 @@ +#!/usr/bin/python3 +# -*- coding: utf-8 -*- +""" + +https://arxiv.org/abs/2008.00264 + +https://github.com/huyanxin/DeepComplexCRN 
+ +""" + +if __name__ == "__main__": + pass diff --git a/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py b/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py index c65bc4f04b6fb14641dfdc5fbe075dbccc29df31..9054e4704b4cd0076f7b181b4b28a74adbd5c2e8 100644 --- a/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py +++ b/toolbox/torchaudio/models/dfnet2/modeling_dfnet2.py @@ -11,7 +11,6 @@ https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd444 """ import os import math -from collections import defaultdict from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np @@ -109,7 +108,7 @@ class CausalConv2d(nn.Module): else: self.activation = nn.Identity() - def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None): + def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): """ :param inputs: shape: [b, c, t, f] :param cache: shape: [b, c, lookback, f]; @@ -560,15 +559,14 @@ class Encoder(nn.Module): feat_spec: torch.Tensor, cache_dict: dict = None, ): - if cache_dict is None: - cache_dict = defaultdict(lambda: None) - cache0 = cache_dict["cache0"] - cache1 = cache_dict["cache1"] - cache2 = cache_dict["cache2"] - cache3 = cache_dict["cache3"] - cache4 = cache_dict["cache4"] - cache5 = cache_dict["cache5"] - cache6 = cache_dict["cache6"] + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + cache2 = cache_dict.get("cache2", None) + cache3 = cache_dict.get("cache3", None) + cache4 = cache_dict.get("cache4", None) + cache5 = cache_dict.get("cache5", None) + cache6 = cache_dict.get("cache6", None) # feat_erb shape: (b, 1, t, erb_bins) e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0) @@ -716,13 +714,12 @@ class ErbDecoder(nn.Module): ) def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> torch.Tensor: - if cache_dict is None: - cache_dict = defaultdict(lambda: None) - cache0 = cache_dict["cache0"] - cache1 = cache_dict["cache1"] - cache2 = cache_dict["cache2"] - cache3 = cache_dict["cache3"] - cache4 = cache_dict["cache4"] + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) + cache2 = cache_dict.get("cache2", None) + cache3 = cache_dict.get("cache3", None) + cache4 = cache_dict.get("cache4", None) # Estimates erb mask b, _, t, f8 = e3.shape @@ -814,10 +811,9 @@ class DfDecoder(nn.Module): ) def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> torch.Tensor: - if cache_dict is None: - cache_dict = defaultdict(lambda: None) - cache0 = cache_dict["cache0"] - cache1 = cache_dict["cache1"] + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) # emb shape: [batch_size, time_steps, df_bins // 4 * channels] b, t, _ = emb.shape @@ -995,10 +991,9 @@ class DeepFiltering(nn.Module): coefs: torch.Tensor, cache_dict: dict = None, ): - if cache_dict is None: - cache_dict = defaultdict(lambda: None) - cache0 = cache_dict["cache0"] - cache1 = cache_dict["cache1"] + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) # spec shape: [b, 1, t, spec_bins, 2] spec_c = torch.view_as_complex(spec.contiguous()) @@ -1163,10 +1158,9 @@ class DfNet2(nn.Module): return spec, feat_erb, feat_spec def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None): - if cache_dict is None: - cache_dict = 
defaultdict(lambda: None) - cache0 = cache_dict["cache0"] - cache1 = cache_dict["cache1"] + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) + cache1 = cache_dict.get("cache1", None) feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0) feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1) @@ -1249,6 +1243,65 @@ class DfNet2(nn.Module): return est_spec, est_wav, est_mask, lsnr + def forward_chunk(self, + sub_noisy: torch.Tensor, + cache_dict0: dict = None, + cache_dict1: dict = None, + cache_dict2: dict = None, + cache_dict3: dict = None, + cache_dict4: dict = None, + cache_dict5: dict = None, + cache_dict6: dict = None, + ): + + spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy) + # spec shape: [b, 1, t, f, 2] + # feat_erb shape: [b, 1, t, erb_bins] + # feat_spec shape: [b, 2, t, df_bins] + if self.config.use_ema_norm: + feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0) + + e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1) + + mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2) + # mask shape: [b, 1, t, erb_bins] + mask = self.erb_bands.erb_scale_inv(mask) + # mask shape: [b, 1, t, f] + + spec_m = self.mask.forward(spec, mask) + # spec_m shape: [b, 1, t, f, 2] + spec_m = spec_m[:, :, :, :self.config.spec_bins, :] + # spec_m shape: [b, 1, t, spec_bins, 2] + + # lsnr shape: [b, t, 1] + lsnr = torch.transpose(lsnr, dim0=2, dim1=1) + # lsnr shape: [b, 1, t] + + df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3) + df_coefs = self.df_out_transform(df_coefs) + # df_coefs shape: [b, df_order, t, df_bins, 2] + + spec_ = spec[:, :, :, :self.config.spec_bins, :] + # spec shape: [b, 1, t, spec_bins, 2] + spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4) + # spec_f shape: [b, 1, t, df_bins, 2], torch.float32 + + spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5) + + spec_e = torch.squeeze(spec_e, dim=1) + spec_e = spec_e.permute(0, 2, 1, 3) + # spec_e shape: [b, spec_bins, t, 2] + + # spec_e shape: [b, spec_bins, t, 2] + est_spec = torch.view_as_complex(spec_e.contiguous()) + # est_spec shape: [b, spec_bins, t], torch.complex64 + est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) + # est_spec shape: [b, f, t], torch.complex64 + + est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6) + # est_wav shape: [b, 1, hop_size] + return est_wav, cache_dict0, cache_dict1, cache_dict2, cache_dict3, cache_dict4, cache_dict5, cache_dict6 + def forward_chunk_by_chunk(self, noisy: torch.Tensor, ): @@ -1275,52 +1328,13 @@ class DfNet2(nn.Module): end = begin + self.win_size sub_noisy = noisy[:, :, begin: end] - spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy) - # spec shape: [b, 1, t, f, 2] - # feat_erb shape: [b, 1, t, erb_bins] - # feat_spec shape: [b, 2, t, df_bins] - if self.config.use_ema_norm: - feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0) - - e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1) - - mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2) - # mask shape: [b, 1, t, erb_bins] - mask = self.erb_bands.erb_scale_inv(mask) - # mask shape: [b, 1, t, f] - - spec_m = self.mask.forward(spec, 
mask) - # spec_m shape: [b, 1, t, f, 2] - spec_m = spec_m[:, :, :, :self.config.spec_bins, :] - # spec_m shape: [b, 1, t, spec_bins, 2] - - # lsnr shape: [b, t, 1] - lsnr = torch.transpose(lsnr, dim0=2, dim1=1) - # lsnr shape: [b, 1, t] - - df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3) - df_coefs = self.df_out_transform(df_coefs) - # df_coefs shape: [b, df_order, t, df_bins, 2] - - spec_ = spec[:, :, :, :self.config.spec_bins, :] - # spec shape: [b, 1, t, spec_bins, 2] - spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4) - # spec_f shape: [b, 1, t, df_bins, 2], torch.float32 - - spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5) - - spec_e = torch.squeeze(spec_e, dim=1) - spec_e = spec_e.permute(0, 2, 1, 3) - # spec_e shape: [b, spec_bins, t, 2] - - # spec_e shape: [b, spec_bins, t, 2] - est_spec = torch.view_as_complex(spec_e.contiguous()) - # est_spec shape: [b, spec_bins, t], torch.complex64 - est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) - # est_spec shape: [b, f, t], torch.complex64 - - est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6) - # est_wav shape: [b, 1, hop_size] + (est_wav, + cache_dict0, cache_dict1, cache_dict2, cache_dict3, + cache_dict4, cache_dict5, cache_dict6) = self.forward_chunk( + sub_noisy, + cache_dict0, cache_dict1, cache_dict2, cache_dict3, + cache_dict4, cache_dict5, cache_dict6 + ) waveform_list.append(est_wav) @@ -1335,27 +1349,26 @@ class DfNet2(nn.Module): :param cache_dict: :return: """ - if cache_dict is None: - cache_dict = defaultdict(lambda: None) - cache_spec_m = cache_dict["cache_spec_m"] + cache_dict = cache_dict or dict() + cache0 = cache_dict.get("cache0", None) - if cache_spec_m is None: + if cache0 is None: b, c, t, f, _ = spec_m.shape - cache_spec_m = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2)) + cache0 = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2)) # cache0 shape: [b, 1, lookahead, f, 2] spec_m_cat = torch.concat(tensors=[ - cache_spec_m, spec_m, + cache0, spec_m, ], dim=2) spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :] - new_cache_spec_m = spec_m_cat[:, :, -self.config.df_lookahead:, :, :] + new_cache0 = spec_m_cat[:, :, -self.config.df_lookahead:, :, :] spec_e = torch.concat(tensors=[ spec_f, spec_m[..., self.df_decoder.df_bins:, :] ], dim=3) new_cache_dict = { - "cache_spec_m": new_cache_spec_m, + "cache0": new_cache0, } return spec_e, new_cache_dict diff --git a/toolbox/torchaudio/models/dtln/modeling_dtln.py b/toolbox/torchaudio/models/dtln/modeling_dtln.py index 2c9877fcfc52b0096a997af4c89b8805258cc403..4a8580ece11edbfb72a5511ebbcf9eb2ace114ed 100644 --- a/toolbox/torchaudio/models/dtln/modeling_dtln.py +++ b/toolbox/torchaudio/models/dtln/modeling_dtln.py @@ -1,9 +1,17 @@ #!/usr/bin/python3 # -*- coding: utf-8 -*- """ +https://www.isca-archive.org/interspeech_2020/westhausen20_interspeech.pdf + https://github.com/AkenoSyuRi/DTLNPytorch https://github.com/breizhn/DTLN + +Dataset: DNS3 DNS-Challenge +SNR -5 to 25 dB +5 to 30 dB +Window length 32ms, window shift 8ms + Trained on 500 hours of dns3 data, reaching PESQ 3.04 on the dns3 test set. """ @@ -245,13 +253,12 @@ class DTLNModel(nn.Module): # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}") t = (num_samples_pad - self.fft_size) // self.hop_size + 1 + overlap_size = self.fft_size - self.hop_size denoise_list = list() out_state1 = None out_state2 = None - overlap_size = self.fft_size -
self.hop_size denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype) - # denoise_list.append(torch.clone(denoise_cache)) for i in range(t): begin = i * self.hop_size end = begin + self.fft_size diff --git a/toolbox/torchaudio/models/ehnet/modeling_ehnet.py b/toolbox/torchaudio/models/ehnet/modeling_ehnet.py index 0acf083dc5fb6f61c5da78d866081310f3768d1c..5afecbfcdff63bddb1df740d6ee324a986622141 100644 --- a/toolbox/torchaudio/models/ehnet/modeling_ehnet.py +++ b/toolbox/torchaudio/models/ehnet/modeling_ehnet.py @@ -71,7 +71,6 @@ class CausalTransConvBlock(nn.Module): return x - class CRN(nn.Module): """ Input: [batch size, channels=1, T, n_fft] diff --git a/toolbox/torchaudio/models/nx_clean_unet/__init__.py b/toolbox/torchaudio/models/nx_clean_unet/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py b/toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/causal_convolution/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py b/toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py deleted file mode 100644 index 2a479f417a5968ef05ebce42e0b102620c31b5ef..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/causal_convolution/causal_conv2d.py +++ /dev/null @@ -1,261 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import math -import os -from typing import List, Optional, Union, Iterable - -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional as F - - -norm_layer_dict = { - "batch_norm_2d": torch.nn.BatchNorm2d -} - - -activation_layer_dict = { - "relu": torch.nn.ReLU, - "identity": torch.nn.Identity, - "sigmoid": torch.nn.Sigmoid, -} - - -class CausalConv2d(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Iterable[int]], - f_stride: int = 1, - dilation: int = 1, - do_f_pad: bool = True, - bias: bool = True, - separable: bool = False, - norm_layer: str = "batch_norm_2d", - activation_layer: str = "relu", - lookahead: int = 0 - ): - super(CausalConv2d, self).__init__() - kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) - - if do_f_pad: - f_pad = kernel_size[1] // 2 + dilation - 1 - else: - f_pad = 0 - - self.causal_left_pad = kernel_size[0] - 1 - lookahead - self.causal_right_pad = lookahead - self.constant_pad = nn.ConstantPad2d( - padding=(0, 0, self.causal_left_pad, self.causal_right_pad), - value=0.0 - ) - - groups = math.gcd(in_channels, out_channels) if separable else 1 - self.conv1 = nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=(0, f_pad), - stride=(1, f_stride), - dilation=(1, dilation), - groups=groups, - bias=bias, - ) - - self.conv2 = None - if not any([groups == 1, max(kernel_size) == 1]): - self.conv2 = nn.Conv2d( - out_channels, - out_channels, - kernel_size=1, - 
bias=False, - ) - - self.norm = None - if norm_layer is not None: - norm_layer = norm_layer_dict[norm_layer] - self.norm = norm_layer(out_channels) - - self.activation = None - if activation_layer is not None: - activation_layer = activation_layer_dict[activation_layer] - self.activation = activation_layer() - - def forward(self, - inputs: torch.Tensor, - causal_cache: torch.Tensor = None, - ): - - if causal_cache is None: - # inputs shape: [batch_size, 1, time_steps, hidden_size] - x = self.constant_pad.forward(inputs) - else: - # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size] - # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size] - x = torch.concat(tensors=[causal_cache, inputs], dim=2) - # x shape: [batch_size, 1, time_steps2, hidden_size] - # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad - - x = self.conv1.forward(x) - # inputs shape: [batch_size, 1, time_steps, hidden_size] - - if self.conv2: - x = self.conv2.forward(x) - - if self.norm: - x = self.norm(x) - if self.activation: - x = self.activation(x) - - causal_cache = x[:, :, -self.causal_left_pad:, :] - - # inputs shape: [batch_size, 1, time_steps, hidden_size] - return x, causal_cache - - -class CausalConv2dEncoder(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Iterable[int]], - f_stride: int = 1, - dilation: int = 1, - do_f_pad: bool = True, - bias: bool = True, - separable: bool = False, - norm_layer: str = "batch_norm_2d", - activation_layer: str = "relu", - lookahead: int = 0, - num_layers: int = 5, - ): - super(CausalConv2dEncoder, self).__init__() - self.num_layers = num_layers - - self.total_causal_left_pad = 0 - self.total_causal_right_pad = 0 - - self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[]) - for i_layer in range(num_layers): - conv = CausalConv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - f_stride=f_stride, - dilation=dilation, - do_f_pad=do_f_pad, - bias=bias, - separable=separable, - norm_layer=norm_layer, - activation_layer=activation_layer, - lookahead=lookahead, - ) - self.causal_conv_list.append(conv) - - self.total_causal_left_pad += conv.causal_left_pad - self.total_causal_right_pad += conv.causal_right_pad - - in_channels = out_channels - - def forward(self, inputs: torch.Tensor): - # inputs shape: [batch_size, 1, time_steps, hidden_size] - - x = inputs - for layer in self.causal_conv_list: - x, _ = layer.forward(x) - return x - - def forward_chunk(self, - chunk: torch.Tensor, - causal_cache: torch.Tensor = None, - ): - # causal_cache shape: [self.num_layers, 1, causal_left_pad, hidden_size] - - new_causal_cache_list = list() - for idx, causal_conv in enumerate(self.causal_conv_list): - chunk, new_causal_cache = causal_conv.forward( - inputs=chunk, causal_cache=causal_cache[idx: idx+1] if causal_cache is not None else None - ) - new_causal_cache_list.append(new_causal_cache) - - new_causal_cache = torch.cat(new_causal_cache_list, dim=0) - return chunk, new_causal_cache - - def forward_chunk_by_chunk(self, inputs: torch.Tensor): - # inputs shape: [batch_size, 1, time_steps, hidden_size] - # batch_size = 1 - - batch_size, channels, time_steps, hidden_size = inputs.shape - - causal_cache = None - - outputs = [] - for idx in range(0, time_steps, 1): - begin = idx - end = begin + self.total_causal_right_pad + 1 - chunk_xs = inputs[:, :, begin:end, :] - - ys, attention_cache = self.forward_chunk( - chunk=chunk_xs, - 
causal_cache=causal_cache, - ) - # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size] - ys = ys[:, :, :1, :] - - # ys shape: [batch_size, chunk_size, hidden_size] - outputs.append(ys) - - ys = torch.cat(outputs, 2) - return ys - - -def main2(): - conv = CausalConv2d( - in_channels=1, - out_channels=64, - kernel_size=3, - bias=False, - separable=True, - f_stride=1, - lookahead=0, - ) - - spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32) - # spec shape: [batch_size, 1, time_steps, hidden_size] - cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32) - - output, _ = conv.forward(spec) - print(output.shape) - - output, _ = conv.forward(spec, cache) - print(output.shape) - - return - - -def main(): - causal = CausalConv2dEncoder( - in_channels=1, - out_channels=1, - kernel_size=3, - bias=False, - separable=True, - f_stride=1, - lookahead=0, - num_layers=3, - ) - - spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32) - # spec shape: [batch_size, 1, time_steps, hidden_size] - - output = causal.forward(spec) - print(output.shape) - - output = causal.forward_chunk_by_chunk(spec) - print(output.shape) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py b/toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py deleted file mode 100644 index c02c4c8b4e4fcbf8cc09e14523eeeabf0c0a1e0d..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/configuration_nx_clean_unet.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from toolbox.torchaudio.configuration_utils import PretrainedConfig - - -class NXCleanUNetConfig(PretrainedConfig): - """ - https://github.com/yxlu-0102/MP-SENet/blob/main/config.json - """ - def __init__(self, - sample_rate: int = 8000, - segment_size: int = 16000, - n_fft: int = 512, - win_length: int = 200, - hop_length: int = 80, - - down_sampling_num_layers: int = 5, - down_sampling_in_channels: int = 1, - down_sampling_hidden_channels: int = 64, - down_sampling_kernel_size: int = 4, - down_sampling_stride: int = 2, - - causal_in_channels: int = 64, - causal_out_channels: int = 64, - causal_kernel_size: int = 3, - causal_bias: bool = False, - causal_separable: bool = True, - causal_f_stride: int = 1, - # causal_lookahead: int = 0, - causal_num_layers: int = 3, - - tsfm_hidden_size: int = 256, - tsfm_attention_heads: int = 4, - tsfm_num_blocks: int = 6, - tsfm_dropout_rate: float = 0.1, - tsfm_max_length: int = 1024, - tsfm_chunk_size: int = 4, - tsfm_num_left_chunks: int = 128, - tsfm_num_right_chunks: int = 2, - - discriminator_dim: int = 16, - discriminator_in_channel: int = 2, - - compress_factor: float = 0.3, - - batch_size: int = 4, - learning_rate: float = 0.0005, - adam_b1: float = 0.8, - adam_b2: float = 0.99, - lr_decay: float = 0.99, - seed: int = 1234, - - **kwargs - ): - super(NXCleanUNetConfig, self).__init__(**kwargs) - self.sample_rate = sample_rate - self.segment_size = segment_size - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - - self.down_sampling_num_layers = down_sampling_num_layers - self.down_sampling_in_channels = down_sampling_in_channels - self.down_sampling_hidden_channels = down_sampling_hidden_channels - self.down_sampling_kernel_size = down_sampling_kernel_size - self.down_sampling_stride = down_sampling_stride - - self.causal_in_channels = causal_in_channels - self.causal_out_channels = 
causal_out_channels - self.causal_kernel_size = causal_kernel_size - self.causal_bias = causal_bias - self.causal_separable = causal_separable - self.causal_f_stride = causal_f_stride - # self.causal_lookahead = causal_lookahead - self.causal_num_layers = causal_num_layers - - self.tsfm_hidden_size = tsfm_hidden_size - self.tsfm_attention_heads = tsfm_attention_heads - self.tsfm_num_blocks = tsfm_num_blocks - self.tsfm_dropout_rate = tsfm_dropout_rate - self.tsfm_max_length = tsfm_max_length - self.tsfm_chunk_size = tsfm_chunk_size - self.tsfm_num_left_chunks = tsfm_num_left_chunks - self.tsfm_num_right_chunks = tsfm_num_right_chunks - - self.discriminator_dim = discriminator_dim - self.discriminator_in_channel = discriminator_in_channel - - self.compress_factor = compress_factor - - self.batch_size = batch_size - self.learning_rate = learning_rate - self.adam_b1 = adam_b1 - self.adam_b2 = adam_b2 - self.lr_decay = lr_decay - self.seed = seed - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_clean_unet/discriminator.py b/toolbox/torchaudio/models/nx_clean_unet/discriminator.py deleted file mode 100644 index 54c0c70a979bcd4e673696ea67c615b6d3557204..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/discriminator.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -from typing import Optional, Union - -import torch -import torch.nn as nn -import torchaudio - -from toolbox.torchaudio.configuration_utils import CONFIG_FILE -from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig -from toolbox.torchaudio.models.nx_clean_unet.utils import LearnableSigmoid1d - - -class MetricDiscriminator(nn.Module): - def __init__(self, config: NXCleanUNetConfig): - super(MetricDiscriminator, self).__init__() - dim = config.discriminator_dim - self.in_channel = config.discriminator_in_channel - - self.n_fft = config.n_fft - self.win_length = config.win_length - self.hop_length = config.hop_length - - self.transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, - win_length=self.win_length, - hop_length=self.hop_length, - power=1.0, - window_fn=torch.hann_window, - # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, - ) - - self.layers = nn.Sequential( - nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim, affine=True), - nn.PReLU(dim), - nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*2, affine=True), - nn.PReLU(dim*2), - nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*4, affine=True), - nn.PReLU(dim*4), - nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*8, affine=True), - nn.PReLU(dim*8), - nn.AdaptiveMaxPool2d(1), - nn.Flatten(), - nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)), - nn.Dropout(0.3), - nn.PReLU(dim*4), - nn.utils.spectral_norm(nn.Linear(dim*4, 1)), - LearnableSigmoid1d(1) - ) - - def forward(self, x, y): - x = self.transform.forward(x) - y = self.transform.forward(y) - - xy = torch.stack((x, y), dim=1) - return self.layers(xy) - - -MODEL_FILE = "discriminator.pt" - - -class MetricDiscriminatorPretrainedModel(MetricDiscriminator): - def __init__(self, - config: NXCleanUNetConfig, - ): - super(MetricDiscriminatorPretrainedModel, self).__init__( - config=config, - ) - self.config = 
config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - -def main(): - config = NXCleanUNetConfig() - discriminator = MetricDiscriminator(config=config) - - # shape: [batch_size, num_samples] - # x = torch.ones([4, int(4.5 * 16000)]) - # y = torch.ones([4, int(4.5 * 16000)]) - x = torch.ones([4, 16000]) - y = torch.ones([4, 16000]) - - output = discriminator.forward(x, y) - print(output.shape) - print(output) - - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav b/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav deleted file mode 100644 index 466d0d39b1e4b04015b894c33a5b8026ef037d63..0000000000000000000000000000000000000000 Binary files a/toolbox/torchaudio/models/nx_clean_unet/enhanced_audio.wav and /dev/null differ diff --git a/toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py b/toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py deleted file mode 100644 index 7dcf854b0baa41c711ce595cd0350238d2d9f7c3..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/inference_nx_clean_unet.py +++ /dev/null @@ -1,96 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import logging -from pathlib import Path -import shutil -import tempfile -import zipfile - -import librosa -import numpy as np -import torch -import torchaudio - -from project_settings import project_path -from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig -from toolbox.torchaudio.models.nx_clean_unet.modeling_nx_clean_unet import NXCleanUNetPretrainedModel, MODEL_FILE - -logger = logging.getLogger("toolbox") - - -class InferenceNXCleanUNet(object): - def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): - self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file - self.device = torch.device(device) - - logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") - config, model = self.load_models(self.pretrained_model_path_or_zip_file) - logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") - - self.config = config - self.model = model - self.model.to(device) - self.model.eval() - - def load_models(self, model_path: str): - model_path = Path(model_path) - if model_path.name.endswith(".zip"): - with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: - out_root = Path(tempfile.gettempdir()) / "nx_denoise" - out_root.mkdir(parents=True, exist_ok=True) - 
f_zip.extractall(path=out_root) - model_path = out_root / model_path.stem - - config = NXCleanUNetConfig.from_pretrained( - pretrained_model_name_or_path=model_path.as_posix(), - ) - model = NXCleanUNetPretrainedModel.from_pretrained( - pretrained_model_name_or_path=model_path.as_posix(), - ) - model.to(self.device) - model.eval() - - shutil.rmtree(model_path) - return config, model - - def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: - if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: - raise AssertionError(f"The value range of audio samples should be between -1 and 1.") - - # noisy_audio shape: [batch_size, num_samples] - noisy_audios = noisy_audio.to(self.device) - - with torch.no_grad(): - enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios) - # enhanced_audios = self.model.forward(noisy_audios) - # enhanced_audio shape: [batch_size, n_samples] - # enhanced_audios = torch.squeeze(enhanced_audios, dim=1) - - enhanced_audio = enhanced_audios[0] - # enhanced_audio shape: [num_samples,] - return enhanced_audio - -def main(): - model_zip_file = project_path / "trained_models/nx-clean-unet-14-epoch.zip" - infer_nx_clean_unet = InferenceNXCleanUNet(model_zip_file) - - sample_rate = 8000 - noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav" - noisy_audio, _ = librosa.load( - noisy_audio_file.as_posix(), - sr=sample_rate, - ) - noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] - noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) - noisy_audio = noisy_audio.unsqueeze(dim=0) - - enhanced_audio = infer_nx_clean_unet.enhancement_by_tensor(noisy_audio) - - filename = "enhanced_audio.wav" - torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/loss.py b/toolbox/torchaudio/models/nx_clean_unet/loss.py deleted file mode 100644 index 475535006ee63213332fdc19ae91da1d81fe9cfc..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/loss.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import numpy as np -import torch - - -def anti_wrapping_function(x): - - return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) - - -def phase_losses(phase_r, phase_g): - - ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) - gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))) - iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))) - - return ip_loss, gd_loss, iaf_loss - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_clean_unet/metrics.py b/toolbox/torchaudio/models/nx_clean_unet/metrics.py deleted file mode 100644 index 78468894a56d4488021e83ea47e07c785a385269..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/metrics.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from joblib import Parallel, delayed -import numpy as np -from pesq import pesq -from typing import List - -from pesq import cypesq - - -def run_pesq(clean_audio: np.ndarray, - noisy_audio: np.ndarray, - sample_rate: int = 16000, - mode: str = "wb", - ) -> float: - if sample_rate == 8000 and mode == "wb": - raise AssertionError(f"mode should be `nb` when sample_rate is 8000") - try: - pesq_score = pesq(sample_rate, 
clean_audio, noisy_audio, mode) - except cypesq.NoUtterancesError as e: - pesq_score = -1 - except Exception as e: - print(f"pesq failed. error type: {type(e)}, error text: {str(e)}") - pesq_score = -1 - return pesq_score - - -def run_batch_pesq(clean_audio_list: List[np.ndarray], - noisy_audio_list: List[np.ndarray], - sample_rate: int = 16000, - mode: str = "wb", - n_jobs: int = 4, - ) -> List[float]: - parallel = Parallel(n_jobs=n_jobs) - - parallel_tasks = list() - for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list): - parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode) - parallel_tasks.append(parallel_task) - - pesq_score_list = parallel.__call__(parallel_tasks) - return pesq_score_list - - -def run_pesq_score(clean_audio_list: List[np.ndarray], - noisy_audio_list: List[np.ndarray], - sample_rate: int = 16000, - mode: str = "wb", - n_jobs: int = 4, - ) -> List[float]: - - pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list, - noisy_audio_list=noisy_audio_list, - sample_rate=sample_rate, - mode=mode, - n_jobs=n_jobs, - ) - - pesq_score = np.mean(pesq_score_list) - return pesq_score - - -def main(): - clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) - noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) - - clean_audio_list = list(clean_audio) - noisy_audio_list = list(noisy_audio) - - pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list) - print(pesq_score_list) - - pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list) - print(pesq_score) - - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py b/toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py deleted file mode 100644 index b03dd1b836fbb2e1caae4a6d44dfac7f7adc83f9..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/modeling_nx_clean_unet.py +++ /dev/null @@ -1,401 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -from typing import List, Optional, Union - -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional as F - -from toolbox.torchaudio.configuration_utils import CONFIG_FILE -from toolbox.torchaudio.models.nx_clean_unet.configuration_nx_clean_unet import NXCleanUNetConfig -from toolbox.torchaudio.models.nx_clean_unet.transformers.transformers import TransformerEncoder -from toolbox.torchaudio.models.nx_clean_unet.causal_convolution.causal_conv2d import CausalConv2dEncoder - - -class DownSamplingBlock(nn.Module): - def __init__(self, - in_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - ): - super(DownSamplingBlock, self).__init__() - self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride) - self.relu = nn.ReLU() - self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1) - self.glu = nn.GLU(dim=1) - - def forward(self, x: torch.Tensor): - # x shape: [batch_size, 1, num_samples] - x = self.conv1.forward(x) - # x shape: [batch_size, hidden_channels, new_num_samples] - x = self.relu(x) - x = self.conv2.forward(x) - # x shape: [batch_size, hidden_channels*2, new_num_samples] - x = self.glu(x) - # x shape: [batch_size, hidden_channels, new_num_samples] - # new_num_samples = (num_samples-kernel_size) // stride + 1 - return x - - -class DownSampling(nn.Module): - def __init__(self, - num_layers: int, - in_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - ): - 
super(DownSampling, self).__init__() - self.num_layers = num_layers - - down_sampling_block_list = list() - for idx in range(self.num_layers): - down_sampling_block = DownSamplingBlock( - in_channels=in_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - stride=stride, - ) - down_sampling_block_list.append(down_sampling_block) - in_channels = hidden_channels - - self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list) - - def forward(self, x: torch.Tensor): - # x shape: [batch_size, channels, num_samples] - skip_connection_list = list() - for down_sampling_block in self.down_sampling_block_list: - x = down_sampling_block.forward(x) - skip_connection_list.append(x) - # x shape: [batch_size, hidden_channels, num_samples**] - return x, skip_connection_list - - -class UpSamplingBlock(nn.Module): - def __init__(self, - out_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - do_relu: bool = True, - ): - super(UpSamplingBlock, self).__init__() - self.do_relu = do_relu - - self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1) - self.glu = nn.GLU(dim=1) - self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride) - self.relu = nn.ReLU() - - def forward(self, x: torch.Tensor): - # x shape: [batch_size, hidden_channels*2, num_samples] - x = self.conv1.forward(x) - # x shape: [batch_size, hidden_channels, num_samples] - x = self.glu(x) - # x shape: [batch_size, hidden_channels, num_samples] - x = self.convt.forward(x) - # x shape: [batch_size, hidden_channels, new_num_samples] - # new_num_samples = (num_samples - 1) * stride + kernel_size - if self.do_relu: - x = self.relu(x) - return x - - -class UpSampling(nn.Module): - def __init__(self, - num_layers: int, - out_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - ): - super(UpSampling, self).__init__() - self.num_layers = num_layers - - up_sampling_block_list = list() - for idx in range(self.num_layers-1): - up_sampling_block = UpSamplingBlock( - out_channels=hidden_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - stride=stride, - do_relu=True, - ) - up_sampling_block_list.append(up_sampling_block) - else: - up_sampling_block = UpSamplingBlock( - out_channels=out_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - stride=stride, - do_relu=False, - ) - up_sampling_block_list.append(up_sampling_block) - self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list) - - def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]): - skip_connection_list = skip_connection_list[::-1] - - # x shape: [batch_size, channels, num_samples] - for idx, up_sampling_block in enumerate(self.up_sampling_block_list): - skip_x = skip_connection_list[idx] - x = x + skip_x - # x = x + skip_x[:, :, :x.size(2)] - x = up_sampling_block.forward(x) - return x - - -def get_padding_length(length, num_layers: int, kernel_size: int, stride: int): - for _ in range(num_layers): - if length < kernel_size: - length = 1 - else: - length = 1 + np.ceil((length - kernel_size) / stride) - - for _ in range(num_layers): - length = (length - 1) * stride + kernel_size - - padded_length = int(length) - return padded_length - - -class NXCleanUNet(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - - self.down_sampling = DownSampling( - num_layers=config.down_sampling_num_layers, - in_channels=config.down_sampling_in_channels, - 
hidden_channels=config.down_sampling_hidden_channels, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - self.causal_encoder = CausalConv2dEncoder( - in_channels=config.causal_in_channels, - out_channels=config.causal_out_channels, - kernel_size=config.causal_kernel_size, - bias=config.causal_bias, - separable=config.causal_separable, - f_stride=config.causal_f_stride, - lookahead=0, - num_layers=config.causal_num_layers, - ) - self.transformer = TransformerEncoder( - input_size=config.down_sampling_hidden_channels, - hidden_size=config.tsfm_hidden_size, - attention_heads=config.tsfm_attention_heads, - num_blocks=config.tsfm_num_blocks, - dropout_rate=config.tsfm_dropout_rate, - chunk_size=config.tsfm_chunk_size, - num_left_chunks=config.tsfm_num_left_chunks, - num_right_chunks=config.tsfm_num_right_chunks, - ) - self.up_sampling = UpSampling( - num_layers=config.down_sampling_num_layers, - out_channels=config.down_sampling_in_channels, - hidden_channels=config.down_sampling_hidden_channels, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - - def forward(self, noisy_audios: torch.Tensor): - # noisy_audios shape: [batch_size, n_samples] - noisy_audios = torch.unsqueeze(noisy_audios, dim=1) - # noisy_audios shape: [batch_size, 1, n_samples] - - n_samples = noisy_audios.shape[-1] - padded_length = get_padding_length( - n_samples, - num_layers=self.config.down_sampling_num_layers, - kernel_size=self.config.down_sampling_kernel_size, - stride=self.config.down_sampling_stride, - ) - noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0) - - bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded) - # bottle_neck shape: [batch_size, channels, time_steps] - - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, time_steps, input_size] - - bottle_neck = bottle_neck.unsqueeze(dim=1) - bottle_neck = self.causal_encoder.forward(bottle_neck) - bottle_neck = bottle_neck.squeeze(dim=1) - # bottle_neck shape: [batch_size, time_steps, input_size] - - bottle_neck = self.transformer.forward(bottle_neck) - # bottle_neck shape: [batch_size, time_steps, input_size] - - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, channels, time_steps] - - enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list) - - enhanced_audios = enhanced_audios[:, :, :n_samples] - # enhanced_audios shape: [batch_size, 1, n_samples] - - enhanced_audios = torch.squeeze(enhanced_audios, dim=1) - # enhanced_audios shape: [batch_size, n_samples] - - return enhanced_audios - - def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor): - # noisy_audios shape: [batch_size, n_samples] - noisy_audios = torch.unsqueeze(noisy_audios, dim=1) - # noisy_audios shape: [batch_size, 1, n_samples] - - n_samples = noisy_audios.shape[-1] - padded_length = get_padding_length( - n_samples, - num_layers=self.config.down_sampling_num_layers, - kernel_size=self.config.down_sampling_kernel_size, - stride=self.config.down_sampling_stride, - ) - noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0) - - bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded) - # bottle_neck shape: [batch_size, channels, time_steps] - - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: 
[batch_size, time_steps, input_size] - - bottle_neck = bottle_neck.unsqueeze(dim=1) - bottle_neck = self.causal_encoder.forward_chunk_by_chunk(bottle_neck) - bottle_neck = bottle_neck.squeeze(dim=1) - # bottle_neck shape: [batch_size, time_steps, input_size] - - bottle_neck = self.transformer.forward_chunk_by_chunk(bottle_neck) - # bottle_neck shape: [batch_size, time_steps, input_size] - - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, channels, time_steps] - - enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list) - - enhanced_audios = enhanced_audios[:, :, :n_samples] - # enhanced_audios shape: [batch_size, 1, n_samples] - - enhanced_audios = torch.squeeze(enhanced_audios, dim=1) - # enhanced_audios shape: [batch_size, n_samples] - - return enhanced_audios - - - -MODEL_FILE = "generator.pt" - - -class NXCleanUNetPretrainedModel(NXCleanUNet): - def __init__(self, - config: NXCleanUNetConfig, - ): - super(NXCleanUNetPretrainedModel, self).__init__( - config=config, - ) - self.config = config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXCleanUNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - - -def main2(): - - config = NXCleanUNetConfig() - down_sampling = DownSampling( - num_layers=config.down_sampling_num_layers, - in_channels=config.down_sampling_in_channels, - hidden_channels=config.down_sampling_hidden_channels, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - up_sampling = UpSampling( - num_layers=config.down_sampling_num_layers, - out_channels=config.down_sampling_in_channels, - hidden_channels=config.down_sampling_hidden_channels, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - - # shape: [batch_size, channels, num_samples] - # min length: 94, stride: 32, 32 == 2**5 - # x = torch.ones([4, 1, 94]) - # x = torch.ones([4, 1, 126]) - # x = torch.ones([4, 1, 158]) - x = torch.ones([4, 1, 190]) - - length = x.shape[-1] - padded_length = get_padding_length( - length, - num_layers=config.down_sampling_num_layers, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - x = F.pad(input=x, pad=(0, padded_length - length), mode="constant", value=0) - # print(x) - print(x.shape) - bottle_neck = down_sampling.forward(x) - print("-" * 150) - x = up_sampling.forward(bottle_neck) - print(x.shape) - return - - -def main(): - - config = NXCleanUNetConfig() - - # shape: [batch_size, channels, num_samples] - # min length: 94, stride: 32, 32 == 2**5 - # x = torch.ones([4, 94]) - # x = 
torch.ones([4, 126]) - # x = torch.ones([4, 158]) - # x = torch.ones([4, 190]) - x = torch.ones([4, 16000]) - - model = NXCleanUNet(config) - enhanced_audios = model.forward(x) - print(enhanced_audios.shape) - return - - -if __name__ == "__main__": - main2() diff --git a/toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py b/toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py deleted file mode 100644 index c22d0d6bab79e7ad88c4606db651320fc3e15cd9..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/transformers/attention.py +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import math -from typing import Tuple - -import torch -import torch.nn as nn - - -class MultiHeadSelfAttention(nn.Module): - def __init__(self, n_head: int, n_feat: int, dropout_rate: float): - """ - :param n_head: int. the number of heads. - :param n_feat: int. the number of features. - :param dropout_rate: float. dropout rate. - """ - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - transform query, key and value. - :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat). - :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat). - :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat). - :return: - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) - ) -> torch.Tensor: - """ - compute attention context vector. - :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k). - :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2). - :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or - (batch_size, time1, time2), (0, 0, 0) means fake mask. - :return: torch.Tensor. transformed value. (batch_size, time1, d_model). - weighted by the attention score (batch_size, time1, time2). - """ - n_batch = value.size(0) - # NOTE: When will `if mask.size(2) > 0` be True? - # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the - # 1st chunk to ease the onnx export.] - # 2. pytorch training - if mask.size(2) > 0: # time2 > 0 - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - # For last chunk, time2 might be larger than scores.size(-1) - mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) - - # NOTE: When will `if mask.size(2) > 0` be False? - # 1. onnx(16/-1, -1/-1, 16/0) - # 2. 
jit (16/-1, -1/-1, 16/0, 16/4) - else: - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat) - - return self.linear_out(x) # (batch, time1, n_feat) - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: - - q, k, v = self.forward_qkv(x, x, x) - - if cache.size(0) > 0: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - # NOTE: We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) - - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask), new_cache - - -class RelativeMultiHeadSelfAttention(nn.Module): - - def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120): - """ - :param n_head: int. the number of heads. - :param n_feat: int. the number of features. - :param dropout_rate: float. dropout rate. - :param max_relative_position: int. maximum relative position for relative position encoding. - """ - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - # Relative position encoding - self.max_relative_position = max_relative_position - self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k)) - - def forward_qkv(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - transform query, key and value. - :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat). - :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat). - :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat). - :return: - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = None - ) -> torch.Tensor: - """ - compute attention context vector. - :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k). - :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps). - :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps). - :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model). - weighted by the attention score (batch_size, query_time_steps, key_time_steps). 
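The masked-softmax step documented above is the core of forward_attention, so a minimal standalone sketch of just that step follows; the tensor sizes and the mask pattern are illustrative assumptions, not values taken from the model.

import torch

# Illustrative sizes; not taken from any config in this repository.
batch_size, n_head, time1, time2, d_k = 2, 4, 3, 5, 8

scores = torch.randn(batch_size, n_head, time1, time2)    # raw attention scores
value = torch.randn(batch_size, n_head, time2, d_k)       # transformed value
mask = torch.ones(batch_size, time1, time2, dtype=torch.bool)
mask[:, :, -2:] = False                                    # pretend the last two key frames are padding

# Same recipe as forward_attention: broadcast the mask over heads, push masked
# positions to -inf before the softmax, then zero those positions in the
# attention weights before they touch the value tensor.
m = mask.unsqueeze(1).eq(0)                                # True where attention is NOT allowed
scores = scores.masked_fill(m, -float("inf"))
attn = torch.softmax(scores, dim=-1).masked_fill(m, 0.0)
context = torch.matmul(attn, value)                        # (batch, head, time1, d_k)
print(context.shape)                                       # torch.Size([2, 4, 3, 8])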
- """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) - # mask shape: [batch_size, 1, query_time_steps, key_time_steps] - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - else: - attn = torch.softmax(scores, dim=-1) - # attn shape: [batch_size, n_head, query_time_steps, key_time_steps] - - p_attn = self.dropout(attn) - - x = torch.matmul(p_attn, value) - # x shape: [batch_size, n_head, query_time_steps, d_k] - x = x.transpose(1, 2) - # x shape: [batch_size, query_time_steps, n_head, d_k] - - x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat) - # x shape: [batch_size, query_time_steps, n_head * d_k] - # x shape: [batch_size, query_time_steps, n_feat] - - x = self.linear_out(x) - # x shape: [batch_size, query_time_steps, n_feat] - return x - - def relative_position_encoding(self, length: int) -> torch.Tensor: - """ - Generate relative position encoding. - :param length: int. length of the sequence. - :return: torch.Tensor. relative position encoding. shape=(length, length, d_k). - """ - range_vec = torch.arange(length) - distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1) - distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) - final_mat = distance_mat_clipped + self.max_relative_position - return final_mat - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = None, - cache: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param x: - :param mask: - :param cache: Tensor, shape: [1, n_heads, time_steps, dim] - :return: - """ - # attention! self attention. - - q, k, v = self.forward_qkv(x, x, x) - # q k v shape: [batch_size, self.h, query_time_steps, self.d_k] - - if cache is not None: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - - # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2] - new_cache = torch.cat((k, v), dim=-1) - - # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps] - native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - - # Compute relative position encoding - q_length, k_length = q.size(2), k.size(2) - relative_position = self.relative_position_encoding(k_length) - - relative_position = relative_position[-q_length:] - - relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1) - - relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k) - relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k) - - relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k) - # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps] - - # score - scores = native_scores + relative_position_scores - - return self.forward_attention(v, scores, mask), new_cache - - -def main(): - rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1) - - x = torch.ones(size=(1, 200, 256), dtype=torch.float32) - xt, new_cache = rel_attention.forward(x, x, x) - - # x = torch.ones(size=(1, 1, 256), dtype=torch.float32) - # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32) - # xt, new_cache = rel_attention.forward(x, x, x, cache=cache) - - 
print(xt.shape) - print(new_cache.shape) - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py b/toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py deleted file mode 100644 index 087be346c5619573cf5350290dfd3a70a4b685a5..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/transformers/mask.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import torch - - -def make_pad_mask(lengths: torch.Tensor, - max_len: int = 0, - ) -> torch.Tensor: - batch_size = lengths.size(0) - max_len = max_len if max_len > 0 else lengths.max().item() - seq_range = torch.arange( - 0, - max_len, - dtype=torch.int64, - device=lengths.device - ) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - return mask - - - -def subsequent_chunk_mask( - size: int, - chunk_size: int, - num_left_chunks: int = -1, - num_right_chunks: int = 0, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """ - Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - - Examples: - > subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - - :param size: int. size of mask. - :param chunk_size: int. size of chunk. - :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks. - :param num_right_chunks: int. number of right chunks. - :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device. - :return: torch.Tensor. mask - """ - - ret = torch.zeros(size, size, device=device, dtype=torch.bool) - for i in range(size): - if num_left_chunks < 0: - start = 0 - else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) - ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size) - ret[i, start:ending] = True - return ret - - -def main(): - chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2) - print(chunk_mask) - - chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1) - print(chunk_mask) - - chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1) - print(chunk_mask) - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py b/toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py deleted file mode 100644 index 43a5f499a6258acb3c0d8a8503d8d5e5afec3bad..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/transformers/transformers.py +++ /dev/null @@ -1,266 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from typing import Dict, Optional, Tuple, List, Union - -import torch -import torch.nn as nn - -from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask -from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention - - -class PositionwiseFeedForward(nn.Module): - def __init__(self, - input_dim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU()): - """ - FeedForward are applied on each position of the sequence. - the output dim is same with the input dim. - - :param input_dim: int. input dimension. - :param hidden_units: int. the number of hidden units. 
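Because the chunk mask below is what turns this transformer into a streaming encoder, here is a small self-contained sketch that repeats the subsequent_chunk_mask logic from the deleted mask.py and prints the example given in its docstring; it is for illustration only and is not part of the deleted module.

import torch

def subsequent_chunk_mask(size: int, chunk_size: int, num_left_chunks: int = -1, num_right_chunks: int = 0) -> torch.Tensor:
    # Same logic as the deleted toolbox function: row i may attend to its own
    # chunk, num_right_chunks chunks of lookahead, and at most num_left_chunks
    # chunks of history (all history when num_left_chunks < 0).
    ret = torch.zeros(size, size, dtype=torch.bool)
    for i in range(size):
        start = 0 if num_left_chunks < 0 else max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

print(subsequent_chunk_mask(4, 2).int())                       # matches the docstring example
print(subsequent_chunk_mask(8, 2, num_left_chunks=2, num_right_chunks=1).int())  # with one chunk of lookahead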
- :param dropout_rate: float. dropout rate. - :param activation: torch.nn.Module. activation function. - """ - super(PositionwiseFeedForward, self).__init__() - self.w_1 = torch.nn.Linear(input_dim, hidden_units) - self.activation = activation - self.dropout = torch.nn.Dropout(dropout_rate) - self.w_2 = torch.nn.Linear(hidden_units, input_dim) - - def forward(self, xs: torch.Tensor) -> torch.Tensor: - """ - Forward function. - :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim). - :return: output tensor. shape=(batch_size, max_length, dim). - """ - return self.w_2(self.dropout(self.activation(self.w_1(xs)))) - - -class TransformerBlock(nn.Module): - def __init__(self, - input_dim: int, - dropout_rate: float = 0.1, - n_heads: int = 4, - max_relative_position: int = 5120 - ): - super().__init__() - self.norm1 = nn.LayerNorm(input_dim, eps=1e-5) - self.attention = RelativeMultiHeadSelfAttention( - n_head=n_heads, - n_feat=input_dim, - dropout_rate=dropout_rate, - max_relative_position=max_relative_position, - ) - - self.dropout1 = nn.Dropout(dropout_rate) - self.norm2 = nn.LayerNorm(input_dim, eps=1e-5) - self.ffn = PositionwiseFeedForward( - input_dim=input_dim, - hidden_units=input_dim, - dropout_rate=dropout_rate - ) - self.dropout2 = nn.Dropout(dropout_rate) - self.norm3 = nn.LayerNorm(input_dim, eps=1e-5) - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor = None, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - - :param x: torch.Tensor. shape=(batch_size, time, input_dim). - :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time). - :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE - shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim. - :return: - torch.Tensor: Output tensor (batch_size, time, input_dim). - torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2). 
- """ - - xt = self.norm1(x) - - x_att, new_att_cache = self.attention.forward( - xt, mask=mask, cache=attention_cache - ) - x = x + self.dropout1(xt) - xt = self.norm2(x) - xt = self.ffn.forward(xt) - x = x + self.dropout2(xt) - - x = self.norm3(x) - - return x, new_att_cache - - -class TransformerEncoder(nn.Module): - """ - https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364 - """ - def __init__(self, - input_size: int = 64, - hidden_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 6, - dropout_rate: float = 0.1, - max_relative_position: int = 1024, - chunk_size: int = 1, - num_left_chunks: int = 128, - num_right_chunks: int = 2, - ): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - self.max_relative_position = max_relative_position - self.chunk_size = chunk_size - self.num_left_chunks = num_left_chunks - self.num_right_chunks = num_right_chunks - - self.input_linear = nn.Linear( - in_features=self.input_size, - out_features=self.hidden_size, - ) - - self.encoder_layer_list = torch.nn.ModuleList([ - TransformerBlock( - input_dim=hidden_size, - n_heads=attention_heads, - dropout_rate=dropout_rate, - max_relative_position=max_relative_position, - ) for _ in range(num_blocks) - ]) - - self.output_linear = nn.Linear( - in_features=self.hidden_size, - out_features=self.input_size, - ) - - def forward(self, - xs: torch.Tensor, - ): - """ - :param xs: Tensor, shape: [batch_size, time_steps, input_size] - :return: Tensor, shape: [batch_size, time_steps, input_size] - """ - batch_size, time_steps, _ = xs.shape - # xs shape: [batch_size, time_steps, input_size] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, time_steps, hidden_size] - - chunk_masks = subsequent_chunk_mask( - size=time_steps, - chunk_size=self.chunk_size, - num_left_chunks=self.num_left_chunks, - num_right_chunks=self.num_right_chunks, - ) - chunk_masks = chunk_masks.to(xs.device) - # chunk_masks shape: [1, time_steps, time_steps] - chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps)) - # chunk_masks shape: [batch_size, time_steps, time_steps] - - for encoder_layer in self.encoder_layer_list: - xs, _ = encoder_layer.forward(xs, chunk_masks) - - # xs shape: [batch_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, time_steps, input_size] - - return xs - - def forward_chunk(self, - xs: torch.Tensor, - max_att_cache_length: int, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Forward just one chunk. - :param xs: torch.Tensor. chunk input, with shape (b=1, time, mel-dim), - where `time == (chunk_size - 1) * subsample_rate + subsample.right_context + 1` - :param max_att_cache_length: - :param attention_cache: torch.Tensor. 
- :return: - """ - # xs shape: [batch_size, time_steps, input_size] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, time_steps, hidden_size] - - r_att_cache = [] - for idx, encoder_layer in enumerate(self.encoder_layer_list): - xs, new_att_cache = encoder_layer.forward( - x=xs, attention_cache=attention_cache[idx: idx+1] if attention_cache is not None else None, - ) - if new_att_cache.size(2) > max_att_cache_length: - begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - end = self.num_right_chunks * self.chunk_size - new_att_cache = new_att_cache[:, :, -begin:-end, :] - r_att_cache.append(new_att_cache) - - r_att_cache = torch.cat(r_att_cache, dim=0) - - return xs, r_att_cache - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - ) -> torch.Tensor: - - batch_size, time_steps, _ = xs.shape - - # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2] - max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - attention_cache = None - - outputs = [] - for idx in range(0, time_steps - self.chunk_size, self.chunk_size): - begin = idx - end = begin + self.chunk_size * (self.num_right_chunks + 1) - chunk_xs = xs[:, begin:end, :] - # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}") - - ys, attention_cache = self.forward_chunk( - xs=chunk_xs, - max_att_cache_length=max_att_cache_length, - attention_cache=attention_cache, - ) - # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), hidden_size] - ys = ys[:, :self.chunk_size, :] - - # ys shape: [batch_size, chunk_size, hidden_size] - ys = self.output_linear.forward(ys) - # ys shape: [batch_size, chunk_size, input_size] - - outputs.append(ys) - - ys = torch.cat(outputs, 1) - return ys - - -def main(): - - encoder = TransformerEncoder( - input_size=64, - hidden_size=256, - attention_heads=4, - num_blocks=6, - dropout_rate=0.1, - ) - print(encoder) - - x = torch.ones([4, 200, 64]) - - y = encoder.forward(xs=x) - print(y.shape) - - # y = encoder.forward_chunk_by_chunk(xs=x) - # print(y.shape) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_clean_unet/utils.py b/toolbox/torchaudio/models/nx_clean_unet/utils.py deleted file mode 100644 index 84a6918b6a8f945e196b6ef909d7ba3575ce686e..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn - - -class LearnableSigmoid1d(nn.Module): - def __init__(self, in_features, beta=1): - super().__init__() - self.beta = beta - self.slope = nn.Parameter(torch.ones(in_features)) - self.slope.requiresGrad = True - - def forward(self, x): - # x shape: [batch_size, time_steps, spec_bins] - return self.beta * torch.sigmoid(self.slope * x) - - -def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True): - - hann_window = torch.hann_window(win_size).to(y.device) - stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, - center=center, pad_mode='reflect', normalized=False, return_complex=True) - stft_spec = torch.view_as_real(stft_spec) - mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9) - pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5) - # Magnitude Compression - mag = torch.pow(mag, compress_factor) - com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1) - - 
return mag, pha, com - - -def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): - # Magnitude Decompression - mag = torch.pow(mag, (1.0/compress_factor)) - com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha)) - hann_window = torch.hann_window(win_size).to(com.device) - wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center) - - return wav - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml b/toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml deleted file mode 100644 index b7cf99d020a417744c1be7f42ce6164d9125dde0..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_clean_unet/yaml/config.yaml +++ /dev/null @@ -1,51 +0,0 @@ -model_name: "nx_clean_unet" - -sample_rate: 8000 -segment_size: 16000 -n_fft: 512 -win_size: 200 -hop_size: 80 -# With hop_size set to 80, each STFT time step covers 10 ms, so the down-sampling is chosen to give roughly the same resolution. - -# 2**down_sampling_num_layers, -# e.g. 2**6=64 means 64 samples are collapsed into one time step after down-sampling, -# so one step is 64/sample_rate = 0.008 s. -# Then tsfm_chunk_size=2 corresponds to 16 ms and tsfm_chunk_size=4 to 32 ms. -# Assuming we look 1 s to the left and 30 ms to the right each time: -# tsfm_chunk_size=1, tsfm_num_left_chunks=128, tsfm_num_right_chunks=4 -# tsfm_chunk_size=2, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2 -# tsfm_chunk_size=4, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1 -down_sampling_num_layers: 6 -down_sampling_in_channels: 1 -down_sampling_hidden_channels: 64 -down_sampling_kernel_size: 4 -down_sampling_stride: 2 - -causal_in_channels: 1 -causal_out_channels: 1 -causal_kernel_size: 3 -causal_bias: false -causal_separable: true -causal_f_stride: 1 -causal_num_layers: 3 - -tsfm_hidden_size: 256 -tsfm_attention_heads: 8 -tsfm_num_blocks: 6 -tsfm_dropout_rate: 0.1 -tsfm_max_length: 512 -tsfm_chunk_size: 1 -tsfm_num_left_chunks: 128 -tsfm_num_right_chunks: 4 - -discriminator_dim: 32 -discriminator_in_channel: 2 - -compress_factor: 0.3 - -batch_size: 4 -learning_rate: 0.0005 -adam_b1: 0.8 -adam_b2: 0.99 -lr_decay: 0.99 -seed: 1234 diff --git a/toolbox/torchaudio/models/nx_denoise/__init__.py b/toolbox/torchaudio/models/nx_denoise/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py b/toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/causal_convolution/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py b/toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py deleted file mode 100644 index 101b739df7a96c16e713c64edcd68ac1a8279989..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/causal_convolution/causal_conv2d.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import math -import os -from typing import List, Optional, Union, Iterable - -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional as F - - -norm_layer_dict = { - "batch_norm_2d": 
torch.nn.BatchNorm2d -} - - -activation_layer_dict = { - "relu": torch.nn.ReLU, - "identity": torch.nn.Identity, - "sigmoid": torch.nn.Sigmoid, -} - - -class CausalConv2d(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Iterable[int]], - f_stride: int = 1, - dilation: int = 1, - do_f_pad: bool = True, - bias: bool = True, - separable: bool = False, - norm_layer: str = "batch_norm_2d", - activation_layer: str = "relu", - lookahead: int = 0 - ): - super(CausalConv2d, self).__init__() - kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) - - if do_f_pad: - f_pad = kernel_size[1] // 2 + dilation - 1 - else: - f_pad = 0 - - self.causal_left_pad = kernel_size[0] - 1 - lookahead - self.causal_right_pad = lookahead - self.constant_pad = nn.ConstantPad2d( - padding=(0, 0, self.causal_left_pad, self.causal_right_pad), - value=0.0 - ) - - groups = math.gcd(in_channels, out_channels) if separable else 1 - self.conv1 = nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=(0, f_pad), - stride=(1, f_stride), - dilation=(1, dilation), - groups=groups, - bias=bias, - ) - - self.conv2 = None - if not any([groups == 1, max(kernel_size) == 1]): - self.conv2 = nn.Conv2d( - out_channels, - out_channels, - kernel_size=1, - bias=False, - ) - - self.norm = None - if norm_layer is not None: - norm_layer = norm_layer_dict[norm_layer] - self.norm = norm_layer(out_channels) - - self.activation = None - if activation_layer is not None: - activation_layer = activation_layer_dict[activation_layer] - self.activation = activation_layer() - - def forward(self, - inputs: torch.Tensor, - causal_cache: List[torch.Tensor] = None, - ): - - if causal_cache is None: - # inputs shape: [batch_size, 1, time_steps, hidden_size] - x = self.constant_pad.forward(inputs) - else: - # inputs shape: [batch_size, 1, time_steps + self.causal_right_pad, hidden_size] - # causal_cache shape: [batch_size, 1, self.causal_left_pad, hidden_size] - x = torch.concat(tensors=[causal_cache, inputs], dim=2) - # x shape: [batch_size, 1, time_steps2, hidden_size] - # time_steps2 = time_steps + self.causal_left_pad + self.causal_right_pad - - causal_cache = x[:, :, -self.causal_left_pad:, :] - - x = self.conv1.forward(x) - # inputs shape: [batch_size, 1, time_steps, hidden_size] - - if self.conv2: - x = self.conv2.forward(x) - - if self.norm: - x = self.norm(x) - if self.activation: - x = self.activation(x) - - # inputs shape: [batch_size, 1, time_steps, hidden_size] - return x, causal_cache - - -class CausalConv2dEncoder(nn.Module): - def __init__(self, - in_channels: int, - hidden_channels: int, - out_channels: int, - kernel_size: Union[int, Iterable[int]], - f_stride: int = 1, - dilation: int = 1, - do_f_pad: bool = True, - bias: bool = True, - separable: bool = False, - norm_layer: str = "batch_norm_2d", - activation_layer: str = "relu", - lookahead: int = 0, - num_layers: int = 5, - ): - super(CausalConv2dEncoder, self).__init__() - self.num_layers = num_layers - - self.total_causal_left_pad = 0 - self.total_causal_right_pad = 0 - - self.causal_conv_list: List[CausalConv2d] = nn.ModuleList(modules=[]) - for i_layer in range(num_layers): - conv = CausalConv2d( - in_channels=in_channels, - out_channels=hidden_channels, - kernel_size=kernel_size, - f_stride=f_stride, - dilation=dilation, - do_f_pad=do_f_pad, - bias=bias, - separable=separable, - norm_layer=norm_layer, - activation_layer=activation_layer, - lookahead=lookahead, - ) - 
self.causal_conv_list.append(conv) - - self.total_causal_left_pad += conv.causal_left_pad - self.total_causal_right_pad += conv.causal_right_pad - - in_channels = hidden_channels - else: - conv = CausalConv2d( - in_channels=hidden_channels, - out_channels=out_channels, - kernel_size=kernel_size, - f_stride=f_stride, - dilation=dilation, - do_f_pad=do_f_pad, - bias=bias, - separable=separable, - norm_layer=norm_layer, - activation_layer=activation_layer, - lookahead=lookahead, - ) - self.causal_conv_list.append(conv) - - self.total_causal_left_pad += conv.causal_left_pad - self.total_causal_right_pad += conv.causal_right_pad - - - def forward(self, inputs: torch.Tensor): - # inputs shape: [batch_size, 1, time_steps, hidden_size] - - x = inputs - for layer in self.causal_conv_list: - x, _ = layer.forward(x) - return x - - def forward_chunk(self, - chunk: torch.Tensor, - causal_cache: List[torch.Tensor] = None, - ): - # causal_cache shape: [self.num_layers, batch_size, 1, causal_left_pad, hidden_size] - - new_causal_cache_list: List[torch.Tensor] = list() - for idx, causal_conv in enumerate(self.causal_conv_list): - chunk, new_causal_cache = causal_conv.forward( - inputs=chunk, causal_cache=causal_cache[idx] if causal_cache is not None else None - ) - # print(f"idx: {idx}, new_causal_cache: {new_causal_cache.shape}") - new_causal_cache_list.append(new_causal_cache) - - return chunk, new_causal_cache_list - - def forward_chunk_by_chunk(self, inputs: torch.Tensor): - # inputs shape: [batch_size, 1, time_steps, hidden_size] - # batch_size = 1 - - batch_size, channels, time_steps, hidden_size = inputs.shape - - new_causal_cache_list: List[torch.Tensor] = None - - outputs = [] - for idx in range(0, time_steps, 1): - begin = idx - end = begin + self.total_causal_right_pad + 1 - chunk_xs = inputs[:, :, begin:end, :] - - ys, new_causal_cache_list = self.forward_chunk( - chunk=chunk_xs, - causal_cache=new_causal_cache_list, - ) - # ys shape: [batch_size, channels, self.total_causal_right_pad + 1 , hidden_size] - ys = ys[:, :, :1, :] - - # ys shape: [batch_size, chunk_size, hidden_size] - outputs.append(ys) - - ys = torch.cat(outputs, 2) - return ys - - -def main2(): - conv = CausalConv2d( - in_channels=1, - out_channels=64, - kernel_size=3, - bias=False, - separable=True, - f_stride=1, - lookahead=0, - ) - - spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32) - # spec shape: [batch_size, 1, time_steps, hidden_size] - cache = torch.randn(size=(1, 1, conv.causal_left_pad, 64), dtype=torch.float32) - - output, _ = conv.forward(spec) - print(output.shape) - - output, _ = conv.forward(spec, cache) - print(output.shape) - - return - - -def main(): - causal = CausalConv2dEncoder( - in_channels=1, - out_channels=1, - kernel_size=3, - bias=False, - separable=True, - f_stride=1, - lookahead=0, - num_layers=3, - ) - - spec = torch.randn(size=(1, 1, 200, 64), dtype=torch.float32) - # spec shape: [batch_size, 1, time_steps, hidden_size] - - output = causal.forward(spec) - print(output.shape) - - output = causal.forward_chunk_by_chunk(spec) - print(output.shape) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py b/toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py deleted file mode 100644 index cf4bb1835dabce5b96c8047d1064111692ab4840..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/configuration_nx_denoise.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: 
utf-8 -*- -from toolbox.torchaudio.configuration_utils import PretrainedConfig - - -class NXDenoiseConfig(PretrainedConfig): - """ - https://github.com/yxlu-0102/MP-SENet/blob/main/config.json - """ - def __init__(self, - sample_rate: int = 8000, - segment_size: int = 16000, - n_fft: int = 512, - win_length: int = 200, - hop_length: int = 80, - - down_sampling_num_layers: int = 5, - down_sampling_in_channels: int = 1, - down_sampling_hidden_channels: int = 64, - down_sampling_kernel_size: int = 4, - down_sampling_stride: int = 2, - - causal_in_channels: int = 1, - causal_hidden_channels: int = 64, - causal_kernel_size: int = 3, - causal_bias: bool = False, - causal_separable: bool = True, - causal_f_stride: int = 1, - # causal_lookahead: int = 0, - causal_num_layers: int = 3, - - tsfm_hidden_size: int = 256, - tsfm_attention_heads: int = 4, - tsfm_num_blocks: int = 6, - tsfm_dropout_rate: float = 0.1, - tsfm_max_time_relative_position: int = 1024, - tsfm_max_freq_relative_position: int = 128, - tsfm_chunk_size: int = 4, - tsfm_num_left_chunks: int = 128, - tsfm_num_right_chunks: int = 2, - - discriminator_dim: int = 16, - discriminator_in_channel: int = 2, - - compress_factor: float = 0.3, - - batch_size: int = 4, - learning_rate: float = 0.0005, - adam_b1: float = 0.8, - adam_b2: float = 0.99, - lr_decay: float = 0.99, - seed: int = 1234, - - **kwargs - ): - super(NXDenoiseConfig, self).__init__(**kwargs) - self.sample_rate = sample_rate - self.segment_size = segment_size - self.n_fft = n_fft - self.win_length = win_length - self.hop_length = hop_length - - self.down_sampling_num_layers = down_sampling_num_layers - self.down_sampling_in_channels = down_sampling_in_channels - self.down_sampling_hidden_channels = down_sampling_hidden_channels - self.down_sampling_kernel_size = down_sampling_kernel_size - self.down_sampling_stride = down_sampling_stride - - self.causal_in_channels = causal_in_channels - self.causal_hidden_channels = causal_hidden_channels - self.causal_kernel_size = causal_kernel_size - self.causal_bias = causal_bias - self.causal_separable = causal_separable - self.causal_f_stride = causal_f_stride - # self.causal_lookahead = causal_lookahead - self.causal_num_layers = causal_num_layers - - self.tsfm_hidden_size = tsfm_hidden_size - self.tsfm_attention_heads = tsfm_attention_heads - self.tsfm_num_blocks = tsfm_num_blocks - self.tsfm_dropout_rate = tsfm_dropout_rate - self.tsfm_max_time_relative_position = tsfm_max_time_relative_position - self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position - self.tsfm_chunk_size = tsfm_chunk_size - self.tsfm_num_left_chunks = tsfm_num_left_chunks - self.tsfm_num_right_chunks = tsfm_num_right_chunks - - self.discriminator_dim = discriminator_dim - self.discriminator_in_channel = discriminator_in_channel - - self.compress_factor = compress_factor - - self.batch_size = batch_size - self.learning_rate = learning_rate - self.adam_b1 = adam_b1 - self.adam_b2 = adam_b2 - self.lr_decay = lr_decay - self.seed = seed - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/discriminator.py b/toolbox/torchaudio/models/nx_denoise/discriminator.py deleted file mode 100644 index 8be8e6930b84f8f0b78dafa05349bef8e4687565..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/discriminator.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -from typing import Optional, Union - -import torch -import torch.nn as nn -import torchaudio - 
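As a quick sanity check on the streaming defaults above, the sketch below recomputes the approximate frame duration and attention context they imply; it mirrors the arithmetic in the nx_clean_unet config.yaml comments and treats each stride-2 down-sampling layer as an exact factor of 2, which is only an approximation.

# Rough figures implied by the NXDenoiseConfig defaults above.
sample_rate = 8000
down_sampling_num_layers = 5
tsfm_chunk_size = 4
tsfm_num_left_chunks = 128
tsfm_num_right_chunks = 2

samples_per_step = 2 ** down_sampling_num_layers            # 32 samples per bottleneck frame
step_ms = 1000.0 * samples_per_step / sample_rate            # 4.0 ms
chunk_ms = step_ms * tsfm_chunk_size                         # 16.0 ms per attention chunk
lookahead_ms = chunk_ms * tsfm_num_right_chunks              # 32.0 ms of right context
history_s = chunk_ms * tsfm_num_left_chunks / 1000.0         # 2.048 s of left context
print(step_ms, chunk_ms, lookahead_ms, history_s)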
-from toolbox.torchaudio.configuration_utils import CONFIG_FILE -from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig -from toolbox.torchaudio.models.nx_denoise.utils import LearnableSigmoid1d - - -class MetricDiscriminator(nn.Module): - def __init__(self, config: NXDenoiseConfig): - super(MetricDiscriminator, self).__init__() - dim = config.discriminator_dim - self.in_channel = config.discriminator_in_channel - - self.n_fft = config.n_fft - self.win_length = config.win_length - self.hop_length = config.hop_length - - self.transform = torchaudio.transforms.Spectrogram( - n_fft=self.n_fft, - win_length=self.win_length, - hop_length=self.hop_length, - power=1.0, - window_fn=torch.hann_window, - # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window, - ) - - self.layers = nn.Sequential( - nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim, affine=True), - nn.PReLU(dim), - nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*2, affine=True), - nn.PReLU(dim*2), - nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*4, affine=True), - nn.PReLU(dim*4), - nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*8, affine=True), - nn.PReLU(dim*8), - nn.AdaptiveMaxPool2d(1), - nn.Flatten(), - nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)), - nn.Dropout(0.3), - nn.PReLU(dim*4), - nn.utils.spectral_norm(nn.Linear(dim*4, 1)), - LearnableSigmoid1d(1) - ) - - def forward(self, x, y): - x = self.transform.forward(x) - y = self.transform.forward(y) - - xy = torch.stack((x, y), dim=1) - return self.layers(xy) - - -MODEL_FILE = "discriminator.pt" - - -class MetricDiscriminatorPretrainedModel(MetricDiscriminator): - def __init__(self, - config: NXDenoiseConfig, - ): - super(MetricDiscriminatorPretrainedModel, self).__init__( - config=config, - ) - self.config = config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - -def main(): - config = NXDenoiseConfig() - discriminator = MetricDiscriminator(config=config) - - # shape: [batch_size, num_samples] - # x = torch.ones([4, int(4.5 * 16000)]) - # y = torch.ones([4, int(4.5 * 16000)]) - x = torch.ones([4, 16000]) - y = torch.ones([4, 16000]) - - output = discriminator.forward(x, y) - print(output.shape) - print(output) - - return - - -if __name__ == "__main__": - main() diff --git 
a/toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py b/toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py deleted file mode 100644 index 5f5dce5ed9727ab8460a1ad5f1d61c93db1e4fe6..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/inference_nx_denoise.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import logging -from pathlib import Path -import shutil -import tempfile -import zipfile - -import librosa -import numpy as np -import torch -import torchaudio - -from project_settings import project_path -from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig -from toolbox.torchaudio.models.nx_denoise.modeling_nx_denoise import NXDenoisePretrainedModel, MODEL_FILE - -logger = logging.getLogger("toolbox") - - -class InferenceNXDenoise(object): - def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): - self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file - self.device = torch.device(device) - - logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") - config, model = self.load_models(self.pretrained_model_path_or_zip_file) - logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") - - self.config = config - self.model = model - self.model.to(device) - self.model.eval() - - def load_models(self, model_path: str): - model_path = Path(model_path) - if model_path.name.endswith(".zip"): - with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: - out_root = Path(tempfile.gettempdir()) / "nx_denoise" - out_root.mkdir(parents=True, exist_ok=True) - f_zip.extractall(path=out_root) - model_path = out_root / model_path.stem - - config = NXDenoiseConfig.from_pretrained( - pretrained_model_name_or_path=model_path.as_posix(), - ) - model = NXDenoisePretrainedModel.from_pretrained( - pretrained_model_name_or_path=model_path.as_posix(), - ) - model.to(self.device) - model.eval() - - shutil.rmtree(model_path) - return config, model - - def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: - if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: - raise AssertionError(f"The value range of audio samples should be between -1 and 1.") - - # noisy_audio shape: [batch_size, num_samples] - noisy_audios = noisy_audio.to(self.device) - - with torch.no_grad(): - # enhanced_audios = self.model.forward_chunk_by_chunk(noisy_audios) - enhanced_audios = self.model.forward(noisy_audios) - # enhanced_audio shape: [batch_size, n_samples] - # enhanced_audios = torch.squeeze(enhanced_audios, dim=1) - - enhanced_audio = enhanced_audios[0] - # enhanced_audio shape: [num_samples,] - return enhanced_audio - - -def main(): - model_zip_file = project_path / "trained_models/nx-denoise.zip" - runtime = InferenceNXDenoise(model_zip_file) - - sample_rate = 8000 - noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav" - noisy_audio, _ = librosa.load( - noisy_audio_file.as_posix(), - sr=sample_rate, - ) - noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] - noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) - noisy_audio = noisy_audio.unsqueeze(dim=0) - - enhanced_audio = runtime.enhancement_by_tensor(noisy_audio) - - filename = "enhanced_audio.wav" - torchaudio.save(filename, enhanced_audio.detach().cpu().unsqueeze(dim=0), sample_rate) - - return - - -if __name__ == '__main__': - main() diff --git 
a/toolbox/torchaudio/models/nx_denoise/loss.py b/toolbox/torchaudio/models/nx_denoise/loss.py deleted file mode 100644 index 475535006ee63213332fdc19ae91da1d81fe9cfc..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/loss.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import numpy as np -import torch - - -def anti_wrapping_function(x): - - return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) - - -def phase_losses(phase_r, phase_g): - - ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) - gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))) - iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))) - - return ip_loss, gd_loss, iaf_loss - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/metrics.py b/toolbox/torchaudio/models/nx_denoise/metrics.py deleted file mode 100644 index 78468894a56d4488021e83ea47e07c785a385269..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/metrics.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from joblib import Parallel, delayed -import numpy as np -from pesq import pesq -from typing import List - -from pesq import cypesq - - -def run_pesq(clean_audio: np.ndarray, - noisy_audio: np.ndarray, - sample_rate: int = 16000, - mode: str = "wb", - ) -> float: - if sample_rate == 8000 and mode == "wb": - raise AssertionError(f"mode should be `nb` when sample_rate is 8000") - try: - pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode) - except cypesq.NoUtterancesError as e: - pesq_score = -1 - except Exception as e: - print(f"pesq failed. 
error type: {type(e)}, error text: {str(e)}") - pesq_score = -1 - return pesq_score - - -def run_batch_pesq(clean_audio_list: List[np.ndarray], - noisy_audio_list: List[np.ndarray], - sample_rate: int = 16000, - mode: str = "wb", - n_jobs: int = 4, - ) -> List[float]: - parallel = Parallel(n_jobs=n_jobs) - - parallel_tasks = list() - for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list): - parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode) - parallel_tasks.append(parallel_task) - - pesq_score_list = parallel.__call__(parallel_tasks) - return pesq_score_list - - -def run_pesq_score(clean_audio_list: List[np.ndarray], - noisy_audio_list: List[np.ndarray], - sample_rate: int = 16000, - mode: str = "wb", - n_jobs: int = 4, - ) -> List[float]: - - pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list, - noisy_audio_list=noisy_audio_list, - sample_rate=sample_rate, - mode=mode, - n_jobs=n_jobs, - ) - - pesq_score = np.mean(pesq_score_list) - return pesq_score - - -def main(): - clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) - noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) - - clean_audio_list = list(clean_audio) - noisy_audio_list = list(noisy_audio) - - pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list) - print(pesq_score_list) - - pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list) - print(pesq_score) - - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py b/toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py deleted file mode 100644 index fbccf9c2b97abe81de857db68e284bb19fbc0c75..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/modeling_nx_denoise.py +++ /dev/null @@ -1,392 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -from typing import List, Optional, Union - -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional as F - -from toolbox.torchaudio.configuration_utils import CONFIG_FILE -from toolbox.torchaudio.models.nx_denoise.configuration_nx_denoise import NXDenoiseConfig -from toolbox.torchaudio.models.nx_denoise.causal_convolution.causal_conv2d import CausalConv2dEncoder -from toolbox.torchaudio.models.nx_denoise.transformers.transformers import TSTransformerEncoder - - -class DownSamplingBlock(nn.Module): - def __init__(self, - in_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - ): - super(DownSamplingBlock, self).__init__() - self.conv1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, stride) - self.relu = nn.ReLU() - self.conv2 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1) - self.glu = nn.GLU(dim=1) - - def forward(self, x: torch.Tensor): - # x shape: [batch_size, 1, num_samples] - x = self.conv1.forward(x) - # x shape: [batch_size, hidden_channels, new_num_samples] - x = self.relu(x) - x = self.conv2.forward(x) - # x shape: [batch_size, hidden_channels*2, new_num_samples] - x = self.glu(x) - # x shape: [batch_size, hidden_channels, new_num_samples] - # new_num_samples = (num_samples-kernel_size) // stride + 1 - return x - - -class DownSampling(nn.Module): - def __init__(self, - num_layers: int, - in_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - ): - super(DownSampling, self).__init__() - self.num_layers = num_layers - - down_sampling_block_list = list() - for idx in range(self.num_layers): - down_sampling_block = 
DownSamplingBlock( - in_channels=in_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - stride=stride, - ) - down_sampling_block_list.append(down_sampling_block) - in_channels = hidden_channels - - self.down_sampling_block_list = nn.ModuleList(modules=down_sampling_block_list) - - def forward(self, x: torch.Tensor): - # x shape: [batch_size, channels, num_samples] - skip_connection_list = list() - for down_sampling_block in self.down_sampling_block_list: - x = down_sampling_block.forward(x) - skip_connection_list.append(x) - # x shape: [batch_size, hidden_channels, num_samples**] - return x, skip_connection_list - - -class UpSamplingBlock(nn.Module): - def __init__(self, - out_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - do_relu: bool = True, - ): - super(UpSamplingBlock, self).__init__() - self.do_relu = do_relu - - self.conv1 = nn.Conv1d(hidden_channels, hidden_channels * 2, 1) - self.glu = nn.GLU(dim=1) - self.convt = nn.ConvTranspose1d(hidden_channels, out_channels, kernel_size, stride) - self.relu = nn.ReLU() - - def forward(self, x: torch.Tensor): - # x shape: [batch_size, hidden_channels*2, num_samples] - x = self.conv1.forward(x) - # x shape: [batch_size, hidden_channels, num_samples] - x = self.glu(x) - # x shape: [batch_size, hidden_channels, num_samples] - x = self.convt.forward(x) - # x shape: [batch_size, hidden_channels, new_num_samples] - # new_num_samples = (num_samples - 1) * stride + kernel_size - if self.do_relu: - x = self.relu(x) - return x - - -class UpSampling(nn.Module): - def __init__(self, - num_layers: int, - out_channels: int, - hidden_channels: int, - kernel_size: int, - stride: int, - ): - super(UpSampling, self).__init__() - self.num_layers = num_layers - - up_sampling_block_list = list() - for idx in range(self.num_layers-1): - up_sampling_block = UpSamplingBlock( - out_channels=hidden_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - stride=stride, - do_relu=True, - ) - up_sampling_block_list.append(up_sampling_block) - else: - up_sampling_block = UpSamplingBlock( - out_channels=out_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - stride=stride, - do_relu=False, - ) - up_sampling_block_list.append(up_sampling_block) - self.up_sampling_block_list = nn.ModuleList(modules=up_sampling_block_list) - - def forward(self, x: torch.Tensor, skip_connection_list: List[torch.Tensor]): - skip_connection_list = skip_connection_list[::-1] - - # x shape: [batch_size, channels, num_samples] - for idx, up_sampling_block in enumerate(self.up_sampling_block_list): - skip_x = skip_connection_list[idx] - x = x + skip_x - # x = x + skip_x[:, :, :x.size(2)] - x = up_sampling_block.forward(x) - return x - - -def get_padding_length(length, num_layers: int, kernel_size: int, stride: int): - for _ in range(num_layers): - if length < kernel_size: - length = 1 - else: - length = 1 + np.ceil((length - kernel_size) / stride) - - for _ in range(num_layers): - length = (length - 1) * stride + kernel_size - - padded_length = int(length) - return padded_length - - -class NXDenoise(nn.Module): - def __init__(self, config: NXDenoiseConfig): - super().__init__() - self.config = config - - self.down_sampling = DownSampling( - num_layers=config.down_sampling_num_layers, - in_channels=config.down_sampling_in_channels, - hidden_channels=config.down_sampling_hidden_channels, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - self.causal_conv_in = 
CausalConv2dEncoder( - in_channels=config.causal_in_channels, - hidden_channels=config.causal_hidden_channels, - out_channels=config.causal_hidden_channels, - kernel_size=config.causal_kernel_size, - bias=config.causal_bias, - separable=config.causal_separable, - f_stride=config.causal_f_stride, - lookahead=0, - num_layers=config.causal_num_layers, - ) - self.ts_transformer = TSTransformerEncoder( - input_size=config.down_sampling_hidden_channels, - hidden_size=config.tsfm_hidden_size, - attention_heads=config.tsfm_attention_heads, - num_blocks=config.tsfm_num_blocks, - dropout_rate=config.tsfm_dropout_rate, - max_time_relative_position=config.tsfm_max_time_relative_position, - max_freq_relative_position=config.tsfm_max_freq_relative_position, - chunk_size=config.tsfm_chunk_size, - num_left_chunks=config.tsfm_num_left_chunks, - num_right_chunks=config.tsfm_num_right_chunks, - ) - self.causal_conv_out = CausalConv2dEncoder( - in_channels=config.causal_hidden_channels, - hidden_channels=config.causal_hidden_channels, - out_channels=config.causal_in_channels, - kernel_size=config.causal_kernel_size, - bias=config.causal_bias, - separable=config.causal_separable, - f_stride=config.causal_f_stride, - lookahead=0, - num_layers=config.causal_num_layers, - ) - self.up_sampling = UpSampling( - num_layers=config.down_sampling_num_layers, - out_channels=config.down_sampling_in_channels, - hidden_channels=config.down_sampling_hidden_channels, - kernel_size=config.down_sampling_kernel_size, - stride=config.down_sampling_stride, - ) - - def forward(self, noisy_audios: torch.Tensor): - # noisy_audios shape: [batch_size, n_samples] - noisy_audios = torch.unsqueeze(noisy_audios, dim=1) - # noisy_audios shape: [batch_size, 1, n_samples] - - n_samples = noisy_audios.shape[-1] - padded_length = get_padding_length( - n_samples, - num_layers=self.config.down_sampling_num_layers, - kernel_size=self.config.down_sampling_kernel_size, - stride=self.config.down_sampling_stride, - ) - noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0) - - # down sampling - bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded) - # bottle_neck shape: [batch_size, channels, time_steps] - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, time_steps, channels] - bottle_neck = torch.unsqueeze(bottle_neck, dim=1) - # bottle_neck shape: [batch_size, 1, time_steps, freq_dim] - - # causal conv in - bottle_neck = self.causal_conv_in.forward(bottle_neck) - # bottle_neck shape: [batch_size, channels, time_steps, freq_dim] - - # ts transformer - # bottle_neck shape: [batch_size, channels, time_steps, freq_dim] - bottle_neck = self.ts_transformer.forward(bottle_neck) - # bottle_neck shape: [batch_size, channels, time_steps, freq_dim] - - # causal conv out - bottle_neck = self.causal_conv_out.forward(bottle_neck) - # bottle_neck shape: [batch_size, 1, time_steps, freq_dim] - - # up sampling - bottle_neck = torch.squeeze(bottle_neck, dim=1) - # bottle_neck shape: [batch_size, time_steps, channels] - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, channels, time_steps] - - enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list) - - enhanced_audios = enhanced_audios[:, :, :n_samples] - # enhanced_audios shape: [batch_size, 1, n_samples] - - enhanced_audios = torch.squeeze(enhanced_audios, dim=1) - # enhanced_audios shape: [batch_size, n_samples] - - 
return enhanced_audios - - - def forward_chunk_by_chunk(self, noisy_audios: torch.Tensor): - # noisy_audios shape: [batch_size, n_samples] - noisy_audios = torch.unsqueeze(noisy_audios, dim=1) - # noisy_audios shape: [batch_size, 1, n_samples] - - n_samples = noisy_audios.shape[-1] - padded_length = get_padding_length( - n_samples, - num_layers=self.config.down_sampling_num_layers, - kernel_size=self.config.down_sampling_kernel_size, - stride=self.config.down_sampling_stride, - ) - noisy_audios_padded = F.pad(input=noisy_audios, pad=(0, padded_length - n_samples), mode="constant", value=0) - - # down sampling - bottle_neck, skip_connection_list = self.down_sampling.forward(noisy_audios_padded) - # bottle_neck shape: [batch_size, channels, time_steps] - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, time_steps, channels] - bottle_neck = torch.unsqueeze(bottle_neck, dim=1) - # bottle_neck shape: [batch_size, 1, time_steps, freq_dim] - - # causal conv in - bottle_neck = self.causal_conv_in.forward_chunk_by_chunk(bottle_neck) - # bottle_neck shape: [batch_size, channels, time_steps, freq_dim] - - # ts transformer - # bottle_neck shape: [batch_size, channels, time_steps, freq_dim] - bottle_neck = self.ts_transformer.forward_chunk_by_chunk(bottle_neck) - # bottle_neck shape: [batch_size, channels, time_steps, freq_dim] - - # causal conv out - bottle_neck = self.causal_conv_out.forward_chunk_by_chunk(bottle_neck) - # bottle_neck shape: [batch_size, 1, time_steps, freq_dim] - - # up sampling - bottle_neck = torch.squeeze(bottle_neck, dim=1) - # bottle_neck shape: [batch_size, time_steps, channels] - bottle_neck = torch.transpose(bottle_neck, dim0=-2, dim1=-1) - # bottle_neck shape: [batch_size, channels, time_steps] - - enhanced_audios = self.up_sampling.forward(bottle_neck, skip_connection_list) - - enhanced_audios = enhanced_audios[:, :, :n_samples] - # enhanced_audios shape: [batch_size, 1, n_samples] - - enhanced_audios = torch.squeeze(enhanced_audios, dim=1) - # enhanced_audios shape: [batch_size, n_samples] - - return enhanced_audios - - -MODEL_FILE = "generator.pt" - - -class NXDenoisePretrainedModel(NXDenoise): - def __init__(self, - config: NXDenoiseConfig, - ): - super(NXDenoisePretrainedModel, self).__init__( - config=config, - ) - self.config = config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXDenoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - -def main(): - - config = NXDenoiseConfig() - - # shape: [batch_size, channels, num_samples] - # min length: 94, stride: 32, 32 == 2**5 - # x = torch.ones([4, 94]) - # x = 
torch.ones([4, 126]) - # x = torch.ones([4, 158]) - # x = torch.ones([4, 190]) - x = torch.ones([4, 16000]) - - model = NXDenoise(config) - enhanced_audios = model.forward(x) - print(enhanced_audios.shape) - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py b/toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/stftnet/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/stftnet/istftnet.py b/toolbox/torchaudio/models/nx_denoise/stftnet/istftnet.py deleted file mode 100644 index 02f80808d91971c2a76fd2606703ec5b825f26de..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/stftnet/istftnet.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -""" -https://arxiv.org/abs/2203.02395 -""" - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py b/toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py deleted file mode 100644 index b51732e16ae40f62b9df7f0c0668b323feec6103..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/stftnet/stfnets.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -""" -https://arxiv.org/abs/1902.07849 -""" - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/transformers/__init__.py b/toolbox/torchaudio/models/nx_denoise/transformers/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/transformers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/transformers/attention.py b/toolbox/torchaudio/models/nx_denoise/transformers/attention.py deleted file mode 100644 index 9492d0498e8dcfd2afc02e853c491100a6ba18f7..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/transformers/attention.py +++ /dev/null @@ -1,263 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import math -from typing import Tuple - -import torch -import torch.nn as nn - - -class MultiHeadSelfAttention(nn.Module): - def __init__(self, n_head: int, n_feat: int, dropout_rate: float): - """ - :param n_head: int. the number of heads. - :param n_feat: int. the number of features. - :param dropout_rate: float. dropout rate. - """ - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - transform query, key and value. - :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat). - :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat). - :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat). 
- :return: - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) - ) -> torch.Tensor: - """ - compute attention context vector. - :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k). - :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2). - :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or - (batch_size, time1, time2), (0, 0, 0) means fake mask. - :return: torch.Tensor. transformed value. (batch_size, time1, d_model). - weighted by the attention score (batch_size, time1, time2). - """ - n_batch = value.size(0) - # NOTE: When will `if mask.size(2) > 0` be True? - # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the - # 1st chunk to ease the onnx export.] - # 2. pytorch training - if mask.size(2) > 0: # time2 > 0 - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - # For last chunk, time2 might be larger than scores.size(-1) - mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) - - # NOTE: When will `if mask.size(2) > 0` be False? - # 1. onnx(16/-1, -1/-1, 16/0) - # 2. jit (16/-1, -1/-1, 16/0, 16/4) - else: - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat) - - return self.linear_out(x) # (batch, time1, n_feat) - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: - - q, k, v = self.forward_qkv(x, x, x) - - if cache.size(0) > 0: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - # NOTE: We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) - - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask), new_cache - - -class RelativeMultiHeadSelfAttention(nn.Module): - - def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120): - """ - :param n_head: int. the number of heads. - :param n_feat: int. the number of features. - :param dropout_rate: float. dropout rate. - :param max_relative_position: int. maximum relative position for relative position encoding. 
- """ - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - # Relative position encoding - self.max_relative_position = max_relative_position - self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k)) - - def forward_qkv(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - transform query, key and value. - :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat). - :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat). - :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat). - :return: - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = None - ) -> torch.Tensor: - """ - compute attention context vector. - :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k). - :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps). - :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps). - :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model). - weighted by the attention score (batch_size, query_time_steps, key_time_steps). - """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) - # mask shape: [batch_size, 1, query_time_steps, key_time_steps] - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - else: - attn = torch.softmax(scores, dim=-1) - # attn shape: [batch_size, n_head, query_time_steps, key_time_steps] - - p_attn = self.dropout(attn) - - x = torch.matmul(p_attn, value) - # x shape: [batch_size, n_head, query_time_steps, d_k] - x = x.transpose(1, 2) - # x shape: [batch_size, query_time_steps, n_head, d_k] - - x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat) - # x shape: [batch_size, query_time_steps, n_head * d_k] - # x shape: [batch_size, query_time_steps, n_feat] - - x = self.linear_out(x) - # x shape: [batch_size, query_time_steps, n_feat] - return x - - def relative_position_encoding(self, length: int) -> torch.Tensor: - """ - Generate relative position encoding. - :param length: int. length of the sequence. - :return: torch.Tensor. relative position encoding. shape=(length, length, d_k). 
- """ - range_vec = torch.arange(length) - distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1) - distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) - final_mat = distance_mat_clipped + self.max_relative_position - return final_mat - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = None, - cache: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - # attention! self attention. - - q, k, v = self.forward_qkv(x, x, x) - # q k v shape: [batch_size, self.h, query_time_steps, self.d_k] - - if cache is not None: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - - # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2] - new_cache = torch.cat((k, v), dim=-1) - - # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps] - native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - - # Compute relative position encoding - q_length, k_length = q.size(2), k.size(2) - relative_position = self.relative_position_encoding(k_length) - - relative_position = relative_position[-q_length:] - - relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1) - - relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k) - relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k) - - relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k) - # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps] - - # score - scores = native_scores + relative_position_scores - - return self.forward_attention(v, scores, mask), new_cache - - -def main(): - rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1) - - x = torch.ones(size=(1, 200, 256), dtype=torch.float32) - xt, new_cache = rel_attention.forward(x, x, x) - - # x = torch.ones(size=(1, 1, 256), dtype=torch.float32) - # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32) - # xt, new_cache = rel_attention.forward(x, x, x, cache=cache) - - print(xt.shape) - print(new_cache.shape) - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_denoise/transformers/mask.py b/toolbox/torchaudio/models/nx_denoise/transformers/mask.py deleted file mode 100644 index 087be346c5619573cf5350290dfd3a70a4b685a5..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/transformers/mask.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import torch - - -def make_pad_mask(lengths: torch.Tensor, - max_len: int = 0, - ) -> torch.Tensor: - batch_size = lengths.size(0) - max_len = max_len if max_len > 0 else lengths.max().item() - seq_range = torch.arange( - 0, - max_len, - dtype=torch.int64, - device=lengths.device - ) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - return mask - - - -def subsequent_chunk_mask( - size: int, - chunk_size: int, - num_left_chunks: int = -1, - num_right_chunks: int = 0, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """ - Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - - 
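- Each query frame may attend to its own chunk, to num_left_chunks chunks of history and to num_right_chunks chunks of lookahead; for example, with chunk_size=1, num_left_chunks=128 and num_right_chunks=4 a frame sees 128 past frames and 4 future frames. -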
Examples: - > subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - - :param size: int. size of mask. - :param chunk_size: int. size of chunk. - :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks. - :param num_right_chunks: int. number of right chunks. - :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device. - :return: torch.Tensor. mask - """ - - ret = torch.zeros(size, size, device=device, dtype=torch.bool) - for i in range(size): - if num_left_chunks < 0: - start = 0 - else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) - ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size) - ret[i, start:ending] = True - return ret - - -def main(): - chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2) - print(chunk_mask) - - chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1) - print(chunk_mask) - - chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1) - print(chunk_mask) - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_denoise/transformers/transformers.py b/toolbox/torchaudio/models/nx_denoise/transformers/transformers.py deleted file mode 100644 index 97c0583474a8f7254dc27db50ff7191d1645ad17..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/transformers/transformers.py +++ /dev/null @@ -1,479 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from typing import Dict, Optional, Tuple, List, Union - -import torch -import torch.nn as nn - -from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask -from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention - - -class PositionwiseFeedForward(nn.Module): - def __init__(self, - input_dim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU()): - """ - FeedForward are applied on each position of the sequence. - the output dim is same with the input dim. - - :param input_dim: int. input dimension. - :param hidden_units: int. the number of hidden units. - :param dropout_rate: float. dropout rate. - :param activation: torch.nn.Module. activation function. - """ - super(PositionwiseFeedForward, self).__init__() - self.w_1 = torch.nn.Linear(input_dim, hidden_units) - self.activation = activation - self.dropout = torch.nn.Dropout(dropout_rate) - self.w_2 = torch.nn.Linear(hidden_units, input_dim) - - def forward(self, xs: torch.Tensor) -> torch.Tensor: - """ - Forward function. - :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim). - :return: output tensor. shape=(batch_size, max_length, dim). 
- """ - return self.w_2(self.dropout(self.activation(self.w_1(xs)))) - - -class TransformerBlock(nn.Module): - def __init__(self, - input_dim: int, - dropout_rate: float = 0.1, - n_heads: int = 4, - max_relative_position: int = 5120 - ): - super().__init__() - self.norm1 = nn.LayerNorm(input_dim, eps=1e-5) - self.attention = RelativeMultiHeadSelfAttention( - n_head=n_heads, - n_feat=input_dim, - dropout_rate=dropout_rate, - max_relative_position=max_relative_position, - ) - - self.dropout1 = nn.Dropout(dropout_rate) - self.norm2 = nn.LayerNorm(input_dim, eps=1e-5) - self.ffn = PositionwiseFeedForward( - input_dim=input_dim, - hidden_units=input_dim, - dropout_rate=dropout_rate - ) - self.dropout2 = nn.Dropout(dropout_rate) - self.norm3 = nn.LayerNorm(input_dim, eps=1e-5) - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor = None, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param x: torch.Tensor. shape=(batch_size, time, input_dim). - :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time, time). - :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE - shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim. - :return: - torch.Tensor: Output tensor (batch_size, time, input_dim). - torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2). - """ - xt = self.norm1(x) - - x_att, new_att_cache = self.attention.forward( - xt, mask=mask, cache=attention_cache - ) - x = x + self.dropout1(x_att) - xt = self.norm2(x) - xt = self.ffn.forward(xt) - x = x + self.dropout2(xt) - - x = self.norm3(x) - - return x, new_att_cache - - -class TransformerEncoder(nn.Module): - """ - https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364 - """ - def __init__(self, - input_size: int = 64, - hidden_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 6, - dropout_rate: float = 0.1, - max_relative_position: int = 1024, - chunk_size: int = 1, - num_left_chunks: int = 128, - num_right_chunks: int = 2, - ): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - self.max_relative_position = max_relative_position - self.chunk_size = chunk_size - self.num_left_chunks = num_left_chunks - self.num_right_chunks = num_right_chunks - - self.input_linear = nn.Linear( - in_features=self.input_size, - out_features=self.hidden_size, - ) - - self.encoder_layer_list = torch.nn.ModuleList([ - TransformerBlock( - input_dim=hidden_size, - n_heads=attention_heads, - dropout_rate=dropout_rate, - max_relative_position=max_relative_position, - ) for _ in range(num_blocks) - ]) - - self.output_linear = nn.Linear( - in_features=self.hidden_size, - out_features=self.input_size, - ) - - def forward(self, - xs: torch.Tensor, - ): - """ - :param xs: Tensor, shape: [batch_size, time_steps, input_size] - :return: Tensor, shape: [batch_size, time_steps, input_size] - """ - batch_size, time_steps, _ = xs.shape - # xs shape: [batch_size, time_steps, input_size] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, time_steps, hidden_size] - - chunk_masks = subsequent_chunk_mask( - size=time_steps, - chunk_size=self.chunk_size, - num_left_chunks=self.num_left_chunks, - num_right_chunks=self.num_right_chunks, - ) - chunk_masks = chunk_masks.to(xs.device) - # chunk_masks shape: [time_steps, time_steps] - chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps)) - # chunk_masks shape: [batch_size, time_steps,
time_steps] - - for encoder_layer in self.encoder_layer_list: - xs, _ = encoder_layer.forward(xs, chunk_masks) - - # xs shape: [batch_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, time_steps, input_size] - - return xs - - def forward_chunk(self, - xs: torch.Tensor, - max_att_cache_length: int, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param xs: - :param max_att_cache_length: - :param attention_cache: Tensor, [num_layers, ...] - :return: - """ - # xs shape: [batch_size, time_steps, input_size] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, time_steps, hidden_size] - - r_att_cache = [] - for idx, encoder_layer in enumerate(self.encoder_layer_list): - xs, new_att_cache = encoder_layer.forward( - x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None, - ) - # new_att_cache shape: [batch_size, n_heads, time_steps, dim] - if new_att_cache.size(2) > max_att_cache_length: - begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - end = self.num_right_chunks * self.chunk_size - new_att_cache = new_att_cache[:, :, -begin:-end, :] - r_att_cache.append(new_att_cache) - - r_att_cache = torch.stack(r_att_cache, dim=0) - - # xs shape: [batch_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, time_steps, input_size] - - return xs, r_att_cache - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - ) -> torch.Tensor: - - batch_size, time_steps, _ = xs.shape - - # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2] - max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - attention_cache = None - - outputs = [] - for idx in range(0, time_steps, self.chunk_size): - begin = idx - end = begin + self.chunk_size * (self.num_right_chunks + 1) - chunk_xs = xs[:, begin:end, :] - # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}") - - ys, attention_cache = self.forward_chunk( - xs=chunk_xs, - max_att_cache_length=max_att_cache_length, - attention_cache=attention_cache, - ) - - # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), input_size] - ys = ys[:, :self.chunk_size, :] - - outputs.append(ys) - - ys = torch.cat(outputs, 1) - return ys - - -class TSTransformerBlock(nn.Module): - def __init__(self, - input_dim: int, - dropout_rate: float = 0.1, - n_heads: int = 4, - max_time_relative_position: int = 1024, - max_freq_relative_position: int = 128, - ): - super(TSTransformerBlock, self).__init__() - self.time_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_time_relative_position) - self.freq_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_freq_relative_position) - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = None, - attention_cache: torch.Tensor = None, - ): - """ - - :param x: Tensor. shape: [batch_size, hidden_size, time_steps, input_size] - :param mask: Tensor. 
shape: [time_steps, time_steps] - :param attention_cache: - :return: - """ - b, c, t, f = x.size() - - mask = None if mask is None else torch.broadcast_to(mask, size=(b*f, t, t)) - - x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c) - x_, new_att_cache = self.time_transformer.forward(x, mask, attention_cache) - x = x_ + x - x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c) - x_, _ = self.freq_transformer.forward(x) - x = x_ + x - x = x.view(b, t, f, c).permute(0, 3, 1, 2) - return x, new_att_cache - - -class TSTransformerEncoder(nn.Module): - def __init__(self, - input_size: int = 64, - hidden_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 6, - dropout_rate: float = 0.1, - max_time_relative_position: int = 1024, - max_freq_relative_position: int = 128, - chunk_size: int = 1, - num_left_chunks: int = 128, - num_right_chunks: int = 2, - ): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - self.max_time_relative_position = max_time_relative_position - self.max_freq_relative_position = max_freq_relative_position - self.chunk_size = chunk_size - self.num_left_chunks = num_left_chunks - self.num_right_chunks = num_right_chunks - - self.input_linear = nn.Linear( - in_features=self.input_size, - out_features=self.hidden_size, - ) - - self.encoder_layer_list = torch.nn.ModuleList([ - TSTransformerBlock( - input_dim=hidden_size, - n_heads=attention_heads, - dropout_rate=dropout_rate, - max_time_relative_position=max_time_relative_position, - max_freq_relative_position=max_freq_relative_position, - ) for _ in range(num_blocks) - ]) - - self.output_linear = nn.Linear( - in_features=self.hidden_size, - out_features=self.input_size, - ) - - def forward(self, - xs: torch.Tensor, - ): - """ - :param xs: Tensor, shape: [batch_size, channels, time_steps, input_size] - :return: Tensor, shape: [batch_size, channels, time_steps, input_size] - """ - batch_size, channels, time_steps, _ = xs.shape - # xs shape: [batch_size, channels, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, input_size, time_steps, channels] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, input_size, time_steps, hidden_size] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, hidden_size, time_steps, input_size] - - chunk_masks = subsequent_chunk_mask( - size=time_steps, - chunk_size=self.chunk_size, - num_left_chunks=self.num_left_chunks, - num_right_chunks=self.num_right_chunks, - ) - chunk_masks = chunk_masks.to(xs.device) - # chunk_masks shape: [time_steps, time_steps] - - for encoder_layer in self.encoder_layer_list: - xs, _ = encoder_layer.forward(xs, chunk_masks) - # xs shape: [batch_size, hidden_size, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, input_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, input_size, time_steps, channels] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, channels, time_steps, input_size] - - return xs - - def forward_chunk(self, - xs: torch.Tensor, - max_att_cache_length: int, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param xs: - :param max_att_cache_length: - :param attention_cache: Tensor, shape: [num_layers, ...] 
- :return: - """ - # xs shape: [batch_size, channels, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - xs = self.input_linear.forward(xs) - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, hidden_size, time_steps, input_size] - - r_att_cache = [] - for idx, encoder_layer in enumerate(self.encoder_layer_list): - xs, new_att_cache = encoder_layer.forward( - x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None, - ) - # new_att_cache shape: [b*f, n_heads, time_steps, dim] - if new_att_cache.size(2) > max_att_cache_length: - begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - end = self.num_right_chunks * self.chunk_size - new_att_cache = new_att_cache[:, :, -begin:-end, :] - r_att_cache.append(new_att_cache) - - r_att_cache = torch.stack(r_att_cache, dim=0) - - # xs shape: [batch_size, hidden_size, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - xs = self.output_linear.forward(xs) - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, channels, time_steps, input_size] - - return xs, r_att_cache - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - ) -> torch.Tensor: - - batch_size, channels, time_steps, _ = xs.shape - - max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - attention_cache = None - - outputs = [] - for idx in range(0, time_steps, self.chunk_size): - begin = idx - end = begin + self.chunk_size * (self.num_right_chunks + 1) - chunk_xs = xs[:, :, begin:end, :] - # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size] - - ys, attention_cache = self.forward_chunk( - xs=chunk_xs, - max_att_cache_length=max_att_cache_length, - attention_cache=attention_cache, - ) - # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size] - ys = ys[:, :, :self.chunk_size, :] - - outputs.append(ys) - - ys = torch.cat(outputs, dim=2) - return ys - - -def main2(): - - encoder = TransformerEncoder( - input_size=64, - hidden_size=256, - attention_heads=4, - num_blocks=6, - dropout_rate=0.1, - ) - print(encoder) - - x = torch.ones([4, 200, 64]) - - x = torch.ones([4, 200, 64]) - y = encoder.forward(xs=x) - print(y.shape) - - x = torch.ones([4, 200, 64]) - y = encoder.forward_chunk_by_chunk(xs=x) - print(y.shape) - - return - - -def main(): - - encoder = TSTransformerEncoder( - input_size=8, - hidden_size=16, - attention_heads=2, - num_blocks=2, - dropout_rate=0.1, - ) - # print(encoder) - - x = torch.ones([4, 8, 200, 8]) - y = encoder.forward(xs=x) - print(y.shape) - - x = torch.ones([4, 8, 200, 8]) - y = encoder.forward_chunk_by_chunk(xs=x) - print(y.shape) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_denoise/utils.py b/toolbox/torchaudio/models/nx_denoise/utils.py deleted file mode 100644 index 84a6918b6a8f945e196b6ef909d7ba3575ce686e..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn - - -class LearnableSigmoid1d(nn.Module): - def __init__(self, in_features, beta=1): - super().__init__() - self.beta = beta - self.slope = nn.Parameter(torch.ones(in_features)) - self.slope.requiresGrad = True - - def forward(self, x): - # x shape: [batch_size, time_steps, spec_bins] - return self.beta * torch.sigmoid(self.slope * x) - - -def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True): - - 
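- # Returns the power-law compressed magnitude, the phase, and the compressed complex spectrum stacked as (real, imag) in the last dimension.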
hann_window = torch.hann_window(win_size).to(y.device) - stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, - center=center, pad_mode='reflect', normalized=False, return_complex=True) - stft_spec = torch.view_as_real(stft_spec) - mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9) - pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5) - # Magnitude Compression - mag = torch.pow(mag, compress_factor) - com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1) - - return mag, pha, com - - -def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): - # Magnitude Decompression - mag = torch.pow(mag, (1.0/compress_factor)) - com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha)) - hann_window = torch.hann_window(win_size).to(com.device) - wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center) - - return wav - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_denoise/yaml/config.yaml b/toolbox/torchaudio/models/nx_denoise/yaml/config.yaml deleted file mode 100644 index a0b33db42a1f57579e3b4f973ce724392acfa582..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_denoise/yaml/config.yaml +++ /dev/null @@ -1,51 +0,0 @@ -model_name: "nx_denoise" - -sample_rate: 8000 -segment_size: 16000 -n_fft: 512 -win_size: 200 -hop_size: 80 -# Since hop_size is 80, one STFT frame corresponds to a 10 ms step, so the down-sampling is chosen to give a similar time resolution. - -# 2**down_sampling_num_layers, -# e.g. 2**6=64 means that 64 samples are merged into one time step after down-sampling, -# so one step is 64/sample_rate = 0.008 s. -# Then tsfm_chunk_size=2 corresponds to 16 ms and tsfm_chunk_size=4 to 32 ms. -# Assuming we look about 1 s to the left and 30 ms to the right at each step, then: -# tsfm_chunk_size=1, tsfm_num_left_chunks=128, tsfm_num_right_chunks=4 -# tsfm_chunk_size=2, tsfm_num_left_chunks=64, tsfm_num_right_chunks=2 -# tsfm_chunk_size=4, tsfm_num_left_chunks=32, tsfm_num_right_chunks=1 -down_sampling_num_layers: 6 -down_sampling_in_channels: 1 -down_sampling_hidden_channels: 64 -down_sampling_kernel_size: 4 -down_sampling_stride: 2 - -causal_in_channels: 1 -causal_out_channels: 64 -causal_kernel_size: 3 -causal_bias: false -causal_separable: true -causal_f_stride: 1 -causal_num_layers: 3 - -tsfm_hidden_size: 256 -tsfm_attention_heads: 8 -tsfm_num_blocks: 6 -tsfm_dropout_rate: 0.1 -tsfm_max_length: 512 -tsfm_chunk_size: 1 -tsfm_num_left_chunks: 128 -tsfm_num_right_chunks: 4 - -discriminator_dim: 32 -discriminator_in_channel: 2 - -compress_factor: 0.3 - -batch_size: 4 -learning_rate: 0.0005 -adam_b1: 0.8 -adam_b2: 0.99 -lr_decay: 0.99 -seed: 1234 diff --git a/toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py b/toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py deleted file mode 100644 index 1e82d9c3bd5675feea8ae945a212359c3438f66a..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_dfnet/configuration_nx_dfnet.py +++ /dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from typing import Tuple - -from toolbox.torchaudio.configuration_utils import PretrainedConfig - - -class NXDfNetConfig(PretrainedConfig): - def __init__(self, - sample_rate: int = 8000, - freq_bins: int = 256, - win_size: int = 200, - hop_size: int = 100, - - conv_channels: int = 64, - conv_kernel_size_input: Tuple[int, int] = (3, 3), - conv_kernel_size_inner: Tuple[int, int] = (1, 3), - conv_lookahead: int = 0, - - convt_kernel_size_inner: Tuple[int, int] = (1, 3), - - embedding_hidden_size: int = 256, - encoder_combine_op: str = "concat", - - encoder_emb_skip_op:
str = "none", - encoder_emb_linear_groups: int = 16, - encoder_emb_hidden_size: int = 256, - - encoder_linear_groups: int = 32, - - lsnr_max: int = 30, - lsnr_min: int = -15, - norm_tau: float = 1., - - decoder_emb_num_layers: int = 3, - decoder_emb_skip_op: str = "none", - decoder_emb_linear_groups: int = 16, - decoder_emb_hidden_size: int = 256, - - df_decoder_hidden_size: int = 256, - df_num_layers: int = 2, - df_order: int = 5, - df_bins: int = 96, - df_gru_skip: str = "grouped_linear", - df_decoder_linear_groups: int = 16, - df_pathway_kernel_size_t: int = 5, - df_lookahead: int = 2, - - use_post_filter: bool = False, - **kwargs - ): - super(NXDfNetConfig, self).__init__(**kwargs) - # transform - self.sample_rate = sample_rate - self.freq_bins = freq_bins - self.win_size = win_size - self.hop_size = hop_size - - # conv - self.conv_channels = conv_channels - self.conv_kernel_size_input = conv_kernel_size_input - self.conv_kernel_size_inner = conv_kernel_size_inner - self.conv_lookahead = conv_lookahead - - self.convt_kernel_size_inner = convt_kernel_size_inner - - self.embedding_hidden_size = embedding_hidden_size - - # encoder - self.encoder_emb_skip_op = encoder_emb_skip_op - self.encoder_emb_linear_groups = encoder_emb_linear_groups - self.encoder_emb_hidden_size = encoder_emb_hidden_size - - self.encoder_linear_groups = encoder_linear_groups - self.encoder_combine_op = encoder_combine_op - - self.lsnr_max = lsnr_max - self.lsnr_min = lsnr_min - self.norm_tau = norm_tau - - # decoder - self.decoder_emb_num_layers = decoder_emb_num_layers - self.decoder_emb_skip_op = decoder_emb_skip_op - self.decoder_emb_linear_groups = decoder_emb_linear_groups - self.decoder_emb_hidden_size = decoder_emb_hidden_size - - # df decoder - self.df_decoder_hidden_size = df_decoder_hidden_size - self.df_num_layers = df_num_layers - self.df_order = df_order - self.df_bins = df_bins - self.df_gru_skip = df_gru_skip - self.df_decoder_linear_groups = df_decoder_linear_groups - self.df_pathway_kernel_size_t = df_pathway_kernel_size_t - self.df_lookahead = df_lookahead - - # runtime - self.use_post_filter = use_post_filter - - -if __name__ == "__main__": - pass diff --git a/toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py b/toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py deleted file mode 100644 index 924fa0023d78da8d9432482338b75f59007aead1..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_dfnet/modeling_nx_dfnet.py +++ /dev/null @@ -1,989 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -import math -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn -from torch.nn import functional as F -import torchaudio - -from toolbox.torchaudio.models.nx_dfnet.utils import overlap_and_add -from toolbox.torchaudio.models.nx_dfnet.configuration_nx_dfnet import NXDfNetConfig -from toolbox.torchaudio.configuration_utils import CONFIG_FILE - - -MODEL_FILE = "model.pt" - - -norm_layer_dict = { - "batch_norm_2d": torch.nn.BatchNorm2d -} - - -activation_layer_dict = { - "relu": torch.nn.ReLU, - "identity": torch.nn.Identity, - "sigmoid": torch.nn.Sigmoid, -} - - -class CausalConv2d(nn.Sequential): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Iterable[int]], - fstride: int = 1, - dilation: int = 1, - fpad: bool = True, - bias: bool = True, - separable: bool = False, - norm_layer: str = "batch_norm_2d", - activation_layer: str = "relu", - 
lookahead: int = 0 - ): - """ - Causal Conv2d by delaying the signal for any lookahead. - - Expected input format: [batch_size, channels, time_steps, spec_dim] - - :param in_channels: - :param out_channels: - :param kernel_size: - :param fstride: - :param dilation: - :param fpad: - """ - super(CausalConv2d, self).__init__() - kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) - - if fpad: - fpad_ = kernel_size[1] // 2 + dilation - 1 - else: - fpad_ = 0 - - # for last 2 dim, pad (left, right, top, bottom). - pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) - - layers = list() - if any(x > 0 for x in pad): - layers.append(nn.ConstantPad2d(pad, 0.0)) - - groups = math.gcd(in_channels, out_channels) if separable else 1 - if groups == 1: - separable = False - if max(kernel_size) == 1: - separable = False - - layers.append( - nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=(0, fpad_), - stride=(1, fstride), # stride over time is always 1 - dilation=(1, dilation), # dilation over time is always 1 - groups=groups, - bias=bias, - ) - ) - - if separable: - layers.append( - nn.Conv2d( - out_channels, - out_channels, - kernel_size=1, - bias=False, - ) - ) - - if norm_layer is not None: - norm_layer = norm_layer_dict[norm_layer] - layers.append(norm_layer(out_channels)) - - if activation_layer is not None: - activation_layer = activation_layer_dict[activation_layer] - layers.append(activation_layer()) - - super().__init__(*layers) - - def forward(self, inputs): - for module in self: - inputs = module(inputs) - return inputs - - -class CausalConvTranspose2d(nn.Sequential): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Iterable[int]], - fstride: int = 1, - dilation: int = 1, - fpad: bool = True, - bias: bool = True, - separable: bool = False, - norm_layer: str = "batch_norm_2d", - activation_layer: str = "relu", - lookahead: int = 0 - ): - """ - Causal ConvTranspose2d. - - Expected input format: [batch_size, channels, time_steps, spec_dim] - """ - super(CausalConvTranspose2d, self).__init__() - - kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size - - if fpad: - fpad_ = kernel_size[1] // 2 - else: - fpad_ = 0 - - # for last 2 dim, pad (left, right, top, bottom). 
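- # i.e. pad only the time axis: kernel_size[0] - 1 - lookahead frames before and lookahead frames after the current frame; the frequency axis is not padded here.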
- pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) - - layers = [] - if any(x > 0 for x in pad): - layers.append(nn.ConstantPad2d(pad, 0.0)) - - groups = math.gcd(in_channels, out_channels) if separable else 1 - if groups == 1: - separable = False - - layers.append( - nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=(kernel_size[0] - 1, fpad_ + dilation - 1), - output_padding=(0, fpad_), - stride=(1, fstride), # stride over time is always 1 - dilation=(1, dilation), # dilation over time is always 1 - groups=groups, - bias=bias, - ) - ) - - if separable: - layers.append( - nn.Conv2d( - out_channels, - out_channels, - kernel_size=1, - bias=False, - ) - ) - - if norm_layer is not None: - norm_layer = norm_layer_dict[norm_layer] - layers.append(norm_layer(out_channels)) - - if activation_layer is not None: - activation_layer = activation_layer_dict[activation_layer] - layers.append(activation_layer()) - - super().__init__(*layers) - - -class GroupedLinear(nn.Module): - - def __init__(self, input_size: int, hidden_size: int, groups: int = 1): - super().__init__() - # self.weight: Tensor - self.input_size = input_size - self.hidden_size = hidden_size - self.groups = groups - assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" - assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" - self.ws = input_size // groups - self.register_parameter( - "weight", - torch.nn.Parameter( - torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True - ), - ) - self.reset_parameters() - - def reset_parameters(self): - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x: [..., I] - b, t, _ = x.shape - # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] - new_shape = (b, t, self.groups, self.ws) - x = x.view(new_shape) - # The better way, but not supported by torchscript - # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] - x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] - x = x.flatten(2, 3) # [B, T, H] - return x - - def __repr__(self): - cls = self.__class__.__name__ - return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" - - -class SqueezedGRU_S(nn.Module): - """ - SGE net: Video object detection with squeezed GRU and information entropy map - https://arxiv.org/abs/2106.07224 - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - output_size: Optional[int] = None, - num_layers: int = 1, - linear_groups: int = 8, - batch_first: bool = True, - skip_op: str = "none", - activation_layer: str = "identity", - ): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - self.linear_in = nn.Sequential( - GroupedLinear( - input_size=input_size, - hidden_size=hidden_size, - groups=linear_groups, - ), - activation_layer_dict[activation_layer](), - ) - - # gru skip operator - self.gru_skip_op = None - - if skip_op == "none": - self.gru_skip_op = None - elif skip_op == "identity": - if not input_size != output_size: - raise AssertionError("Dimensions do not match") - self.gru_skip_op = nn.Identity() - elif skip_op == "grouped_linear": - self.gru_skip_op = GroupedLinear( - input_size=hidden_size, - hidden_size=hidden_size, - groups=linear_groups, - ) - else: - raise NotImplementedError() - - self.gru = nn.GRU( - input_size=hidden_size, - hidden_size=hidden_size, - 
num_layers=num_layers, - batch_first=batch_first, - bidirectional=False, - ) - - if output_size is not None: - self.linear_out = nn.Sequential( - GroupedLinear( - input_size=hidden_size, - hidden_size=output_size, - groups=linear_groups, - ), - activation_layer_dict[activation_layer](), - ) - else: - self.linear_out = nn.Identity() - - def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: - x = self.linear_in(inputs) - - x, h = self.gru.forward(x, h) - - x = self.linear_out(x) - - if self.gru_skip_op is not None: - x = x + self.gru_skip_op(inputs) - - return x, h - - -class Add(nn.Module): - def forward(self, a, b): - return a + b - - -class Concat(nn.Module): - def forward(self, a, b): - return torch.cat((a, b), dim=-1) - - -class DeepSTFT(nn.Module): - def __init__(self, win_size: int, freq_bins: int): - super(DeepSTFT, self).__init__() - self.win_size = win_size - self.freq_bins = freq_bins - - self.conv1d_U = nn.Conv1d( - in_channels=1, - out_channels=freq_bins * 2, - kernel_size=win_size, - stride=win_size // 2, - bias=False - ) - - def forward(self, signal: torch.Tensor): - """ - :param signal: Tensor, shape: [batch_size, num_samples] - :return: v, Tensor, shape: [batch_size, freq_bins, time_steps, 2], - where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1 - """ - signal = torch.unsqueeze(signal, 1) - # signal shape: [batch_size, 1, num_samples] - spec = F.relu(self.conv1d_U(signal)) - # spec shape: [batch_size, freq_bins * 2, time_steps] - b, f2, t = spec.shape - spec = spec.view(b, f2//2, 2, t).permute(0, 1, 3, 2) - # spec shape: [batch_size, freq_bins, time_steps, 2] - return spec - - -class DeepISTFT(nn.Module): - def __init__(self, win_size: int, freq_bins: int): - super(DeepISTFT, self).__init__() - self.win_size = win_size - self.freq_bins = freq_bins - - self.basis_signals = nn.Linear( - in_features=freq_bins * 2, - out_features=win_size, - bias=False - ) - - def forward(self, - spec: torch.Tensor, - ): - """ - :param spec: Tensor, shape: [batch_size, freq_bins, time_steps, 2], - where time_steps = (num_samples-win_size) / (win_size/2) + 1 = 2num_samples/win_size-1 - :return: Tensor, shape: [batch_size, c, num_samples], - """ - b, f, t, _ = spec.shape - # spec shape: [b, f, t, 2] - spec = spec.permute(0, 2, 1, 3) - # spec shape: [b, t, f, 2] - spec = spec.view(b, 1, t, -1) - # spec shape: [b, 1, t, f2] - signal = self.basis_signals(spec) - # signal shape: [b, 1, t, win_size] - signal = overlap_and_add(signal, self.win_size//2) - # signal shape: [b, 1, num_samples] - return signal - - -class Encoder(nn.Module): - def __init__(self, config: NXDfNetConfig): - super(Encoder, self).__init__() - self.embedding_input_size = config.conv_channels * config.freq_bins // 4 - self.embedding_output_size = config.conv_channels * config.freq_bins // 4 - self.embedding_hidden_size = config.embedding_hidden_size - - self.spec_conv0 = CausalConv2d( - in_channels=1, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_input, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - self.spec_conv1 = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_inner, - bias=False, - separable=True, - fstride=2, - lookahead=config.conv_lookahead, - ) - self.spec_conv2 = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_inner, - bias=False, - 
separable=True, - fstride=2, - lookahead=config.conv_lookahead, - ) - self.spec_conv3 = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_inner, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - - self.df_conv0 = CausalConv2d( - in_channels=2, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_input, - bias=False, - separable=True, - fstride=1, - ) - self.df_conv1 = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_inner, - bias=False, - separable=True, - fstride=2, - ) - self.df_fc_emb = nn.Sequential( - GroupedLinear( - config.conv_channels * config.df_bins // 2, - self.embedding_input_size, - groups=config.encoder_linear_groups - ), - nn.ReLU(inplace=True) - ) - - if config.encoder_combine_op == "concat": - self.embedding_input_size *= 2 - self.combine = Concat() - else: - self.combine = Add() - - # emb_gru - if config.freq_bins % 8 != 0: - raise AssertionError("freq_bins should be divisible by 8") - - self.emb_gru = SqueezedGRU_S( - self.embedding_input_size, - self.embedding_hidden_size, - output_size=self.embedding_output_size, - num_layers=1, - batch_first=True, - skip_op=config.encoder_emb_skip_op, - linear_groups=config.encoder_emb_linear_groups, - activation_layer="relu", - ) - - # lsnr - self.lsnr_fc = nn.Sequential( - nn.Linear(self.embedding_output_size, 1), - nn.Sigmoid() - ) - self.lsnr_scale = config.lsnr_max - config.lsnr_min - self.lsnr_offset = config.lsnr_min - - def forward(self, - power_spec: torch.Tensor, - df_spec: torch.Tensor, - hidden_state: torch.Tensor = None, - ): - # power_spec shape: (batch_size, 1, time_steps, spec_dim) - e0 = self.spec_conv0.forward(power_spec) - e1 = self.spec_conv1.forward(e0) - e2 = self.spec_conv2.forward(e1) - e3 = self.spec_conv3.forward(e2) - # e0 shape: [batch_size, channels, time_steps, spec_dim] - # e1 shape: [batch_size, channels, time_steps, spec_dim // 2] - # e2 shape: [batch_size, channels, time_steps, spec_dim // 4] - # e3 shape: [batch_size, channels, time_steps, spec_dim // 4] - - # df_spec, shape: (batch_size, 2, time_steps, df_bins) - c0 = self.df_conv0(df_spec) - c1 = self.df_conv1(c0) - # c0 shape: [batch_size, channels, time_steps, df_bins] - # c1 shape: [batch_size, channels, time_steps, df_bins // 2] - - cemb = c1.permute(0, 2, 3, 1) - # cemb shape: [batch_size, time_steps, df_bins // 2, channels] - cemb = cemb.flatten(2) - # cemb shape: [batch_size, time_steps, df_bins // 2 * channels] - cemb = self.df_fc_emb(cemb) - # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels] - - # e3 shape: [batch_size, channels, time_steps, spec_dim // 4] - emb = e3.permute(0, 2, 3, 1) - # emb shape: [batch_size, time_steps, spec_dim // 4, channels] - emb = emb.flatten(2) - # emb shape: [batch_size, time_steps, spec_dim // 4 * channels] - - emb = self.combine(emb, cemb) - # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2] - # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels] - - emb, h = self.emb_gru.forward(emb, hidden_state) - # emb shape: [batch_size, time_steps, spec_dim // 4 * channels] - # h shape: [batch_size, 1, spec_dim] - - lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset - # lsnr shape: [batch_size, time_steps, 1] - - return e0, e1, e2, e3, emb, c0, lsnr, h - - -class Decoder(nn.Module): - def __init__(self, config: NXDfNetConfig): - super(Decoder, 
self).__init__() - - if config.freq_bins % 8 != 0: - raise AssertionError("freq_bins should be divisible by 8") - - self.emb_in_dim = config.conv_channels * config.freq_bins // 4 - self.emb_out_dim = config.conv_channels * config.freq_bins // 4 - self.emb_hidden_dim = config.decoder_emb_hidden_size - - self.emb_gru = SqueezedGRU_S( - self.emb_in_dim, - self.emb_hidden_dim, - output_size=self.emb_out_dim, - num_layers=config.decoder_emb_num_layers - 1, - batch_first=True, - skip_op=config.decoder_emb_skip_op, - linear_groups=config.decoder_emb_linear_groups, - activation_layer="relu", - ) - self.conv3p = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=1, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - self.convt3 = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.conv_kernel_size_inner, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - self.conv2p = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=1, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - self.convt2 = CausalConvTranspose2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.convt_kernel_size_inner, - bias=False, - separable=True, - fstride=2, - lookahead=config.conv_lookahead, - ) - self.conv1p = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=1, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - self.convt1 = CausalConvTranspose2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=config.convt_kernel_size_inner, - bias=False, - separable=True, - fstride=2, - lookahead=config.conv_lookahead, - ) - self.conv0p = CausalConv2d( - in_channels=config.conv_channels, - out_channels=config.conv_channels, - kernel_size=1, - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - self.conv0_out = CausalConv2d( - in_channels=config.conv_channels, - out_channels=1, - kernel_size=config.conv_kernel_size_inner, - activation_layer="sigmoid", - bias=False, - separable=True, - fstride=1, - lookahead=config.conv_lookahead, - ) - - def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: - # Estimates erb mask - b, _, t, f8 = e3.shape - - # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] - emb, _ = self.emb_gru(emb) - # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] - emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) - e3 = self.convt3(self.conv3p(e3) + emb) - # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] - e2 = self.convt2(self.conv2p(e2) + e3) - # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] - e1 = self.convt1(self.conv1p(e1) + e2) - # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] - mask = self.conv0_out(self.conv0p(e0) + e1) - # mask shape: [batch_size, 1, time_steps, freq_dim] - return mask - - -class DfDecoder(nn.Module): - def __init__(self, config: NXDfNetConfig): - super(DfDecoder, self).__init__() - - self.embedding_input_size = config.conv_channels * config.freq_bins // 4 - self.df_decoder_hidden_size = config.df_decoder_hidden_size - self.df_num_layers = config.df_num_layers - - self.df_order = config.df_order - - self.df_bins = config.df_bins - self.df_out_ch = config.df_order * 2 - 
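- # each deep-filtering tap contributes a real and an imaginary coefficient, so df_order * 2 output channels per frequency bin.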
- self.df_convp = CausalConv2d( - config.conv_channels, - self.df_out_ch, - fstride=1, - kernel_size=(config.df_pathway_kernel_size_t, 1), - separable=True, - bias=False, - ) - self.df_gru = SqueezedGRU_S( - self.embedding_input_size, - self.df_decoder_hidden_size, - num_layers=self.df_num_layers, - batch_first=True, - skip_op="none", - activation_layer="relu", - ) - - if config.df_gru_skip == "none": - self.df_skip = None - elif config.df_gru_skip == "identity": - if config.embedding_hidden_size != config.df_decoder_hidden_size: - raise AssertionError("Dimensions do not match") - self.df_skip = nn.Identity() - elif config.df_gru_skip == "grouped_linear": - self.df_skip = GroupedLinear( - self.embedding_input_size, - self.df_decoder_hidden_size, - groups=config.df_decoder_linear_groups - ) - else: - raise NotImplementedError() - - self.df_out: nn.Module - out_dim = self.df_bins * self.df_out_ch - - self.df_out = nn.Sequential( - GroupedLinear( - input_size=self.df_decoder_hidden_size, - hidden_size=out_dim, - groups=config.df_decoder_linear_groups - ), - nn.Tanh() - ) - self.df_fc_a = nn.Sequential( - nn.Linear(self.df_decoder_hidden_size, 1), - nn.Sigmoid() - ) - - def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor: - # emb shape: [batch_size, time_steps, df_bins // 4 * channels] - b, t, _ = emb.shape - df_coefs, _ = self.df_gru(emb) - if self.df_skip is not None: - df_coefs = df_coefs + self.df_skip(emb) - # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] - - # c0 shape: [batch_size, channels, time_steps, df_bins] - c0 = self.df_convp(c0) - # c0 shape: [batch_size, df_order * 2, time_steps, df_bins] - c0 = c0.permute(0, 2, 3, 1) - # c0 shape: [batch_size, time_steps, df_bins, df_order * 2] - - df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order - # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] - df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch) - # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] - df_coefs = df_coefs + c0 - # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] - return df_coefs - - -class DfOutputReshapeMF(nn.Module): - """Coefficients output reshape for multiframe/MultiFrameModule - - Requires input of shape B, C, T, F, 2. - """ - - def __init__(self, df_order: int, df_bins: int): - super().__init__() - self.df_order = df_order - self.df_bins = df_bins - - def forward(self, coefs: torch.Tensor) -> torch.Tensor: - # [B, T, F, O*2] -> [B, O, T, F, 2] - new_shape = list(coefs.shape) - new_shape[-1] = -1 - new_shape.append(2) - coefs = coefs.view(new_shape) - coefs = coefs.permute(0, 3, 1, 2, 4) - return coefs - - -class Mask(nn.Module): - def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): - super().__init__() - self.use_post_filter = use_post_filter - self.eps = eps - - def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: - """ - Post-Filter - - A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. - https://arxiv.org/abs/2008.04259 - - :param mask: Real valued mask, typically of shape [B, C, T, F]. - :param beta: Global gain factor. 
- :return: - """ - mask_sin = mask * torch.sin(np.pi * mask / 2) - mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) - return mask_pf - - def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - # spec shape: [batch_size, 1, time_steps, freq_bins, 2] - - if not self.training and self.use_post_filter: - mask = self.post_filter(mask) - - # mask shape: [batch_size, 1, time_steps, freq_bins] - mask = mask.unsqueeze(4) - # mask shape: [batch_size, 1, time_steps, freq_bins, 1] - return spec * mask - - -class DeepFiltering(nn.Module): - def __init__(self, - df_bins: int, - df_order: int, - lookahead: int = 0, - ): - super(DeepFiltering, self).__init__() - self.df_bins = df_bins - self.df_order = df_order - self.need_unfold = df_order > 1 - self.lookahead = lookahead - - self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) - - def spec_unfold(self, spec: torch.Tensor): - """ - Pads and unfolds the spectrogram according to frame_size. - :param spec: complex Tensor, Spectrogram of shape [B, C, T, F]. - :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size. - """ - if self.need_unfold: - # spec shape: [batch_size, freq_bins, time_steps] - spec_pad = self.pad(spec) - # spec_pad shape: [batch_size, 1, time_steps_pad, freq_bins] - spec_unfold = spec_pad.unfold(2, self.df_order, 1) - # spec_unfold shape: [batch_size, 1, time_steps, freq_bins, df_order] - return spec_unfold - else: - return spec.unsqueeze(-1) - - def forward(self, - spec: torch.Tensor, - coefs: torch.Tensor, - ): - # spec shape: [batch_size, 1, time_steps, freq_bins, 2] - spec = spec.contiguous() - spec_u = self.spec_unfold(torch.view_as_complex(spec)) - # spec_u shape: [batch_size, 1, time_steps, freq_bins, df_order] - - # coefs shape: [batch_size, df_order, time_steps, df_bins, 2] - coefs = torch.view_as_complex(coefs) - # coefs shape: [batch_size, df_order, time_steps, df_bins] - spec_f = spec_u.narrow(-2, 0, self.df_bins) - # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order] - - coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:]) - # coefs shape: [batch_size, 1, df_order, time_steps, df_bins] - - spec_f = self.df(spec_f, coefs) - # spec_f shape: [batch_size, 1, time_steps, df_bins] - - if self.training: - spec = spec.clone() - spec[..., :self.df_bins, :] = torch.view_as_real(spec_f) - # spec shape: [batch_size, 1, time_steps, freq_bins, 2] - return spec - - @staticmethod - def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: - """ - Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. - :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N]. - :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F]. - :return: (complex Tensor). Spectrogram of shape [B, C, T, F]. 
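- For example, spec of shape [1, 1, 10, 96, 5] and coefs of shape [1, 1, 5, 10, 96] (with N = df_order = 5) give a filtered spectrogram of shape [1, 1, 10, 96]; for every time-frequency bin the N taps are multiplied with the N unfolded frames and summed over n.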
- """ - return torch.einsum("...tfn,...ntf->...tf", spec, coefs) - - -class NXDfNet(nn.Module): - def __init__(self, config: NXDfNetConfig): - super(NXDfNet, self).__init__() - self.config = config - - self.stft = DeepSTFT(win_size=config.win_size, freq_bins=config.freq_bins) - self.istft = DeepISTFT(win_size=config.win_size, freq_bins=config.freq_bins) - - self.encoder = Encoder(config) - self.decoder = Decoder(config) - - self.df_decoder = DfDecoder(config) - self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) - self.df_op = DeepFiltering( - df_bins=config.df_bins, - df_order=config.df_order, - lookahead=config.df_lookahead, - ) - - self.mask = Mask(use_post_filter=config.use_post_filter) - - def forward(self, - noisy: torch.Tensor, - ): - """ - :param noisy: Tensor, shape: [batch_size, num_samples] - :return: - """ - spec = self.stft.forward(noisy) - # spec shape: [batch_size, freq_bins, time_steps, 2] - power_spec = torch.sum(torch.square(spec), dim=-1) - power_spec = power_spec.unsqueeze(1).permute(0, 1, 3, 2) - # power_spec shape: [batch_size, freq_bins, time_steps] - # power_spec shape: [batch_size, 1, freq_bins, time_steps] - # power_spec shape: [batch_size, 1, time_steps, freq_bins] - - df_spec = spec.permute(0, 3, 2, 1) - # df_spec shape: [batch_size, 2, time_steps, freq_bins] - df_spec = df_spec[..., :self.df_decoder.df_bins] - # df_spec shape: [batch_size, 2, time_steps, df_bins] - - # spec shape: [batch_size, freq_bins, time_steps, 2] - spec = torch.transpose(spec, dim0=1, dim1=2) - # spec shape: [batch_size, time_steps, freq_bins, 2] - spec = torch.unsqueeze(spec, dim=1) - # spec shape: [batch_size, 1, time_steps, freq_bins, 2] - - e0, e1, e2, e3, emb, c0, _, h = self.encoder.forward(power_spec, df_spec) - - mask = self.decoder.forward(emb, e3, e2, e1, e0) - # mask shape: [batch_size, 1, time_steps, freq_bins] - if torch.any(mask > 1) or torch.any(mask < 0): - raise AssertionError - - spec_m = self.mask.forward(spec, mask) - - # lsnr shape: [batch_size, time_steps, 1] - # lsnr = torch.transpose(lsnr, dim0=2, dim1=1) - # lsnr shape: [batch_size, 1, time_steps] - - df_coefs = self.df_decoder.forward(emb, c0) - df_coefs = self.df_out_transform(df_coefs) - # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2] - - spec_e = self.df_op.forward(spec.clone(), df_coefs) - # spec_e shape: [batch_size, 1, time_steps, freq_bins, 2] - - spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :] - - spec_e = torch.squeeze(spec_e, dim=1) - spec_e = spec_e.permute(0, 2, 1, 3) - # spec_e shape: [batch_size, freq_bins, time_steps, 2] - - denoise = self.istft.forward(spec_e) - # spec_e shape: [batch_size, freq_bins, time_steps, 2] - return denoise - - -class NXDfNetPretrainedModel(NXDfNet): - def __init__(self, - config: NXDfNetConfig, - ): - super(NXDfNetPretrainedModel, self).__init__( - config=config, - ) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXDfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = 
None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - -def main(): - - config = NXDfNetConfig() - model = NXDfNet(config=config) - - inputs = torch.randn(size=(1, 16000), dtype=torch.float32) - - denoise = model.forward(inputs) - print(denoise.shape) - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_dfnet/utils.py b/toolbox/torchaudio/models/nx_dfnet/utils.py deleted file mode 100644 index 8bfe77e68ec9c6254d4334187a28a3e014a04c8e..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_dfnet/utils.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -""" -https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py -""" -import math -import torch - - -def overlap_and_add(signal: torch.Tensor, frame_step: int): - """ - Reconstructs a signal from a framed representation. - - Adds potentially overlapping frames of a signal with shape - `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`. - The resulting tensor has shape `[..., output_size]` where - - output_size = (frames - 1) * frame_step + frame_length - - Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py - - :param signal: Tensor, shape: [..., frames, frame_length]. All dimensions may be unknown, and rank must be at least 2. - :param frame_step: int, overlap offsets. Must be less than or equal to frame_length. - :return: Tensor, shape: [..., output_size]. - containing the overlap-added frames of signal's inner-most two dimensions. 
- output_size = (frames - 1) * frame_step + frame_length - """ - outer_dimensions = signal.size()[:-2] - frames, frame_length = signal.size()[-2:] - - subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor - subframe_step = frame_step // subframe_length - subframes_per_frame = frame_length // subframe_length - - output_size = frame_step * (frames - 1) + frame_length - output_subframes = output_size // subframe_length - - subframe_signal = signal.view(*outer_dimensions, -1, subframe_length) - - frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step) - - frame = frame.clone().detach() - frame = frame.to(signal.device) - frame = frame.long() - - frame = frame.contiguous().view(-1) - - result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length) - result.index_add_(-2, frame, subframe_signal) - result = result.view(*outer_dimensions, -1) - return result - - -if __name__ == "__main__": - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/__init__.py b/toolbox/torchaudio/models/nx_mpnet/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py b/toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/causal_convolution/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py b/toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py deleted file mode 100644 index ad2cb5efeeaebefcd8c161fdbc98031bfb8601e2..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/causal_convolution/causal_conv2d.py +++ /dev/null @@ -1,445 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from typing import List, Tuple, Union - -import torch -import torch.nn as nn - -from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid2d - - -class SPConvTranspose2d(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int]], - r=1 - ): - super(SPConvTranspose2d, self).__init__() - self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.) - self.out_channels = out_channels - self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1)) - self.r = r - - def forward(self, x: torch.Tensor): - x = self.pad_freq(x) - out = self.conv(x) - - b, c, t, f = out.shape - - out = out.view((b, self.r, c // self.r, t, f)) - out = out.permute(0, 2, 3, 4, 1) - out = out.contiguous().view((b, c // self.r, t, -1)) - return out - - -class CausalConv2dBlock(nn.Module): - def __init__(self, - in_channels: int, - out_channels: int, - dilation: int, - kernel_size: Tuple[int, int] = (2, 3), - ): - super(CausalConv2dBlock, self).__init__() - self.pad_length = dilation - - self.pad_time = nn.ConstantPad2d((0, 0, self.pad_length, 0), value=0.) - self.pad_freq = nn.ConstantPad2d((1, 1, 0, 0), value=0.) 
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, dilation=(dilation, 1)) - self.norm = nn.InstanceNorm2d(out_channels, affine=True) - self.activation = nn.PReLU(out_channels) - - def forward(self, - x: torch.Tensor, - cache_pad: torch.Tensor = None, - ): - """ - - :param x: Tensor, shape: [batch_size, channels, time_steps, dim] - :param cache_pad: - :return: - """ - if cache_pad is None: - x = self.pad_time(x) - else: - x = torch.concat(tensors=[cache_pad, x], dim=2) - new_cache_pad = x[:, :, -self.pad_length:, :] - - x = self.pad_freq(x) - - x = self.conv(x) - x = self.norm(x) - x = self.activation(x) - return x, new_cache_pad - - -class CausalConv2dEncoder(nn.Module): - def __init__(self, - num_blocks: int, - hidden_size: int, - ): - super(CausalConv2dEncoder, self).__init__() - self.num_blocks = num_blocks - - self.blocks: List[CausalConv2dBlock] = nn.ModuleList([]) - for idx in range(num_blocks): - in_channels = hidden_size * (idx+1) - dilation = 2 ** idx - block = CausalConv2dBlock( - in_channels=in_channels, - out_channels=hidden_size, - dilation=dilation, - kernel_size=(2, 3), - ) - self.blocks.append(block) - - def forward(self, - x: torch.Tensor, - cache_pad_list: List[torch.Tensor] = None, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim]. - :param cache_pad_list: List[Tensor] - :return: - """ - new_cache_pad_list = list() - - skip = x - for idx, block in enumerate(self.blocks): - x, new_cache_pad = block.forward( - skip, - cache_pad=None if cache_pad_list is None else cache_pad_list[idx] - ) - new_cache_pad_list.append(new_cache_pad) - skip = torch.cat([x, skip], dim=1) - # x shape: [batch_size, channels, time_steps, dim]. - return x, new_cache_pad_list - - def forward_chunk(self, - chunk: torch.Tensor, - cache_pad_list: List[torch.Tensor] = None, - ): - return self.forward(chunk, cache_pad_list) - - def forward_chunk_by_chunk(self, - x: torch.Tensor, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim]. 
- :return: - """ - batch_size, channels, time_steps, _ = x.shape - - cache_pad_list = None - - outputs = list() - for idx in range(time_steps): - chunk = x[:, :, idx:idx+1, :] - - y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list) - outputs.append(y) - - outputs = torch.concat(outputs, dim=2) - return outputs - - -class DenseEncoder(nn.Module): - def __init__(self, - num_blocks: int, - in_channels: int, - out_channels: int, - ): - super(DenseEncoder, self).__init__() - self.dense_conv_1 = nn.Sequential( - nn.Conv2d(in_channels, out_channels, (1, 1)), - nn.InstanceNorm2d(out_channels, affine=True), - nn.PReLU(out_channels) - ) - self.dense_block = CausalConv2dEncoder( - num_blocks=num_blocks, hidden_size=out_channels, - ) - self.dense_conv_2 = nn.Sequential( - nn.Conv2d(out_channels, out_channels, (1, 3), (1, 2), padding=(0, 1)), - nn.InstanceNorm2d(out_channels, affine=True), - nn.PReLU(out_channels) - ) - - def forward(self, - x: torch.Tensor, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim] - :return: - """ - x = self.dense_conv_1(x) - x, _ = self.dense_block.forward(x) - x = self.dense_conv_2(x) - # x shape: [b, c, t, f//2] - return x - - def forward_chunk(self, - x: torch.Tensor, - cache_pad_list: List[torch.Tensor] = None, - ): - x = self.dense_conv_1(x) - x, new_cache_pad_list = self.dense_block.forward(x, cache_pad_list) - x = self.dense_conv_2(x) - # x shape: [b, c, t, f//2] - return x, new_cache_pad_list - - def forward_chunk_by_chunk(self, - x: torch.Tensor, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim]. - :return: - """ - batch_size, channels, time_steps, _ = x.shape - - cache_pad_list = None - - outputs = list() - for idx in range(time_steps): - chunk = x[:, :, idx:idx+1, :] - - y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list) - outputs.append(y) - - outputs = torch.concat(outputs, dim=2) - return outputs - - -class MaskDecoder(nn.Module): - def __init__(self, - num_blocks: int, - hidden_size: int, - out_channels: int = 1, - beta: float = 2.0, - n_fft: int = 512, - ): - super(MaskDecoder, self).__init__() - self.dense_block = CausalConv2dEncoder( - num_blocks=num_blocks, hidden_size=hidden_size, - ) - self.mask_conv = nn.Sequential( - SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2), - nn.InstanceNorm2d(hidden_size, affine=True), - nn.PReLU(hidden_size), - nn.Conv2d(hidden_size, out_channels, (1, 2)) - ) - self.lsigmoid = LearnableSigmoid2d(n_fft//2+1, beta=beta) - - def forward(self, - x: torch.Tensor, - ): - """ - - :param x: Tensor, shape: [batch_size, channels, time_steps, dim] - :return: - """ - x, _ = self.dense_block(x) - x = self.mask_conv(x) - # x shape: [batch_size, 1, time_steps, dim*2-1] - x = x.permute(0, 3, 2, 1).squeeze(-1) - # x shape: [b, f, t] - x = self.lsigmoid(x) - return x - - def forward_chunk(self, - x: torch.Tensor, - cache_pad_list: List[torch.Tensor] = None, - ): - x, new_cache_pad_list = self.dense_block(x, cache_pad_list) - x = self.mask_conv(x) - # x shape: [batch_size, 1, time_steps, dim*2-1] - x = x.permute(0, 3, 2, 1).squeeze(-1) - # x shape: [b, f, t] - x = self.lsigmoid(x) - return x, new_cache_pad_list - - def forward_chunk_by_chunk(self, - x: torch.Tensor, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim]. 
- :return: - """ - batch_size, channels, time_steps, _ = x.shape - - cache_pad_list = None - - outputs = list() - for idx in range(time_steps): - chunk = x[:, :, idx:idx+1, :] - - y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list) - outputs.append(y) - - outputs = torch.concat(outputs, dim=2) - return outputs - - -class PhaseDecoder(nn.Module): - def __init__(self, - num_blocks: int, - hidden_size: int, - out_channels: int = 1, - ): - super(PhaseDecoder, self).__init__() - self.dense_block = CausalConv2dEncoder( - num_blocks=num_blocks, hidden_size=hidden_size, - ) - - self.phase_conv = nn.Sequential( - SPConvTranspose2d(hidden_size, hidden_size, (1, 3), 2), - nn.InstanceNorm2d(hidden_size, affine=True), - nn.PReLU(hidden_size) - ) - self.phase_conv_r = nn.Conv2d(hidden_size, out_channels, (1, 2)) - self.phase_conv_i = nn.Conv2d(hidden_size, out_channels, (1, 2)) - - def forward(self, - x: torch.Tensor, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim] - :return: - """ - x, _ = self.dense_block(x) - - x = self.phase_conv(x) - x_r = self.phase_conv_r(x) - x_i = self.phase_conv_i(x) - x = torch.atan2(x_i, x_r) - x = x.permute(0, 3, 2, 1).squeeze(-1) - # x shape: [b, f, t] - return x - - def forward_chunk(self, - x: torch.Tensor, - cache_pad_list: List[torch.Tensor] = None, - ): - x, new_cache_pad_list = self.dense_block(x, cache_pad_list) - - x = self.phase_conv(x) - x_r = self.phase_conv_r(x) - x_i = self.phase_conv_i(x) - x = torch.atan2(x_i, x_r) - x = x.permute(0, 3, 2, 1).squeeze(-1) - # x shape: [b, f, t] - return x, new_cache_pad_list - - def forward_chunk_by_chunk(self, - x: torch.Tensor, - ): - """ - :param x: Tensor, shape: [batch_size, channels, time_steps, dim]. - :return: - """ - batch_size, channels, time_steps, _ = x.shape - - cache_pad_list = None - - outputs = list() - for idx in range(time_steps): - chunk = x[:, :, idx:idx+1, :] - - y, cache_pad_list = self.forward_chunk(chunk, cache_pad_list=cache_pad_list) - outputs.append(y) - - outputs = torch.concat(outputs, dim=2) - return outputs - - -def main1(): - - encoder = CausalConv2dEncoder( - num_blocks=3, hidden_size=8, - ) - - # x shape: [batch_size, channels, time_steps, dim] - x = torch.rand(size=(1, 8, 200, 32)) - x, new_cache_pad_list = encoder.forward(x) - print(x.shape) - for new_cache_pad in new_cache_pad_list: - print(new_cache_pad.shape) - - x = torch.rand(size=(1, 8, 200, 32)) - x = encoder.forward_chunk_by_chunk(x) - print(x.shape) - - return - - -def main2(): - - encoder = DenseEncoder( - num_blocks=3, in_channels=8, out_channels=8 - ) - - # x shape: [batch_size, channels, time_steps, dim] - x = torch.rand(size=(1, 8, 200, 32)) - x, new_cache_pad_list = encoder.forward(x) - print(x.shape) - for new_cache_pad in new_cache_pad_list: - print(new_cache_pad.shape) - - x = torch.rand(size=(1, 8, 200, 32)) - x = encoder.forward_chunk_by_chunk(x) - print(x.shape) - - return - - -def main3(): - - encoder = MaskDecoder( - num_blocks=3, hidden_size=64, out_channels=1, - n_fft=512, - ) - - # 512 // 2 + 1 = 257 - # 129 * 2 - 1 = 257 - # 257 // 2 + 1 = 129 - - # x shape: [batch_size, channels, time_steps, dim] - x = torch.rand(size=(1, 64, 201, 129)) - x, new_cache_pad_list = encoder.forward(x) - print(x.shape) - for new_cache_pad in new_cache_pad_list: - print(new_cache_pad.shape) - - x = torch.rand(size=(1, 64, 201, 129)) - x = encoder.forward_chunk_by_chunk(x) - print(x.shape) - - return - - - -def main(): - - encoder = PhaseDecoder( - num_blocks=3, hidden_size=64, 
out_channels=1, - ) - - # 512 // 2 + 1 = 257 - # 129 * 2 - 1 = 257 - # 257 // 2 + 1 = 129 - - # x shape: [batch_size, channels, time_steps, dim] - x = torch.rand(size=(1, 64, 201, 129)) - x, new_cache_pad_list = encoder.forward(x) - print(x.shape) - for new_cache_pad in new_cache_pad_list: - print(new_cache_pad.shape) - - x = torch.rand(size=(1, 64, 201, 129)) - x = encoder.forward_chunk_by_chunk(x) - print(x.shape) - - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py b/toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py deleted file mode 100644 index 3a503fb2ba9308446d24a6be25114f24284324fd..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/configuration_nx_mpnet.py +++ /dev/null @@ -1,90 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from toolbox.torchaudio.configuration_utils import PretrainedConfig - - -class NXMPNetConfig(PretrainedConfig): - """ - https://github.com/yxlu-0102/MP-SENet/blob/main/config.json - """ - def __init__(self, - sample_rate: int = 8000, - segment_size: int = 16000, - n_fft: int = 512, - win_size: int = 200, - hop_size: int = 80, - - dense_num_blocks: int = 4, - dense_hidden_size: int = 64, - - mask_num_blocks: int = 4, - mask_hidden_size: int = 64, - - phase_num_blocks: int = 4, - phase_hidden_size: int = 64, - - tsfm_hidden_size: int = 64, - tsfm_attention_heads: int = 4, - tsfm_num_blocks: int = 4, - tsfm_dropout_rate: float = 0.0, - tsfm_max_time_relative_position: int = 2048, - tsfm_max_freq_relative_position: int = 256, - tsfm_chunk_size: int = 1, - tsfm_num_left_chunks: int = 64, - tsfm_num_right_chunks: int = 2, - - discriminator_dim: int = 32, - discriminator_in_channel: int = 2, - - compress_factor: float = 0.3, - - batch_size: int = 4, - learning_rate: float = 0.0005, - adam_b1: float = 0.8, - adam_b2: float = 0.99, - lr_decay: float = 0.99, - seed: int = 1234, - - **kwargs - ): - super(NXMPNetConfig, self).__init__(**kwargs) - self.sample_rate = sample_rate - self.segment_size = segment_size - self.n_fft = n_fft - self.win_size = win_size - self.hop_size = hop_size - - self.dense_num_blocks = dense_num_blocks - self.dense_hidden_size = dense_hidden_size - - self.mask_num_blocks = mask_num_blocks - self.mask_hidden_size = mask_hidden_size - - self.phase_num_blocks = phase_num_blocks - self.phase_hidden_size = phase_hidden_size - - self.tsfm_hidden_size = tsfm_hidden_size - self.tsfm_attention_heads = tsfm_attention_heads - self.tsfm_num_blocks = tsfm_num_blocks - self.tsfm_dropout_rate = tsfm_dropout_rate - self.tsfm_max_time_relative_position = tsfm_max_time_relative_position - self.tsfm_max_freq_relative_position = tsfm_max_freq_relative_position - self.tsfm_chunk_size = tsfm_chunk_size - self.tsfm_num_left_chunks = tsfm_num_left_chunks - self.tsfm_num_right_chunks = tsfm_num_right_chunks - - self.discriminator_dim = discriminator_dim - self.discriminator_in_channel = discriminator_in_channel - - self.compress_factor = compress_factor - - self.batch_size = batch_size - self.learning_rate = learning_rate - self.adam_b1 = adam_b1 - self.adam_b2 = adam_b2 - self.lr_decay = lr_decay - self.seed = seed - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/discriminator.py b/toolbox/torchaudio/models/nx_mpnet/discriminator.py deleted file mode 100644 index 53f1fbb5efeb5133c6535c13da9ad681646b8f17..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/discriminator.py +++ 
/dev/null @@ -1,102 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -from typing import Optional, Union - -import torch -import torch.nn as nn -import numpy as np -import torch.nn.functional as F -from pesq import pesq -from joblib import Parallel, delayed - -from toolbox.torchaudio.configuration_utils import CONFIG_FILE -from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig -from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d - - -class MetricDiscriminator(nn.Module): - def __init__(self, config: NXMPNetConfig): - super(MetricDiscriminator, self).__init__() - dim = config.discriminator_dim - in_channel = config.discriminator_in_channel - - self.layers = nn.Sequential( - nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim, affine=True), - nn.PReLU(dim), - nn.utils.spectral_norm(nn.Conv2d(dim, dim*2, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*2, affine=True), - nn.PReLU(dim*2), - nn.utils.spectral_norm(nn.Conv2d(dim*2, dim*4, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*4, affine=True), - nn.PReLU(dim*4), - nn.utils.spectral_norm(nn.Conv2d(dim*4, dim*8, (4,4), (2,2), (1,1), bias=False)), - nn.InstanceNorm2d(dim*8, affine=True), - nn.PReLU(dim*8), - nn.AdaptiveMaxPool2d(1), - nn.Flatten(), - nn.utils.spectral_norm(nn.Linear(dim*8, dim*4)), - nn.Dropout(0.3), - nn.PReLU(dim*4), - nn.utils.spectral_norm(nn.Linear(dim*4, 1)), - LearnableSigmoid1d(1) - ) - - def forward(self, x, y): - xy = torch.stack((x, y), dim=1) - return self.layers(xy) - - -MODEL_FILE = "discriminator.pt" - - -class MetricDiscriminatorPretrainedModel(MetricDiscriminator): - def __init__(self, - config: NXMPNetConfig, - ): - super(MetricDiscriminatorPretrainedModel, self).__init__( - config=config, - ) - self.config = config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/inference_nx_mpnet.py b/toolbox/torchaudio/models/nx_mpnet/inference_nx_mpnet.py deleted file mode 100644 index b403e798a715e9464ba11bd78c7bb4002aeab2b1..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/inference_nx_mpnet.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import logging -from pathlib import Path -import shutil -import tempfile -import zipfile - -import librosa -import numpy as np -import torch -import torchaudio - -from project_settings import project_path -from 
toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig -from toolbox.torchaudio.models.nx_mpnet.modeling_nx_mpnet import NXMPNetPretrainedModel, MODEL_FILE -from toolbox.torchaudio.models.nx_mpnet.utils import mag_pha_stft, mag_pha_istft - -logger = logging.getLogger("toolbox") - - -class InferenceNXMPNet(object): - def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"): - self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file - self.device = torch.device(device) - - logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}") - config, generator = self.load_models(self.pretrained_model_path_or_zip_file) - logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}") - - self.config = config - self.generator = generator - self.generator.to(device) - self.generator.eval() - - def load_models(self, model_path: str): - model_path = Path(model_path) - if model_path.name.endswith(".zip"): - with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip: - out_root = Path(tempfile.gettempdir()) / "nx_denoise" - out_root.mkdir(parents=True, exist_ok=True) - f_zip.extractall(path=out_root) - model_path = out_root / model_path.stem - - config = NXMPNetConfig.from_pretrained( - pretrained_model_name_or_path=model_path.as_posix(), - ) - generator = NXMPNetPretrainedModel.from_pretrained( - pretrained_model_name_or_path=model_path.as_posix(), - ) - generator.to(self.device) - generator.eval() - - shutil.rmtree(model_path) - return config, generator - - def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor: - if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1: - raise AssertionError(f"The value range of audio samples should be between -1 and 1.") - - noisy_audio = noisy_audio.to(self.device) - - with torch.no_grad(): - noisy_mag, noisy_pha, noisy_com = mag_pha_stft( - noisy_audio, - self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor - ) - # mag_g, pha_g, com_g = self.generator.forward(noisy_mag, noisy_pha) - mag_g, pha_g, com_g = self.generator.forward_chunk_by_chunk(noisy_mag, noisy_pha) - audio_g = mag_pha_istft( - mag_g, pha_g, - self.config.n_fft, self.config.hop_size, self.config.win_size, self.config.compress_factor - ) - enhanced_audio = audio_g.detach() - - enhanced_audio = enhanced_audio[0] - return enhanced_audio - - -def main(): - model_zip_file = project_path / "trained_models/nx-mpnet-aishell-2-epoch.zip" - infer_mpnet = InferenceNXMPNet(model_zip_file) - - sample_rate = 8000 - noisy_audio_file = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_2.wav" - noisy_audio, _ = librosa.load( - noisy_audio_file.as_posix(), - sr=sample_rate, - ) - # noisy_audio = noisy_audio[int(7*sample_rate):int(9*sample_rate)] - noisy_audio = torch.tensor(noisy_audio, dtype=torch.float32) - noisy_audio = noisy_audio.unsqueeze(dim=0) - - enhanced_audio = infer_mpnet.enhancement_by_tensor(noisy_audio) - - filename = "enhanced_audio.wav" - torchaudio.save(filename, enhanced_audio.unsqueeze(dim=0).detach().cpu(), sample_rate) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_mpnet/loss.py b/toolbox/torchaudio/models/nx_mpnet/loss.py deleted file mode 100644 index 475535006ee63213332fdc19ae91da1d81fe9cfc..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/loss.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/python3 -# -*- 
coding: utf-8 -*- -import numpy as np -import torch - - -def anti_wrapping_function(x): - - return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi) - - -def phase_losses(phase_r, phase_g): - - ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g)) - gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))) - iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))) - - return ip_loss, gd_loss, iaf_loss - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/metrics.py b/toolbox/torchaudio/models/nx_mpnet/metrics.py deleted file mode 100644 index 78468894a56d4488021e83ea47e07c785a385269..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/metrics.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from joblib import Parallel, delayed -import numpy as np -from pesq import pesq -from typing import List - -from pesq import cypesq - - -def run_pesq(clean_audio: np.ndarray, - noisy_audio: np.ndarray, - sample_rate: int = 16000, - mode: str = "wb", - ) -> float: - if sample_rate == 8000 and mode == "wb": - raise AssertionError(f"mode should be `nb` when sample_rate is 8000") - try: - pesq_score = pesq(sample_rate, clean_audio, noisy_audio, mode) - except cypesq.NoUtterancesError as e: - pesq_score = -1 - except Exception as e: - print(f"pesq failed. error type: {type(e)}, error text: {str(e)}") - pesq_score = -1 - return pesq_score - - -def run_batch_pesq(clean_audio_list: List[np.ndarray], - noisy_audio_list: List[np.ndarray], - sample_rate: int = 16000, - mode: str = "wb", - n_jobs: int = 4, - ) -> List[float]: - parallel = Parallel(n_jobs=n_jobs) - - parallel_tasks = list() - for clean_audio, noisy_audio in zip(clean_audio_list, noisy_audio_list): - parallel_task = delayed(run_pesq)(clean_audio, noisy_audio, sample_rate, mode) - parallel_tasks.append(parallel_task) - - pesq_score_list = parallel.__call__(parallel_tasks) - return pesq_score_list - - -def run_pesq_score(clean_audio_list: List[np.ndarray], - noisy_audio_list: List[np.ndarray], - sample_rate: int = 16000, - mode: str = "wb", - n_jobs: int = 4, - ) -> List[float]: - - pesq_score_list = run_batch_pesq(clean_audio_list=clean_audio_list, - noisy_audio_list=noisy_audio_list, - sample_rate=sample_rate, - mode=mode, - n_jobs=n_jobs, - ) - - pesq_score = np.mean(pesq_score_list) - return pesq_score - - -def main(): - clean_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) - noisy_audio = np.random.uniform(low=0, high=1, size=(2, 160000,)) - - clean_audio_list = list(clean_audio) - noisy_audio_list = list(noisy_audio) - - pesq_score_list = run_batch_pesq(clean_audio_list, noisy_audio_list) - print(pesq_score_list) - - pesq_score = run_pesq_score(clean_audio_list, noisy_audio_list) - print(pesq_score) - - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py b/toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py deleted file mode 100644 index 849e022d1f4d6448fd6306a23104a86803e38971..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/modeling_nx_mpnet.py +++ /dev/null @@ -1,235 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import os -from typing import List, Optional, Union - -import numpy as np -import torch -import torch.nn as nn - -from toolbox.torchaudio.configuration_utils import CONFIG_FILE -from 
toolbox.torchaudio.models.nx_mpnet.causal_convolution.causal_conv2d import DenseEncoder, MaskDecoder, PhaseDecoder -from toolbox.torchaudio.models.nx_mpnet.transformers.transformers import TSTransformerEncoder -from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig - - -class NXMPNet(nn.Module): - def __init__(self, - config: NXMPNetConfig, - ): - super(NXMPNet, self).__init__() - self.config = config - - self.dense_encoder = DenseEncoder( - num_blocks=config.dense_num_blocks, - in_channels=2, - out_channels=config.dense_hidden_size, - ) - self.ts_transformer = TSTransformerEncoder( - input_size=config.dense_hidden_size, - hidden_size=config.tsfm_hidden_size, - attention_heads=config.tsfm_attention_heads, - num_blocks=config.tsfm_num_blocks, - dropout_rate=config.tsfm_dropout_rate, - max_time_relative_position=config.tsfm_max_time_relative_position, - max_freq_relative_position=config.tsfm_max_freq_relative_position, - chunk_size=config.tsfm_chunk_size, - num_left_chunks=config.tsfm_num_left_chunks, - num_right_chunks=config.tsfm_num_right_chunks, - ) - self.mask_decoder = MaskDecoder( - num_blocks=config.mask_num_blocks, - hidden_size=config.mask_hidden_size, - out_channels=1, - n_fft=config.n_fft, - ) - self.phase_decoder = PhaseDecoder( - num_blocks=config.phase_num_blocks, - hidden_size=config.phase_hidden_size, - out_channels=1, - ) - - def forward(self, noisy_amp, noisy_pha): - """ - :param noisy_amp: Tensor, shape: [b, f, t] - :param noisy_pha: Tensor, shape: [b, f, t] - :return: - """ - x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1) # [B, 2, T, F] - # x shape: [b, 2, t, f] - x = self.dense_encoder.forward(x) - # x shape: [b, c, t, f//2] - - x = self.ts_transformer.forward(x) - # x shape: [b, c, t, f//2] - - denoised_amp = noisy_amp * self.mask_decoder(x) - denoised_pha = self.phase_decoder(x) - denoised_com = torch.stack( - tensors=( - denoised_amp * torch.cos(denoised_pha), - denoised_amp * torch.sin(denoised_pha) - ), - dim=-1 - ) - - return denoised_amp, denoised_pha, denoised_com - - def forward_chunk(self, - chunk_noisy_amp: torch.Tensor, - chunk_noisy_pha: torch.Tensor, - cache: dict, - ): - dense_encoder_cache_pad_list = cache["dense_encoder_cache_pad_list"] - mask_decoder_cache_pad_list = cache["mask_decoder_cache_pad_list"] - phase_decoder_cache_pad_list = cache["phase_decoder_cache_pad_list"] - ts_transformer_cache_att_list = cache["ts_transformer_cache_att_list"] - max_att_cache_length = cache["max_att_cache_length"] - - x = torch.stack((chunk_noisy_amp, chunk_noisy_pha), dim=-1).permute(0, 3, 2, 1) # [B, 2, T, F] - # x shape: [b, 2, t, f] - x, new_dense_encoder_cache_pad_list = self.dense_encoder.forward_chunk(x, cache_pad_list=dense_encoder_cache_pad_list) - # x shape: [b, c, t, f//2] - - x, new_ts_transformer_cache_att_list = self.ts_transformer.forward_chunk( - x, - max_att_cache_length=max_att_cache_length, - cache_att_list=ts_transformer_cache_att_list - ) - # x shape: [b, c, t, f//2] - - mask, new_mask_decoder_cache_pad_list = self.mask_decoder.forward_chunk(x, cache_pad_list=mask_decoder_cache_pad_list) - denoised_amp = chunk_noisy_amp * mask - denoised_pha, new_phase_decoder_cache_pad_list = self.phase_decoder.forward_chunk(x, cache_pad_list=phase_decoder_cache_pad_list) - denoised_com = torch.stack( - tensors=( - denoised_amp * torch.cos(denoised_pha), - denoised_amp * torch.sin(denoised_pha) - ), - dim=-1 - ) - - cache = { - "dense_encoder_cache_pad_list": new_dense_encoder_cache_pad_list, - 
"mask_decoder_cache_pad_list": new_mask_decoder_cache_pad_list, - "phase_decoder_cache_pad_list": new_phase_decoder_cache_pad_list, - "ts_transformer_cache_att_list": new_ts_transformer_cache_att_list, - "max_att_cache_length": max_att_cache_length, - - } - - return denoised_amp, denoised_pha, denoised_com, cache - - def forward_chunk_by_chunk(self, - noisy_amp: torch.Tensor, - noisy_pha: torch.Tensor, - ): - """ - :param noisy_amp: Tensor, shape: [b, f, t] - :param noisy_pha: Tensor, shape: [b, f, t] - :return: - """ - b, f, t = noisy_amp.shape - - max_att_cache_length = (self.config.tsfm_num_left_chunks + self.config.tsfm_num_right_chunks) * self.config.tsfm_chunk_size - - cache = { - "dense_encoder_cache_pad_list": None, - "mask_decoder_cache_pad_list": None, - "phase_decoder_cache_pad_list": None, - "ts_transformer_cache_att_list": None, - "max_att_cache_length": max_att_cache_length, - - } - - denoised_amp_list = list() - denoised_pha_list = list() - denoised_com_list = list() - - for idx in range(t): - chunk_noisy_amp = noisy_amp[:, :, idx:idx+1] - chunk_noisy_pha = noisy_pha[:, :, idx:idx+1] - - denoised_amp, denoised_pha, denoised_com, cache = self.forward_chunk(chunk_noisy_amp, chunk_noisy_pha, cache) - denoised_amp_list.append(denoised_amp) - denoised_pha_list.append(denoised_pha) - denoised_com_list.append(denoised_com) - - denoised_amp_list = torch.concat(denoised_amp_list, dim=2) - denoised_pha_list = torch.concat(denoised_pha_list, dim=2) - denoised_com_list = torch.concat(denoised_com_list, dim=2) - return denoised_amp_list, denoised_pha_list, denoised_com_list - - -MODEL_FILE = "generator.pt" - - -class NXMPNetPretrainedModel(NXMPNet): - def __init__(self, - config: NXMPNetConfig, - ): - super(NXMPNetPretrainedModel, self).__init__( - config=config, - ) - self.config = config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - model = cls(config) - - if os.path.isdir(pretrained_model_name_or_path): - ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) - else: - ckpt_file = pretrained_model_name_or_path - - with open(ckpt_file, "rb") as f: - state_dict = torch.load(f, map_location="cpu", weights_only=True) - model.load_state_dict(state_dict, strict=True) - return model - - def save_pretrained(self, - save_directory: Union[str, os.PathLike], - state_dict: Optional[dict] = None, - ): - - model = self - - if state_dict is None: - state_dict = model.state_dict() - - os.makedirs(save_directory, exist_ok=True) - - # save state dict - model_file = os.path.join(save_directory, MODEL_FILE) - torch.save(state_dict, model_file) - - # save config - config_file = os.path.join(save_directory, CONFIG_FILE) - self.config.to_yaml_file(config_file) - return save_directory - - -def main(): - config = NXMPNetConfig() - - model = NXMPNet(config) - - noisy_amp = torch.rand([1, 257, 201], dtype=torch.float32) - noisy_pha = torch.rand([1, 257, 201], dtype=torch.float32) - - denoised_amp, denoised_pha, denoised_com = model.forward(noisy_amp, noisy_pha) - print(denoised_amp.shape) - print(denoised_pha.shape) - print(denoised_com.shape) - - denoised_amp, denoised_pha, denoised_com = model.forward_chunk_by_chunk(noisy_amp, noisy_pha) - print(denoised_amp.shape) - print(denoised_pha.shape) - print(denoised_com.shape) - return - - -if __name__ == "__main__": - main() diff --git a/toolbox/torchaudio/models/nx_mpnet/transformers/__init__.py 
b/toolbox/torchaudio/models/nx_mpnet/transformers/__init__.py deleted file mode 100644 index 8bc5155c67cae42f80e8126d1727b0edc1e02398..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/transformers/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- - - -if __name__ == '__main__': - pass diff --git a/toolbox/torchaudio/models/nx_mpnet/transformers/attention.py b/toolbox/torchaudio/models/nx_mpnet/transformers/attention.py deleted file mode 100644 index 9492d0498e8dcfd2afc02e853c491100a6ba18f7..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/transformers/attention.py +++ /dev/null @@ -1,263 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import math -from typing import Tuple - -import torch -import torch.nn as nn - - -class MultiHeadSelfAttention(nn.Module): - def __init__(self, n_head: int, n_feat: int, dropout_rate: float): - """ - :param n_head: int. the number of heads. - :param n_feat: int. the number of features. - :param dropout_rate: float. dropout rate. - """ - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - transform query, key and value. - :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat). - :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat). - :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat). - :return: - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) - ) -> torch.Tensor: - """ - compute attention context vector. - :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, time2, d_k). - :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, time1, time2). - :param mask: torch.Tensor. mask. shape=(batch_size, 1, time2) or - (batch_size, time1, time2), (0, 0, 0) means fake mask. - :return: torch.Tensor. transformed value. (batch_size, time1, d_model). - weighted by the attention score (batch_size, time1, time2). - """ - n_batch = value.size(0) - # NOTE: When will `if mask.size(2) > 0` be True? - # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the - # 1st chunk to ease the onnx export.] - # 2. pytorch training - if mask.size(2) > 0: # time2 > 0 - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - # For last chunk, time2 might be larger than scores.size(-1) - mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) - - # NOTE: When will `if mask.size(2) > 0` be False? 
- # 1. onnx(16/-1, -1/-1, 16/0) - # 2. jit (16/-1, -1/-1, 16/0, 16/4) - else: - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat) - - return self.linear_out(x) # (batch, time1, n_feat) - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: - - q, k, v = self.forward_qkv(x, x, x) - - if cache.size(0) > 0: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - # NOTE: We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) - - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask), new_cache - - -class RelativeMultiHeadSelfAttention(nn.Module): - - def __init__(self, n_head: int, n_feat: int, dropout_rate: float, max_relative_position: int = 5120): - """ - :param n_head: int. the number of heads. - :param n_feat: int. the number of features. - :param dropout_rate: float. dropout rate. - :param max_relative_position: int. maximum relative position for relative position encoding. - """ - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - # Relative position encoding - self.max_relative_position = max_relative_position - self.relative_position_k = nn.Parameter(torch.randn(max_relative_position * 2 + 1, self.d_k)) - - def forward_qkv(self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - transform query, key and value. - :param query: torch.Tensor. query tensor. shape=(batch_size, time1, n_feat). - :param key: torch.Tensor. key tensor. shape=(batch_size, time2, n_feat). - :param value: torch.Tensor. value tensor. shape=(batch_size, time2, n_feat). - :return: - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention(self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = None - ) -> torch.Tensor: - """ - compute attention context vector. - :param value: torch.Tensor. transformed value. shape=(batch_size, n_head, key_time_steps, d_k). - :param scores: torch.Tensor. attention score. shape=(batch_size, n_head, query_time_steps, key_time_steps). - :param mask: torch.Tensor. mask. shape=(batch_size, 1, key_time_steps) or (batch_size, query_time_steps, key_time_steps). - :return: torch.Tensor. transformed value. (batch_size, query_time_steps, d_model). - weighted by the attention score (batch_size, query_time_steps, key_time_steps). 
- """ - n_batch = value.size(0) - if mask is not None: - mask = mask.unsqueeze(1).eq(0) - # mask shape: [batch_size, 1, query_time_steps, key_time_steps] - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) - else: - attn = torch.softmax(scores, dim=-1) - # attn shape: [batch_size, n_head, query_time_steps, key_time_steps] - - p_attn = self.dropout(attn) - - x = torch.matmul(p_attn, value) - # x shape: [batch_size, n_head, query_time_steps, d_k] - x = x.transpose(1, 2) - # x shape: [batch_size, query_time_steps, n_head, d_k] - - x = x.contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, n_feat) - # x shape: [batch_size, query_time_steps, n_head * d_k] - # x shape: [batch_size, query_time_steps, n_feat] - - x = self.linear_out(x) - # x shape: [batch_size, query_time_steps, n_feat] - return x - - def relative_position_encoding(self, length: int) -> torch.Tensor: - """ - Generate relative position encoding. - :param length: int. length of the sequence. - :return: torch.Tensor. relative position encoding. shape=(length, length, d_k). - """ - range_vec = torch.arange(length) - distance_mat = range_vec.unsqueeze(0) - range_vec.unsqueeze(1) - distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position) - final_mat = distance_mat_clipped + self.max_relative_position - return final_mat - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = None, - cache: torch.Tensor = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - # attention! self attention. - - q, k, v = self.forward_qkv(x, x, x) - # q k v shape: [batch_size, self.h, query_time_steps, self.d_k] - - if cache is not None: - key_cache, value_cache = torch.split( - cache, cache.size(-1) // 2, dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - - # new_cache shape: [batch_size, self.h, time_steps, self.d_k * 2] - new_cache = torch.cat((k, v), dim=-1) - - # native_scores shape: [batch_size, self.h, q_time_steps, k_time_steps] - native_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - - # Compute relative position encoding - q_length, k_length = q.size(2), k.size(2) - relative_position = self.relative_position_encoding(k_length) - - relative_position = relative_position[-q_length:] - - relative_position_k = self.relative_position_k[relative_position.view(-1)].view(q_length, k_length, -1) - - relative_position_k = relative_position_k.unsqueeze(0).unsqueeze(0) # (1, 1, q_length, k_length, d_k) - relative_position_k = relative_position_k.expand(q.size(0), q.size(1), -1, -1, -1) # (batch, head, q_length, k_length, d_k) - - relative_position_scores = torch.matmul(q.unsqueeze(3), relative_position_k.transpose(-2, -1)).squeeze(3) / math.sqrt(self.d_k) - # relative_position_scores shape: [batch_size, self.h, q_time_steps, k_time_steps] - - # score - scores = native_scores + relative_position_scores - - return self.forward_attention(v, scores, mask), new_cache - - -def main(): - rel_attention = RelativeMultiHeadSelfAttention(n_head=4, n_feat=256, dropout_rate=0.1) - - x = torch.ones(size=(1, 200, 256), dtype=torch.float32) - xt, new_cache = rel_attention.forward(x, x, x) - - # x = torch.ones(size=(1, 1, 256), dtype=torch.float32) - # cache = torch.ones(size=(1, 4, 199, 128), dtype=torch.float32) - # xt, new_cache = rel_attention.forward(x, x, x, cache=cache) - - print(xt.shape) - print(new_cache.shape) - return - - -if __name__ == '__main__': - main() diff --git 
a/toolbox/torchaudio/models/nx_mpnet/transformers/mask.py b/toolbox/torchaudio/models/nx_mpnet/transformers/mask.py deleted file mode 100644 index 087be346c5619573cf5350290dfd3a70a4b685a5..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/transformers/mask.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import torch - - -def make_pad_mask(lengths: torch.Tensor, - max_len: int = 0, - ) -> torch.Tensor: - batch_size = lengths.size(0) - max_len = max_len if max_len > 0 else lengths.max().item() - seq_range = torch.arange( - 0, - max_len, - dtype=torch.int64, - device=lengths.device - ) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - return mask - - - -def subsequent_chunk_mask( - size: int, - chunk_size: int, - num_left_chunks: int = -1, - num_right_chunks: int = 0, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """ - Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - - Examples: - > subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - - :param size: int. size of mask. - :param chunk_size: int. size of chunk. - :param num_left_chunks: int. number of left chunks. <0: use full chunk. >=0 use num_left_chunks. - :param num_right_chunks: int. number of right chunks. - :param device: torch.device. "cpu" or "cuda" or torch.Tensor.device. - :return: torch.Tensor. mask - """ - - ret = torch.zeros(size, size, device=device, dtype=torch.bool) - for i in range(size): - if num_left_chunks < 0: - start = 0 - else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) - ending = min((i // chunk_size + 1 + num_right_chunks) * chunk_size, size) - ret[i, start:ending] = True - return ret - - -def main(): - chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2) - print(chunk_mask) - - chunk_mask = subsequent_chunk_mask(size=8, chunk_size=2, num_left_chunks=2, num_right_chunks=1) - print(chunk_mask) - - chunk_mask = subsequent_chunk_mask(size=9, chunk_size=2, num_left_chunks=2, num_right_chunks=1) - print(chunk_mask) - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py b/toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py deleted file mode 100644 index 14331630facee251c3f0def8e7e590d7860f1838..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/transformers/transformers.py +++ /dev/null @@ -1,477 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -from typing import Dict, Optional, Tuple, List, Union - -import torch -import torch.nn as nn - -from toolbox.torchaudio.models.nx_clean_unet.transformers.mask import subsequent_chunk_mask -from toolbox.torchaudio.models.nx_clean_unet.transformers.attention import MultiHeadSelfAttention, RelativeMultiHeadSelfAttention - - -class PositionwiseFeedForward(nn.Module): - def __init__(self, - input_dim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU()): - """ - FeedForward are applied on each position of the sequence. - the output dim is same with the input dim. - - :param input_dim: int. input dimension. - :param hidden_units: int. the number of hidden units. - :param dropout_rate: float. dropout rate. - :param activation: torch.nn.Module. activation function. 
- """ - super(PositionwiseFeedForward, self).__init__() - self.w_1 = torch.nn.Linear(input_dim, hidden_units) - self.activation = activation - self.dropout = torch.nn.Dropout(dropout_rate) - self.w_2 = torch.nn.Linear(hidden_units, input_dim) - - def forward(self, xs: torch.Tensor) -> torch.Tensor: - """ - Forward function. - :param xs: torch.Tensor. input tensor. shape=(batch_size, max_length, dim). - :return: output tensor. shape=(batch_size, max_length, dim). - """ - return self.w_2(self.dropout(self.activation(self.w_1(xs)))) - - -class TransformerBlock(nn.Module): - def __init__(self, - input_dim: int, - dropout_rate: float = 0.1, - n_heads: int = 4, - max_relative_position: int = 5120 - ): - super().__init__() - self.norm1 = nn.LayerNorm(input_dim, eps=1e-5) - self.attention = RelativeMultiHeadSelfAttention( - n_head=n_heads, - n_feat=input_dim, - dropout_rate=dropout_rate, - max_relative_position=max_relative_position, - ) - - self.dropout1 = nn.Dropout(dropout_rate) - self.norm2 = nn.LayerNorm(input_dim, eps=1e-5) - self.ffn = PositionwiseFeedForward( - input_dim=input_dim, - hidden_units=input_dim, - dropout_rate=dropout_rate - ) - self.dropout2 = nn.Dropout(dropout_rate) - self.norm3 = nn.LayerNorm(input_dim, eps=1e-5) - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor = None, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param x: torch.Tensor. shape=(batch_size, time, input_dim). - :param mask: torch.Tensor. mask tensor for the input. shape=(batch_size, time,time). - :param attention_cache: torch.Tensor. cache tensor of the KEY & VALUE - shape=(batch_size=1, head, cache_t1, d_k * 2), head * d_k == input_dim. - :return: - torch.Tensor: Output tensor (batch_size, time, input_dim). - torch.Tensor: att_cache tensor, (batch_size=1, head, cache_t1 + time, d_k * 2). 
- """ - xt = self.norm1(x) - - x_att, new_att_cache = self.attention.forward( - xt, mask=mask, cache=attention_cache - ) - x = x + self.dropout1(xt) - xt = self.norm2(x) - xt = self.ffn.forward(xt) - x = x + self.dropout2(xt) - - x = self.norm3(x) - - return x, new_att_cache - - -class TransformerEncoder(nn.Module): - """ - https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/encoder.py#L364 - """ - def __init__(self, - input_size: int = 64, - hidden_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 6, - dropout_rate: float = 0.1, - max_relative_position: int = 1024, - chunk_size: int = 1, - num_left_chunks: int = 128, - num_right_chunks: int = 2, - ): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - self.max_relative_position = max_relative_position - self.chunk_size = chunk_size - self.num_left_chunks = num_left_chunks - self.num_right_chunks = num_right_chunks - - self.input_linear = nn.Linear( - in_features=self.input_size, - out_features=self.hidden_size, - ) - - self.encoder_layer_list = torch.nn.ModuleList([ - TransformerBlock( - input_dim=hidden_size, - n_heads=attention_heads, - dropout_rate=dropout_rate, - max_relative_position=max_relative_position, - ) for _ in range(num_blocks) - ]) - - self.output_linear = nn.Linear( - in_features=self.hidden_size, - out_features=self.input_size, - ) - - def forward(self, - xs: torch.Tensor, - ): - """ - :param xs: Tensor, shape: [batch_size, time_steps, input_size] - :return: Tensor, shape: [batch_size, time_steps, input_size] - """ - batch_size, time_steps, _ = xs.shape - # xs shape: [batch_size, time_steps, input_size] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, time_steps, hidden_size] - - chunk_masks = subsequent_chunk_mask( - size=time_steps, - chunk_size=self.chunk_size, - num_left_chunks=self.num_left_chunks, - num_right_chunks=self.num_right_chunks, - ) - chunk_masks = chunk_masks.to(xs.device) - # chunk_masks shape: [time_steps, time_steps] - chunk_masks = torch.broadcast_to(chunk_masks, size=(batch_size, time_steps, time_steps)) - # chunk_masks shape: [batch_size, time_steps, time_steps] - - for encoder_layer in self.encoder_layer_list: - xs, _ = encoder_layer.forward(xs, chunk_masks) - - # xs shape: [batch_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, time_steps, input_size] - - return xs - - def forward_chunk(self, - xs: torch.Tensor, - max_att_cache_length: int, - attention_cache: torch.Tensor = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param xs: - :param max_att_cache_length: - :param attention_cache: Tensor, [num_layers, ...] 
- :return: - """ - # xs shape: [batch_size, time_steps, input_size] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, time_steps, hidden_size] - - r_att_cache = [] - for idx, encoder_layer in enumerate(self.encoder_layer_list): - xs, new_att_cache = encoder_layer.forward( - x=xs, attention_cache=attention_cache[idx] if attention_cache is not None else None, - ) - # new_att_cache shape: [batch_size, n_heads, time_steps, dim] - if new_att_cache.size(2) > max_att_cache_length: - begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - end = self.num_right_chunks * self.chunk_size - new_att_cache = new_att_cache[:, :, -begin:-end, :] - r_att_cache.append(new_att_cache) - - r_att_cache = torch.stack(r_att_cache, dim=0) - - # xs shape: [batch_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, time_steps, input_size] - - return xs, r_att_cache - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - ) -> torch.Tensor: - - batch_size, time_steps, _ = xs.shape - - # attention_cache shape: [num_blocks, attention_heads, self.num_left_chunks * self.chunk_size, n_heads * d_k * 2] - max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - attention_cache = None - - outputs = [] - for idx in range(0, time_steps, self.chunk_size): - begin = idx - end = begin + self.chunk_size * (self.num_right_chunks + 1) - chunk_xs = xs[:, begin:end, :] - # print(f"begin: {begin}, end: {end}, length: {chunk_xs.size(1)}") - - ys, attention_cache = self.forward_chunk( - xs=chunk_xs, - max_att_cache_length=max_att_cache_length, - attention_cache=attention_cache, - ) - - # ys shape: [batch_size, self.chunk_size * (self.num_right_chunks + 1), input_size] - ys = ys[:, :self.chunk_size, :] - - outputs.append(ys) - - ys = torch.cat(outputs, 1) - return ys - - -class TSTransformerBlock(nn.Module): - def __init__(self, - input_dim: int, - dropout_rate: float = 0.1, - n_heads: int = 4, - max_time_relative_position: int = 2048, - max_freq_relative_position: int = 256, - ): - super(TSTransformerBlock, self).__init__() - self.time_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_time_relative_position) - self.freq_transformer = TransformerBlock(input_dim, dropout_rate, n_heads, max_freq_relative_position) - - def forward(self, - x: torch.Tensor, - mask: torch.Tensor = None, - attention_cache: torch.Tensor = None, - ): - """ - - :param x: Tensor. shape: [batch_size, hidden_size, time_steps, input_size] - :param mask: Tensor. 
shape: [time_steps, time_steps] - :param attention_cache: - :return: - """ - b, c, t, f = x.size() - - mask = None if mask is None else torch.broadcast_to(mask, size=(b*f, t, t)) - - x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c) - x_, new_att_cache = self.time_transformer.forward(x, mask, attention_cache) - x = x_ + x - x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c) - x_, _ = self.freq_transformer.forward(x) - x = x_ + x - x = x.view(b, t, f, c).permute(0, 3, 1, 2) - return x, new_att_cache - - -class TSTransformerEncoder(nn.Module): - def __init__(self, - input_size: int = 64, - hidden_size: int = 256, - attention_heads: int = 4, - num_blocks: int = 6, - dropout_rate: float = 0.1, - max_time_relative_position: int = 2048, - max_freq_relative_position: int = 256, - chunk_size: int = 1, - num_left_chunks: int = 128, - num_right_chunks: int = 2, - ): - super().__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - self.max_time_relative_position = max_time_relative_position - self.max_freq_relative_position = max_freq_relative_position - self.chunk_size = chunk_size - self.num_left_chunks = num_left_chunks - self.num_right_chunks = num_right_chunks - - self.input_linear = nn.Linear( - in_features=self.input_size, - out_features=self.hidden_size, - ) - - self.encoder_layer_list = torch.nn.ModuleList([ - TSTransformerBlock( - input_dim=hidden_size, - n_heads=attention_heads, - dropout_rate=dropout_rate, - max_time_relative_position=max_time_relative_position, - max_freq_relative_position=max_freq_relative_position, - ) for _ in range(num_blocks) - ]) - - self.output_linear = nn.Linear( - in_features=self.hidden_size, - out_features=self.input_size, - ) - - def forward(self, - xs: torch.Tensor, - ): - """ - :param xs: Tensor, shape: [batch_size, channels, time_steps, input_size] - :return: Tensor, shape: [batch_size, channels, time_steps, input_size] - """ - batch_size, channels, time_steps, _ = xs.shape - # xs shape: [batch_size, channels, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, input_size, time_steps, channels] - xs = self.input_linear.forward(xs) - # xs shape: [batch_size, input_size, time_steps, hidden_size] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, hidden_size, time_steps, input_size] - - chunk_masks = subsequent_chunk_mask( - size=time_steps, - chunk_size=self.chunk_size, - num_left_chunks=self.num_left_chunks, - num_right_chunks=self.num_right_chunks, - ) - chunk_masks = chunk_masks.to(xs.device) - # chunk_masks shape: [time_steps, time_steps] - - for encoder_layer in self.encoder_layer_list: - xs, _ = encoder_layer.forward(xs, chunk_masks) - # xs shape: [batch_size, hidden_size, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, input_size, time_steps, hidden_size] - xs = self.output_linear.forward(xs) - # xs shape: [batch_size, input_size, time_steps, channels] - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, channels, time_steps, input_size] - - return xs - - def forward_chunk(self, - xs: torch.Tensor, - max_att_cache_length: int, - cache_att_list: List[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - - :param xs: - :param max_att_cache_length: - :param cache_att_list: Tensor, shape: [num_layers, ...] 
- :return: - """ - # xs shape: [batch_size, channels, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - xs = self.input_linear.forward(xs) - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, hidden_size, time_steps, input_size] - - new_cache_att_list = list() - for idx, encoder_layer in enumerate(self.encoder_layer_list): - xs, new_cache_att = encoder_layer.forward( - x=xs, attention_cache=cache_att_list[idx] if cache_att_list is not None else None, - ) - # new_att_cache shape: [b*f, n_heads, time_steps, dim] - if new_cache_att.size(2) > max_att_cache_length: - begin = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - end = self.num_right_chunks * self.chunk_size - new_cache_att = new_cache_att[:, :, -begin:-end, :] - new_cache_att_list.append(new_cache_att) - - # xs shape: [batch_size, hidden_size, time_steps, input_size] - xs = xs.permute(0, 3, 2, 1) - xs = self.output_linear.forward(xs) - xs = xs.permute(0, 3, 2, 1) - # xs shape: [batch_size, channels, time_steps, input_size] - - return xs, new_cache_att_list - - def forward_chunk_by_chunk( - self, - xs: torch.Tensor, - ) -> torch.Tensor: - - batch_size, channels, time_steps, _ = xs.shape - - max_att_cache_length = (self.num_left_chunks + self.num_right_chunks) * self.chunk_size - cache_att_list = None - - outputs = [] - for idx in range(0, time_steps, self.chunk_size): - begin = idx - end = begin + self.chunk_size * (self.num_right_chunks + 1) - chunk_xs = xs[:, :, begin:end, :] - # chunk_xs shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size] - - ys, cache_att_list = self.forward_chunk( - xs=chunk_xs, - max_att_cache_length=max_att_cache_length, - cache_att_list=cache_att_list, - ) - # ys shape: [batch_size, channels, self.chunk_size * (self.num_right_chunks + 1), input_size] - ys = ys[:, :, :self.chunk_size, :] - - outputs.append(ys) - - ys = torch.cat(outputs, dim=2) - return ys - - -def main2(): - - encoder = TransformerEncoder( - input_size=64, - hidden_size=256, - attention_heads=4, - num_blocks=6, - dropout_rate=0.1, - ) - print(encoder) - - x = torch.ones([4, 200, 64]) - - x = torch.ones([4, 200, 64]) - y = encoder.forward(xs=x) - print(y.shape) - - x = torch.ones([4, 200, 64]) - y = encoder.forward_chunk_by_chunk(xs=x) - print(y.shape) - - return - - -def main(): - - encoder = TSTransformerEncoder( - input_size=8, - hidden_size=16, - attention_heads=2, - num_blocks=2, - dropout_rate=0.1, - ) - # print(encoder) - - x = torch.ones([4, 8, 200, 8]) - y = encoder.forward(xs=x) - print(y.shape) - - x = torch.ones([4, 8, 200, 8]) - y = encoder.forward_chunk_by_chunk(xs=x) - print(y.shape) - - return - - -if __name__ == '__main__': - main() diff --git a/toolbox/torchaudio/models/nx_mpnet/utils.py b/toolbox/torchaudio/models/nx_mpnet/utils.py deleted file mode 100644 index 97e8ef2b76a94304a2730a4a097058bed8613430..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/utils.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/python3 -# -*- coding: utf-8 -*- -import torch -import torch.nn as nn - - -class LearnableSigmoid1d(nn.Module): - def __init__(self, in_features, beta=1): - super().__init__() - self.beta = beta - self.slope = nn.Parameter(torch.ones(in_features)) - self.slope.requiresGrad = True - - def forward(self, x): - # x shape: [batch_size, time_steps, spec_bins] - return self.beta * torch.sigmoid(self.slope * x) - - -class LearnableSigmoid2d(nn.Module): - def __init__(self, in_features, beta=1): - super().__init__() - self.beta = beta - 
self.slope = nn.Parameter(torch.ones(in_features, 1)) -        self.slope.requiresGrad = True - -    def forward(self, x): -        return self.beta * torch.sigmoid(self.slope * x) - - -def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True): - -    hann_window = torch.hann_window(win_size).to(y.device) -    stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, -                           center=center, pad_mode='reflect', normalized=False, return_complex=True) -    stft_spec = torch.view_as_real(stft_spec) -    mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9) -    pha = torch.atan2(stft_spec[:, :, :, 1] + 1e-10, stft_spec[:, :, :, 0] + 1e-5) -    # Magnitude Compression -    mag = torch.pow(mag, compress_factor) -    com = torch.stack((mag*torch.cos(pha), mag*torch.sin(pha)), dim=-1) - -    return mag, pha, com - - -def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True): -    # Magnitude Decompression -    mag = torch.pow(mag, (1.0/compress_factor)) -    com = torch.complex(mag*torch.cos(pha), mag*torch.sin(pha)) -    hann_window = torch.hann_window(win_size).to(com.device) -    wav = torch.istft(com, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window, center=center) - -    return wav - - -if __name__ == '__main__': -    pass
diff --git a/toolbox/torchaudio/models/nx_mpnet/yaml/config.yaml b/toolbox/torchaudio/models/nx_mpnet/yaml/config.yaml deleted file mode 100644 index 76b00987ab80d58269cdd5c00dc7f09e498752f7..0000000000000000000000000000000000000000 --- a/toolbox/torchaudio/models/nx_mpnet/yaml/config.yaml +++ /dev/null @@ -1,38 +0,0 @@ -model_name: "nx_denoise" - -sample_rate: 8000 -segment_size: 16000 -n_fft: 512 -win_size: 200 -hop_size: 80 - -dense_num_blocks: 4 -dense_hidden_size: 64 - -mask_num_blocks: 4 -mask_hidden_size: 64 - -phase_num_blocks: 4 -phase_hidden_size: 64 - -tsfm_hidden_size: 64 -tsfm_attention_heads: 4 -tsfm_num_blocks: 4 -tsfm_dropout_rate: 0.0 -tsfm_max_time_relative_position: 2048 -tsfm_max_freq_relative_position: 256 -tsfm_chunk_size: 1 -tsfm_num_left_chunks: 64 -tsfm_num_right_chunks: 2 - -discriminator_dim: 32 -discriminator_in_channel: 2 - -compress_factor: 0.3 - -batch_size: 4 -learning_rate: 0.0005 -adam_b1: 0.8 -adam_b2: 0.99 -lr_decay: 0.99 -seed: 1234
diff --git a/toolbox/torchaudio/models/percepnet/modeling_percetnet.py b/toolbox/torchaudio/models/percepnet/modeling_percetnet.py index f38e243b78bebe6baa7a213d0a0b359e782ca1bc..bcdfcfa2069f8225a9b304ce960d724bc1b9557a 100644 --- a/toolbox/torchaudio/models/percepnet/modeling_percetnet.py +++ b/toolbox/torchaudio/models/percepnet/modeling_percetnet.py @@ -4,8 +4,97 @@ https://github.com/jzi040941/PercepNet https://arxiv.org/abs/2008.04259 + +https://modelzoo.co/model/percepnet + +Too complicated: +(1) The pytorch model is only one part of the whole pipeline. +(2) Training samples must first go through pitch analysis, spectral-envelope estimation and similar computations. + """ +import torch +import torch.nn as nn + + +class PercepNet(nn.Module): +    """ +    https://github.com/jzi040941/PercepNet/blob/main/rnn_train.py#L105 + +    4.1% of an x86 CPU core +    """ +    def __init__(self, input_dim=70): +        super(PercepNet, self).__init__() +        # self.hidden_dim = hidden_dim +        # self.n_layers = n_layers + +        self.fc = nn.Sequential( +            nn.Linear(input_dim, 128), +            nn.ReLU() +        ) +        self.conv1 = nn.Sequential( +            nn.Conv1d(128, 512, 5, stride=1, padding=4), +            nn.ReLU() +        )  # padding to align with the c++ dnn; the extra tail is trimmed in forward() +        self.conv2 = nn.Sequential( +            nn.Conv1d(512, 512, 3, stride=1, padding=2), +            nn.Tanh() +        ) +        #self.gru = nn.GRU(512, 512, 3, batch_first=True) +        self.gru1 = nn.GRU(512, 512, 1, batch_first=True) +        self.gru2 = nn.GRU(512,
512, 1, batch_first=True) + self.gru3 = nn.GRU(512, 512, 1, batch_first=True) + + self.gru_gb = nn.GRU(512, 512, 1, batch_first=True) + self.gru_rb = nn.GRU(1024, 128, 1, batch_first=True) + + self.fc_gb = nn.Sequential( + nn.Linear(512*5, 34), + nn.Sigmoid() + ) + self.fc_rb = nn.Sequential( + nn.Linear(128, 34), + nn.Sigmoid() + ) + + def forward(self, x: torch.Tensor): + # x shape: [b, t, f] + x = self.fc(x) + x = x.permute([0, 2, 1]) + # x shape: [b, f, t] + + # causal conv + x = self.conv1(x) + x = x[:, :, :-4] + + # x shape: [b, f, t] + convout = self.conv2(x) + convout = convout[:, :, :-2] + convout = convout.permute([0, 2, 1]) + # convout shape: [b, t, f] + + gru1_out, gru1_state = self.gru1(convout) + gru2_out, gru2_state = self.gru2(gru1_out) + gru3_out, gru3_state = self.gru3(gru2_out) + + gru_gb_out, gru_gb_state = self.gru_gb(gru3_out) + concat_gb_layer = torch.cat(tensors=(convout, gru1_out, gru2_out, gru3_out, gru_gb_out), dim=-1) + gb = self.fc_gb(concat_gb_layer) + + # concat rb need fix + concat_rb_layer = torch.cat(tensors=(gru3_out, convout), dim=-1) + rnn_rb_out, gru_rb_state = self.gru_rb(concat_rb_layer) + rb = self.fc_rb(rnn_rb_out) + + output = torch.cat((gb, rb), dim=-1) + return output + + +def main(): + model = PercepNet() + x = torch.randn(20, 8, 70) + out = model(x) + print(out.shape) -if __name__ == '__main__': - pass +if __name__ == "__main__": + main()
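
A note on the convolution padding in the PercepNet module added above: nn.Conv1d(128, 512, 5, padding=4) pads 4 frames on both sides, and forward() then drops the last 4 output frames (x = x[:, :, :-4]), which leaves a purely causal convolution; conv2 (kernel 3, padding=2, trimmed by 2) follows the same pattern. The snippet below is a minimal stand-alone check of that equivalence, not part of the repository; the tensor sizes and random seed are arbitrary.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Illustrative check (not repository code): symmetric padding followed by
# trimming the trailing frames equals a left-padded, causal convolution.
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv1d(128, 512, kernel_size=5, stride=1, padding=4)
x = torch.randn(2, 128, 50)    # [batch, channels, time]

# PercepNet-style: pad both sides inside the conv, then drop the last 4 frames.
y_trimmed = conv(x)[:, :, :-4]

# Causal form: pad 4 frames on the left only and convolve without padding.
y_causal = F.conv1d(F.pad(x, (4, 0)), conv.weight, conv.bias, stride=1, padding=0)

print(torch.allclose(y_trimmed, y_causal))    # True; both are [2, 512, 50]

Since fc_gb and fc_rb each output 34 values that are concatenated along the last axis, the smoke test in main() is expected to print torch.Size([20, 8, 68]) for its [20, 8, 70] input.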
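
For the TSTransformerEncoder removed earlier in this diff, forward_chunk_by_chunk feeds chunk_size * (num_right_chunks + 1) frames per step, and whenever a layer's attention cache grows beyond (num_left_chunks + num_right_chunks) * chunk_size frames it is sliced back to the num_left_chunks * chunk_size left-context frames, discarding the right-lookahead frames. Below is a small stand-alone sketch of that slicing arithmetic, using the defaults from the deleted config.yaml (chunk_size=1, num_left_chunks=64, num_right_chunks=2); the cache tensor is a random stand-in, not a real attention cache.

#!/usr/bin/python3
# -*- coding: utf-8 -*-
# Illustrative sketch (not repository code) of the attention-cache trimming
# performed in forward_chunk: only left-context frames are carried forward.
import torch

chunk_size, num_left_chunks, num_right_chunks = 1, 64, 2
max_att_cache_length = (num_left_chunks + num_right_chunks) * chunk_size    # 66

# Stand-in for new_att_cache with shape [batch*freq, n_heads, time_steps, dim].
new_att_cache = torch.randn(4, 4, 80, 32)

if new_att_cache.size(2) > max_att_cache_length:
    begin = (num_left_chunks + num_right_chunks) * chunk_size    # 66
    end = num_right_chunks * chunk_size                          # 2
    # Keep the 64 most recent left-context frames; drop the 2 look-ahead frames.
    new_att_cache = new_att_cache[:, :, -begin:-end, :]

print(new_att_cache.shape)    # torch.Size([4, 4, 64, 32])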