diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/__pycache__/configuration_bigcodec.cpython-39.pyc b/__pycache__/configuration_bigcodec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..addbfc333619bd1790037f582e33f7e54140298d
Binary files /dev/null and b/__pycache__/configuration_bigcodec.cpython-39.pyc differ
diff --git a/__pycache__/modeling_bigcodec.cpython-39.pyc b/__pycache__/modeling_bigcodec.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7f619275e7198a7b71ec73beb764d5d8c5987256
Binary files /dev/null and b/__pycache__/modeling_bigcodec.cpython-39.pyc differ
diff --git a/__pycache__/modeling_xcodec2.cpython-39.pyc b/__pycache__/modeling_xcodec2.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6c5cdb78b8ffb2eb83869d16ab5cc58281e81b93
Binary files /dev/null and b/__pycache__/modeling_xcodec2.cpython-39.pyc differ
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..5acd94f351833f55f40fd49e4f9da49d34c8540f
--- /dev/null
+++ b/config.json
@@ -0,0 +1,11 @@
+{
+    "model_type": "xcodec2",
+    "semantic_hidden_size": 1024,
+    "codec_encoder_hidden_size": 1024,
+    "codec_decoder_hidden_size": 1024,
+    "use_vocos": true,
+    "architectures": [
+        "XCodec2Model"
+    ]
+}
+
\ No newline at end of file
diff --git a/configuration_bigcodec.py b/configuration_bigcodec.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7aa72b64eca8cae1220413f47039a112100d866
--- /dev/null
+++ b/configuration_bigcodec.py
@@ -0,0 +1,19 @@
+from transformers import PretrainedConfig
+
+class BigCodecConfig(PretrainedConfig):
+    model_type = "bigcodec"
+
+    def __init__(
+        self,
+        # the hyperparameters below are just example values
+        semantic_hidden_size=1024,
+        codec_encoder_hidden_size=1024,
+        codec_decoder_hidden_size=1024,
+        use_vocos=True,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.semantic_hidden_size = semantic_hidden_size
+        self.codec_encoder_hidden_size = codec_encoder_hidden_size
+        self.codec_decoder_hidden_size = codec_decoder_hidden_size
+        self.use_vocos = use_vocos
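A minimal sketch (illustrative, not from the repository) of how BigCodecConfig round-trips through the standard PretrainedConfig API; the output directory name is a placeholder:

from configuration_bigcodec import BigCodecConfig

# Build a config and write it to disk; save_pretrained emits a config.json
# whose "model_type" comes from the class attribute ("bigcodec").
cfg = BigCodecConfig(semantic_hidden_size=1024, use_vocos=True)
cfg.save_pretrained("./bigcodec_config")   # placeholder directory

# Reload it and read a field back.
cfg2 = BigCodecConfig.from_pretrained("./bigcodec_config")
print(cfg2.codec_decoder_hidden_size)      # 1024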
diff --git a/modeling_xcodec2.py b/modeling_xcodec2.py
new file mode 100644
index 0000000000000000000000000000000000000000..3594648a5e1a7e9e5c7677ff066fb6b541730b9e
--- /dev/null
+++ b/modeling_xcodec2.py
@@ -0,0 +1,165 @@
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+from configuration_bigcodec import BigCodecConfig
+
+# Make sure these module paths are correct
+from vq.codec_encoder import CodecEncoder_Transformer
+from vq.codec_decoder_vocos import CodecDecoderVocos
+from vq.module import SemanticEncoder
+from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
+
+class XCodec2Model(PreTrainedModel):
+    config_class = BigCodecConfig
+
+    def __init__(self, config: BigCodecConfig):
+        super().__init__(config)
+
+        # 1) Semantic model
+        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
+            "facebook/w2v-bert-2.0",
+            output_hidden_states=True
+        )
+        self.semantic_model.eval()
+
+        self.SemanticEncoder_module = SemanticEncoder(
+            config.semantic_hidden_size,
+            config.semantic_hidden_size,
+            config.semantic_hidden_size
+        )
+
+        # 2) Codec encoder
+        self.CodecEnc = CodecEncoder_Transformer()
+
+        # 3) Codec decoder
+        self.generator = CodecDecoderVocos()
+
+        # 4) Two fully connected layers
+        self.fc_prior = nn.Linear(2048, 2048)
+        self.fc_post_a = nn.Linear(2048, 1024)
+        feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+        self.feature_extractor = feature_extractor
+
+    def forward(self, input_waveform, sample_rate=16000):
+        """
+        This method does not have to be called forward and could be split into
+        separate methods, but keeping the core logic in forward makes the model
+        compatible with pipelines.
+
+        Args:
+            input_waveform: [batch_size, waveform_length]
+            sample_rate: defaults to 16000
+        Returns:
+            the reconstructed speech audio (Tensor)
+        """
+        # 1) Feature extraction
+        # If padding is needed, it can be done here
+        input_features = self.feature_extractor(
+            input_waveform,
+            sampling_rate=sample_rate,
+            return_tensors="pt"
+        ).input_features.to(self.device)  # [batch, frames, feat_dim]
+
+        # 2) Semantic layer
+        semantic_output = self.semantic_model(input_features)
+        semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th layer
+        semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
+        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
+
+        # 3) Codec encoder
+        wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
+        vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024], for illustration only
+        vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]
+
+        # Align the number of time frames with the semantic features; this is only a
+        # simple example treatment. A real implementation may need to align the
+        # dimensions beforehand.
+        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
+            # Either truncating or zero-padding works; decide what fits your setup
+            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
+            vq_emb = vq_emb[:, :, :min_len]
+            semantic_encoded = semantic_encoded[:, :, :min_len]
+
+        # 4) Concatenate
+        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 1024 + 1024, frames]
+
+        # 5) fc_prior
+        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
+
+        # 6) Quantization part of the decoder
+        _, vq_code, _ = self.generator(concat_emb, vq=True)
+        vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+        vq_post_emb = vq_post_emb.transpose(1, 2)
+
+        # 7) fc_post_a
+        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
+
+        # 8) Finally decode back to a waveform
+        recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]
+        # recon_audio: [batch, time]
+        return recon_audio
+
+    def encode_code(self, input_waveform, sample_rate=16000):
+        """
+        Encode the input audio into discrete codes.
+
+        Args:
+            input_waveform: [batch_size, waveform_length]
+            sample_rate: defaults to 16000
+        Returns:
+            the encoded codes (Tensor)
+        """
+        with torch.no_grad():
+            # 1) Feature extraction
+            input_features = self.feature_extractor(
+                input_waveform,
+                sampling_rate=sample_rate,
+                return_tensors="pt"
+            ).input_features.to(self.device)  # [batch, frames, feat_dim]
+
+            # 2) Semantic layer
+            semantic_output = self.semantic_model(input_features)
+            semantic_hidden_16 = semantic_output.hidden_states[16]  # take the 16th layer
+            semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)  # [batch, hidden_dim, frames]
+            semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)
+
+            # 3) Codec encoder
+            wav = input_waveform.unsqueeze(1).to(self.device)  # shape: [batch, 1, time]
+            vq_emb = self.CodecEnc(wav)  # [batch, time//down, 1024], for illustration only
+            vq_emb = vq_emb.transpose(1, 2)  # -> [batch, 1024, frames]
+
+            # Align the number of time frames with the semantic features (example treatment)
+            if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
+                min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
+                vq_emb = vq_emb[:, :, :min_len]
+                semantic_encoded = semantic_encoded[:, :, :min_len]
+
+            # 4) Concatenate
+            concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)  # [batch, 2048, frames]
+
+            # 5) fc_prior
+            concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)
+
+            # 6) Quantization part of the decoder, returning the codes
+            _, vq_code, _ = self.generator(concat_emb, vq=True)
+            # vq_code: [batch, frames]
+            return vq_code
+
+    def decode_code(self, vq_code):
+        """
+        Decode the codes back into audio.
+
+        Args:
+            vq_code: encoded codes (Tensor) [batch, frames]
+        Returns:
+            the decoded audio (Tensor) [batch, waveform_length]
+        """
+        with torch.no_grad():
+            # Get the quantized embeddings
+            vq_post_emb = self.generator.quantizer.get_output_from_indices(vq_code.transpose(1, 2))
+            vq_post_emb = vq_post_emb.transpose(1, 2)  # [batch, 1024, frames]
+
+            # 7) fc_post_a
+            vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)  # [batch, 1024, frames]
+
+            # 8) Finally decode back to a waveform
+            recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]  # [batch, time]
+            return recon_audio
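A minimal round-trip sketch of the encode_code / decode_code API above (illustrative, not from the repository). It assumes a 16 kHz mono input, uses torchaudio only for loading and resampling, and the checkpoint path and file names are placeholders; as in test.py below, decode_code is assumed to return a [batch, 1, time] tensor:

import torch
import torchaudio
from modeling_xcodec2 import XCodec2Model

model = XCodec2Model.from_pretrained("/path/to/xcodec2")  # placeholder checkpoint path
model.eval().cuda()

wav, sr = torchaudio.load("example.wav")                  # [channels, time]
wav = wav.mean(dim=0, keepdim=True)                       # downmix to mono -> [1, time]
if sr != 16000:
    wav = torchaudio.functional.resample(wav, sr, 16000)  # the model expects 16 kHz input

with torch.no_grad():
    codes = model.encode_code(input_waveform=wav)         # [batch, frames] code indices
    audio = model.decode_code(codes).cpu()                # [batch, 1, time] reconstruction

torchaudio.save("roundtrip.wav", audio[0], 16000)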
diff --git a/module.py b/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/pytorch_model.bin b/pytorch_model.bin
new file mode 100644
index 0000000000000000000000000000000000000000..793a4ae86341f078614eefe86860b099f66d74a0
--- /dev/null
+++ b/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8cb939062f3930e56ff22082f49c95461aedc8ceade7ff7b16a1b10f1e92e0be
+size 3291343655
diff --git a/reconstructed.wav b/reconstructed.wav
new file mode 100644
index 0000000000000000000000000000000000000000..c41d76703f5f388f201f1dfc4175a9c35bca04df
Binary files /dev/null and b/reconstructed.wav differ
diff --git a/test.flac b/test.flac
new file mode 100755
index 0000000000000000000000000000000000000000..237fef5216e4c0098de87513e646a2b79f0dd89b
Binary files /dev/null and b/test.flac differ
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e7a25e3c9614d4d2d98067492a0043e86c0d711
--- /dev/null
+++ b/test.py
@@ -0,0 +1,21 @@
+import torch
+import soundfile as sf
+from transformers import AutoConfig
+
+from modeling_xcodec2 import XCodec2Model
+
+model_path = "/data/zheny/xcodec2"  # this is the name of your repository on Hugging Face
+
+model = XCodec2Model.from_pretrained(model_path)
+model.eval().cuda()
+
+# Prepare an audio clip
+wav, sr = sf.read("test.flac")
+wav_tensor = torch.from_numpy(wav).float().unsqueeze(0)  # [1, time]
+
+with torch.no_grad():
+    vq_code = model.encode_code(input_waveform=wav_tensor)
+    print(vq_code)
+    recon_wav = model.decode_code(vq_code).cpu()
+
+sf.write("reconstructed.wav", recon_wav[0, 0, :].numpy(), sr)
diff --git a/vq/__init__.py b/vq/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..82b8462e5c4c1d1e8ac0e6e6ce02589f4e55d959
--- /dev/null
+++ b/vq/__init__.py
@@ -0,0 +1,4 @@
+from vq.codec_encoder import CodecEncoder
+from vq.codec_decoder import CodecDecoder
+from vq.codec_decoder_vocos import CodecDecoderVocos
+from vq.codec_encoder import CodecEncoder_Transformer, CodecEncoder_only_Transformer
\ No newline at end of file
diff --git a/vq/__pycache__/__init__.cpython-310.pyc b/vq/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8045fbf325a1a92bcab9bb26006a965959ea6d0f
Binary files /dev/null and b/vq/__pycache__/__init__.cpython-310.pyc differ
diff --git a/vq/__pycache__/__init__.cpython-311.pyc b/vq/__pycache__/__init__.cpython-311.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..3a622e8da2ba3969a82261d2c1e6a3dbe2a87155
Binary files /dev/null and b/vq/__pycache__/__init__.cpython-311.pyc differ
diff --git a/vq/__pycache__/__init__.cpython-312.pyc b/vq/__pycache__/__init__.cpython-312.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..a1bd6ded01b7b342483f5e8ab7993e74a8b74054
Binary files /dev/null and b/vq/__pycache__/__init__.cpython-312.pyc differ
diff --git a/vq/__pycache__/__init__.cpython-38.pyc
b/vq/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..dde7ebf760686dcb7201064a63e9619b280c0c6e Binary files /dev/null and b/vq/__pycache__/__init__.cpython-38.pyc differ diff --git a/vq/__pycache__/__init__.cpython-39.pyc b/vq/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6edbe0d9811e1f6d087bee9c44604e457ae546 Binary files /dev/null and b/vq/__pycache__/__init__.cpython-39.pyc differ diff --git a/vq/__pycache__/activations.cpython-310.pyc b/vq/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d58546839943071a21bee919697b57c3a6e01384 Binary files /dev/null and b/vq/__pycache__/activations.cpython-310.pyc differ diff --git a/vq/__pycache__/activations.cpython-311.pyc b/vq/__pycache__/activations.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b68f81ee8be448f64ef4c8f7bab31d57bdd561a8 Binary files /dev/null and b/vq/__pycache__/activations.cpython-311.pyc differ diff --git a/vq/__pycache__/activations.cpython-312.pyc b/vq/__pycache__/activations.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..7b65c1363e2538d7f7da65eacf3801535d9dfe7e Binary files /dev/null and b/vq/__pycache__/activations.cpython-312.pyc differ diff --git a/vq/__pycache__/activations.cpython-38.pyc b/vq/__pycache__/activations.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..0df9e1a83b50d72a8b7da1b0e85a3c34c5084f60 Binary files /dev/null and b/vq/__pycache__/activations.cpython-38.pyc differ diff --git a/vq/__pycache__/activations.cpython-39.pyc b/vq/__pycache__/activations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93430eb9e182116727bb2fd9a8c1ced9e1a423b Binary files /dev/null and b/vq/__pycache__/activations.cpython-39.pyc differ diff --git a/vq/__pycache__/blocks.cpython-310.pyc b/vq/__pycache__/blocks.cpython-310.pyc new file mode 100755 index 0000000000000000000000000000000000000000..2d2533abbe61cf93b51d5a76fadbfeadf81dca75 Binary files /dev/null and b/vq/__pycache__/blocks.cpython-310.pyc differ diff --git a/vq/__pycache__/blocks.cpython-39.pyc b/vq/__pycache__/blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa36890d26c5d256baf108973f548e68237fc1a9 Binary files /dev/null and b/vq/__pycache__/blocks.cpython-39.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-310.pyc b/vq/__pycache__/bs_roformer5.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ac6beca8e2a0640d18bea04d67a5afcf631a308 Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-310.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-38.pyc b/vq/__pycache__/bs_roformer5.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..912a7abb3dc54bd1cbde57f0c1d9d6c81d8b2d2d Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-38.pyc differ diff --git a/vq/__pycache__/bs_roformer5.cpython-39.pyc b/vq/__pycache__/bs_roformer5.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92d8ebd686f5ca067f8afedd3c11df3eae114f8b Binary files /dev/null and b/vq/__pycache__/bs_roformer5.cpython-39.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-310.pyc b/vq/__pycache__/codec_decoder.cpython-310.pyc new file mode 100755 index 
0000000000000000000000000000000000000000..70b5729ec447bd862aec51154656561ccae09cc8 Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-310.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-311.pyc b/vq/__pycache__/codec_decoder.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d3bba0654a05fc2d5df7f0cfa85d4e1e8f9e43a9 Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-311.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-312.pyc b/vq/__pycache__/codec_decoder.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9fd232bd7d3e8217ae32122d85158c8d3d07a0ba Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-312.pyc differ diff --git a/vq/__pycache__/codec_decoder.cpython-39.pyc b/vq/__pycache__/codec_decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..342c797d5e48be38a60b3987b4ccf03e115c42f1 Binary files /dev/null and b/vq/__pycache__/codec_decoder.cpython-39.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc new file mode 100755 index 0000000000000000000000000000000000000000..18326f4b24686e299d318969eefeb9f53d0e101b Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-310.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..16d7559f3e464f7cc8229cf5cf3864cf6ce5270a Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-311.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4d14cbb3475d4adca86027ce48447562dd8bd872 Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-312.pyc differ diff --git a/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc b/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..039c2ff83bd0dece5f89baf003ee5c212a9ec9eb Binary files /dev/null and b/vq/__pycache__/codec_decoder_vocos.cpython-39.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-310.pyc b/vq/__pycache__/codec_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bae05d20ccc33e03b13e92a6ddf4792c7963bac Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-310.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-311.pyc b/vq/__pycache__/codec_encoder.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a95e106c52b918fc9879df154585f48f5b97f084 Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-311.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-312.pyc b/vq/__pycache__/codec_encoder.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..033919b2465e77531ac1f4340a42e8247dcf927e Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-312.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-38.pyc b/vq/__pycache__/codec_encoder.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..b200d30c057053802717cd326fcfbd047261134e Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-38.pyc differ diff --git a/vq/__pycache__/codec_encoder.cpython-39.pyc 
b/vq/__pycache__/codec_encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6776bd150fda7c289c24c8798d44fb5abe57993 Binary files /dev/null and b/vq/__pycache__/codec_encoder.cpython-39.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc new file mode 100755 index 0000000000000000000000000000000000000000..164f333e2d1ee7902db697d446e83ea603e5d395 Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-310.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..d1018bb516e0ca39612d3046cc66e93d3e9fc857 Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-311.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a79936f97258dfac4e4c1fa00bcba84459f754bf Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-312.pyc differ diff --git a/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc b/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff4eb90349eef0207e03b2d991f85499da62ab59 Binary files /dev/null and b/vq/__pycache__/factorized_vector_quantize.cpython-39.pyc differ diff --git a/vq/__pycache__/module.cpython-310.pyc b/vq/__pycache__/module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cd9176e099c4aeba05411b5aab4a9914ea3fccf Binary files /dev/null and b/vq/__pycache__/module.cpython-310.pyc differ diff --git a/vq/__pycache__/module.cpython-311.pyc b/vq/__pycache__/module.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8bae56a74114e5ecb523e39275bf451e406e43a2 Binary files /dev/null and b/vq/__pycache__/module.cpython-311.pyc differ diff --git a/vq/__pycache__/module.cpython-312.pyc b/vq/__pycache__/module.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..67dc7809ed198ecba09efecfa433224dc7b5ea4a Binary files /dev/null and b/vq/__pycache__/module.cpython-312.pyc differ diff --git a/vq/__pycache__/module.cpython-38.pyc b/vq/__pycache__/module.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..2fd7e9ff24b9fb50a53e26e196d0c2e7062b71ef Binary files /dev/null and b/vq/__pycache__/module.cpython-38.pyc differ diff --git a/vq/__pycache__/module.cpython-39.pyc b/vq/__pycache__/module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97d6cb68f45b74a72d220b758c224fc1d7bcf974 Binary files /dev/null and b/vq/__pycache__/module.cpython-39.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-310.pyc b/vq/__pycache__/residual_vq.cpython-310.pyc new file mode 100755 index 0000000000000000000000000000000000000000..91031f08f8680edad3363081827fcd95c3908e74 Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-310.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-311.pyc b/vq/__pycache__/residual_vq.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..92069858ec62cb51a7cce7cdf1550695c656e557 Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-311.pyc differ diff --git 
a/vq/__pycache__/residual_vq.cpython-312.pyc b/vq/__pycache__/residual_vq.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..4d81960f8bbec9c309fa08267b44245d8a5b4959 Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-312.pyc differ diff --git a/vq/__pycache__/residual_vq.cpython-39.pyc b/vq/__pycache__/residual_vq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c004cb105a51fd161b325e6568b965634f0cddce Binary files /dev/null and b/vq/__pycache__/residual_vq.cpython-39.pyc differ diff --git a/vq/__pycache__/unet.cpython-312.pyc b/vq/__pycache__/unet.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e431b555864d1029ed9735ed1bad0bae1b75c4f8 Binary files /dev/null and b/vq/__pycache__/unet.cpython-312.pyc differ diff --git a/vq/__pycache__/unet.cpython-39.pyc b/vq/__pycache__/unet.cpython-39.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9e12a03c0fba3df01e7d87e249fd77a3bab599f1 Binary files /dev/null and b/vq/__pycache__/unet.cpython-39.pyc differ diff --git a/vq/activations.py b/vq/activations.py new file mode 100755 index 0000000000000000000000000000000000000000..2444a7bd9d52018e97892820a072b39a21245372 --- /dev/null +++ b/vq/activations.py @@ -0,0 +1,120 @@ +# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. +# LICENSE is in incl_licenses directory. + +import torch +from torch import nn, sin, pow +from torch.nn import Parameter + + +class Snake(nn.Module): + ''' + Implementation of a sine-based periodic activation function + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter + References: + - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snake(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha: trainable parameter + alpha is initialized to 1 by default, higher values = higher-frequency. + alpha will be trained along with the rest of your model. + ''' + super(Snake, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. 
+ Snake ∶= x + 1/a * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + if self.alpha_logscale: + alpha = torch.exp(alpha) + x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x + + +class SnakeBeta(nn.Module): + ''' + A modified Snake function which uses separate parameters for the magnitude of the periodic components + Shape: + - Input: (B, C, T) + - Output: (B, C, T), same shape as the input + Parameters: + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + References: + - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: + https://arxiv.org/abs/2006.08195 + Examples: + >>> a1 = snakebeta(256) + >>> x = torch.randn(256) + >>> x = a1(x) + ''' + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): + ''' + Initialization. + INPUT: + - in_features: shape of the input + - alpha - trainable parameter that controls frequency + - beta - trainable parameter that controls magnitude + alpha is initialized to 1 by default, higher values = higher-frequency. + beta is initialized to 1 by default, higher values = higher-magnitude. + alpha will be trained along with the rest of your model. + ''' + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = Parameter(torch.zeros(in_features) * alpha) + self.bias = Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = Parameter(torch.ones(in_features) * alpha) + self.bias = Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.bias.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + ''' + Forward pass of the function. + Applies the function to the input elementwise. + SnakeBeta ∶= x + 1/b * sin^2 (xa) + ''' + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.bias.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2) + + return x \ No newline at end of file diff --git a/vq/alias_free_torch/__init__.py b/vq/alias_free_torch/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..a2318b63198250856809c0cb46210a4147b829bc --- /dev/null +++ b/vq/alias_free_torch/__init__.py @@ -0,0 +1,6 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
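A quick sanity check of the Snake activation defined in vq/activations.py above, Snake(x) = x + (1/alpha) * sin^2(alpha * x). This is an illustrative snippet (not from the repository), assuming the module is importable as vq.activations:

import torch
from vq.activations import Snake

x = torch.randn(2, 256, 100)      # (B, C, T); Snake keeps one alpha per channel
snake = Snake(256, alpha=1.0)
y = snake(x)
print(y.shape)                    # torch.Size([2, 256, 100]), shape is preserved

# With alpha fixed at 1 the activation reduces to x + sin^2(x)
print(torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-5))  # True, up to the small eps term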
+ +from .filter import * +from .resample import * +from .act import * \ No newline at end of file diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a22bd3dc8fb49f8d4de96bae72d267e781b4ac78 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8c624a5d6d4dc886d41f429249030bde5539887e Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..a43e41d400b719ffcbd2edeb74835601d56febf5 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..80a6fe5a8f4ae483b2d80bf053a3d6b0a555934b Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc b/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01dd0800751de52eba5308427e45658cfa47442d Binary files /dev/null and b/vq/alias_free_torch/__pycache__/__init__.cpython-39.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-310.pyc b/vq/alias_free_torch/__pycache__/act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffe5fe4a82a3e1cbf2953da33930503f8ee72b84 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-311.pyc b/vq/alias_free_torch/__pycache__/act.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..6db1958197fb0a213b4ff1df75dcfbf89317ac7d Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-312.pyc b/vq/alias_free_torch/__pycache__/act.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..8cfb68556b30eb0c90e80734d84ce4630c84f6fe Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-38.pyc b/vq/alias_free_torch/__pycache__/act.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..5ae72232f292c8b85ef9b03ed69be2ff1c1f76e9 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/act.cpython-39.pyc b/vq/alias_free_torch/__pycache__/act.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11d880c40d0fed4924c1d5245722c5218e185ef2 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/act.cpython-39.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9f69124ecafe75d22e22a313a9d163717518814c Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..ce1c1840a49147f4771cb6c43c75fafdff331ebf Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..e7c8b0886d3813d2f3184509934becbdd6a7cf6f Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..947c698cca0f981ecc738c950d4d437e2840ba66 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc b/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ca966e33e20803fba84a5da40097a83907bc7d Binary files /dev/null and b/vq/alias_free_torch/__pycache__/filter.cpython-39.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2a0bda9f06882109c8568d1571dd953f7186918 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-310.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc new file mode 100755 index 0000000000000000000000000000000000000000..51f3f964031846bf45ae58bc751a8d74afe4495f Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-311.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc new file mode 100755 index 0000000000000000000000000000000000000000..bd086d178282ace7cda97ede64bd598c43012dab Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-312.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc new file mode 100755 index 0000000000000000000000000000000000000000..9e1268b9d83d1115c5f7fa2ac77f6d82227d3652 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-38.pyc differ diff --git a/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc b/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15613b4678ac49403758cba125a01e482c12b390 Binary files /dev/null and b/vq/alias_free_torch/__pycache__/resample.cpython-39.pyc differ diff --git a/vq/alias_free_torch/act.py b/vq/alias_free_torch/act.py new file mode 100755 index 0000000000000000000000000000000000000000..028debd697dd60458aae75010057df038bd3518a --- /dev/null +++ b/vq/alias_free_torch/act.py @@ -0,0 +1,28 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. 
+ +import torch.nn as nn +from .resample import UpSample1d, DownSample1d + + +class Activation1d(nn.Module): + def __init__(self, + activation, + up_ratio: int = 2, + down_ratio: int = 2, + up_kernel_size: int = 12, + down_kernel_size: int = 12): + super().__init__() + self.up_ratio = up_ratio + self.down_ratio = down_ratio + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + # x: [B,C,T] + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + + return x \ No newline at end of file diff --git a/vq/alias_free_torch/filter.py b/vq/alias_free_torch/filter.py new file mode 100755 index 0000000000000000000000000000000000000000..7ad6ea87c1f10ddd94c544037791d7a4634d5ae1 --- /dev/null +++ b/vq/alias_free_torch/filter.py @@ -0,0 +1,95 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + +if 'sinc' in dir(torch): + sinc = torch.sinc +else: + # This code is adopted from adefossez's julius.core.sinc under the MIT License + # https://adefossez.github.io/julius/julius/core.html + # LICENSE is in incl_licenses directory. + def sinc(x: torch.Tensor): + """ + Implementation of sinc, i.e. sin(pi * x) / (pi * x) + __Warning__: Different to julius.sinc, the input is multiplied by `pi`! + """ + return torch.where(x == 0, + torch.tensor(1., device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x) + + +# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License +# https://adefossez.github.io/julius/julius/lowpass.html +# LICENSE is in incl_licenses directory. +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size] + even = (kernel_size % 2 == 0) + half_size = kernel_size // 2 + + #For kaiser window + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.: + beta = 0.1102 * (A - 8.7) + elif A >= 21.: + beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.) + else: + beta = 0. + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + + # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio + if even: + time = (torch.arange(-half_size, half_size) + 0.5) + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * sinc(2 * cutoff * time) + # Normalize filter to have sum = 1, otherwise we will have a small leakage + # of the constant component in the input signal. + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + + return filter + + +class LowPassFilter1d(nn.Module): + def __init__(self, + cutoff=0.5, + half_width=0.6, + stride: int = 1, + padding: bool = True, + padding_mode: str = 'replicate', + kernel_size: int = 12): + # kernel_size should be even number for stylegan3 setup, + # in this implementation, odd number is also possible. 
+ super().__init__() + if cutoff < -0.: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = (kernel_size % 2 == 0) + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + #input [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), + mode=self.padding_mode) + out = F.conv1d(x, self.filter.expand(C, -1, -1), + stride=self.stride, groups=C) + + return out \ No newline at end of file diff --git a/vq/alias_free_torch/resample.py b/vq/alias_free_torch/resample.py new file mode 100755 index 0000000000000000000000000000000000000000..750e6c3402cc5ac939c4b9d075246562e0e1d1a7 --- /dev/null +++ b/vq/alias_free_torch/resample.py @@ -0,0 +1,49 @@ +# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0 +# LICENSE is in incl_licenses directory. + +import torch.nn as nn +from torch.nn import functional as F +from .filter import LowPassFilter1d +from .filter import kaiser_sinc_filter1d + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.stride = ratio + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + kernel_size=self.kernel_size) + self.register_buffer("filter", filter) + + # x: [B, C, T] + def forward(self, x): + _, C, _ = x.shape + + x = F.pad(x, (self.pad, self.pad), mode='replicate') + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + x = x[..., self.pad_left:-self.pad_right] + + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size) + + def forward(self, x): + xx = self.lowpass(x) + + return xx \ No newline at end of file diff --git a/vq/blocks.py b/vq/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..3996fec146cbf4f3caef4f9da3bbbe04f7729bbb --- /dev/null +++ b/vq/blocks.py @@ -0,0 +1,183 @@ +from typing import Callable, Sequence, Type, Union + +import numpy as np +import torch +import torch.nn as nn + +ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]] + + +class FeedForwardModule(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.net = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class Residual(nn.Module): + + def __init__(self, module: nn.Module) -> None: + super().__init__() + self.module = module + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.module(x) + x + + +class DilatedConvolutionalUnit(FeedForwardModule): + + def __init__( + self, + hidden_dim: int, + dilation: int, 
+ kernel_size: int, + activation: ModuleFactory, + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + self.net = nn.Sequential( + activation(), + normalization( + nn.Conv1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=kernel_size, + dilation=dilation, + padding=((kernel_size - 1) * dilation) // 2, + )), + activation(), + nn.Conv1d(in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=1), + ) + + +class UpsamplingUnit(FeedForwardModule): + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + activation: ModuleFactory, + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + self.net = nn.Sequential( + activation(), + normalization( + nn.ConvTranspose1d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2+ stride % 2, + output_padding=1 if stride % 2 != 0 else 0 + ))) + + +class DownsamplingUnit(FeedForwardModule): + + def __init__( + self, + input_dim: int, + output_dim: int, + stride: int, + activation: ModuleFactory, + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + self.net = nn.Sequential( + activation(), + normalization( + nn.Conv1d( + in_channels=input_dim, + out_channels=output_dim, + kernel_size=2 * stride, + stride=stride, + padding= stride // 2+ stride % 2, + + ))) + + +class DilatedResidualEncoder(FeedForwardModule): + + def __init__( + self, + capacity: int, + dilated_unit: Type[DilatedConvolutionalUnit], + downsampling_unit: Type[DownsamplingUnit], + ratios: Sequence[int], + dilations: Union[Sequence[int], Sequence[Sequence[int]]], + pre_network_conv: Type[nn.Conv1d], + post_network_conv: Type[nn.Conv1d], + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + channels = capacity * 2**np.arange(len(ratios) + 1) + + dilations_list = self.normalize_dilations(dilations, ratios) + + net = [normalization(pre_network_conv(out_channels=channels[0]))] + + for ratio, dilations, input_dim, output_dim in zip( + ratios, dilations_list, channels[:-1], channels[1:]): + for dilation in dilations: + net.append(Residual(dilated_unit(input_dim, dilation))) + net.append(downsampling_unit(input_dim, output_dim, ratio)) + + net.append(post_network_conv(in_channels=output_dim)) + + self.net = nn.Sequential(*net) + + @staticmethod + def normalize_dilations(dilations: Union[Sequence[int], + Sequence[Sequence[int]]], + ratios: Sequence[int]): + if isinstance(dilations[0], int): + dilations = [dilations for _ in ratios] + return dilations + + +class DilatedResidualDecoder(FeedForwardModule): + + def __init__( + self, + capacity: int, + dilated_unit: Type[DilatedConvolutionalUnit], + upsampling_unit: Type[UpsamplingUnit], + ratios: Sequence[int], + dilations: Union[Sequence[int], Sequence[Sequence[int]]], + pre_network_conv: Type[nn.Conv1d], + post_network_conv: Type[nn.Conv1d], + normalization: Callable[[nn.Module], + nn.Module] = lambda x: x) -> None: + super().__init__() + channels = capacity * 2**np.arange(len(ratios) + 1) + channels = channels[::-1] + + dilations_list = self.normalize_dilations(dilations, ratios) + dilations_list = dilations_list[::-1] + + net = [pre_network_conv(out_channels=channels[0])] + + for ratio, dilations, input_dim, output_dim in zip( + ratios, dilations_list, channels[:-1], channels[1:]): + net.append(upsampling_unit(input_dim, output_dim, ratio)) + for dilation in 
dilations: + net.append(Residual(dilated_unit(output_dim, dilation))) + + net.append(normalization(post_network_conv(in_channels=output_dim))) + + self.net = nn.Sequential(*net) + + @staticmethod + def normalize_dilations(dilations: Union[Sequence[int], + Sequence[Sequence[int]]], + ratios: Sequence[int]): + if isinstance(dilations[0], int): + dilations = [dilations for _ in ratios] + return dilations \ No newline at end of file diff --git a/vq/bs_roformer5.py b/vq/bs_roformer5.py new file mode 100755 index 0000000000000000000000000000000000000000..08aa016d731a6a5cae3e4f38514d97187ad7adb4 --- /dev/null +++ b/vq/bs_roformer5.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Module, ModuleList +import torchaudio +from einops import rearrange +import numpy as np +# from rotary_embedding_torch import RotaryEmbedding + +from torchtune.modules import RotaryPositionalEmbeddings + + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + r"""https://github.com/meta-llama/llama/blob/main/llama/model.py""" + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm_x = torch.mean(x ** 2, dim=-1, keepdim=True) + output = x * torch.rsqrt(norm_x + self.eps) * self.weight + return output + + + +class MLP(nn.Module): + def __init__(self, dim: int) -> None: + super().__init__() + + self.fc1 = nn.Linear(dim, 4 * dim, bias=False) + self.silu = nn.SiLU() + self.fc2 = nn.Linear(4 * dim, dim, bias=False) + + def forward(self, x): + x = self.fc1(x) + x = self.silu(x) + x = self.fc2(x) + return x + + +class Attention(nn.Module): + + def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): + super().__init__() + + assert dim % n_heads == 0 + + self.n_heads = n_heads + self.dim = dim + self.rotary_embed = rotary_embed + + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + assert self.flash, "Must have flash attention." 
+ + self.c_attn = nn.Linear(dim, 3 * dim, bias=False) + self.c_proj = nn.Linear(dim, dim, bias=False) + + def forward(self, x): + r""" + Args: + x: (b, t, h*d) + + Constants: + b: batch_size + t: time steps + r: 3 + h: heads_num + d: heads_dim + """ + B, T, C = x.size() + + q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b h t d', r=3, h=self.n_heads) + # q, k, v: (b, h, t, d) + + q = self.rotary_embed(q) + k = self.rotary_embed(k) + + if self.flash: + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=False) + + y = rearrange(y, 'b h t d -> b t (h d)') + + y = self.c_proj(y) + # shape: (b, t, h*d) + + return y + + +class TransformerBlock(nn.Module): + def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings): + + super().__init__() + self.dim = dim + self.n_heads = n_heads + + self.att_norm = RMSNorm(dim) + self.ffn_norm = RMSNorm(dim) + self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed) + self.mlp = MLP(dim=dim) + + + def forward( + self, + x: torch.Tensor, + ): + x = x + self.att(self.att_norm(x)) + x = x + self.mlp(self.ffn_norm(x)) + return x + + +if __name__ == '__main__': + rotary_embed_128 = RotaryPositionalEmbeddings(dim=128) + transformer_block = TransformerBlock( + dim=1024, + n_heads=8, + rotary_embed=rotary_embed_128 + ) + x = torch.randn(2, 128, 1024) + y = transformer_block(x) + print(y.shape) + c=1 \ No newline at end of file diff --git a/vq/codec_decoder.py b/vq/codec_decoder.py new file mode 100755 index 0000000000000000000000000000000000000000..969a395f43d8b8acd37a88aadc76cf7e5a49f269 --- /dev/null +++ b/vq/codec_decoder.py @@ -0,0 +1,304 @@ +import sys + +import numpy as np +import torch +import torch.nn as nn +from vq.residual_vq import ResidualVQ +from vq.module import WNConv1d, DecoderBlock, ResLSTM +from vq.alias_free_torch import * +from vq import activations +import vq.blocks as blocks +from torch.nn import utils + +from vq.bs_roformer5 import TransformerBlock + +from torchtune.modules import RotaryPositionalEmbeddings + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecDecoder(nn.Module): + def __init__(self, + in_channels=1024, + upsample_initial_channel=1536, + ngf=48, + use_rnn=True, + rnn_bidirectional=False, + rnn_num_layers=2, + up_ratios=(5, 4, 4, 4, 2), + dilations=(1, 3, 9), + vq_num_quantizers=1, + vq_dim=2048, + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=32, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + self.quantizer = ResidualVQ( + num_quantizers=vq_num_quantizers, + dim=vq_dim, # double the dim for acousitc and semantic + codebook_size=codebook_size, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + ) + channels = upsample_initial_channel + layers = [WNConv1d(in_channels, channels, kernel_size=7, padding=3)] + + if use_rnn: + layers += [ + ResLSTM(channels, + num_layers=rnn_num_layers, + bidirectional=rnn_bidirectional + ) + ] + + for i, stride in enumerate(up_ratios): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride, dilations)] + + layers += [ + Activation1d(activation=activations.SnakeBeta(output_dim, alpha_logscale=True)), + 
+            WNConv1d(output_dim, 1, kernel_size=7, padding=3),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*layers)
+
+        self.reset_parameters()
+
+    def forward(self, x, vq=True):
+        if vq is True:
+            x, q, commit_loss = self.quantizer(x)
+            return x, q, commit_loss
+        x = self.model(x)
+        return x
+
+    def vq2emb(self, vq):
+        self.quantizer = self.quantizer.eval()
+        x = self.quantizer.vq2emb(vq)
+        return x
+
+    def get_emb(self):
+        self.quantizer = self.quantizer.eval()
+        embs = self.quantizer.get_emb()
+        return embs
+
+    def inference_vq(self, vq):
+        x = vq[None, :, :]
+        x = self.model(x)
+        return x
+
+    def inference_0(self, x):
+        x, q, loss, perp = self.quantizer(x)
+        x = self.model(x)
+        return x, None
+
+    def inference(self, x):
+        x = self.model(x)
+        return x, None
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+
+        def _remove_weight_norm(m):
+            try:
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+
+        def _apply_weight_norm(m):
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+                torch.nn.utils.weight_norm(m)
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        self.apply(init_weights)
+
+
+class CodecDecoder_oobleck_Transformer(nn.Module):
+    def __init__(self,
+                 ngf=32,
+                 up_ratios=(5, 4, 4, 4, 2),
+                 dilations=(1, 3, 9),
+                 vq_num_quantizers=1,
+                 vq_dim=1024,
+                 vq_commit_weight=0.25,
+                 vq_weight_init=False,
+                 vq_full_commit_loss=False,
+                 codebook_size=16384,
+                 codebook_dim=16,
+                 hidden_dim=1024,
+                 depth=12,
+                 heads=16,
+                 pos_meb_dim=64,
+                 ):
+        super().__init__()
+        self.hop_length = np.prod(up_ratios)
+        self.capacity = ngf
+        self.up_ratios = up_ratios
+        self.hidden_dim = hidden_dim
+        self.quantizer = ResidualVQ(
+            num_quantizers=vq_num_quantizers,
+            dim=vq_dim,  # double the dim for acoustic and semantic
+            codebook_size=codebook_size,
+            codebook_dim=codebook_dim,
+            threshold_ema_dead_code=2,
+            commitment=vq_commit_weight,
+            weight_init=vq_weight_init,
+            full_commit_loss=vq_full_commit_loss,
+        )
+
+        time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim)
+
+        transformer_blocks = [
+            TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed)
+            for _ in range(depth)
+        ]
+
+        self.transformers = nn.Sequential(*transformer_blocks)
+
+        self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6)
+
+        self.conv_blocks = blocks.DilatedResidualDecoder(
+            capacity=self.capacity,
+            dilated_unit=self.dilated_unit,
+            upsampling_unit=self.upsampling_unit,
+            ratios=up_ratios,  # reverse of the encoder's downsampling ratios
+            dilations=dilations,
+            pre_network_conv=self.pre_conv,
+            post_network_conv=self.post_conv,
+        )
+
+        self.reset_parameters()
+
+    def forward(self, x, vq=True):
+        if vq is True:
+            x, q, commit_loss = self.quantizer(x)
+            return x, q, commit_loss
+        x = self.transformers(x)
+        x = self.final_layer_norm(x)
+        x = x.permute(0, 2, 1)
+        x = self.conv_blocks(x)
+        return x
+
+    def vq2emb(self, vq):
+        self.quantizer = self.quantizer.eval()
+        x = self.quantizer.vq2emb(vq)
+        return x
+
+    def get_emb(self):
+        self.quantizer = self.quantizer.eval()
+        embs = self.quantizer.get_emb()
+        return embs
+
+    def inference_vq(self, vq):
+        x = vq[None, :, :]
+        x = self.model(x)
+        return x
+
+    def inference_0(self, x):
+        x, q, loss, perp = self.quantizer(x)
+        x = self.model(x)
+        return x, None
+
+    def inference(self, x):
+        x = self.model(x)
+        return x, None
+
+    def remove_weight_norm(self):
+        """Remove weight normalization module from all of the layers."""
+
+        def _remove_weight_norm(m):
+            try:
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization module from all of the layers."""
+
+        def _apply_weight_norm(m):
+            if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
+                torch.nn.utils.weight_norm(m)
+
+        self.apply(_apply_weight_norm)
+
+    def reset_parameters(self):
+        self.apply(init_weights)
+
+    def pre_conv(self, out_channels):
+        return nn.Conv1d(in_channels=self.hidden_dim, out_channels=out_channels, kernel_size=1)
+
+    # Post-processing convolution that maps the model output to the final number of output channels
+    def post_conv(self, in_channels):
+        return nn.Conv1d(in_channels=in_channels, out_channels=1, kernel_size=1)
+
+    def dilated_unit(self, hidden_dim, dilation):
+        return blocks.DilatedConvolutionalUnit(
+            hidden_dim=hidden_dim,
+            dilation=dilation,
+            kernel_size=3,
+            activation=nn.ReLU,
+            normalization=utils.weight_norm
+        )
+
+    # Upsampling unit
+    def upsampling_unit(self, input_dim, output_dim, stride):
+        return blocks.UpsamplingUnit(
+            input_dim=input_dim,
+            output_dim=output_dim,
+            stride=stride,
+            activation=nn.ReLU,
+            normalization=utils.weight_norm
+        )
+
+def main():
+    # Select the device
+    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    print(f"Using device: {device}")
+
+    # Initialize the model
+    model = CodecDecoder_oobleck_Transformer().to(device)
+    print("Model initialized.")
+
+    # Create a test input: batch_size x in_channels x sequence_length
+    batch_size = 2
+    in_channels = 1024
+    sequence_length = 100  # example length; adjust as needed
+    dummy_input = torch.randn(batch_size, sequence_length, in_channels).to(device)
+    print(f"Dummy input shape: {dummy_input.shape}")
+
+    # Put the model in eval mode
+    model.eval()
+
+    output_no_vq = model(dummy_input, vq=False)
+    c = 1
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
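For orientation (an illustrative note, not from the repository): the hop length of these decoders, i.e. the number of waveform samples per latent frame, is the product of the upsampling ratios, while the Vocos-style decoder defined in the next file sets it directly via hop_length=320, which at the 16 kHz input rate used in modeling_xcodec2.py corresponds to 50 latent frames per second. A quick check of the arithmetic:

import numpy as np

up_ratios = (5, 4, 4, 4, 2)
print(int(np.prod(up_ratios)))  # 640, the hop length CodecDecoder derives from its ratios
print(16000 / 320)              # 50.0 frames per second for CodecDecoderVocos (hop_length=320)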
diff --git a/vq/codec_decoder_vocos.py b/vq/codec_decoder_vocos.py
new file mode 100755
index 0000000000000000000000000000000000000000..2bb45820ff92794327e060a2b2a52ad5df644b5b
--- /dev/null
+++ b/vq/codec_decoder_vocos.py
@@ -0,0 +1,638 @@
+import sys
+sys.path.append('/aifs4su/data/zheny/bigcodec_final/BigCodec_conv_transformer_vocos')
+import numpy as np
+import torch
+import torch.nn as nn
+from vq.residual_vq import ResidualVQ
+from vq.module import WNConv1d, DecoderBlock, ResLSTM
+from vq.alias_free_torch import *
+from vq import activations
+from typing import Optional
+from vq.module import ConvNeXtBlock, AdaLayerNorm
+from vq.bs_roformer5 import TransformerBlock
+# from rotary_embedding_torch import RotaryEmbedding
+from torchtune.modules import RotaryPositionalEmbeddings
+from vector_quantize_pytorch import ResidualFSQ
+from torch.nn import Module, ModuleList
+class ISTFT(nn.Module):
+    """
+    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+    See issue: https://github.com/pytorch/pytorch/issues/62323
+    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+    The NOLA constraint is met as we trim padded samples anyway.
+
+    Args:
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames.
+        win_length (int): The size of window frame and STFT filter.
+        padding (str, optional): Type of padding.
Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + window = torch.hann_window(win_length) + self.register_buffer("window", window) + + def forward(self, spec: torch.Tensor) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * self.window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + + + +class FourierHead(nn.Module): + """Base class for inverse fourier modules.""" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. + + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class ISTFTHead(FourierHead): + """ + ISTFT Head module for predicting STFT complex coefficients. + + Args: + dim (int): Hidden dimension of the model. + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames, which should align with + the resolution of the input features. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". + """ + + def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): + super().__init__() + out_dim = n_fft + 2 + self.out = torch.nn.Linear(dim, out_dim) + self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the ISTFTHead module. + + Args: + x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, + L is the sequence length, and H denotes the model dimension. 
+ + Returns: + Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. + """ + x_pred = self.out(x ) + # x_pred = x + x_pred = x_pred.transpose(1, 2) + mag, p = x_pred.chunk(2, dim=1) + mag = torch.exp(mag) + mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes + # wrapping happens here. These two lines produce real and imaginary value + x = torch.cos(p) + y = torch.sin(p) + # recalculating phase here does not produce anything new + # only costs time + # phase = torch.atan2(y, x) + # S = mag * torch.exp(phase * 1j) + # better directly produce the complex value + S = mag * (x + 1j * y) + audio = self.istft(S) + return audio.unsqueeze(1),x_pred + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv1d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv1d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb=None): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv1d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h = q.shape + q = q.permute(0, 2, 1) # b,hw,c + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + + h_ = self.proj_out(h_) + + return x + h_ + +def 
make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. + """ + + def __init__( + self, hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): + super().__init__() + + self.embed = nn.Conv1d(hidden_dim, hidden_dim, kernel_size=7, padding=3) + + + + self.temb_ch = 0 + block_in = hidden_dim + dropout = 0.1 + + prior_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ] + self.prior_net = nn.Sequential(*prior_net) + + depth = depth + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + + self.transformers = nn.Sequential(*transformer_blocks) + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + post_net : tp.List[nn.Module] = [ + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ResnetBlock(in_channels=block_in,out_channels=block_in, + temb_channels=self.temb_ch,dropout=dropout), + ] + self.post_net = nn.Sequential(*post_net) + + def forward(self, x: torch.Tensor ) -> torch.Tensor: + x = x.transpose(1, 2) + x = self.embed(x) + x = self.prior_net(x) + x = x.transpose(1, 2) + x= self.transformers(x) + x = x.transpose(1, 2) + x = self.post_net(x) + x = x.transpose(1, 2) + x = self.final_layer_norm(x) + return x + + + + + + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecDecoderVocos(nn.Module): + def __init__(self, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + hop_length=320, + vq_num_quantizers=1, + vq_dim=2048, #1024 2048 + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + ): + super().__init__() + self.hop_length = hop_length + + self.quantizer = ResidualFSQ( + dim = vq_dim, + levels = [4, 4, 4, 4, 4,4,4,4], + 
num_quantizers = 1 + ) + + # self.quantizer = ResidualVQ( + # num_quantizers=vq_num_quantizers, + # dim=vq_dim, + # codebook_size=codebook_size, + # codebook_dim=codebook_dim, + # threshold_ema_dead_code=2, + # commitment=vq_commit_weight, + # weight_init=vq_weight_init, + # full_commit_loss=vq_full_commit_loss, + # ) + + + self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) + + self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + # x, q, commit_loss = self.quantizer(x) + x = x.permute(0, 2, 1) + x, q = self.quantizer(x) + x = x.permute(0, 2, 1) + q = q.permute(0, 2, 1) + return x, q, None + x = self.backbone(x) + x,_ = self.head(x) + + return x ,_ + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + +class CodecDecoderVocos_transpose(nn.Module): + def __init__(self, + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + hop_length=320, + vq_num_quantizers=1, + vq_dim=1024, #1024 2048 + vq_commit_weight=0.25, + vq_weight_init=False, + vq_full_commit_loss=False, + codebook_size=16384, + codebook_dim=16, + ): + super().__init__() + self.hop_length = hop_length + + + self.quantizer = ResidualVQ( + num_quantizers=vq_num_quantizers, + dim=vq_dim, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + threshold_ema_dead_code=2, + commitment=vq_commit_weight, + weight_init=vq_weight_init, + full_commit_loss=vq_full_commit_loss, + ) + + + self.backbone = VocosBackbone( hidden_dim=hidden_dim,depth=depth,heads=heads,pos_meb_dim=pos_meb_dim) + + self.inverse_mel_conv = nn.Sequential( + nn.GELU(), + nn.ConvTranspose1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + stride=2, + padding=1, + output_padding=1 # 确保输出长度与编码前匹配 + ), + nn.GELU(), + nn.ConvTranspose1d( + in_channels=hidden_dim, + out_channels=hidden_dim, + kernel_size=3, + padding=1 + ) + ) + + self.head = ISTFTHead(dim=hidden_dim, n_fft=self.hop_length*4, hop_length=self.hop_length, padding="same") + + self.reset_parameters() + + def forward(self, x, vq=True): + if vq is True: + x, q, commit_loss = self.quantizer(x) + return x, q, commit_loss + x = self.backbone(x) + x,_ = self.head(x) + + return x ,_ + + def vq2emb(self, vq): + self.quantizer = self.quantizer.eval() + x = self.quantizer.vq2emb(vq) + return x + + def get_emb(self): + self.quantizer = self.quantizer.eval() + embs = 
self.quantizer.get_emb() + return embs + + def inference_vq(self, vq): + x = vq[None,:,:] + x = self.model(x) + return x + + def inference_0(self, x): + x, q, loss, perp = self.quantizer(x) + x = self.model(x) + return x, None + + def inference(self, x): + x = self.model(x) + return x, None + + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + + +def main(): + # 设置设备 + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + + # 初始化模型 + model = CodecDecoderVocos_transpose().to(device) + print("Model initialized.") + + # 创建测试输入: batch_size x in_channels x sequence_length + batch_size = 2 + in_channels = 1024 + sequence_length = 50 # 示例长度,可以根据需要调整 + dummy_input = torch.randn(batch_size, in_channels, sequence_length).to(device) + print(f"Dummy input shape: {dummy_input.shape}") + + # 将模型设为评估模式 + model.eval() + + # 前向传播(使用 VQ) + # with torch.no_grad(): + # try: + # output, q, commit_loss = model(dummy_input, vq=True) + # print("Forward pass with VQ:") + # print(f"Output shape: {output.shape}") + # print(f"Quantized codes shape: {q.shape}") + # print(f"Commitment loss: {commit_loss}") + # except Exception as e: + # print(f"Error during forward pass with VQ: {e}") + + # 前向传播(不使用 VQ) + with torch.no_grad(): + # try: + output_no_vq = model(dummy_input, vq=False) + print("\nForward pass without VQ:") + print(f"Output shape: {output_no_vq.shape}") + c=1 + # except Exception as e: + # print(f"Error during forward pass without VQ: {e}") + + + # model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters()) + # model_size_mb = model_size_bytes / (1024 ** 2) + # print(f"Model size: {model_size_bytes} bytes ({model_size_mb:.2f} MB)") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/vq/codec_encoder.py b/vq/codec_encoder.py new file mode 100755 index 0000000000000000000000000000000000000000..7ef7cc7b29bfeed8f9c791f7f8ae8c33c319678a --- /dev/null +++ b/vq/codec_encoder.py @@ -0,0 +1,335 @@ +import sys + +import torch +from torch import nn +import numpy as np +from vq.module import WNConv1d, EncoderBlock, ResLSTM +from vq.alias_free_torch import * +from vq import activations +from vq.bs_roformer5 import TransformerBlock + +from torchtune.modules import RotaryPositionalEmbeddings +import vq.blocks as blocks +from torch.nn import utils +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + +class CodecEncoder(nn.Module): + def __init__(self, + ngf=48, + use_rnn=True, + rnn_bidirectional=False, + rnn_num_layers=2, + up_ratios=(2, 2, 4, 4, 5), + dilations=(1, 3, 9), + out_channels=1024): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf = ngf + self.up_ratios = up_ratios + + # Create first convolution + d_model = ngf + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for i, 
stride in enumerate(up_ratios): + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride, dilations=dilations)] + # RNN + if use_rnn: + self.block += [ + ResLSTM(d_model, + num_layers=rnn_num_layers, + bidirectional=rnn_bidirectional + ) + ] + # Create last convolution + self.block += [ + Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, out_channels, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + self.reset_parameters() + + def forward(self, x): + out = self.block(x) + return out + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + +class Transpose(nn.Module): + def __init__(self, dim1, dim2): + super(Transpose, self).__init__() + self.dim1 = dim1 + self.dim2 = dim2 + + def forward(self, x): + return x.transpose(self.dim1, self.dim2) + +class CodecEncoder_Transformer(nn.Module): + def __init__(self, + ngf=48, + up_ratios=[2, 2, 4, 4, 5], + dilations=(1, 3, 9), + hidden_dim=1024, + depth=12, + heads=12, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf =ngf + self.up_ratios = up_ratios + + d_model = ngf + self.conv_blocks = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + + for i, stride in enumerate(up_ratios): + d_model *= 2 + self.conv_blocks += [EncoderBlock(d_model, stride=stride, dilations=dilations)] + + self.conv_blocks = nn.Sequential(*self.conv_blocks) + + + # time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + # transformer_blocks = [ + # TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + # for _ in range(depth) + # ] + + + # self.transformers = nn.Sequential(*transformer_blocks) + + # self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + self.conv_final_block = [ + Activation1d(activation=activations.SnakeBeta(d_model, alpha_logscale=True)), + WNConv1d(d_model, hidden_dim, kernel_size=3, padding=1), + ] + self.conv_final_block = nn.Sequential(*self.conv_final_block) + + self.reset_parameters() + + def forward(self, x): + x = self.conv_blocks(x) + # x = x.permute(0, 2, 1) + # x= self.transformers(x) + # x = self.final_layer_norm(x) + # x = x.permute(0, 2, 1) + x = self.conv_final_block (x) + x = x.permute(0, 2, 1) + return x + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + + +class 
Codec_oobleck_Transformer(nn.Module): + def __init__(self, + ngf=32, + up_ratios=(2, 2,4,4, 5), + dilations=(1, 3, 9), + hidden_dim=1024, + depth=12, + heads=16, + pos_meb_dim=64, + ): + super().__init__() + self.hop_length = np.prod(up_ratios) + self.ngf =ngf + self.up_ratios = up_ratios + self.hidden_dim = hidden_dim + + + self.conv_blocks = blocks.DilatedResidualEncoder( + capacity=ngf, + dilated_unit=self.dilated_unit, + downsampling_unit=self.downsampling_unit, + ratios=up_ratios, + dilations=dilations, + pre_network_conv=self.pre_conv, + post_network_conv=self.post_conv, + ) + + + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + + self.reset_parameters() + + def forward(self, x): + x = self.conv_blocks(x) + x = x.permute(0, 2, 1) + x= self.transformers(x) + x = self.final_layer_norm(x) + return x + + def inference(self, x): + return self.block(x) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + + def _apply_weight_norm(m): + if isinstance(m, nn.Conv1d): + torch.nn.utils.weight_norm(m) + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + self.apply(init_weights) + + def dilated_unit(self,hidden_dim, dilation): + return blocks.DilatedConvolutionalUnit(hidden_dim, + dilation, + kernel_size=3, + activation=nn.ReLU, + normalization=utils.weight_norm) + + def downsampling_unit(self, input_dim: int, output_dim: int, stride: int): + return blocks.DownsamplingUnit(input_dim, + output_dim, + stride, + nn.ReLU, + normalization=utils.weight_norm) + + def pre_conv(self,out_channels): + return nn.Conv1d(1, out_channels, 1) + + def post_conv(self,in_channels): + return nn.Conv1d(in_channels, self.hidden_dim, 1) + + + + + +class CodecEncoder_only_Transformer(nn.Module): + def __init__(self,hidden_dim=1024,depth=12,heads=16,pos_meb_dim=64): + super().__init__() + # self.embed = nn.Linear(input_dim, hidden_dim )input_dim=300, + + depth = depth + time_rotary_embed = RotaryPositionalEmbeddings(dim=pos_meb_dim) + + + transformer_blocks = [ + TransformerBlock(dim=hidden_dim, n_heads=heads, rotary_embed=time_rotary_embed) + for _ in range(depth) + ] + + + self.transformers = nn.Sequential(*transformer_blocks) + + self.final_layer_norm = nn.LayerNorm(hidden_dim, eps=1e-6) + + def forward(self, x: torch.Tensor ) -> torch.Tensor: + # x = self.embed(x) + + + x= self.transformers(x) + x = self.final_layer_norm(x) + + return x + + + + + + + +def get_model_size(model): + # 计算总参数数 + total_params = sum(p.numel() for p in model.parameters()) + + # 假设每个参数都是32位浮点数,计算模型大小(以字节为单位) + model_size_bytes = total_params # 每个参数4字节 + + # 转换为更易读的单位(例如,MB) + model_size_mb = model_size_bytes / (1024 ** 2) + + return total_params, model_size_mb + +if __name__ == '__main__': + model = Codec_oobleck_Transformer() + x = torch.randn(1, 1, 16000) # example input tensor + output = model(x) + print("Output shape:", output.shape) diff --git a/vq/factorized_vector_quantize.py b/vq/factorized_vector_quantize.py new file mode 
100755 index 0000000000000000000000000000000000000000..35f0c66736112f771a1933ca7e156b8cd5259e66 --- /dev/null +++ b/vq/factorized_vector_quantize.py @@ -0,0 +1,109 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +class FactorizedVectorQuantize(nn.Module): + def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + + if dim != self.codebook_dim: + self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim)) + self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim)) + else: + self.in_proj = nn.Identity() + self.out_proj = nn.Identity() + self._codebook = nn.Embedding(codebook_size, self.codebook_dim) + + @property + def codebook(self): + return self._codebook + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + z = rearrange(z, "b d t -> b t d") + + # Factorized codes project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x T x D) + z_e = rearrange(z_e, "b t d -> b d t") + z_q, indices = self.decode_latents(z_e) + + + if self.training: + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction='none').mean([1, 2]) * self.commitment + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction='none').mean([1, 2]) + commit_loss = commitment_loss + codebook_loss + else: + commit_loss = torch.zeros(z.shape[0], device = z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = rearrange(z_q, "b d t -> b t d") + z_q = self.out_proj(z_q) + z_q = rearrange(z_q, "b t d -> b d t") + + return z_q, indices, commit_loss + + def vq2emb(self, vq, proj=True): + emb = self.embed_code(vq) + if proj: + emb = self.out_proj(emb) + return emb + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices \ No newline at end of file diff --git a/vq/module.py b/vq/module.py new file mode 100755 index 0000000000000000000000000000000000000000..0c4f69b351abbc3906ced487f4609ed784c29975 --- /dev/null +++ b/vq/module.py @@ -0,0 
+1,420 @@ +import torch.nn as nn +from einops import rearrange +from . import activations +from .alias_free_torch import * +from torch.nn.utils import weight_norm + +from typing import Optional, Tuple + +from torch.nn.utils import weight_norm, remove_weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Activation1d(activation=activations.SnakeBeta(dim, alpha_logscale=True)), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + return x + self.block(x) + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1, dilations = (1, 3, 9)): + super().__init__() + runits = [ResidualUnit(dim // 2, dilation=d) for d in dilations] + self.block = nn.Sequential( + *runits, + Activation1d(activation=activations.SnakeBeta(dim//2, alpha_logscale=True)), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + ), + ) + + def forward(self, x): + return self.block(x) + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, dilations = (1, 3, 9)): + super().__init__() + self.block = nn.Sequential( + Activation1d(activation=activations.SnakeBeta(input_dim, alpha_logscale=True)), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=stride // 2 + stride % 2, + output_padding= stride % 2, + ) + ) + self.block.extend([ResidualUnit(output_dim, dilation=d) for d in dilations]) + + def forward(self, x): + return self.block(x) + +class ResLSTM(nn.Module): + def __init__(self, dimension: int, + num_layers: int = 2, + bidirectional: bool = False, + skip: bool = True): + super().__init__() + self.skip = skip + self.lstm = nn.LSTM(dimension, dimension if not bidirectional else dimension // 2, + num_layers, batch_first=True, + bidirectional=bidirectional) + + def forward(self, x): + """ + Args: + x: [B, F, T] + + Returns: + y: [B, F, T] + """ + x = rearrange(x, "b f t -> b t f") + y, _ = self.lstm(x) + if self.skip: + y = y + x + y = rearrange(y, "b t f -> b f t") + return y + + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. 
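# --- Illustrative usage (not part of the original patch) --------------------
# The block keeps the (B, C, T) shape: depthwise conv, channels-last LayerNorm,
# pointwise MLP, optional layer scale, then a residual connection.  The sizes
# below are assumed examples only.
import torch

block = ConvNeXtBlock(dim=1024, intermediate_dim=3072, layer_scale_init_value=1 / 12)
y = block(torch.randn(2, 1024, 50))
assert y.shape == (2, 1024, 50)
# -----------------------------------------------------------------------------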
+ """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + adanorm_num_embeddings: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.adanorm = adanorm_num_embeddings is not None + if adanorm_num_embeddings: + self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward(self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + num_embeddings (int): Number of embeddings. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + self.shift = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding_id) + shift = self.shift(cond_embedding_id) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale + shift + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. 
+ """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + weight_norm(nn.Conv1d(dim, dim, kernel_size, 1, dilation=1, padding=self.get_padding(kernel_size, 1))), + ] + ) + + self.gamma = nn.ParameterList( + [ + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + nn.Parameter(layer_scale_init_value * torch.ones(dim, 1), requires_grad=True) + if layer_scale_init_value is not None + else None, + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor: + """ + Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values. + + Args: + x (Tensor): Input tensor. + clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7. + + Returns: + Tensor: Element-wise logarithm of the input tensor with clipping applied. 
+ """ + return torch.log(torch.clip(x, min=clip_val)) + + +def symlog(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * torch.log1p(x.abs()) + + +def symexp(x: torch.Tensor) -> torch.Tensor: + return torch.sign(x) * (torch.exp(x.abs()) - 1) + + + +class SemanticEncoder(nn.Module): + def __init__( + self, + input_channels: int, + code_dim: int, + encode_channels: int, + kernel_size: int = 3, + bias: bool = True, + ): + super(SemanticEncoder, self).__init__() + + # 初始卷积,将 input_channels 映射到 encode_channels + self.initial_conv = nn.Conv1d( + in_channels=input_channels, + out_channels=encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + # 残差块 + self.residual_blocks = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv1d( + encode_channels, + encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias + ), + nn.ReLU(inplace=True), + nn.Conv1d( + encode_channels, + encode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=bias + ) + ) + + # 最终卷积,将 encode_channels 映射到 code_dim + self.final_conv = nn.Conv1d( + in_channels=encode_channels, + out_channels=code_dim, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + def forward(self, x): + """ + 前向传播方法。 + + Args: + x (Tensor): 输入张量,形状为 (Batch, Input_channels, Length) + + Returns: + Tensor: 编码后的张量,形状为 (Batch, Code_dim, Length) + """ + x = self.initial_conv(x) # (Batch, Encode_channels, Length) + x = self.residual_blocks(x) + x # 残差连接 + x = self.final_conv(x) # (Batch, Code_dim, Length) + return x + +class SemanticDecoder(nn.Module): + def __init__( + self, + code_dim: int, + output_channels: int, + decode_channels: int, + kernel_size: int = 3, + bias: bool = True, + ): + super(SemanticDecoder, self).__init__() + + # Initial convolution to map code_dim to decode_channels + self.initial_conv = nn.Conv1d( + in_channels=code_dim, + out_channels=decode_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + # Residual Blocks + self.residual_blocks = nn.Sequential( + nn.ReLU(inplace=True), + nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias), + nn.ReLU(inplace=True), + nn.Conv1d(decode_channels, decode_channels, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias) + ) + + # Final convolution to map decode_channels to output_channels + self.final_conv = nn.Conv1d( + in_channels=decode_channels, + out_channels=output_channels, + kernel_size=kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + bias=False + ) + + def forward(self, z): + # z: (Batch, Code_dim, Length) + x = self.initial_conv(z) # (Batch, Decode_channels, Length) + x = self.residual_blocks(x) + x # Residual connection + x = self.final_conv(x) # (Batch, Output_channels, Length) + return x \ No newline at end of file diff --git a/vq/residual_vq.py b/vq/residual_vq.py new file mode 100755 index 0000000000000000000000000000000000000000..40d3338fd940aa2e41177827d6e24c5269765b86 --- /dev/null +++ b/vq/residual_vq.py @@ -0,0 +1,53 @@ +import math +import torch +from torch import nn +from .factorized_vector_quantize import FactorizedVectorQuantize + +class ResidualVQ(nn.Module): + def __init__( + self, + *, + num_quantizers, + codebook_size, + **kwargs + ): + super().__init__() + VQ = FactorizedVectorQuantize + if type(codebook_size) == int: + codebook_size = 
[codebook_size] * num_quantizers + self.layers = nn.ModuleList([VQ(codebook_size=size, **kwargs) for size in codebook_size]) + self.num_quantizers = num_quantizers + + def forward(self, x): + quantized_out = 0. + residual = x + + all_losses = [] + all_indices = [] + + for idx, layer in enumerate(self.layers): + quantized, indices, loss = layer(residual) + + residual = residual - quantized + + quantized_out = quantized_out + quantized + + loss = loss.mean() + + all_indices.append(indices) + all_losses.append(loss) + all_losses, all_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, all_indices, all_losses + + def vq2emb(self, vq, proj=True): + # [B, T, num_quantizers] + quantized_out = 0. + for idx, layer in enumerate(self.layers): + quantized = layer.vq2emb(vq[:, :, idx], proj=proj) + quantized_out = quantized_out + quantized + return quantized_out + def get_emb(self): + embs = [] + for idx, layer in enumerate(self.layers): + embs.append(layer.get_emb()) + return embs diff --git a/vq/unet.py b/vq/unet.py new file mode 100755 index 0000000000000000000000000000000000000000..ca31029d0866b61663f75045c7770cc7208d9482 --- /dev/null +++ b/vq/unet.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import numpy as np + + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3)): + super(EncoderBlock, self).__init__() + + self.pool_size = 2 + + self.conv_block = ConvBlock(in_channels, out_channels, kernel_size) + + def forward(self, x): + latent = self.conv_block(x) + output = F.avg_pool2d(latent, kernel_size=self.pool_size) + return output, latent + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3)): + super(DecoderBlock, self).__init__() + + stride = 2 + + self.upsample = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=stride, + stride=stride, + padding=(0, 0), + bias=False, + ) + + self.conv_block = ConvBlock(in_channels * 2, out_channels, kernel_size) + + def forward(self, x, latent): + x = self.upsample(x) + x = torch.cat((x, latent), dim=1) + output = self.conv_block(x) + return output + + +class UNet(nn.Module): + def __init__(self,freq_dim=1281,out_channel=1024): + super(UNet, self).__init__() + + self.downsample_ratio = 16 + + + in_channels = 1 #self.audio_channels * self.cmplx_num + + self.encoder_block1 = EncoderBlock(in_channels, 16) + self.encoder_block2 = EncoderBlock(16, 64) + self.encoder_block3 = EncoderBlock(64, 256) + self.encoder_block4 = EncoderBlock(256, 1024) + self.middle = EncoderBlock(1024, 1024) + self.decoder_block1 = DecoderBlock(1024, 256) + self.decoder_block2 = DecoderBlock(256, 64) + self.decoder_block3 = DecoderBlock(64, 16) + self.decoder_block4 = DecoderBlock(16, 16) + + self.fc = nn.Linear(freq_dim*16, out_channel) + + def forward(self, x_ori): + """ + Args: + complex_sp: (batch_size, channels_num, time_steps, freq_bins),复数张量 + + Returns: + output: (batch_size, channels_num, time_steps, freq_bins),复数张量 + """ + + + x= self.process_image(x_ori) + x1, latent1 = self.encoder_block1(x) + x2, latent2 = self.encoder_block2(x1) + x3, latent3 = self.encoder_block3(x2) + x4, latent4 = self.encoder_block4(x3) + _, h = self.middle(x4) + x5 = self.decoder_block1(h, latent4) + x6 = self.decoder_block2(x5, latent3) + x7 = self.decoder_block3(x6, latent2) + x8 = self.decoder_block4(x7, latent1) + x= self.unprocess_image(x8,x_ori.shape[2]) + x = 
x.permute(0, 2, 1, 3).contiguous() # 将形状变为 [6, 256, 16, 1024] + x = x.view(x.size(0), x.size(1), -1) + x= self.fc(x) + + return x + + def process_image(self, x): + """ + 处理频谱以便可以被 downsample_ratio 整除。 + + Args: + x: (B, C, T, F) + + Returns: + output: (B, C, T_padded, F_reduced) + """ + + B, C, T, Freq = x.shape + + pad_len = ( + int(np.ceil(T / self.downsample_ratio)) * self.downsample_ratio + - T + ) + x = F.pad(x, pad=(0, 0, 0, pad_len)) + + output = x[:, :, :, 0 : Freq - 1] + + return output + + def unprocess_image(self, x,time_steps): + """ + 恢复频谱到原始形状。 + + Args: + x: (B, C, T_padded, F_reduced) + + Returns: + output: (B, C, T_original, F_original) + """ + x = F.pad(x, pad=(0, 1)) + + output = x[:, :,0:time_steps, :] + + return output + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=(3, 3)): + super(ConvBlock, self).__init__() + + padding = [kernel_size[0] // 2, kernel_size[1] // 2] + + self.bn1 = nn.BatchNorm2d(in_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + + self.conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=False, + ) + + self.conv2 = nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=False, + ) + + if in_channels != out_channels: + self.shortcut = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(1, 1), + padding=(0, 0), + ) + self.is_shortcut = True + else: + self.is_shortcut = False + + def forward(self, x): + h = self.conv1(F.leaky_relu_(self.bn1(x))) + h = self.conv2(F.leaky_relu_(self.bn2(h))) + + if self.is_shortcut: + return self.shortcut(x) + h + else: + return x + h + + +def test_unet(): + # 定义输入参数 + batch_size = 6 + channels = 1 # 音频通道数 + time_steps = 256 # 时间步数 + freq_bins = 1024 # 频率 bins 数 + + # 创建一个随机的复数张量作为输入 + real_part = torch.randn(batch_size, channels, time_steps, freq_bins) + imag_part = torch.randn(batch_size, channels, time_steps, freq_bins) + complex_sp = real_part #torch.complex(real_part, imag_part) + + # 实例化 UNet 模型 + model = UNet() + + # 前向传播 + output = model(complex_sp) + + # 输出输入和输出的形状 + print("输入形状:", complex_sp.shape) + print("输出形状:", output.shape) + + # 检查输出是否为复数张量 + assert torch.is_complex(output), "输出不是复数张量" + + # 检查输出形状是否与输入形状一致 + assert output.shape == complex_sp.shape, "输出形状与输入形状不一致" + + print("测试通过,模型正常工作。") + +# 运行测试函数 +if __name__ == "__main__": + test_unet() \ No newline at end of file
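# --- Illustrative shape check (not part of the original patch) --------------
# With the default freq_dim=1281, the UNet in vq/unet.py maps a real-valued
# spectrogram-like input (B, 1, T, 1281) to per-frame features (B, T, 1024):
# process_image drops the last frequency bin so 1280 survives the four 2x
# poolings, and the final Linear(freq_dim * 16, out_channel) folds channels and
# frequency bins into one vector per time step.  Sizes are assumed examples.
import torch
from vq.unet import UNet

model = UNet()                                 # freq_dim=1281, out_channel=1024
spec = torch.randn(2, 1, 256, 1281)            # (B, C, T, F); T divisible by 16
feats = model(spec)
assert feats.shape == (2, 256, 1024)           # (B, T, out_channel)
# -----------------------------------------------------------------------------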