diff --git a/.gitattributes b/.gitattributes
index 5b2c2cefa8481bbfc91b2d73fe574301a8f16d78..fbe3882f375b9158c51a9ea099d6c0bfbf5979e1 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,43 +1,47 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
-examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
-examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
-examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text
+examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text
+examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd filter=lfs diff=lfs merge=lfs -text
+modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..28c115f84cdbae6dd450777a7bbe5c23f130a3ec
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,28 @@
+# general things to ignore
+.DS_Store
+build/
+build_contrib/
+dist/
+.cache/
+*.egg-info/
+*.egg
+*.py[cod]
+__pycache__/
+*.so
+*~
+
+# IDE
+.vscode/
+.idea/
+
+# misc
+checkpoints/
+test_waves/
+reconstructed/
+.python-version
+ruff.log
+/configs/inuse/
+runs/
+/garbages/
+/flagged/
+/experimental/
diff --git a/README.md b/README.md
index eb7e29cb828e7162313500060b2137ae6137e8dd..887bc8c01db2950bd731f3f2aaa2ec1fdcdfebe3 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,13 @@
----
-title: Seed Voice Conversion
-emoji: 🎤🔄
-colorFrom: green
-colorTo: green
-sdk: gradio
-sdk_version: 4.42.0
-app_file: app.py
-pinned: false
-license: gpl-3.0
----
-
+---
+title: Seed Voice Conversion
+emoji: 🎤🔄
+colorFrom: green
+colorTo: green
+sdk: gradio
+sdk_version: 5.23.0
+app_file: app_v1v2.py
+pinned: false
+license: gpl-3.0
+---
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app_v1v2.py b/app_v1v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5e63efa4da6092483cf663e81eacd0a56886c4f
--- /dev/null
+++ b/app_v1v2.py
@@ -0,0 +1,175 @@
+import spaces
+import gradio as gr
+import torch
+import yaml
+import argparse
+from seed_vc_wrapper import SeedVCWrapper
+
+# Set up device and torch configurations
+if torch.cuda.is_available():
+ device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+ device = torch.device("mps")
+else:
+ device = torch.device("cpu")
+
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.triton.unique_kernel_names = True
+
+if hasattr(torch._inductor.config, "fx_graph_cache"):
+ # Experimental feature to reduce compilation times, will be on by default in future
+ torch._inductor.config.fx_graph_cache = True
+
+dtype = torch.float16
+
+def load_v2_models(args):
+ from hydra.utils import instantiate
+ from omegaconf import DictConfig
+ cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
+ vc_wrapper = instantiate(cfg)
+ vc_wrapper.load_checkpoints()
+ vc_wrapper.to(device)
+ vc_wrapper.eval()
+
+ vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device)
+
+ if args.compile:
+ vc_wrapper.compile_ar()
+ # vc_wrapper.compile_cfm()
+
+ return vc_wrapper
+
+def create_v1_interface():
+ # Initialize the V1 wrapper
+ vc_wrapper = SeedVCWrapper()
+
+ # Set up Gradio interface
+ description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br>"
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br>"
+                   "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
+                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。")
+
+ inputs = [
+ gr.Audio(type="filepath", label="Source Audio / 源音频"),
+ gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+ gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数",
+ info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"),
+ gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
+ info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate",
+ info="has subtle influence / 有微小影响"),
+ gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False,
+ info="Must set to true for singing voice conversion / 歌声转换时必须勾选"),
+ gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True,
+ info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"),
+ gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0,
+ info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"),
+ ]
+
+ examples = [
+ ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0],
+ ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0],
+ ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav",
+ "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0],
+ ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav",
+ "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12],
+ ]
+
+ outputs = [
+ gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+ gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+ ]
+
+ return gr.Interface(
+ fn=vc_wrapper.convert_voice,
+ description=description,
+ inputs=inputs,
+ outputs=outputs,
+ title="Seed Voice Conversion V1 (Voice & Singing Voice Conversion)",
+ examples=examples,
+ cache_examples=False,
+ )
+
+def create_v2_interface(vc_wrapper):
+ # Set up Gradio interface
+ description = ("Zero-shot voice/style conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) "
+                   "for details and updates.<br>Note that any reference audio will be forcefully clipped to 25s if beyond this length.<br>"
+                   "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.<br>"
+                   "Please click the 'convert style/emotion/accent' checkbox to convert the style, emotion, or accent of the source audio, or else only timbre conversion will be performed.<br>"
+                   "Clicking the 'anonymization only' checkbox will ignore the reference audio and convert the source to an 'average voice' determined by the model itself.<br>"
+                   "无需训练的 zero-shot 语音/口音转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)<br>"
+                   "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。<br>若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。"
+                   "<br>请勾选 'convert style/emotion/accent' 以转换源音频的风格、情感或口音,否则仅执行音色转换。<br>"
+                   "勾选 'anonymization only' 会无视参考音频而将源音频转换为某种由模型自身决定的 '平均音色'。<br>"
+
+ "Credits to [Vevo](https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo)"
+ )
+ inputs = [
+ gr.Audio(type="filepath", label="Source Audio / 源音频"),
+ gr.Audio(type="filepath", label="Reference Audio / 参考音频"),
+ gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数",
+ info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"),
+ gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整",
+ info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"),
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Intelligibility CFG Rate",
+ info="controls pronunciation intelligibility / 控制发音清晰度"),
+ gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Similarity CFG Rate",
+ info="controls similarity to reference audio / 控制与参考音频的相似度"),
+ gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p",
+ info="AR model sampling top P"),
+ gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature",
+ info="AR model sampling temperature"),
+ gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty",
+ info="AR model sampling repetition penalty"),
+ gr.Checkbox(label="convert style/emotion/accent", value=False),
+ gr.Checkbox(label="anonymization only", value=False),
+ ]
+
+ examples = [
+ ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
+ ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False],
+ ]
+
+ outputs = [
+ gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'),
+ gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav')
+ ]
+
+ return gr.Interface(
+ fn=vc_wrapper.convert_voice_with_streaming,
+ description=description,
+ inputs=inputs,
+ outputs=outputs,
+ title="Seed Voice Conversion V2 (Voice & Style Conversion)",
+ examples=examples,
+ cache_examples=False,
+ )
+
+def main(args):
+ # Load V2 models
+ vc_wrapper_v2 = load_v2_models(args)
+
+ # Create interfaces
+ v1_interface = create_v1_interface()
+ v2_interface = create_v2_interface(vc_wrapper_v2)
+
+ # Create tabs
+ with gr.Blocks(title="Seed Voice Conversion") as demo:
+ gr.Markdown("# Seed Voice Conversion")
+ gr.Markdown("Choose between V1 (Voice & Singing Voice Conversion) or V2 (Voice & Style Conversion)")
+
+ with gr.Tabs():
+ with gr.TabItem("V2 - Voice & Style Conversion"):
+ v2_interface.render()
+ with gr.TabItem("V1 - Voice & Singing Voice Conversion"):
+ v1_interface.render()
+
+ # Launch the combined interface
+ demo.launch()
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--compile", type=bool, default=True)
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/configs/astral_quantization/default_2048.yml b/configs/astral_quantization/default_2048.yml
new file mode 100644
index 0000000000000000000000000000000000000000..54f91e7cea7722a8cdb85c9855bffeeedb84e5db
--- /dev/null
+++ b/configs/astral_quantization/default_2048.yml
@@ -0,0 +1,40 @@
+_target_: modules.astral_quantization.default_model.AstralQuantizer
+tokenizer_name: "openai/whisper-small"
+ssl_model_name: "facebook/hubert-large-ll60k"
+ssl_output_layer: 18
+encoder:
+ _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+ dim: 512
+ num_blocks: 12
+ intermediate_dim: 1536
+ dilation: 1
+ input_dim: 1024
+quantizer:
+ _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+ codebook_size: 2048 # codebook size, must be a power of 2
+ dim: 512
+ entropy_loss_weight: 0.1
+ diversity_gamma: 1.0
+ spherical: True
+ enable_entropy_loss: True
+ soft_entropy_loss: True
+decoder:
+ _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+ dim: 512
+ num_blocks: 12
+ intermediate_dim: 1536
+ dilation: 1
+ output_dim: 1024
+ gin_channels: 192
+asr_decoder:
+ _target_: modules.astral_quantization.asr_decoder.ASRDecoder
+ hidden_dim: 768
+ num_heads: 12
+ depth: 12
+ block_size: 4096
+ in_channels: 512
+ n_vocab: 51866
+ bos_id: 50528
+ eos_id: 50527
+ dropout_rate: 0.0
+ attn_dropout_rate: 0.0
\ No newline at end of file
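A minimal usage sketch, assuming this YAML is consumed with Hydra's instantiate in the same way app_v1v2.py loads configs/v2/vc_wrapper.yaml; whether AstralQuantizer's constructor accepts every sub-block listed above (e.g. decoder, asr_decoder) is an assumption, and loading of trained weights is omitted:

    import yaml
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    # Nested _target_ nodes (encoder, quantizer, decoder, asr_decoder) are built
    # recursively and passed to the top-level class as constructor arguments.
    cfg = DictConfig(yaml.safe_load(open("configs/astral_quantization/default_2048.yml", "r")))
    content_tokenizer = instantiate(cfg)
    content_tokenizer.eval()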
diff --git a/configs/astral_quantization/default_32.yml b/configs/astral_quantization/default_32.yml
new file mode 100644
index 0000000000000000000000000000000000000000..bf160129893fb893eed26eabcb8a2da9c42a7159
--- /dev/null
+++ b/configs/astral_quantization/default_32.yml
@@ -0,0 +1,40 @@
+_target_: default_model.AstralQuantizer
+tokenizer_name: "openai/whisper-small"
+ssl_model_name: "facebook/hubert-large-ll60k"
+ssl_output_layer: 18
+encoder:
+ _target_: modules.convnext.ConvNeXtV2Stage
+ dim: 512
+ num_blocks: 12
+ intermediate_dim: 1536
+ dilation: 1
+ input_dim: 1024
+quantizer:
+ _target_: modules.bsq.BinarySphericalQuantize
+ codebook_size: 32 # codebook size, must be a power of 2
+ dim: 512
+ entropy_loss_weight: 0.1
+ diversity_gamma: 1.0
+ spherical: True
+ enable_entropy_loss: True
+ soft_entropy_loss: True
+decoder:
+ _target_: modules.convnext.ConvNeXtV2Stage
+ dim: 512
+ num_blocks: 12
+ intermediate_dim: 1536
+ dilation: 1
+ output_dim: 1024
+ gin_channels: 192
+asr_decoder:
+ _target_: modules.asr_decoder.ASRDecoder
+ hidden_dim: 768
+ num_heads: 12
+ depth: 12
+ block_size: 4096
+ in_channels: 512
+ n_vocab: 51866
+ bos_id: 50528
+ eos_id: 50527
+ dropout_rate: 0.0
+ attn_dropout_rate: 0.0
\ No newline at end of file
diff --git a/configs/config.json b/configs/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..e74f0b4898f6e47e1d198b62cdba989784ce2bb0
--- /dev/null
+++ b/configs/config.json
@@ -0,0 +1 @@
+{"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02}
\ No newline at end of file
diff --git a/configs/inuse/.gitignore b/configs/inuse/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/inuse/config.json b/configs/inuse/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..17d2df311cd2b19c1888c8fbb82effbfbc6edef3
--- /dev/null
+++ b/configs/inuse/config.json
@@ -0,0 +1 @@
+{"reference_audio_path": "D:/seed-vc/examples/reference/trump_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS USB", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS USB", "sr_type": "sr_model", "diffusion_steps": 8.0, "inference_cfg_rate": 0.7, "max_prompt_length": 3.0, "block_time": 0.58, "crossfade_length": 0.04, "extra_time_ce": 2.5, "extra_time": 0.5, "extra_time_right": 0.02}
\ No newline at end of file
diff --git a/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml b/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0ec7ef4f6d8c7cf160687747ae01e0d39f6c128d
--- /dev/null
+++ b/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml
@@ -0,0 +1,98 @@
+log_dir: "./runs"
+save_freq: 1
+log_interval: 10
+save_interval: 1000
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 1
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+ sr: 44100
+ spect_params:
+ n_fft: 2048
+ win_length: 2048
+ hop_length: 512
+ n_mels: 128
+ fmin: 0
+ fmax: "None"
+
+model_params:
+ dit_type: "DiT" # uDiT or DiT
+ reg_loss_type: "l1" # l1 or l2
+
+ timbre_shifter:
+ se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+ ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+ vocoder:
+ type: "bigvgan"
+ name: "nvidia/bigvgan_v2_44khz_128band_512x"
+
+ speech_tokenizer:
+ type: 'whisper'
+ name: "openai/whisper-small"
+
+ style_encoder:
+ dim: 192
+ campplus_path: "campplus_cn_common.bin"
+
+ DAC:
+ encoder_dim: 64
+ encoder_rates: [2, 5, 5, 6]
+ decoder_dim: 1536
+ decoder_rates: [ 6, 5, 5, 2 ]
+ sr: 24000
+
+ length_regulator:
+ channels: 768
+ is_discrete: false
+ in_channels: 768
+ content_codebook_size: 2048
+ sampling_ratios: [1, 1, 1, 1]
+ vector_quantize: false
+ n_codebooks: 1
+ quantizer_dropout: 0.0
+ f0_condition: true
+ n_f0_bins: 256
+
+ DiT:
+ hidden_dim: 768
+ num_heads: 12
+ depth: 17
+ class_dropout_prob: 0.1
+ block_size: 8192
+ in_channels: 128
+ style_condition: true
+ final_layer_type: 'mlp'
+ target: 'mel' # mel or codec
+ content_dim: 768
+ content_codebook_size: 1024
+ content_type: 'discrete'
+ f0_condition: true
+ n_f0_bins: 256
+ content_codebooks: 1
+ is_causal: false
+ long_skip_connection: false
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+ time_as_token: false
+ style_as_token: false
+ uvit_skip_connection: true
+ add_resblock_in_transformer: false
+
+ wavenet:
+ hidden_dim: 768
+ num_layers: 8
+ kernel_size: 5
+ dilation_rate: 1
+ p_dropout: 0.2
+ style_condition: true
+
+loss_params:
+ base_lr: 0.0001
+ lambda_mel: 45
+ lambda_kl: 1.0
\ No newline at end of file
diff --git a/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml b/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
new file mode 100644
index 0000000000000000000000000000000000000000..492910d163c7773c64f846ee55384e3e8b81ac00
--- /dev/null
+++ b/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml
@@ -0,0 +1,91 @@
+log_dir: "./runs"
+save_freq: 1
+log_interval: 10
+save_interval: 1000
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 2
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+ sr: 22050
+ spect_params:
+ n_fft: 1024
+ win_length: 1024
+ hop_length: 256
+ n_mels: 80
+ fmin: 0
+ fmax: "None"
+
+model_params:
+ dit_type: "DiT" # uDiT or DiT
+ reg_loss_type: "l1" # l1 or l2
+
+ timbre_shifter:
+ se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+ ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+ speech_tokenizer:
+ type: 'whisper'
+ name: "openai/whisper-small"
+
+ style_encoder:
+ dim: 192
+ campplus_path: "campplus_cn_common.bin"
+
+ vocoder:
+ type: "bigvgan"
+ name: "nvidia/bigvgan_v2_22khz_80band_256x"
+
+ length_regulator:
+ channels: 512
+ is_discrete: false
+ in_channels: 768
+ content_codebook_size: 2048
+ sampling_ratios: [1, 1, 1, 1]
+ vector_quantize: false
+ n_codebooks: 1
+ quantizer_dropout: 0.0
+ f0_condition: false
+ n_f0_bins: 512
+
+ DiT:
+ hidden_dim: 512
+ num_heads: 8
+ depth: 13
+ class_dropout_prob: 0.1
+ block_size: 8192
+ in_channels: 80
+ style_condition: true
+ final_layer_type: 'wavenet'
+ target: 'mel' # mel or codec
+ content_dim: 512
+ content_codebook_size: 1024
+ content_type: 'discrete'
+ f0_condition: false
+ n_f0_bins: 512
+ content_codebooks: 1
+ is_causal: false
+ long_skip_connection: true
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+ time_as_token: false
+ style_as_token: false
+ uvit_skip_connection: true
+ add_resblock_in_transformer: false
+
+ wavenet:
+ hidden_dim: 512
+ num_layers: 8
+ kernel_size: 5
+ dilation_rate: 1
+ p_dropout: 0.2
+ style_condition: true
+
+loss_params:
+ base_lr: 0.0001
+ lambda_mel: 45
+ lambda_kl: 1.0
\ No newline at end of file
diff --git a/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml b/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
new file mode 100644
index 0000000000000000000000000000000000000000..e0677397377158dd30ffdf905946fbd297b36bd5
--- /dev/null
+++ b/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml
@@ -0,0 +1,82 @@
+log_dir: "./runs/"
+save_freq: 1
+log_interval: 10
+save_interval: 500
+device: "cuda"
+epochs: 1000 # number of epochs for first stage training (pre-training)
+batch_size: 2
+batch_length: 100 # maximum duration of audio in a batch (in seconds)
+max_len: 80 # maximum number of frames
+pretrained_model: "DiT_uvit_tat_xlsr_ema.pth"
+pretrained_encoder: ""
+load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters
+
+preprocess_params:
+ sr: 22050
+ spect_params:
+ n_fft: 1024
+ win_length: 1024
+ hop_length: 256
+ n_mels: 80
+ fmin: 0
+ fmax: 8000
+
+model_params:
+ dit_type: "DiT" # uDiT or DiT
+ reg_loss_type: "l1" # l1 or l2
+ diffusion_type: "flow"
+
+ timbre_shifter:
+ se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt"
+ ckpt_path: './modules/openvoice/checkpoints_v2/converter'
+
+ vocoder:
+ type: "hifigan"
+
+ speech_tokenizer:
+ type: 'xlsr'
+ output_layer: 12
+ name: 'facebook/wav2vec2-xls-r-300m'
+
+ style_encoder:
+ dim: 192
+ campplus_path: "campplus_cn_common.bin"
+
+ length_regulator:
+ channels: 384
+ is_discrete: false
+ in_channels: 1024
+ content_codebook_size: 1024
+ sampling_ratios: [1, 1, 1, 1]
+ vector_quantize: false
+ n_codebooks: 2
+ quantizer_dropout: 0.0
+ f0_condition: false
+ n_f0_bins: 512
+
+ DiT:
+ hidden_dim: 384
+ num_heads: 6
+ depth: 9
+ class_dropout_prob: 0.1
+ block_size: 8192
+ in_channels: 80
+ style_condition: true
+ final_layer_type: 'mlp'
+ target: 'mel' # mel or betavae
+ content_dim: 384
+ content_codebook_size: 1024
+ content_type: 'discrete'
+ f0_condition: false
+ n_f0_bins: 512
+ content_codebooks: 1
+ is_causal: false
+ long_skip_connection: false
+ zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token
+ time_as_token: true
+ style_as_token: true
+ uvit_skip_connection: true
+ add_resblock_in_transformer: false
+
+loss_params:
+ base_lr: 0.0001
\ No newline at end of file
diff --git a/configs/v2/ar_base.yaml b/configs/v2/ar_base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/configs/v2/dit_small.yaml b/configs/v2/dit_small.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60b93e953d7ec19039af147b7a52002b757a822a
--- /dev/null
+++ b/configs/v2/dit_small.yaml
@@ -0,0 +1,17 @@
+_target_: modules.v2.cfm.CFM
+estimator:
+ _target_: modules.v2.dit_wrapper.DiT
+ time_as_token: true
+ style_as_token: true
+ uvit_skip_connection: false
+ block_size: 8192
+ depth: 13
+ num_heads: 8
+ hidden_dim: 512
+ in_channels: 80
+ content_dim: 512
+ style_encoder_dim: 192
+ class_dropout_prob: 0.1
+ dropout_rate: 0.0
+ attn_dropout_rate: 0.0
+
diff --git a/configs/v2/vc_wrapper.yaml b/configs/v2/vc_wrapper.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3fe5b84431f53ebd12ec60663f40f61f4a8c231
--- /dev/null
+++ b/configs/v2/vc_wrapper.yaml
@@ -0,0 +1,105 @@
+_target_: modules.v2.vc_wrapper.VoiceConversionWrapper
+sr: 22050
+hop_size: 256
+mel_fn:
+ _target_: modules.audio.mel_spectrogram
+ _partial_: true
+ n_fft: 1024
+ win_size: 1024
+ hop_size: 256
+ num_mels: 80
+ sampling_rate: 22050
+ fmin: 0
+ fmax: null
+ center: False
+cfm:
+ _target_: modules.v2.cfm.CFM
+ estimator:
+ _target_: modules.v2.dit_wrapper.DiT
+ time_as_token: true
+ style_as_token: true
+ uvit_skip_connection: false
+ block_size: 8192
+ depth: 13
+ num_heads: 8
+ hidden_dim: 512
+ in_channels: 80
+ content_dim: 512
+ style_encoder_dim: 192
+ class_dropout_prob: 0.1
+ dropout_rate: 0.0
+ attn_dropout_rate: 0.0
+cfm_length_regulator:
+ _target_: modules.v2.length_regulator.InterpolateRegulator
+ channels: 512
+ is_discrete: true
+ codebook_size: 2048
+ sampling_ratios: [ 1, 1, 1, 1 ]
+ f0_condition: false
+ar:
+ _target_: modules.v2.ar.NaiveWrapper
+ model:
+ _target_: modules.v2.ar.NaiveTransformer
+ config:
+ _target_: modules.v2.ar.NaiveModelArgs
+ dropout: 0.0
+ rope_base: 10000.0
+ dim: 768
+ head_dim: 64
+ n_local_heads: 2
+ intermediate_size: 2304
+ n_head: 12
+ n_layer: 12
+ vocab_size: 2049 # 1 + 1 for eos
+ar_length_regulator:
+ _target_: modules.v2.length_regulator.InterpolateRegulator
+ channels: 768
+ is_discrete: true
+ codebook_size: 32
+ sampling_ratios: [ ]
+ f0_condition: false
+style_encoder:
+ _target_: modules.campplus.DTDNN.CAMPPlus
+ feat_dim: 80
+ embedding_size: 192
+content_extractor_narrow:
+ _target_: modules.astral_quantization.default_model.AstralQuantizer
+ tokenizer_name: "openai/whisper-small"
+ ssl_model_name: "facebook/hubert-large-ll60k"
+ ssl_output_layer: 18
+ skip_ssl: true
+ encoder: &bottleneck_encoder
+ _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage
+ dim: 512
+ num_blocks: 12
+ intermediate_dim: 1536
+ dilation: 1
+ input_dim: 1024
+ quantizer:
+ _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+ codebook_size: 32 # codebook size, must be a power of 2
+ dim: 512
+ entropy_loss_weight: 0.1
+ diversity_gamma: 1.0
+ spherical: True
+ enable_entropy_loss: True
+ soft_entropy_loss: True
+content_extractor_wide:
+ _target_: modules.astral_quantization.default_model.AstralQuantizer
+ tokenizer_name: "openai/whisper-small"
+ ssl_model_name: "facebook/hubert-large-ll60k"
+ ssl_output_layer: 18
+ encoder: *bottleneck_encoder
+ quantizer:
+ _target_: modules.astral_quantization.bsq.BinarySphericalQuantize
+ codebook_size: 2048 # codebook size, must be a power of 2
+ dim: 512
+ entropy_loss_weight: 0.1
+ diversity_gamma: 1.0
+ spherical: True
+ enable_entropy_loss: True
+ soft_entropy_loss: True
+vocoder:
+ _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained
+ pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x"
+ use_cuda_kernel: false
diff --git a/hf_utils.py b/hf_utils.py
index 9f8c7f7d5f1b82efbd788c7327f76c0dc6a9355a..4ae986c1d88d5b1214c9d3101249e90c5c370ca9 100644
--- a/hf_utils.py
+++ b/hf_utils.py
@@ -2,7 +2,7 @@ import os
from huggingface_hub import hf_hub_download
-def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"):
+def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None):
os.makedirs("./checkpoints", exist_ok=True)
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints")
if config_filename is None:
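A hedged usage sketch of the changed default; the repo id and filenames below are placeholders, and the behaviour of the `config_filename is None` branch (returning only the model path) is assumed, since the rest of that branch lies outside this hunk:

    from hf_utils import load_custom_model_from_hf

    # As before: request both weights and config explicitly.
    model_path, config_path = load_custom_model_from_hf(
        "some-org/some-model", "pytorch_model.bin", "config.yml"
    )

    # With the new default, weights-only downloads can simply omit config_filename.
    model_only_path = load_custom_model_from_hf("some-org/some-model", "pytorch_model.bin")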
diff --git a/modules/__pycache__/audio.cpython-310.pyc b/modules/__pycache__/audio.cpython-310.pyc
index 79bf4bb4261f2d91fe5d3efb024339a88274162c..651e7ad6c3e297013d527e4a9218ae35f2b92c41 100644
Binary files a/modules/__pycache__/audio.cpython-310.pyc and b/modules/__pycache__/audio.cpython-310.pyc differ
diff --git a/modules/__pycache__/commons.cpython-310.pyc b/modules/__pycache__/commons.cpython-310.pyc
index 9289cfe3ac9362aaeece6546f5b78bbfea6ca40b..5adfe7b95903d4c3134f74bdf68b5b413b848c97 100644
Binary files a/modules/__pycache__/commons.cpython-310.pyc and b/modules/__pycache__/commons.cpython-310.pyc differ
diff --git a/modules/__pycache__/commons.cpython-38.pyc b/modules/__pycache__/commons.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfe6a34f14ee453f4a79de9b39d80fc62b4d7fda
Binary files /dev/null and b/modules/__pycache__/commons.cpython-38.pyc differ
diff --git a/modules/__pycache__/diffusion_transformer.cpython-310.pyc b/modules/__pycache__/diffusion_transformer.cpython-310.pyc
index c721982be6feb37d7ae333799842c197225c9697..4bcbf046526a3c8e13cf180ebea709e89305d848 100644
Binary files a/modules/__pycache__/diffusion_transformer.cpython-310.pyc and b/modules/__pycache__/diffusion_transformer.cpython-310.pyc differ
diff --git a/modules/__pycache__/flow_matching.cpython-310.pyc b/modules/__pycache__/flow_matching.cpython-310.pyc
index 4f42f2602f27a9f430f3daf9ca3825ebe86172b5..4a6f71cb0ef1cd096cdeed48330dbce1aed55734 100644
Binary files a/modules/__pycache__/flow_matching.cpython-310.pyc and b/modules/__pycache__/flow_matching.cpython-310.pyc differ
diff --git a/modules/__pycache__/length_regulator.cpython-310.pyc b/modules/__pycache__/length_regulator.cpython-310.pyc
index 301c8f57e7713f62bf83b9c4fa712fe7680f26be..c2ada28a9ce6e0b8a61901f913b5ddde33b0c9ac 100644
Binary files a/modules/__pycache__/length_regulator.cpython-310.pyc and b/modules/__pycache__/length_regulator.cpython-310.pyc differ
diff --git a/modules/__pycache__/rmvpe.cpython-310.pyc b/modules/__pycache__/rmvpe.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..18f08923f7c01ba34a8719dd56cd1c05fdf4e33d
Binary files /dev/null and b/modules/__pycache__/rmvpe.cpython-310.pyc differ
diff --git a/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc b/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..918dc5ae503d2a7da3860b7637f9af7aa81ea466
Binary files /dev/null and b/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc differ
diff --git a/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc b/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..508b91e576abbb86db68531cda662c8e9daa0fe6
Binary files /dev/null and b/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc differ
diff --git a/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc b/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..381bd7f361d729a1640b7eed29bbfdc0515cbcc0
Binary files /dev/null and b/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc differ
diff --git a/modules/astral_quantization/bsq.py b/modules/astral_quantization/bsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b70f3401dafa187fc19666d31a0877f533abe33
--- /dev/null
+++ b/modules/astral_quantization/bsq.py
@@ -0,0 +1,569 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+In the simplest setup, each dimension is quantized into {-1, 1}.
+An entropy penalty is used to encourage utilization.
+"""
+
+from math import log2, ceil
+from functools import partial, cache
+from collections import namedtuple
+from contextlib import nullcontext
+
+import torch.distributed as dist
+from torch.distributed import nn as dist_nn
+
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+from torch.nn import Module
+from torch.amp import autocast
+
+from einops import rearrange, reduce, pack, unpack
+
+# constants
+
+Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
+
+LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
+
+# distributed helpers
+
+@cache
+def is_distributed():
+ return dist.is_initialized() and dist.get_world_size() > 1
+
+def maybe_distributed_mean(t):
+ if not is_distributed():
+ return t
+
+ dist_nn.all_reduce(t)
+ t = t / dist.get_world_size()
+ return t
+
+# helper functions
+
+def exists(v):
+ return v is not None
+
+def identity(t):
+ return t
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg() if callable(arg) else arg
+ return None
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+def l2norm(t):
+ return F.normalize(t, dim = -1)
+
+# entropy
+
+def log(t, eps = 1e-5):
+ return t.clamp(min = eps).log()
+
+def entropy(prob):
+ return (-prob * log(prob)).sum(dim=-1)
+
+# cosine sim linear
+
+class CosineSimLinear(Module):
+ def __init__(
+ self,
+ dim_in,
+ dim_out,
+ scale = 1.
+ ):
+ super().__init__()
+ self.scale = scale
+ self.weight = nn.Parameter(torch.randn(dim_in, dim_out))
+
+ def forward(self, x):
+ x = F.normalize(x, dim = -1)
+ w = F.normalize(self.weight, dim = 0)
+ return (x @ w) * self.scale
+
+def soft_entropy_loss(u, tau=1.0, gamma=1.0):
+ """
+ Compute the soft entropy loss for Binary Spherical Quantization (BSQ).
+
+ Args:
+ u (torch.Tensor): Input latent embeddings of shape (batch_size, L).
+ tau (float): Temperature scaling factor.
+ gamma (float): Weight for the second entropy term.
+
+ Returns:
+ torch.Tensor: Soft entropy loss.
+ """
+ # Binary quantization: Generate implicit codebook corners
+ L = u.size(1) # Dimensionality of codebook
+ corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5)
+
+ # Compute soft quantization probabilities for all dimensions
+ # q_hat(c|u) for each dimension
+ prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2)) # Shape: (batch_size, L, 2)
+
+ # Entropy of q_hat(c|u) (independent along each dimension)
+ entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1) # Shape: (batch_size, L)
+ entropy_term1 = entropy_per_dim.mean()
+
+ # Expected probabilities for dataset entropy (approximation)
+ expected_probs = prob_matrix.mean(dim=0) # Mean across batch, shape: (L, 2)
+ entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean()
+
+ # Final entropy loss
+ loss = entropy_term1 - gamma * entropy_term2
+ return loss
+
+# class
+
+class BinarySphericalQuantize(Module):
+ def __init__(
+ self,
+ *,
+ dim = None,
+ codebook_size = None,
+ entropy_loss_weight = 0.1,
+ commitment_loss_weight = 0.,
+ diversity_gamma = 1.,
+ straight_through_activation = nn.Identity(),
+ num_codebooks = 1,
+ keep_num_codebooks_dim = None,
+ codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer
+ frac_per_sample_entropy = 0.25, # make less than 1. to only use a random fraction of the probs for per sample entropy
+ has_projections = None,
+ projection_has_bias = True,
+ soft_clamp_input_value = None,
+ cosine_sim_project_in = False,
+ cosine_sim_project_in_scale = None,
+ channel_first = None,
+ experimental_softplus_entropy_loss = False,
+ entropy_loss_offset = 5., # how much to shift the loss before softplus
+ spherical = True, # from https://arxiv.org/abs/2406.07548
+ force_quantization_f32 = True, # will force the quantization step to be full precision
+ enable_entropy_loss = True,
+ soft_entropy_loss = True,
+ ):
+ super().__init__()
+
+ # some assert validations
+
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
+
+ codebook_size = default(codebook_size, lambda: 2 ** dim)
+ self.codebook_size = codebook_size
+
+ codebook_dim = int(log2(codebook_size))
+ codebook_dims = codebook_dim * num_codebooks
+ dim = default(dim, codebook_dims)
+
+ has_projections = default(has_projections, dim != codebook_dims)
+
+ if cosine_sim_project_in:
+ cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale)
+ project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in)
+ else:
+ project_in_klass = partial(nn.Linear, bias = projection_has_bias)
+
+ self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity()
+ self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity()
+ self.has_projections = has_projections
+
+ self.dim = dim
+ self.codebook_dim = codebook_dim
+ self.num_codebooks = num_codebooks
+
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+ # channel first
+
+ self.channel_first = channel_first
+
+ # straight through activation
+
+ self.activation = straight_through_activation
+
+ # whether to use BSQ (binary spherical quantization)
+
+ self.spherical = spherical
+ self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity
+
+ # entropy aux loss related weights
+
+ assert 0 < frac_per_sample_entropy <= 1.
+ self.frac_per_sample_entropy = frac_per_sample_entropy
+
+ self.diversity_gamma = diversity_gamma
+ self.entropy_loss_weight = entropy_loss_weight
+
+ # codebook scale
+
+ self.codebook_scale = codebook_scale
+
+ # commitment loss
+
+ self.commitment_loss_weight = commitment_loss_weight
+
+ # whether to soft clamp the input value from -value to value
+
+ self.soft_clamp_input_value = soft_clamp_input_value
+ assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale
+
+ # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions)
+
+ self.entropy_loss_offset = entropy_loss_offset
+ self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss
+
+ # for no auxiliary loss, during inference
+
+ self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
+ # whether to force quantization step to be f32
+
+ self.force_quantization_f32 = force_quantization_f32
+
+ # codes
+ self.enable_entropy_loss = enable_entropy_loss
+ self.soft_entropy_loss = soft_entropy_loss
+ if codebook_size <= 100000:
+ all_codes = torch.arange(codebook_size)
+ bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+ codebook = self.bits_to_codes(bits)
+
+ self.register_buffer('codebook', codebook.float(), persistent = False)
+ else:
+ all_codes = torch.arange(pow(2, 16))
+ mask = 2 ** torch.arange(16 - 1, -1, -1)
+ bits = ((all_codes[..., None].int() & mask) != 0).float()
+ codebook = self.bits_to_codes(bits)
+
+ self.register_buffer('codebook', codebook.float(), persistent = False)
+
+ def bits_to_codes(self, bits):
+ return bits * self.codebook_scale * 2 - self.codebook_scale
+
+ @property
+ def dtype(self):
+ return self.codebook.dtype
+
+ def indices_to_codes(
+ self,
+ indices,
+ project_out = True
+ ):
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+ should_transpose = default(self.channel_first, is_img_or_video)
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, '... -> ... 1')
+
+ # indices to codes, which are bits of either -1 or 1
+
+ bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
+
+ codes = self.bits_to_codes(bits)
+
+ codes = self.maybe_l2norm(codes)
+
+ codes = rearrange(codes, '... c d -> ... (c d)')
+
+ # whether to project codes out to original dimensions
+ # if the input feature dimensions were not log2(codebook size)
+
+ if project_out:
+ codes = self.project_out(codes)
+
+ # rearrange codes back to original shape
+
+ if should_transpose:
+ codes = rearrange(codes, 'b ... d -> b d ...')
+
+ return codes
+
+ def bits_to_z(self, bits):
+ # assert bits must contain only -1 and 1
+ assert torch.all(bits.abs() == 1)
+ quantized = bits.float()
+ quantized = self.maybe_l2norm(quantized)
+ z = self.project_out(quantized)
+ return z
+
+ def forward(
+ self,
+ x,
+ inv_temperature = 100.,
+ return_loss_breakdown = False,
+ mask = None,
+ return_bits = False
+ ):
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension, which is also log2(codebook size)
+ c - number of codebook dim
+ """
+
+ is_img_or_video = x.ndim >= 4
+ should_transpose = default(self.channel_first, is_img_or_video)
+
+ # standardize image or video into (batch, seq, dimension)
+
+ if should_transpose:
+ x = rearrange(x, 'b d ... -> b ... d')
+ x, ps = pack_one(x, 'b * d')
+
+ assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
+
+ x = self.project_in(x)
+
+ # maybe soft clamp
+
+ if exists(self.soft_clamp_input_value):
+ clamp_value = self.soft_clamp_input_value
+ x = (x / clamp_value).tanh() * clamp_value
+
+ # split out number of codebooks
+
+ x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
+ # maybe l2norm
+
+ x = self.maybe_l2norm(x)
+
+ # whether to force quantization step to be full precision or not
+
+ force_f32 = self.force_quantization_f32
+
+ quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext
+
+ with quantization_context():
+
+ if force_f32:
+ orig_dtype = x.dtype
+ x = x.float()
+
+ # quantize by eq 3.
+
+ original_input = x
+
+ codebook_value = torch.ones_like(x) * self.codebook_scale
+ quantized = torch.where(x > 0, codebook_value, -codebook_value)
+ if return_bits:
+ return quantized
+
+ # calculate indices
+
+ indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
+
+ # maybe l2norm
+
+ quantized = self.maybe_l2norm(quantized)
+
+ # use straight-through gradients (optionally with custom activation fn) if training
+
+ if self.training:
+ x = self.activation(x)
+ x = x + (quantized - x).detach()
+ else:
+ x = quantized
+
+ # entropy aux loss
+ if self.soft_entropy_loss:
+ entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0)
+ elif self.training and self.enable_entropy_loss:
+
+ if force_f32:
+ codebook = self.codebook.float()
+
+ codebook = self.maybe_l2norm(codebook)
+
+ # whether to only use a fraction of probs, for reducing memory
+
+ if self.frac_per_sample_entropy < 1.:
+ # account for mask
+ if exists(mask):
+ original_input = original_input[mask]
+ original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
+ rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16
+
+ sampled_input = original_input[..., rand_mask]
+
+ sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
+
+ sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
+
+ per_sample_probs = sampled_prob
+ else:
+ if exists(mask):
+ original_input = original_input[mask]
+ original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+ # the same as euclidean distance up to a constant
+ distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+
+ prob = (-distance * inv_temperature).softmax(dim = -1)
+
+ per_sample_probs = prob
+
+ # calculate per sample entropy
+
+ per_sample_entropy = entropy(per_sample_probs).mean()
+
+ # distribution over all available tokens in the batch
+
+ avg_prob = reduce(per_sample_probs, '... c d -> c d', 'mean')
+
+ avg_prob = maybe_distributed_mean(avg_prob)
+
+ codebook_entropy = entropy(avg_prob).mean()
+
+ # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+ # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
+
+ entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
+ else:
+ # if not training, just return dummy 0
+ entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
+
+ # whether to make the entropy loss positive or not through a (shifted) softplus
+
+ if self.training and self.experimental_softplus_entropy_loss:
+ entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset)
+
+ # commit loss
+
+ if self.training and self.commitment_loss_weight > 0.:
+
+ commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none')
+
+ if exists(mask):
+ commit_loss = commit_loss[mask]
+
+ commit_loss = commit_loss.mean()
+ else:
+ commit_loss = self.zero
+
+ # input back to original dtype if needed
+
+ if force_f32:
+ x = x.type(orig_dtype)
+
+ # merge back codebook dim
+
+ x = rearrange(x, 'b n c d -> b n (c d)')
+
+ # project out to feature dimension if needed
+
+ x = self.project_out(x)
+
+ # reconstitute image or video dimensions
+
+ if should_transpose:
+ x = unpack_one(x, ps, 'b * d')
+ x = rearrange(x, 'b ... d -> b d ...')
+
+ indices = unpack_one(indices, ps, 'b * c')
+
+ # whether to remove single codebook dim
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, '... 1 -> ...')
+
+ # complete aux loss
+
+ aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
+ # returns
+
+ ret = Return(x, indices, aux_loss)
+
+ if not return_loss_breakdown:
+ return ret
+
+ return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
+
+class GroupedResidualBSQ(Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ groups = 1,
+ accept_image_fmap = False,
+ **kwargs
+ ):
+ super().__init__()
+ self.dim = dim
+ self.groups = groups
+ assert (dim % groups) == 0
+ dim_per_group = dim // groups
+
+ self.accept_image_fmap = accept_image_fmap
+
+ self.rvqs = nn.ModuleList([])
+
+ for _ in range(groups):
+            self.rvqs.append(BinarySphericalQuantize(
+ dim = dim_per_group,
+ **kwargs
+ ))
+
+ self.codebook_size = self.rvqs[0].codebook_size
+
+ @property
+ def codebooks(self):
+ return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs))
+
+ @property
+ def split_dim(self):
+ return 1 if self.accept_image_fmap else -1
+
+ def get_codes_from_indices(self, indices):
+ codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
+ return torch.stack(codes)
+
+ def get_output_from_indices(self, indices):
+ outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices))
+ return torch.cat(outputs, dim = self.split_dim)
+
+ def forward(
+ self,
+ x,
+ return_all_codes = False
+ ):
+ shape, split_dim = x.shape, self.split_dim
+ assert shape[split_dim] == self.dim
+
+ # split the feature dimension into groups
+
+ x = x.chunk(self.groups, dim = split_dim)
+
+ forward_kwargs = dict(
+ )
+
+ # invoke residual vq on each group
+
+ out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x))
+ out = tuple(zip(*out))
+
+ # otherwise, get all the zipped outputs and combine them
+
+ quantized, all_indices, *maybe_aux_loss = out
+
+ quantized = torch.cat(quantized, dim = split_dim)
+ all_indices = torch.stack(all_indices)
+
+ ret = (quantized, all_indices, *maybe_aux_loss)
+ return ret
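A minimal shape-check sketch for BinarySphericalQuantize, with hyperparameters mirroring the 2048-entry codebook block in configs/v2/vc_wrapper.yaml; the input tensor is random and purely illustrative:

    import torch
    from modules.astral_quantization.bsq import BinarySphericalQuantize

    bsq = BinarySphericalQuantize(
        codebook_size=2048,      # 11 bits per code
        dim=512,
        entropy_loss_weight=0.1,
        diversity_gamma=1.0,
        spherical=True,
        enable_entropy_loss=True,
        soft_entropy_loss=True,
    ).eval()

    feats = torch.randn(1, 100, 512)           # (batch, frames, feature dim)
    quantized, indices, aux_loss = bsq(feats)  # Return(quantized, indices, entropy_aux_loss)
    # quantized: (1, 100, 512), projected back to dim; indices: (1, 100) integer codes in [0, 2047]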
diff --git a/modules/astral_quantization/convnext.py b/modules/astral_quantization/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bef9e282ac332b7169339eeda082d0581d739d1
--- /dev/null
+++ b/modules/astral_quantization/convnext.py
@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+
+class ConvNextV2LayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.data_format == "channels_last":
+ x = torch.nn.functional.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ elif self.data_format == "channels_first":
+ input_dtype = x.dtype
+ x = x.float()
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = x.to(dtype=input_dtype)
+ x = self.weight[None, :, None] * x + self.bias[None, :, None]
+ return x
+
+
+class GRN(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+class InterpolationLayer(nn.Module):
+ def __init__(self, ): # this is a default of 1 / 50 * (44100 / 512) / 4
+ super().__init__()
+ pass
+
+ def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ x = F.interpolate(x, size=target_len, mode='linear')
+ return x
+
+class ConvNeXtV2Stage(nn.Module):
+ def __init__(
+ self,
+ dim: int = 512,
+ intermediate_dim: int = 2048,
+ num_blocks: int = 1,
+ dilation: int = 1,
+ downsample_layer_indices: List[int] = None,
+ downsample_factors: List[int] = None,
+ upsample_layer_indices: List[int] = None,
+ upsample_factors: List[int] = None,
+ interpolation_layer_indices: List[int] = None,
+ input_dim: int = None,
+ output_dim: int = None,
+ gin_channels: int = 0,
+ ):
+ super().__init__()
+ # maybe downsample layers
+ if downsample_layer_indices is not None:
+ assert downsample_factors is not None
+ self.downsample_blocks = nn.ModuleList(
+ [
+ nn.Sequential(
+ ConvNextV2LayerNorm(dim, data_format="channels_first"),
+ nn.Conv1d(
+ dim, dim, kernel_size=downsample_factor, stride=downsample_factor
+ ),
+ ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors)
+ ]
+ )
+ self.downsample_layer_indices = downsample_layer_indices
+ else:
+ self.downsample_blocks = nn.ModuleList()
+ self.downsample_layer_indices = []
+
+ # maybe upsample layers
+ if upsample_layer_indices is not None:
+ assert upsample_factors is not None
+ self.upsample_blocks = nn.ModuleList(
+ [
+ nn.Sequential(
+ ConvNextV2LayerNorm(dim, data_format="channels_first"),
+ nn.ConvTranspose1d(
+ dim, dim, kernel_size=upsample_factor, stride=upsample_factor
+ ),
+ ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors)
+ ]
+ )
+ self.upsample_layer_indices = upsample_layer_indices
+ else:
+ self.upsample_blocks = nn.ModuleList()
+ self.upsample_layer_indices = []
+
+ # maybe interpolation layers
+ if interpolation_layer_indices is not None:
+ self.interpolation_blocks = nn.ModuleList(
+ [
+ InterpolationLayer()
+ for _ in interpolation_layer_indices
+ ]
+ )
+ self.interpolation_layer_indices = interpolation_layer_indices
+ else:
+ self.interpolation_blocks = nn.ModuleList()
+ self.interpolation_layer_indices = []
+
+ # main blocks
+ self.blocks = nn.ModuleList(
+ [
+ ConvNeXtV2Block(
+ dim=dim,
+ intermediate_dim=intermediate_dim,
+ dilation=dilation,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+ # maybe input and output projections
+ if input_dim is not None and input_dim != dim:
+ self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1)
+ else:
+ self.input_projection = nn.Identity()
+ if output_dim is not None and output_dim != dim:
+ self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1)
+ else:
+ self.output_projection = nn.Identity()
+
+ if gin_channels > 0:
+ self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1)
+
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ x = self.input_projection(x) # B, D, T
+ if hasattr(self, 'gin'):
+ g = kwargs['g']
+ x = x + self.gin(g)
+ # pad to a multiple of cumprod(downsample_factors)
+ if len(self.downsample_blocks) > 0:
+ downsample_factor = 1
+ for factor in self.downsample_blocks:
+ downsample_factor *= factor[1].stride[0]
+ pad_len = downsample_factor - x.size(-1) % downsample_factor
+ if pad_len > 0:
+ x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
+
+ # main blocks
+ for layer_idx, block in enumerate(self.blocks):
+ if layer_idx in self.downsample_layer_indices:
+ x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x)
+ if layer_idx in self.upsample_layer_indices:
+ x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x)
+ if layer_idx in self.interpolation_layer_indices:
+ x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len'])
+ x = block(x)
+ x = self.output_projection(x)
+ return x
+
+ def setup_caches(self, *args, **kwargs):
+ pass
+
+
+class ConvNeXtV2Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ intermediate_dim: int,
+ dilation: int = 1,
+ ):
+ super().__init__()
+ padding = (dilation * (7 - 1)) // 2
+ self.dwconv = nn.Conv1d(
+ dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
+ ) # depthwise conv
+ self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first")
+ self.pwconv1 = nn.Linear(
+ dim, intermediate_dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.grn = GRN(intermediate_dim)
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ x = self.dwconv(x)
+ x = self.norm(x)
+ x = x.transpose(1, 2) # b d n -> b n d
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ x = x.transpose(1, 2) # b n d -> b d n
+ return residual + x
\ No newline at end of file
diff --git a/modules/astral_quantization/default_model.py b/modules/astral_quantization/default_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca50d1cb1b177a89955c9f13451e5ed159b753b5
--- /dev/null
+++ b/modules/astral_quantization/default_model.py
@@ -0,0 +1,73 @@
+import torch
+from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
+
+class AstralQuantizer(torch.nn.Module):
+ def __init__(
+ self,
+ tokenizer_name: str,
+ ssl_model_name: str,
+ ssl_output_layer: int,
+ encoder: torch.nn.Module,
+ quantizer: torch.nn.Module,
+ skip_ssl: bool = False,
+ ):
+ super().__init__()
+ self.encoder = encoder
+ self.quantizer = quantizer
+ self.tokenizer_name = tokenizer_name
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+ # Load SSL model from Huggingface
+ self.ssl_model_name = ssl_model_name
+ self.ssl_output_layer = ssl_output_layer
+ self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)
+
+ if skip_ssl: # in case the same SSL model has been loaded somewhere else
+ self.ssl_model = None
+ else:
+ self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
+ self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
+ self.ssl_model.encoder.layer_norm = torch.nn.Identity()
+
+ def load_separate_checkpoint(self, checkpoint_path):
+ params = torch.load(checkpoint_path, map_location='cpu')['net']
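+ # strip the DDP "module." prefix so keys match the non-distributed modules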
+ for key in params.keys():
+ for k in list(params[key].keys()):
+ if k.startswith("module."):
+ params[key][k[len("module."):]] = params[key][k]
+ del params[key][k]
+ self.encoder.load_state_dict(params['encoder'])
+ self.quantizer.load_state_dict(params['vq'])
+ # decoder / asr_decoder are optional and may not be attached to this module
+ if getattr(self, 'decoder', None) is not None:
+ self.decoder.load_state_dict(params['decoder'])
+ if getattr(self, 'asr_decoder', None) is not None:
+ self.asr_decoder.load_state_dict(params['predictor'], strict=False)
+
+ def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
+ ssl_fn = self.ssl_model if self.ssl_model is not None else ssl_model
+ assert ssl_fn is not None, "An external ssl_model must be provided when in-class SSL model loading is skipped"
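+ # trim padding per utterance and convert to numpy for the HF feature extractor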
+ waves_16k_input_list = [
+ waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
+ for bib in range(len(waves_16k))
+ ]
+ alt_inputs = self.ssl_feature_extractor(
+ waves_16k_input_list,
+ return_tensors='pt',
+ return_attention_mask=True,
+ padding=True,
+ sampling_rate=16000
+ ).to(waves_16k.device)
+ feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320 # frame rate of hubert is 50 Hz
+
+ outputs = ssl_fn(
+ alt_inputs.input_values,
+ attention_mask=alt_inputs.attention_mask,
+ )
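+ # keep only the valid frames, then feed the SSL features to the content encoder and quantizer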
+ last_hidden_states = outputs.last_hidden_state
+ last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
+ feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
+ last_hidden_states = last_hidden_states.transpose(1, 2)
+ x_hidden = self.encoder(last_hidden_states, feature_lens)
+ x_hidden = x_hidden.transpose(1, 2)
+ x_quantized, indices = self.quantizer(x_hidden)[:2]
+ return x_quantized, indices, feature_lens
\ No newline at end of file
diff --git a/modules/astral_quantization/transformer.py b/modules/astral_quantization/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..015ec417d02e3442d8ea120c3802807c73e89bb4
--- /dev/null
+++ b/modules/astral_quantization/transformer.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
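+# round n up to the nearest multiple of k (used to size the SwiGLU hidden dimension)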
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNorm, self).__init__()
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
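+ # with a conditioning embedding, predict a per-feature (scale, bias) and apply it to the normalized input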
+ def forward(self, input: Tensor, embedding: Optional[Tensor] = None) -> Tensor:
+ if embedding is None:
+ return self.norm(input)
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return weight * self.norm(input) + bias
+
+
+@dataclass
+class ModelArgs:
+ block_size: int = 2048
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: Optional[int] = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ has_cross_attention: bool = False
+ context_dim: int = 0
+ is_causal: bool = False
+ dropout_rate: float = 0.1
+ attn_dropout_rate: float = 0.1
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
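+ # default FFN width follows the LLaMA-style SwiGLU sizing: 2/3 of 4*dim, rounded up to a multiple of 256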
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ self.max_batch_size = -1
+ self.max_seq_length = config.block_size
+ freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+ self.config.rope_base)
+ self.register_buffer("freqs_cis", freqs_cis)
+
+ causal_mask = torch.tril(
+ torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+ )
+ self.register_buffer("causal_mask", causal_mask)
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ input_pos: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ context: Optional[Tensor] = None,
+ context_input_pos: Optional[Tensor] = None,
+ cross_attention_mask: Optional[Tensor] = None,
+ ) -> Tensor:
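+ # default to the registered causal mask; otherwise select the mask rows for the given positions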
+ if mask is None:
+ mask = self.causal_mask[:x.size(1), :x.size(1)]
+ else:
+ mask = mask[..., input_pos]
+ freqs_cis = self.freqs_cis[input_pos]
+ if context is not None:
+ context_freqs_cis = self.freqs_cis[context_input_pos]
+ else:
+ context_freqs_cis = None
+ for layer in self.layers:
+ x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask)
+ x = self.norm(x, c)
+ return x
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ if config.has_cross_attention:
+ self.has_cross_attention = True
+ self.cross_attention = Attention(config, is_cross_attention=True)
+ self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+ else:
+ self.has_cross_attention = False
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ cross_attention_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask)
+ if self.has_cross_attention:
+ h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis)
+ out = h + self.feed_forward(self.ffn_norm(h, c))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ if is_cross_attention:
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+ else:
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.attn_dropout_rate = config.attn_dropout_rate
+
+ def forward(self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ if context is None:
+ q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
+ context_seqlen = seqlen
+ else:
+ q = self.wq(x)
+ k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
+ context_seqlen = context.shape[1]
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+ y = self.wo(y)
+ return y
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
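+ # SwiGLU: silu(w1(x)) gated by w3(x), dropout, then project back with w2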
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(
+ seq_len: int, n_elem: int, base: int = 10000,
+ dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
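+ # rotary embeddings: one (cos, sin) pair per position for each half of the head dimension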
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
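+ # interpret the last dim as (real, imag) pairs and rotate each pair by the cached angles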
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
+
diff --git a/modules/audio.py b/modules/audio.py
index abe783b0e0af630319700c931eb51d2ce375282b..ae677ffb1c124b557b3dbe0343ae415f5281cddb 100644
--- a/modules/audio.py
+++ b/modules/audio.py
@@ -1,82 +1,82 @@
-import numpy as np
-import torch
-import torch.utils.data
-from librosa.filters import mel as librosa_mel_fn
-from scipy.io.wavfile import read
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_wav(full_path):
- sampling_rate, data = read(full_path)
- return data, sampling_rate
-
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
- return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-
-def dynamic_range_decompression(x, C=1):
- return np.exp(x) / C
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
- return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
- return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
- output = dynamic_range_compression_torch(magnitudes)
- return output
-
-
-def spectral_de_normalize_torch(magnitudes):
- output = dynamic_range_decompression_torch(magnitudes)
- return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
- if torch.min(y) < -1.0:
- print("min value is ", torch.min(y))
- if torch.max(y) > 1.0:
- print("max value is ", torch.max(y))
-
- global mel_basis, hann_window # pylint: disable=global-statement
- if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
- mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
- mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
- hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
-
- y = torch.nn.functional.pad(
- y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
- )
- y = y.squeeze(1)
-
- spec = torch.view_as_real(
- torch.stft(
- y,
- n_fft,
- hop_length=hop_size,
- win_length=win_size,
- window=hann_window[str(sampling_rate) + "_" + str(y.device)],
- center=center,
- pad_mode="reflect",
- normalized=False,
- onesided=True,
- return_complex=True,
- )
- )
-
- spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
-
- spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
- spec = spectral_normalize_torch(spec)
-
- return spec
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
+def load_wav(full_path):
+ sampling_rate, data = read(full_path)
+ return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ output = dynamic_range_compression_torch(magnitudes)
+ return output
+
+
+def spectral_de_normalize_torch(magnitudes):
+ output = dynamic_range_decompression_torch(magnitudes)
+ return output
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+ if torch.min(y) < -1.0:
+ print("min value is ", torch.min(y))
+ if torch.max(y) > 1.0:
+ print("max value is ", torch.max(y))
+
+ global mel_basis, hann_window # pylint: disable=global-statement
+ if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
+ )
+ y = y.squeeze(1)
+
+ spec = torch.view_as_real(
+ torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window[str(sampling_rate) + "_" + str(y.device)],
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ )
+
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+ spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec)
+ spec = spectral_normalize_torch(spec)
+
+ return spec
diff --git a/modules/bigvgan/__pycache__/activations.cpython-310.pyc b/modules/bigvgan/__pycache__/activations.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bedc597b24c34a77805b188e40b71f1d5f221118
Binary files /dev/null and b/modules/bigvgan/__pycache__/activations.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc b/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dba1f3aac51089857f5a8226f45af8441fbf65a3
Binary files /dev/null and b/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/env.cpython-310.pyc b/modules/bigvgan/__pycache__/env.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9c5f5045d1715338beb0335539cb76f229124f54
Binary files /dev/null and b/modules/bigvgan/__pycache__/env.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc b/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58995f652192d5e524e02445b22451a1d8ea87b2
Binary files /dev/null and b/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc differ
diff --git a/modules/bigvgan/__pycache__/utils.cpython-310.pyc b/modules/bigvgan/__pycache__/utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..763d6d7ead834b85c9ea9aa27f306f8d59041fd5
Binary files /dev/null and b/modules/bigvgan/__pycache__/utils.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7692a9a3df20d2fe2b7923398589e05741607fd2
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1e13bedd060a666d59013e80b43573fad05e3516
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..089a26db09f702d6a6984743f4befb52b1c80fbf
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/activation1d.py b/modules/bigvgan/alias_free_activation/cuda/activation1d.py
index fc0d313cb265170943fb7cb16742b031038f7859..76797ef1424b99a160803d53cb6c24fe20599bd2 100644
--- a/modules/bigvgan/alias_free_activation/cuda/activation1d.py
+++ b/modules/bigvgan/alias_free_activation/cuda/activation1d.py
@@ -3,10 +3,10 @@
import torch
import torch.nn as nn
-from alias_free_activation.torch.resample import UpSample1d, DownSample1d
+from ..torch.resample import UpSample1d, DownSample1d
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
-from alias_free_activation.cuda import load
+from ..cuda import load
anti_alias_activation_cuda = load.load()
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
new file mode 100644
index 0000000000000000000000000000000000000000..ce9bc6ec886c4bda48c73677998abe0b2b73bfcc
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1
+size 522068
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
new file mode 100644
index 0000000000000000000000000000000000000000..bd3c097a5622dde1a5c17fd152e04750a1dedded
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log
@@ -0,0 +1,7 @@
+# ninja log v5
+9 39554 7516864785377831 anti_alias_activation.o 3a177f31dd72c43c
+13 152601 7516865914203767 anti_alias_activation_cuda.cuda.o 2d613e7382d803fd
+152628 153062 7516865920541751 anti_alias_activation_cuda.pyd f6366e9bdfb27f7
+128 50503 7654004565901584 anti_alias_activation.o 9ed3213f2e0d0858
+133 176837 7654005827401976 anti_alias_activation_cuda.cuda.o a679b6661c609136
+176839 177401 7654005835005523 anti_alias_activation_cuda.pyd f6366e9bdfb27f7
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
new file mode 100644
index 0000000000000000000000000000000000000000..812f06975323c9e1937fa01c943e31ae02322145
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:74c2824b05582070b69f51ec588aadb268c4fddf18fbb4590f901d1cdf32185c
+size 3246655
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
new file mode 100644
index 0000000000000000000000000000000000000000..329fb7a9b147a0af665ff7f7686bd0cc915ecc84
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86c48de557041de7ebaff7926b5f346cc5e4e2dddc6cf5b88409f6cb161db0f4
+size 4724513
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp
new file mode 100644
index 0000000000000000000000000000000000000000..3093a741ef126748042cafaef5c368f3ec5e2d3f
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib
new file mode 100644
index 0000000000000000000000000000000000000000..1be22a5e2a68606c56c961333b67a251bf40d8ea
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib differ
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
new file mode 100644
index 0000000000000000000000000000000000000000..dc51b91fc3d147e24ad0155ee31809556bb7208a
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:db37ea2dd31dfe67e68ee6019877d14638c41724ff9342c55f638f4d2cda3d03
+size 2454528
diff --git a/modules/bigvgan/alias_free_activation/cuda/build/build.ninja b/modules/bigvgan/alias_free_activation/cuda/build/build.ninja
new file mode 100644
index 0000000000000000000000000000000000000000..8c41cf88948be2657b26c226365a86b99278764a
--- /dev/null
+++ b/modules/bigvgan/alias_free_activation/cuda/build/build.ninja
@@ -0,0 +1,38 @@
+ninja_required_version = 1.3
+cxx = cl
+nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc
+
+cflags = -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include /std:c++17 -O3 /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc
+post_cflags =
+cuda_cflags = -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe --diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80
+cuda_post_cflags =
+cuda_dlink_post_cflags =
+sycl_dlink_post_cflags =
+ldflags = /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:D:\Anaconda\envs\vocos\lib\site-packages\torch\lib torch_python.lib /LIBPATH:D:\Anaconda\envs\vocos\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\lib\x64" cudart.lib
+
+rule compile
+ command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags
+ deps = msvc
+
+rule cuda_compile
+ depfile = $out.d
+ deps = gcc
+ command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+
+
+rule link
+ command = "D$:\Visual Studio\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64/link.exe" $in /nologo $ldflags /out:$out
+
+build anti_alias_activation.o: compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation.cpp
+build anti_alias_activation_cuda.cuda.o: cuda_compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation_cuda.cu
+
+
+
+
+
+build anti_alias_activation_cuda.pyd: link anti_alias_activation.o anti_alias_activation_cuda.cuda.o
+
+default anti_alias_activation_cuda.pyd
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..76ee62bc7f4fd61a4da3faf3b5b608082eecd92a
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6235f3b7476dd85812d5100a2d6ea941bca1280b
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb021ff5f7292a9def47b079d56dbfcc5a18f150
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc differ
diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..17524739fecc6a6ffecbc9ec37007645cd6b422d
Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc differ
diff --git a/modules/bigvgan/bigvgan.py b/modules/bigvgan/bigvgan.py
index 5a1196fa9fc6bca4276e23d5fe659e3f5af9b04a..41d6e44a2cb59a39d51cb4994d05804dc497dda7 100644
--- a/modules/bigvgan/bigvgan.py
+++ b/modules/bigvgan/bigvgan.py
@@ -42,15 +42,15 @@ class AMPBlock1(torch.nn.Module):
"""
def __init__(
- self,
- h: AttrDict,
- channels: int,
- kernel_size: int = 3,
- dilation: tuple = (1, 3, 5),
- activation: str = None,
+ self,
+ h: AttrDict,
+ channels: int,
+ kernel_size: int = 3,
+ dilation: tuple = (1, 3, 5),
+ activation: str = None,
):
super().__init__()
-
+
self.h = h
self.convs1 = nn.ModuleList(
@@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
- from alias_free_activation.cuda.activation1d import (
+ from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
@@ -161,15 +161,15 @@ class AMPBlock2(torch.nn.Module):
"""
def __init__(
- self,
- h: AttrDict,
- channels: int,
- kernel_size: int = 3,
- dilation: tuple = (1, 3, 5),
- activation: str = None,
+ self,
+ h: AttrDict,
+ channels: int,
+ kernel_size: int = 3,
+ dilation: tuple = (1, 3, 5),
+ activation: str = None,
):
super().__init__()
-
+
self.h = h
self.convs = nn.ModuleList(
@@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
- from alias_free_activation.cuda.activation1d import (
+ from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
@@ -270,7 +270,7 @@ class BigVGAN(
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
- from alias_free_activation.cuda.activation1d import (
+ from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
@@ -304,7 +304,7 @@ class BigVGAN(
[
weight_norm(
ConvTranspose1d(
- h.upsample_initial_channel // (2**i),
+ h.upsample_initial_channel // (2 ** i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
@@ -320,7 +320,7 @@ class BigVGAN(
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(
resblock_class(h, ch, k, d, activation=h.activation)
@@ -412,20 +412,20 @@ class BigVGAN(
@classmethod
def _from_pretrained(
- cls,
- *,
- model_id: str,
- revision: str,
- cache_dir: str,
- force_download: bool,
- proxies: Optional[Dict],
- resume_download: bool,
- local_files_only: bool,
- token: Union[str, bool, None],
- map_location: str = "cpu", # Additional argument
- strict: bool = False, # Additional argument
- use_cuda_kernel: bool = False,
- **model_kwargs,
+ cls,
+ *,
+ model_id: str,
+ revision: str,
+ cache_dir: str,
+ force_download: bool,
+ proxies: Optional[Dict],
+ resume_download: bool,
+ local_files_only: bool,
+ token: Union[str, bool, None],
+ map_location: str = "cpu", # Additional argument
+ strict: bool = False, # Additional argument
+ use_cuda_kernel: bool = False,
+ **model_kwargs,
):
"""Load Pytorch pretrained weights and return the loaded model."""
@@ -489,4 +489,4 @@ class BigVGAN(
model.remove_weight_norm()
model.load_state_dict(checkpoint_dict["generator"])
- return model
+ return model
\ No newline at end of file
diff --git a/modules/commons.py b/modules/commons.py
index 350208e50ba8630e53c30847db345cc3ace77473..fb0f6f89ef550e7570beffdef1d438e7dc51259f 100644
--- a/modules/commons.py
+++ b/modules/commons.py
@@ -1,490 +1,476 @@
-import math
-import numpy as np
-import torch
-from torch import nn
-from torch.nn import functional as F
-from munch import Munch
-import json
-
-
-class AttrDict(dict):
- def __init__(self, *args, **kwargs):
- super(AttrDict, self).__init__(*args, **kwargs)
- self.__dict__ = self
-
-
-def init_weights(m, mean=0.0, std=0.01):
- classname = m.__class__.__name__
- if classname.find("Conv") != -1:
- m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
- return int((kernel_size * dilation - dilation) / 2)
-
-
-def convert_pad_shape(pad_shape):
- l = pad_shape[::-1]
- pad_shape = [item for sublist in l for item in sublist]
- return pad_shape
-
-
-def intersperse(lst, item):
- result = [item] * (len(lst) * 2 + 1)
- result[1::2] = lst
- return result
-
-
-def kl_divergence(m_p, logs_p, m_q, logs_q):
- """KL(P||Q)"""
- kl = (logs_q - logs_p) - 0.5
- kl += (
- 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
- )
- return kl
-
-
-def rand_gumbel(shape):
- """Sample from the Gumbel distribution, protect from overflows."""
- uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
- return -torch.log(-torch.log(uniform_samples))
-
-
-def rand_gumbel_like(x):
- g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
- return g
-
-
-def slice_segments(x, ids_str, segment_size=4):
- ret = torch.zeros_like(x[:, :, :segment_size])
- for i in range(x.size(0)):
- idx_str = ids_str[i]
- idx_end = idx_str + segment_size
- ret[i] = x[i, :, idx_str:idx_end]
- return ret
-
-
-def slice_segments_audio(x, ids_str, segment_size=4):
- ret = torch.zeros_like(x[:, :segment_size])
- for i in range(x.size(0)):
- idx_str = ids_str[i]
- idx_end = idx_str + segment_size
- ret[i] = x[i, idx_str:idx_end]
- return ret
-
-
-def rand_slice_segments(x, x_lengths=None, segment_size=4):
- b, d, t = x.size()
- if x_lengths is None:
- x_lengths = t
- ids_str_max = x_lengths - segment_size + 1
- ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
- dtype=torch.long
- )
- ret = slice_segments(x, ids_str, segment_size)
- return ret, ids_str
-
-
-def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
- position = torch.arange(length, dtype=torch.float)
- num_timescales = channels // 2
- log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
- num_timescales - 1
- )
- inv_timescales = min_timescale * torch.exp(
- torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
- )
- scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
- signal = F.pad(signal, [0, 0, 0, channels % 2])
- signal = signal.view(1, channels, length)
- return signal
-
-
-def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
- b, channels, length = x.size()
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
- return x + signal.to(dtype=x.dtype, device=x.device)
-
-
-def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
- b, channels, length = x.size()
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
- return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
-
-
-def subsequent_mask(length):
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
- return mask
-
-
-@torch.jit.script
-def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
- n_channels_int = n_channels[0]
- in_act = input_a + input_b
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
- acts = t_act * s_act
- return acts
-
-
-def convert_pad_shape(pad_shape):
- l = pad_shape[::-1]
- pad_shape = [item for sublist in l for item in sublist]
- return pad_shape
-
-
-def shift_1d(x):
- x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
- return x
-
-
-def sequence_mask(length, max_length=None):
- if max_length is None:
- max_length = length.max()
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
- return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def avg_with_mask(x, mask):
- assert mask.dtype == torch.float, "Mask should be float"
-
- if mask.ndim == 2:
- mask = mask.unsqueeze(1)
-
- if mask.shape[1] == 1:
- mask = mask.expand_as(x)
-
- return (x * mask).sum() / mask.sum()
-
-
-def generate_path(duration, mask):
- """
- duration: [b, 1, t_x]
- mask: [b, 1, t_y, t_x]
- """
- device = duration.device
-
- b, _, t_y, t_x = mask.shape
- cum_duration = torch.cumsum(duration, -1)
-
- cum_duration_flat = cum_duration.view(b * t_x)
- path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
- path = path.view(b, t_x, t_y)
- path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
- path = path.unsqueeze(1).transpose(2, 3) * mask
- return path
-
-
-def clip_grad_value_(parameters, clip_value, norm_type=2):
- if isinstance(parameters, torch.Tensor):
- parameters = [parameters]
- parameters = list(filter(lambda p: p.grad is not None, parameters))
- norm_type = float(norm_type)
- if clip_value is not None:
- clip_value = float(clip_value)
-
- total_norm = 0
- for p in parameters:
- param_norm = p.grad.data.norm(norm_type)
- total_norm += param_norm.item() ** norm_type
- if clip_value is not None:
- p.grad.data.clamp_(min=-clip_value, max=clip_value)
- total_norm = total_norm ** (1.0 / norm_type)
- return total_norm
-
-
-def log_norm(x, mean=-4, std=4, dim=2):
- """
- normalized log mel -> mel -> norm -> log(norm)
- """
- x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
- return x
-
-
-def load_F0_models(path):
- # load F0 model
- from .JDC.model import JDCNet
-
- F0_model = JDCNet(num_class=1, seq_len=192)
- params = torch.load(path, map_location="cpu")["net"]
- F0_model.load_state_dict(params)
- _ = F0_model.train()
-
- return F0_model
-
-
-def modify_w2v_forward(self, output_layer=15):
- """
- change forward method of w2v encoder to get its intermediate layer output
- :param self:
- :param layer:
- :return:
- """
- from transformers.modeling_outputs import BaseModelOutput
-
- def forward(
- hidden_states,
- attention_mask=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
- all_hidden_states = () if output_hidden_states else None
- all_self_attentions = () if output_attentions else None
-
- conv_attention_mask = attention_mask
- if attention_mask is not None:
- # make sure padded tokens output 0
- hidden_states = hidden_states.masked_fill(
- ~attention_mask.bool().unsqueeze(-1), 0.0
- )
-
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(
- dtype=hidden_states.dtype
- )
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0],
- 1,
- attention_mask.shape[-1],
- attention_mask.shape[-1],
- )
-
- hidden_states = self.dropout(hidden_states)
-
- if self.embed_positions is not None:
- relative_position_embeddings = self.embed_positions(hidden_states)
- else:
- relative_position_embeddings = None
-
- deepspeed_zero3_is_enabled = False
-
- for i, layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
- dropout_probability = torch.rand([])
-
- skip_the_layer = (
- True
- if self.training and (dropout_probability < self.config.layerdrop)
- else False
- )
- if not skip_the_layer or deepspeed_zero3_is_enabled:
- # under deepspeed zero3 all gpus must run in sync
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- layer.__call__,
- hidden_states,
- attention_mask,
- relative_position_embeddings,
- output_attentions,
- conv_attention_mask,
- )
- else:
- layer_outputs = layer(
- hidden_states,
- attention_mask=attention_mask,
- relative_position_embeddings=relative_position_embeddings,
- output_attentions=output_attentions,
- conv_attention_mask=conv_attention_mask,
- )
- hidden_states = layer_outputs[0]
-
- if skip_the_layer:
- layer_outputs = (None, None)
-
- if output_attentions:
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
-
- if i == output_layer - 1:
- break
-
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, all_hidden_states, all_self_attentions]
- if v is not None
- )
- return BaseModelOutput(
- last_hidden_state=hidden_states,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
-
- return forward
-
-
-MATPLOTLIB_FLAG = False
-
-
-def plot_spectrogram_to_numpy(spectrogram):
- global MATPLOTLIB_FLAG
- if not MATPLOTLIB_FLAG:
- import matplotlib
- import logging
-
- matplotlib.use("Agg")
- MATPLOTLIB_FLAG = True
- mpl_logger = logging.getLogger("matplotlib")
- mpl_logger.setLevel(logging.WARNING)
- import matplotlib.pylab as plt
- import numpy as np
-
- fig, ax = plt.subplots(figsize=(10, 2))
- im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
- plt.colorbar(im, ax=ax)
- plt.xlabel("Frames")
- plt.ylabel("Channels")
- plt.tight_layout()
-
- fig.canvas.draw()
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
- plt.close()
- return data
-
-
-def normalize_f0(f0_sequence):
- # Remove unvoiced frames (replace with -1)
- voiced_indices = np.where(f0_sequence > 0)[0]
- f0_voiced = f0_sequence[voiced_indices]
-
- # Convert to log scale
- log_f0 = np.log2(f0_voiced)
-
- # Calculate mean and standard deviation
- mean_f0 = np.mean(log_f0)
- std_f0 = np.std(log_f0)
-
- # Normalize the F0 sequence
- normalized_f0 = (log_f0 - mean_f0) / std_f0
-
- # Create the normalized F0 sequence with unvoiced frames
- normalized_sequence = np.zeros_like(f0_sequence)
- normalized_sequence[voiced_indices] = normalized_f0
- normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
-
- return normalized_sequence
-
-
-def build_model(args, stage="DiT"):
- if stage == "DiT":
- from modules.flow_matching import CFM
- from modules.length_regulator import InterpolateRegulator
-
- length_regulator = InterpolateRegulator(
- channels=args.length_regulator.channels,
- sampling_ratios=args.length_regulator.sampling_ratios,
- is_discrete=args.length_regulator.is_discrete,
- in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
- vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
- codebook_size=args.length_regulator.content_codebook_size,
- n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
- quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
- f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
- n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
- )
- cfm = CFM(args)
- nets = Munch(
- cfm=cfm,
- length_regulator=length_regulator,
- )
- elif stage == 'codec':
- from dac.model.dac import Encoder
- from modules.quantize import (
- FAquantizer,
- )
-
- encoder = Encoder(
- d_model=args.DAC.encoder_dim,
- strides=args.DAC.encoder_rates,
- d_latent=1024,
- causal=args.causal,
- lstm=args.lstm,
- )
-
- quantizer = FAquantizer(
- in_dim=1024,
- n_p_codebooks=1,
- n_c_codebooks=args.n_c_codebooks,
- n_t_codebooks=2,
- n_r_codebooks=3,
- codebook_size=1024,
- codebook_dim=8,
- quantizer_dropout=0.5,
- causal=args.causal,
- separate_prosody_encoder=args.separate_prosody_encoder,
- timbre_norm=args.timbre_norm,
- )
-
- nets = Munch(
- encoder=encoder,
- quantizer=quantizer,
- )
- else:
- raise ValueError(f"Unknown stage: {stage}")
-
- return nets
-
-
-def load_checkpoint(
- model,
- optimizer,
- path,
- load_only_params=True,
- ignore_modules=[],
- is_distributed=False,
-):
- state = torch.load(path, map_location="cpu")
- params = state["net"]
- for key in model:
- if key in params and key not in ignore_modules:
- if not is_distributed:
- # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
- for k in list(params[key].keys()):
- if k.startswith("module."):
- params[key][k[len("module.") :]] = params[key][k]
- del params[key][k]
- model_state_dict = model[key].state_dict()
- # keep only the parameters whose shapes match the model's state dict
- filtered_state_dict = {
- k: v
- for k, v in params[key].items()
- if k in model_state_dict and v.shape == model_state_dict[k].shape
- }
- skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
- if skipped_keys:
- print(
- f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
- )
- print("%s loaded" % key)
- model[key].load_state_dict(filtered_state_dict, strict=False)
- _ = [model[key].eval() for key in model]
-
- if not load_only_params:
- epoch = state["epoch"] + 1
- iters = state["iters"]
- optimizer.load_state_dict(state["optimizer"])
- optimizer.load_scheduler_state_dict(state["scheduler"])
-
- else:
- epoch = 0
- iters = 0
-
- return model, optimizer, epoch, iters
-
-
-def recursive_munch(d):
- if isinstance(d, dict):
- return Munch((k, recursive_munch(v)) for k, v in d.items())
- elif isinstance(d, list):
- return [recursive_munch(v) for v in d]
- else:
- return d
+import math
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+from munch import Munch
+import json
+import argparse
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ("yes", "true", "t", "y", "1"):
+ return True
+ elif v.lower() in ("no", "false", "f", "n", "0"):
+ return False
+ else:
+ raise argparse.ArgumentTypeError("Boolean value expected.")
+
+class AttrDict(dict):
+ def __init__(self, *args, **kwargs):
+ super(AttrDict, self).__init__(*args, **kwargs)
+ self.__dict__ = self
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size * dilation - dilation) / 2)
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def intersperse(lst, item):
+ result = [item] * (len(lst) * 2 + 1)
+ result[1::2] = lst
+ return result
+
+
+def kl_divergence(m_p, logs_p, m_q, logs_q):
+ """KL(P||Q)"""
+ kl = (logs_q - logs_p) - 0.5
+ kl += (
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
+ )
+ return kl
+
+
+def rand_gumbel(shape):
+ """Sample from the Gumbel distribution, protect from overflows."""
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
+ return -torch.log(-torch.log(uniform_samples))
+
+
+def rand_gumbel_like(x):
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
+ return g
+
+
+def slice_segments(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, :, idx_str:idx_end]
+ return ret
+
+
+def slice_segments_audio(x, ids_str, segment_size=4):
+ ret = torch.zeros_like(x[:, :segment_size])
+ for i in range(x.size(0)):
+ idx_str = ids_str[i]
+ idx_end = idx_str + segment_size
+ ret[i] = x[i, idx_str:idx_end]
+ return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+ b, d, t = x.size()
+ if x_lengths is None:
+ x_lengths = t
+ ids_str_max = x_lengths - segment_size + 1
+ ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
+ dtype=torch.long
+ )
+ ret = slice_segments(x, ids_str, segment_size)
+ return ret, ids_str
+
+
+def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
+ position = torch.arange(length, dtype=torch.float)
+ num_timescales = channels // 2
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
+ num_timescales - 1
+ )
+ inv_timescales = min_timescale * torch.exp(
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
+ )
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
+ signal = signal.view(1, channels, length)
+ return signal
+
+
+def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return x + signal.to(dtype=x.dtype, device=x.device)
+
+
+def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
+ b, channels, length = x.size()
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
+
+
+def subsequent_mask(length):
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+ return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def avg_with_mask(x, mask):
+ assert mask.dtype == torch.float, "Mask should be float"
+
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(1)
+
+ if mask.shape[1] == 1:
+ mask = mask.expand_as(x)
+
+ return (x * mask).sum() / mask.sum()
+
+
+def generate_path(duration, mask):
+ """
+ duration: [b, 1, t_x]
+ mask: [b, 1, t_y, t_x]
+ """
+ device = duration.device
+
+ b, _, t_y, t_x = mask.shape
+ cum_duration = torch.cumsum(duration, -1)
+
+ cum_duration_flat = cum_duration.view(b * t_x)
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+ path = path.view(b, t_x, t_y)
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+ path = path.unsqueeze(1).transpose(2, 3) * mask
+ return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ norm_type = float(norm_type)
+ if clip_value is not None:
+ clip_value = float(clip_value)
+
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ if clip_value is not None:
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
+ total_norm = total_norm ** (1.0 / norm_type)
+ return total_norm
+
+
+def log_norm(x, mean=-4, std=4, dim=2):
+ """
+ normalized log mel -> mel -> norm -> log(norm)
+ """
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
+ return x
+
+
+def load_F0_models(path):
+ # load F0 model
+ from .JDC.model import JDCNet
+
+ F0_model = JDCNet(num_class=1, seq_len=192)
+ params = torch.load(path, map_location="cpu")["net"]
+ F0_model.load_state_dict(params)
+ _ = F0_model.train()
+
+ return F0_model
+
+
+def modify_w2v_forward(self, output_layer=15):
+ """
+ change forward method of w2v encoder to get its intermediate layer output
+ :param self:
+ :param layer:
+ :return:
+ """
+ from transformers.modeling_outputs import BaseModelOutput
+
+ def forward(
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ conv_attention_mask = attention_mask
+ if attention_mask is not None:
+ # make sure padded tokens output 0
+ hidden_states = hidden_states.masked_fill(
+ ~attention_mask.bool().unsqueeze(-1), 0.0
+ )
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(
+ dtype=hidden_states.dtype
+ )
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0],
+ 1,
+ attention_mask.shape[-1],
+ attention_mask.shape[-1],
+ )
+
+ hidden_states = self.dropout(hidden_states)
+
+ if self.embed_positions is not None:
+ relative_position_embeddings = self.embed_positions(hidden_states)
+ else:
+ relative_position_embeddings = None
+
+ deepspeed_zero3_is_enabled = False
+
+ for i, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = (
+ True
+ if self.training and (dropout_probability < self.config.layerdrop)
+ else False
+ )
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
+ # under deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ relative_position_embeddings,
+ output_attentions,
+ conv_attention_mask,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ relative_position_embeddings=relative_position_embeddings,
+ output_attentions=output_attentions,
+ conv_attention_mask=conv_attention_mask,
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if i == output_layer - 1:
+ break
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
+ if v is not None
+ )
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ return forward
+
+
+MATPLOTLIB_FLAG = False
+
+
+def plot_spectrogram_to_numpy(spectrogram):
+ global MATPLOTLIB_FLAG
+ if not MATPLOTLIB_FLAG:
+ import matplotlib
+ import logging
+
+ matplotlib.use("Agg")
+ MATPLOTLIB_FLAG = True
+ mpl_logger = logging.getLogger("matplotlib")
+ mpl_logger.setLevel(logging.WARNING)
+ import matplotlib.pylab as plt
+ import numpy as np
+
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+ plt.colorbar(im, ax=ax)
+ plt.xlabel("Frames")
+ plt.ylabel("Channels")
+ plt.tight_layout()
+
+ fig.canvas.draw()
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ plt.close()
+ return data
+
+
+def normalize_f0(f0_sequence):
+ # Remove unvoiced frames (replace with -1)
+ voiced_indices = np.where(f0_sequence > 0)[0]
+ f0_voiced = f0_sequence[voiced_indices]
+
+ # Convert to log scale
+ log_f0 = np.log2(f0_voiced)
+
+ # Calculate mean and standard deviation
+ mean_f0 = np.mean(log_f0)
+ std_f0 = np.std(log_f0)
+
+ # Normalize the F0 sequence
+ normalized_f0 = (log_f0 - mean_f0) / std_f0
+
+ # Create the normalized F0 sequence with unvoiced frames
+ normalized_sequence = np.zeros_like(f0_sequence)
+ normalized_sequence[voiced_indices] = normalized_f0
+ normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
+
+ return normalized_sequence
+
+
+def build_model(args, stage="DiT"):
+ if stage == "DiT":
+ from modules.flow_matching import CFM
+ from modules.length_regulator import InterpolateRegulator
+
+ length_regulator = InterpolateRegulator(
+ channels=args.length_regulator.channels,
+ sampling_ratios=args.length_regulator.sampling_ratios,
+ is_discrete=args.length_regulator.is_discrete,
+ in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
+ codebook_size=args.length_regulator.content_codebook_size,
+ f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
+ n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
+ )
+ cfm = CFM(args)
+ nets = Munch(
+ cfm=cfm,
+ length_regulator=length_regulator,
+ )
+ else:
+ raise ValueError(f"Unknown stage: {stage}")
+
+ return nets
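+# Example (illustrative; the config path and "model_params" key are assumptions): build_model
+# expects an attribute-style config, e.g.
+#   import yaml
+#   args = recursive_munch(yaml.safe_load(open("configs/config.yml"))["model_params"])
+#   nets = build_model(args, stage="DiT")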
+
+
+def load_checkpoint(
+ model,
+ optimizer,
+ path,
+ load_only_params=True,
+ ignore_modules=[],
+ is_distributed=False,
+ load_ema=False,
+):
+ state = torch.load(path, map_location="cpu")
+ params = state["net"]
+ if load_ema and "ema" in state:
+ print("Loading EMA")
+ for key in model:
+ i = 0
+ for param_name in params[key]:
+ if "input_pos" in param_name:
+ continue
+ assert params[key][param_name].shape == state["ema"][key][0][i].shape
+ params[key][param_name] = state["ema"][key][0][i].clone()
+ i += 1
+ for key in model:
+ if key in params and key not in ignore_modules:
+ if not is_distributed:
+ # strip the DDP prefix ("module.") so keys match a non-distributed model
+ for k in list(params[key].keys()):
+ if k.startswith("module."):
+ params[key][k[len("module.") :]] = params[key][k]
+ del params[key][k]
+ model_state_dict = model[key].state_dict()
+ # keep only the key/value pairs whose shapes match the current model
+ filtered_state_dict = {
+ k: v
+ for k, v in params[key].items()
+ if k in model_state_dict and v.shape == model_state_dict[k].shape
+ }
+ skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
+ if skipped_keys:
+ print(
+ f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
+ )
+ print("%s loaded" % key)
+ model[key].load_state_dict(filtered_state_dict, strict=False)
+ _ = [model[key].eval() for key in model]
+
+ if not load_only_params:
+ epoch = state["epoch"] + 1
+ iters = state["iters"]
+ optimizer.load_state_dict(state["optimizer"])
+ optimizer.load_scheduler_state_dict(state["scheduler"])
+
+ else:
+ epoch = 0
+ iters = 0
+
+ return model, optimizer, epoch, iters
+
+
+def recursive_munch(d):
+ if isinstance(d, dict):
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
+ elif isinstance(d, list):
+ return [recursive_munch(v) for v in d]
+ else:
+ return d
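+# Illustrative: recursive_munch({"a": {"b": 1}}).a.b == 1, i.e. nested dicts become
+# attribute-accessible Munch objects (lists are converted element-wise).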
diff --git a/modules/diffusion_transformer.py b/modules/diffusion_transformer.py
index b7f40975e52d1cc7944192bff30e2e7341e4fedb..f9b468fa6701a72fab3e55e31dadc814e10c78f1 100644
--- a/modules/diffusion_transformer.py
+++ b/modules/diffusion_transformer.py
@@ -1,240 +1,537 @@
-import torch
-from torch import nn
-import math
-
-from modules.gpt_fast.model import ModelArgs, Transformer
-# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
-from modules.wavenet import WN
-from modules.commons import sequence_mask
-
-from torch.nn.utils import weight_norm
-
-def modulate(x, shift, scale):
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
-
-
-#################################################################################
-# Embedding Layers for Timesteps and Class Labels #
-#################################################################################
-
-class TimestepEmbedder(nn.Module):
- """
- Embeds scalar timesteps into vector representations.
- """
- def __init__(self, hidden_size, frequency_embedding_size=256):
- super().__init__()
- self.mlp = nn.Sequential(
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
- nn.SiLU(),
- nn.Linear(hidden_size, hidden_size, bias=True),
- )
- self.frequency_embedding_size = frequency_embedding_size
- self.max_period = 10000
- self.scale = 1000
-
- half = frequency_embedding_size // 2
- freqs = torch.exp(
- -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
- )
- self.register_buffer("freqs", freqs)
-
- def timestep_embedding(self, t):
- """
- Create sinusoidal timestep embeddings.
- :param t: a 1-D Tensor of N indices, one per batch element.
- These may be fractional.
- :param dim: the dimension of the output.
- :param max_period: controls the minimum frequency of the embeddings.
- :return: an (N, D) Tensor of positional embeddings.
- """
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
-
- args = self.scale * t[:, None].float() * self.freqs[None]
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
- if self.frequency_embedding_size % 2:
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
- return embedding
-
- def forward(self, t):
- t_freq = self.timestep_embedding(t)
- t_emb = self.mlp(t_freq)
- return t_emb
-
-
-class StyleEmbedder(nn.Module):
- """
- Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
- """
- def __init__(self, input_size, hidden_size, dropout_prob):
- super().__init__()
- use_cfg_embedding = dropout_prob > 0
- self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
- self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
- self.input_size = input_size
- self.dropout_prob = dropout_prob
-
- def forward(self, labels, train, force_drop_ids=None):
- use_dropout = self.dropout_prob > 0
- if (train and use_dropout) or (force_drop_ids is not None):
- labels = self.token_drop(labels, force_drop_ids)
- else:
- labels = self.style_in(labels)
- embeddings = labels
- return embeddings
-
-class FinalLayer(nn.Module):
- """
- The final layer of DiT.
- """
- def __init__(self, hidden_size, patch_size, out_channels):
- super().__init__()
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
- self.adaLN_modulation = nn.Sequential(
- nn.SiLU(),
- nn.Linear(hidden_size, 2 * hidden_size, bias=True)
- )
-
- def forward(self, x, c):
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
- x = modulate(self.norm_final(x), shift, scale)
- x = self.linear(x)
- return x
-
-class DiT(torch.nn.Module):
- def __init__(
- self,
- args
- ):
- super(DiT, self).__init__()
- self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
- self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
- self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
- model_args = ModelArgs(
- block_size=16384,#args.DiT.block_size,
- n_layer=args.DiT.depth,
- n_head=args.DiT.num_heads,
- dim=args.DiT.hidden_dim,
- head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
- vocab_size=1024,
- uvit_skip_connection=self.uvit_skip_connection,
- )
- self.transformer = Transformer(model_args)
- self.in_channels = args.DiT.in_channels
- self.out_channels = args.DiT.in_channels
- self.num_heads = args.DiT.num_heads
-
- self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
-
- self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
- self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
- self.content_dim = args.DiT.content_dim # for continuous content
- self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
- self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
-
- self.is_causal = args.DiT.is_causal
-
- self.n_f0_bins = args.DiT.n_f0_bins
- self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins)
- self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim)
- self.f0_condition = args.DiT.f0_condition
-
- self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
- self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
- # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True))
- # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True))
-
- input_pos = torch.arange(16384)
- self.register_buffer("input_pos", input_pos)
-
- self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
- self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
- self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
- if self.final_layer_type == 'wavenet':
- self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
- kernel_size=args.wavenet.kernel_size,
- dilation_rate=args.wavenet.dilation_rate,
- n_layers=args.wavenet.num_layers,
- gin_channels=args.wavenet.hidden_dim,
- p_dropout=args.wavenet.p_dropout,
- causal=False)
- self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
- else:
- self.final_mlp = nn.Sequential(
- nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
- nn.SiLU(),
- nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
- )
- self.transformer_style_condition = args.DiT.style_condition
- self.wavenet_style_condition = args.wavenet.style_condition
- assert args.DiT.style_condition == args.wavenet.style_condition
-
- self.class_dropout_prob = args.DiT.class_dropout_prob
- self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
- self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) # residual connection from tranformer output to final output
- self.long_skip_connection = args.DiT.long_skip_connection
- self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
-
- self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
- args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
- args.DiT.hidden_dim)
- if self.style_as_token:
- self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
-
- def setup_caches(self, max_batch_size, max_seq_length):
- self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
- def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False):
- class_dropout = False
- if self.training and torch.rand(1) < self.class_dropout_prob:
- class_dropout = True
- if not self.training and mask_content:
- class_dropout = True
- # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
- cond_in_module = self.cond_projection
-
- B, _, T = x.size()
-
-
- t1 = self.t_embedder(t) # (N, D)
-
- cond = cond_in_module(cond)
- if self.f0_condition and f0 is not None:
- quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
- cond = cond + self.f0_embedder(quantized_f0)
-
- x = x.transpose(1, 2)
- prompt_x = prompt_x.transpose(1, 2)
-
- x_in = torch.cat([x, prompt_x, cond], dim=-1)
- if self.transformer_style_condition and not self.style_as_token:
- x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
- if class_dropout:
- x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
- x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
-
- if self.style_as_token:
- style = self.style_in(style)
- style = torch.zeros_like(style) if class_dropout else style
- x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
- if self.time_as_token:
- x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
- x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
- input_pos = self.input_pos[:x_in.size(1)] # (T,)
- x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
- x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded)
- x_res = x_res[:, 1:] if self.time_as_token else x_res
- x_res = x_res[:, 1:] if self.style_as_token else x_res
- if self.long_skip_connection:
- x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
- if self.final_layer_type == 'wavenet':
- x = self.conv1(x_res)
- x = x.transpose(1, 2)
- t2 = self.t_embedder2(t)
- x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
- x_res) # long residual connection
- x = self.final_layer(x, t1).transpose(1, 2)
- x = self.conv2(x)
- else:
- x = self.final_mlp(x_res)
- x = x.transpose(1, 2)
- return x
+import torch
+from torch import nn
+import math
+
+# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer
+from modules.wavenet import WN
+from modules.commons import sequence_mask
+
+from torch.nn.utils import weight_norm
+
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
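+# e.g. find_multiple(700, 256) == 768 and find_multiple(512, 256) == 512 (round n up to a multiple of k)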
+
+class AdaptiveLayerNorm(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNorm, self).__init__()
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
+ def forward(self, input: Tensor, embedding: Optional[Tensor] = None) -> Tensor:
+ if embedding is None:
+ return self.norm(input)
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return weight * self.norm(input) + bias
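+ # adaLN: the conditioning embedding is projected to per-channel (scale, shift) and the
+ # normalized input is modulated as scale * norm(x) + shift; with no embedding this reduces
+ # to the plain wrapped norm (RMSNorm here).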
+
+
+@dataclass
+class ModelArgs:
+ block_size: int = 2048
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ has_cross_attention: bool = False
+ context_dim: int = 0
+ uvit_skip_connection: bool = False
+ time_as_token: bool = False
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ self.freqs_cis: Optional[Tensor] = None
+ self.mask_cache: Optional[Tensor] = None
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+
+ def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False):
+ if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+ return
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+ dtype = self.norm.project_layer.weight.dtype
+ device = self.norm.project_layer.weight.device
+
+ self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+ self.config.rope_base, dtype).to(device)
+ self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
+ self.use_kv_cache = use_kv_cache
+ self.uvit_skip_connection = self.config.uvit_skip_connection
+ if self.uvit_skip_connection:
+ self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
+ self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
+ else:
+ self.layers_emit_skip = []
+ self.layers_receive_skip = []
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ input_pos: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ context: Optional[Tensor] = None,
+ context_input_pos: Optional[Tensor] = None,
+ cross_attention_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ assert self.freqs_cis is not None, "Caches must be initialized first"
+ if mask is None: # no explicit attention mask was provided; fall back to the causal mask
+ if not self.training and self.use_kv_cache:
+ mask = self.causal_mask[None, None, input_pos]
+ else:
+ mask = self.causal_mask[None, None, input_pos]
+ mask = mask[..., input_pos]
+ freqs_cis = self.freqs_cis[input_pos]
+ if context is not None:
+ context_freqs_cis = self.freqs_cis[context_input_pos]
+ else:
+ context_freqs_cis = None
+ skip_in_x_list = []
+ for i, layer in enumerate(self.layers):
+ if self.uvit_skip_connection and i in self.layers_receive_skip:
+ skip_in_x = skip_in_x_list.pop(-1)
+ else:
+ skip_in_x = None
+ x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
+ if self.uvit_skip_connection and i in self.layers_emit_skip:
+ skip_in_x_list.append(x)
+ x = self.norm(x, c)
+ return x
+
+ @classmethod
+ def from_name(cls, name: str):
+ return cls(ModelArgs.from_name(name))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ if config.has_cross_attention:
+ self.has_cross_attention = True
+ self.cross_attention = Attention(config, is_cross_attention=True)
+ self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+ else:
+ self.has_cross_attention = False
+
+ if config.uvit_skip_connection:
+ self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
+ self.uvit_skip_connection = True
+ else:
+ self.uvit_skip_connection = False
+
+ self.time_as_token = config.time_as_token
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ input_pos: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ cross_attention_mask: Optional[Tensor] = None,
+ skip_in_x: Optional[Tensor] = None,
+ ) -> Tensor:
+ c = None if self.time_as_token else c
+ if self.uvit_skip_connection and skip_in_x is not None:
+ x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
+ h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
+ if self.has_cross_attention:
+ h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
+ out = h + self.feed_forward(self.ffn_norm(h, c))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ if is_cross_attention:
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+ else:
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ # self._register_load_state_dict_pre_hook(self.load_hook)
+
+ # def load_hook(self, state_dict, prefix, *args):
+ # if prefix + "wq.weight" in state_dict:
+ # wq = state_dict.pop(prefix + "wq.weight")
+ # wk = state_dict.pop(prefix + "wk.weight")
+ # wv = state_dict.pop(prefix + "wv.weight")
+ # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ if context is None:
+ q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1)
+ context_seqlen = seqlen
+ else:
+ q = self.wq(x)
+ k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
+ context_seqlen = context.shape[1]
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+ y = self.wo(y)
+ return y
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
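+ # SwiGLU feed-forward as in LLaMA-style transformers: w2(silu(w1(x)) * w3(x)), with
+ # intermediate_size ≈ 2/3 * 4 * dim rounded up to a multiple of 256 (see ModelArgs.__post_init__).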
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(
+ seq_len: int, n_elem: int, base: int = 10000,
+ dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=dtype)
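+# The cache has shape (seq_len, n_elem // 2, 2), storing (cos, sin) per position and frequency;
+# apply_rotary_emb below indexes it by position to rotate query/key channel pairs.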
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
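+# Each consecutive (even, odd) channel pair is rotated by the cached angle:
+# (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin), i.e. complex multiplication by e^{i*theta}.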
+
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+ self.max_period = 10000
+ self.scale = 1000
+
+ half = frequency_embedding_size // 2
+ freqs = torch.exp(
+ -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ )
+ self.register_buffer("freqs", freqs)
+
+ def timestep_embedding(self, t):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ The output dimension is self.frequency_embedding_size, and self.max_period
+ controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+
+ args = self.scale * t[:, None].float() * self.freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if self.frequency_embedding_size % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class StyleEmbedder(nn.Module):
+ """
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
+ """
+ def __init__(self, input_size, hidden_size, dropout_prob):
+ super().__init__()
+ use_cfg_embedding = dropout_prob > 0
+ self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size)
+ self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True))
+ self.input_size = input_size
+ self.dropout_prob = dropout_prob
+
+ def forward(self, labels, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ labels = self.token_drop(labels, force_drop_ids)
+ else:
+ labels = self.style_in(labels)
+ embeddings = labels
+ return embeddings
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of DiT.
+ """
+ def __init__(self, hidden_size, patch_size, out_channels):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True))
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
+ )
+
+ def forward(self, x, c):
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+ x = modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+ return x
+
+class DiT(torch.nn.Module):
+ def __init__(
+ self,
+ args
+ ):
+ super(DiT, self).__init__()
+ self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False
+ self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False
+ self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False
+ model_args = ModelArgs(
+ block_size=16384,#args.DiT.block_size,
+ n_layer=args.DiT.depth,
+ n_head=args.DiT.num_heads,
+ dim=args.DiT.hidden_dim,
+ head_dim=args.DiT.hidden_dim // args.DiT.num_heads,
+ vocab_size=1024,
+ uvit_skip_connection=self.uvit_skip_connection,
+ time_as_token=self.time_as_token,
+ )
+ self.transformer = Transformer(model_args)
+ self.in_channels = args.DiT.in_channels
+ self.out_channels = args.DiT.in_channels
+ self.num_heads = args.DiT.num_heads
+
+ self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True))
+
+ self.content_type = args.DiT.content_type # 'discrete' or 'continuous'
+ self.content_codebook_size = args.DiT.content_codebook_size # for discrete content
+ self.content_dim = args.DiT.content_dim # for continuous content
+ self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content
+ self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content
+
+ self.is_causal = args.DiT.is_causal
+
+ self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim)
+
+ input_pos = torch.arange(16384)
+ self.register_buffer("input_pos", input_pos)
+
+ self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet
+ if self.final_layer_type == 'wavenet':
+ self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim)
+ self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim)
+ self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1)
+ self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim,
+ kernel_size=args.wavenet.kernel_size,
+ dilation_rate=args.wavenet.dilation_rate,
+ n_layers=args.wavenet.num_layers,
+ gin_channels=args.wavenet.hidden_dim,
+ p_dropout=args.wavenet.p_dropout,
+ causal=False)
+ self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim)
+ self.res_projection = nn.Linear(args.DiT.hidden_dim,
+ args.wavenet.hidden_dim) # residual connection from transformer output to final output
+ self.wavenet_style_condition = args.wavenet.style_condition
+ assert args.DiT.style_condition == args.wavenet.style_condition
+ else:
+ self.final_mlp = nn.Sequential(
+ nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim),
+ nn.SiLU(),
+ nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels),
+ )
+ self.transformer_style_condition = args.DiT.style_condition
+
+
+ self.class_dropout_prob = args.DiT.class_dropout_prob
+ self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim)
+
+ self.long_skip_connection = args.DiT.long_skip_connection
+ self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim)
+
+ self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 +
+ args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token),
+ args.DiT.hidden_dim)
+ if self.style_as_token:
+ self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim)
+
+ def setup_caches(self, max_batch_size, max_seq_length):
+ self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False)
+ def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False):
+ class_dropout = False
+ if self.training and torch.rand(1) < self.class_dropout_prob:
+ class_dropout = True
+ if not self.training and mask_content:
+ class_dropout = True
+ # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection
+ cond_in_module = self.cond_projection
+
+ B, _, T = x.size()
+
+
+ t1 = self.t_embedder(t) # (N, D)
+
+ cond = cond_in_module(cond)
+
+ x = x.transpose(1, 2)
+ prompt_x = prompt_x.transpose(1, 2)
+
+ x_in = torch.cat([x, prompt_x, cond], dim=-1)
+ if self.transformer_style_condition and not self.style_as_token:
+ x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1)
+ if class_dropout:
+ x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0
+ x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
+
+ if self.style_as_token:
+ style = self.style_in(style)
+ style = torch.zeros_like(style) if class_dropout else style
+ x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
+ if self.time_as_token:
+ x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
+ x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1)
+ input_pos = self.input_pos[:x_in.size(1)] # (T,)
+ x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None
+ x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)
+ x_res = x_res[:, 1:] if self.time_as_token else x_res
+ x_res = x_res[:, 1:] if self.style_as_token else x_res
+ if self.long_skip_connection:
+ x_res = self.skip_linear(torch.cat([x_res, x], dim=-1))
+ if self.final_layer_type == 'wavenet':
+ x = self.conv1(x_res)
+ x = x.transpose(1, 2)
+ t2 = self.t_embedder2(t)
+ x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection(
+ x_res) # long residual connection
+ x = self.final_layer(x, t1).transpose(1, 2)
+ x = self.conv2(x)
+ else:
+ x = self.final_mlp(x_res)
+ x = x.transpose(1, 2)
+ return x
\ No newline at end of file
diff --git a/modules/flow_matching.py b/modules/flow_matching.py
index c2581c620f884b7c4b60164729240c310198d74a..61389183c6604edf80a9517ead197dd6aa097740 100644
--- a/modules/flow_matching.py
+++ b/modules/flow_matching.py
@@ -49,6 +49,7 @@ class BASECFM(torch.nn.Module, ABC):
B, T = mu.size(0), mu.size(1)
z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+ # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate)
def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5):
@@ -66,7 +67,7 @@ class BASECFM(torch.nn.Module, ABC):
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
- t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+ t, _, _ = t_span[0], t_span[-1], t_span[1] - t_span[0]
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
@@ -79,16 +80,28 @@ class BASECFM(torch.nn.Module, ABC):
if self.zero_prompt_speech_token:
mu[..., :prompt_len] = 0
for step in tqdm(range(1, len(t_span))):
- dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0)
- # Classifier-Free Guidance inference introduced in VoiceBox
+ dt = t_span[step] - t_span[step - 1]
if inference_cfg_rate > 0:
- cfg_dphi_dt = self.estimator(
- x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0),
- torch.zeros_like(style),
- torch.zeros_like(mu), None
+ # Stack original and CFG (null) inputs for batched processing
+ stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0)
+ stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0)
+ stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0)
+ stacked_x = torch.cat([x, x], dim=0)
+ stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0)
+
+ # Perform a single forward pass for both original and CFG inputs
+ stacked_dphi_dt = self.estimator(
+ stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu,
)
- dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt -
- inference_cfg_rate * cfg_dphi_dt)
+
+ # Split the output back into the original and CFG components
+ dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0)
+
+ # Apply CFG formula
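+ # v = (1 + w) * v_cond - w * v_uncond, with w = inference_cfg_rate (classifier-free guidance)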
+ dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt
+ else:
+ dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
+
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
@@ -97,8 +110,7 @@ class BASECFM(torch.nn.Module, ABC):
x[:, :, :prompt_len] = 0
return sol[-1]
-
- def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None):
+ def forward(self, x1, x_lens, prompt_lens, mu, style):
"""Computes diffusion loss
Args:
@@ -134,13 +146,13 @@ class BASECFM(torch.nn.Module, ABC):
if self.zero_prompt_speech_token:
mu[bib, :, :prompt_lens[bib]] = 0
- estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0)
+ estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens)
loss = 0
for bib in range(b):
loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
loss /= b
- return loss, y
+ return loss, estimator_out + (1 - self.sigma_min) * z
diff --git a/modules/length_regulator.py b/modules/length_regulator.py
index a896c6ced97e409ba657f60af59a2f82e1688e65..8bc875326f8b846a09fbb9602d3ebf3ba6cc3b0f 100644
--- a/modules/length_regulator.py
+++ b/modules/length_regulator.py
@@ -1,141 +1,141 @@
-from typing import Tuple
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-from modules.commons import sequence_mask
-import numpy as np
-from dac.nn.quantize import VectorQuantize
-
-# f0_bin = 256
-f0_max = 1100.0
-f0_min = 50.0
-f0_mel_min = 1127 * np.log(1 + f0_min / 700)
-f0_mel_max = 1127 * np.log(1 + f0_max / 700)
-
-def f0_to_coarse(f0, f0_bin):
- f0_mel = 1127 * (1 + f0 / 700).log()
- a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
- b = f0_mel_min * a - 1.
- f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
- # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
- f0_coarse = torch.round(f0_mel).long()
- f0_coarse = f0_coarse * (f0_coarse > 0)
- f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
- f0_coarse = f0_coarse * (f0_coarse < f0_bin)
- f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
- return f0_coarse
-
-class InterpolateRegulator(nn.Module):
- def __init__(
- self,
- channels: int,
- sampling_ratios: Tuple,
- is_discrete: bool = False,
- in_channels: int = None, # only applies to continuous input
- vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
- codebook_size: int = 1024, # for discrete only
- out_channels: int = None,
- groups: int = 1,
- n_codebooks: int = 1, # number of codebooks
- quantizer_dropout: float = 0.0, # dropout for quantizer
- f0_condition: bool = False,
- n_f0_bins: int = 512,
- ):
- super().__init__()
- self.sampling_ratios = sampling_ratios
- out_channels = out_channels or channels
- model = nn.ModuleList([])
- if len(sampling_ratios) > 0:
- self.interpolate = True
- for _ in sampling_ratios:
- module = nn.Conv1d(channels, channels, 3, 1, 1)
- norm = nn.GroupNorm(groups, channels)
- act = nn.Mish()
- model.extend([module, norm, act])
- else:
- self.interpolate = False
- model.append(
- nn.Conv1d(channels, out_channels, 1, 1)
- )
- self.model = nn.Sequential(*model)
- self.embedding = nn.Embedding(codebook_size, channels)
- self.is_discrete = is_discrete
-
- self.mask_token = nn.Parameter(torch.zeros(1, channels))
-
- self.n_codebooks = n_codebooks
- if n_codebooks > 1:
- self.extra_codebooks = nn.ModuleList([
- nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
- ])
- self.extra_codebook_mask_tokens = nn.ParameterList([
- nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
- ])
- self.quantizer_dropout = quantizer_dropout
-
- if f0_condition:
- self.f0_embedding = nn.Embedding(n_f0_bins, channels)
- self.f0_condition = f0_condition
- self.n_f0_bins = n_f0_bins
- self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
- self.f0_mask = nn.Parameter(torch.zeros(1, channels))
- else:
- self.f0_condition = False
-
- if not is_discrete:
- self.content_in_proj = nn.Linear(in_channels, channels)
- if vector_quantize:
- self.vq = VectorQuantize(channels, codebook_size, 8)
-
- def forward(self, x, ylens=None, n_quantizers=None, f0=None):
- # apply token drop
- if self.training:
- n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
- dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
- n_dropout = int(x.shape[0] * self.quantizer_dropout)
- n_quantizers[:n_dropout] = dropout[:n_dropout]
- n_quantizers = n_quantizers.to(x.device)
- # decide whether to drop for each sample in batch
- else:
- n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
- if self.is_discrete:
- if self.n_codebooks > 1:
- assert len(x.size()) == 3
- x_emb = self.embedding(x[:, 0])
- for i, emb in enumerate(self.extra_codebooks):
- x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
- # add mask token if not using this codebook
- # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
- x = x_emb
- elif self.n_codebooks == 1:
- if len(x.size()) == 2:
- x = self.embedding(x)
- else:
- x = self.embedding(x[:, 0])
- else:
- x = self.content_in_proj(x)
- # x in (B, T, D)
- mask = sequence_mask(ylens).unsqueeze(-1)
- if self.interpolate:
- x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
- else:
- x = x.transpose(1, 2).contiguous()
- mask = mask[:, :x.size(2), :]
- ylens = ylens.clamp(max=x.size(2)).long()
- if self.f0_condition:
- if f0 is None:
- x = x + self.f0_mask.unsqueeze(-1)
- else:
- #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
- quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
- quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
- f0_emb = self.f0_embedding(quantized_f0)
- f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
- x = x + f0_emb
- out = self.model(x).transpose(1, 2).contiguous()
- if hasattr(self, 'vq'):
- out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
- out_q = out_q.transpose(1, 2)
- return out_q * mask, ylens, codes, commitment_loss, codebook_loss
- olens = ylens
- return out * mask, olens, None, None, None
+from typing import Tuple
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from modules.commons import sequence_mask
+import numpy as np
+from dac.nn.quantize import VectorQuantize
+
+# f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+def f0_to_coarse(f0, f0_bin):
+ f0_mel = 1127 * (1 + f0 / 700).log()
+ a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
+ b = f0_mel_min * a - 1.
+ f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
+ # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
+ f0_coarse = torch.round(f0_mel).long()
+ f0_coarse = f0_coarse * (f0_coarse > 0)
+ f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
+ f0_coarse = f0_coarse * (f0_coarse < f0_bin)
+ f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
+ return f0_coarse
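+# Maps F0 in Hz to integer mel-scale bins in [1, f0_bin - 1]; unvoiced frames (f0 <= 0) end up
+# in bin 1, so bin 0 is never produced.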
+
+class InterpolateRegulator(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ sampling_ratios: Tuple,
+ is_discrete: bool = False,
+ in_channels: int = None, # only applies to continuous input
+ vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input
+ codebook_size: int = 1024, # for discrete only
+ out_channels: int = None,
+ groups: int = 1,
+ n_codebooks: int = 1, # number of codebooks
+ quantizer_dropout: float = 0.0, # dropout for quantizer
+ f0_condition: bool = False,
+ n_f0_bins: int = 512,
+ ):
+ super().__init__()
+ self.sampling_ratios = sampling_ratios
+ out_channels = out_channels or channels
+ model = nn.ModuleList([])
+ if len(sampling_ratios) > 0:
+ self.interpolate = True
+ for _ in sampling_ratios:
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
+ norm = nn.GroupNorm(groups, channels)
+ act = nn.Mish()
+ model.extend([module, norm, act])
+ else:
+ self.interpolate = False
+ model.append(
+ nn.Conv1d(channels, out_channels, 1, 1)
+ )
+ self.model = nn.Sequential(*model)
+ self.embedding = nn.Embedding(codebook_size, channels)
+ self.is_discrete = is_discrete
+
+ self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+ self.n_codebooks = n_codebooks
+ if n_codebooks > 1:
+ self.extra_codebooks = nn.ModuleList([
+ nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1)
+ ])
+ self.extra_codebook_mask_tokens = nn.ParameterList([
+ nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1)
+ ])
+ self.quantizer_dropout = quantizer_dropout
+
+ if f0_condition:
+ self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+ self.f0_condition = f0_condition
+ self.n_f0_bins = n_f0_bins
+ self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+ self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+ else:
+ self.f0_condition = False
+
+ if not is_discrete:
+ self.content_in_proj = nn.Linear(in_channels, channels)
+ if vector_quantize:
+ self.vq = VectorQuantize(channels, codebook_size, 8)
+
+ def forward(self, x, ylens=None, n_quantizers=None, f0=None):
+ # apply token drop
+ if self.training:
+ n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks
+ dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],))
+ n_dropout = int(x.shape[0] * self.quantizer_dropout)
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
+ n_quantizers = n_quantizers.to(x.device)
+ # decide whether to drop for each sample in batch
+ else:
+ n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers)
+ if self.is_discrete:
+ if self.n_codebooks > 1:
+ assert len(x.size()) == 3
+ x_emb = self.embedding(x[:, 0])
+ for i, emb in enumerate(self.extra_codebooks):
+ x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1])
+ # add mask token if not using this codebook
+ # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i]
+ x = x_emb
+ elif self.n_codebooks == 1:
+ if len(x.size()) == 2:
+ x = self.embedding(x)
+ else:
+ x = self.embedding(x[:, 0])
+ else:
+ x = self.content_in_proj(x)
+ # x in (B, T, D)
+ mask = sequence_mask(ylens).unsqueeze(-1)
+ if self.interpolate:
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+ else:
+ x = x.transpose(1, 2).contiguous()
+ mask = mask[:, :x.size(2), :]
+ ylens = ylens.clamp(max=x.size(2)).long()
+ if self.f0_condition:
+ if f0 is None:
+ x = x + self.f0_mask.unsqueeze(-1)
+ else:
+ #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
+ quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
+ quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
+ f0_emb = self.f0_embedding(quantized_f0)
+ f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+ x = x + f0_emb
+ out = self.model(x).transpose(1, 2).contiguous()
+ if hasattr(self, 'vq'):
+ out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2))
+ out_q = out_q.transpose(1, 2)
+ return out_q * mask, ylens, codes, commitment_loss, codebook_loss
+ olens = ylens
+ return out * mask, olens, None, None, None
diff --git a/modules/rmvpe.py b/modules/rmvpe.py
index 066a9eebdbcb16ab2d9de0e4738ad3575405907f..44ae2e0ec9fde661dd8360d3fd731fe66b5ab51c 100644
--- a/modules/rmvpe.py
+++ b/modules/rmvpe.py
@@ -486,7 +486,13 @@ class RMVPE:
self.resample_kernel = {}
self.is_half = is_half
if device is None:
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ #device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ if torch.cuda.is_available():
+ device = "cuda:0"
+ elif torch.backends.mps.is_available():
+ device = "mps"
+ else:
+ device = "cpu"
self.device = device
self.mel_extractor = MelSpectrogram(
is_half, 128, 16000, 1024, 160, None, 30, 8000
@@ -572,6 +578,37 @@ class RMVPE:
# t3 = ttime()
# print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
return f0
+ def infer_from_audio_batch(self, audio, thred=0.03):
+ # torch.cuda.synchronize()
+ # t0 = ttime()
+ if not torch.is_tensor(audio):
+ audio = torch.from_numpy(audio)
+ mel = self.mel_extractor(
+ audio.float().to(self.device), center=True
+ )
+ # print(123123123,mel.device.type)
+ # torch.cuda.synchronize()
+ # t1 = ttime()
+ hidden = self.mel2hidden(mel)
+ # torch.cuda.synchronize()
+ # t2 = ttime()
+ # print(234234,hidden.device.type)
+ if "privateuseone" not in str(self.device):
+ hidden = hidden.cpu().numpy()
+ else:
+ pass
+ if self.is_half == True:
+ hidden = hidden.astype("float32")
+
+ f0s = []
+ for bib in range(hidden.shape[0]):
+ f0s.append(self.decode(hidden[bib], thred=thred))
+ f0s = np.stack(f0s)
+ f0s = torch.from_numpy(f0s).to(self.device)
+ # torch.cuda.synchronize()
+ # t3 = ttime()
+ # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
+ return f0s
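+ # Usage sketch (illustrative; constructor arguments follow the upstream RMVPE API and the
+ # checkpoint path is an assumption):
+ #   rmvpe = RMVPE("checkpoints/rmvpe.pt", is_half=False, device=None)
+ #   f0s = rmvpe.infer_from_audio_batch(audio_batch)  # audio_batch: (B, T) waveforms at 16 kHz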
def to_local_average_cents(self, salience, thred=0.05):
# t0 = ttime()
diff --git a/modules/v2/__pycache__/ar.cpython-310.pyc b/modules/v2/__pycache__/ar.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4ef2289697c46cebf33dec7e22ad3832e22768d5
Binary files /dev/null and b/modules/v2/__pycache__/ar.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/cfm.cpython-310.pyc b/modules/v2/__pycache__/cfm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fff9c2cda4430b1f3849eb7e0941491130030e32
Binary files /dev/null and b/modules/v2/__pycache__/cfm.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/dit_model.cpython-310.pyc b/modules/v2/__pycache__/dit_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..30550b7c8fa982077b5988cb11ecebcffa45ea8a
Binary files /dev/null and b/modules/v2/__pycache__/dit_model.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc b/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c16d4408efe4500b7caf352370c03562b0a6b43
Binary files /dev/null and b/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/length_regulator.cpython-310.pyc b/modules/v2/__pycache__/length_regulator.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d20e70c5fa10e54c6dd8abba283d448e34b66b96
Binary files /dev/null and b/modules/v2/__pycache__/length_regulator.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/model.cpython-310.pyc b/modules/v2/__pycache__/model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..202b549ed99cfcf71724feae469c648534d13b30
Binary files /dev/null and b/modules/v2/__pycache__/model.cpython-310.pyc differ
diff --git a/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc b/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5ceaa1e09ce09776cb8e26c81e94ccf4850512d
Binary files /dev/null and b/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc differ
diff --git a/modules/v2/ar.py b/modules/v2/ar.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c38b99e14c0aa93bf4a908c231437904886ad8e
--- /dev/null
+++ b/modules/v2/ar.py
@@ -0,0 +1,763 @@
+import dataclasses
+import json
+import math
+from collections import OrderedDict
+from functools import partial, wraps
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional, Tuple, List
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import Tensor
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+def l2norm(t, groups = 1):
+ t = rearrange(t, '... (g d) -> ... g d', g = groups)
+ t = F.normalize(t, p = 2, dim = -1)
+ return rearrange(t, '... g d -> ... (g d)')
+
+@dataclass
+class BaseModelArgs:
+ model_type: str = "base"
+
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ max_seq_len: int = 4096
+ dropout: float = 0.0
+ tie_word_embeddings: bool = True
+ attention_qkv_bias: bool = False
+
+ # Gradient checkpointing
+ use_gradient_checkpointing: bool = False
+
+ # Initialize the model
+ initializer_range: float = 0.02
+
+ qk_norm: bool = False
+ layerscale: bool = False
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+ def save(self, path: str):
+ with open(path, "w") as f:
+ json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False)
+
+
+@dataclass
+class NaiveModelArgs(BaseModelArgs):
+ model_type: str = "naive"
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim)
+ self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+ self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ assert input_pos.shape[0] == k_val.shape[2]
+
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:, :, input_pos] = k_val
+ v_out[:, :, input_pos] = v_val
+
+ return k_out, v_out
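+ # The caches are pre-allocated as (max_batch_size, n_heads, max_seq_len, head_dim); update()
+ # writes k_val/v_val at the given positions in place and returns the full cache tensors.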
+
+
+@dataclass
+class TransformerForwardResult:
+ token_logits: Tensor
+ token_targets: Tensor
+
+
+@dataclass
+class BaseTransformerForwardResult:
+ logits: Tensor
+ hidden_states: Tensor
+
+
+class BaseTransformer(nn.Module):
+ def __init__(
+ self,
+ config: BaseModelArgs,
+ init_weights: bool = True,
+ ) -> None:
+ super().__init__()
+ self.config = config
+
+ # Slow transformer
+ self.embeddings = nn.Embedding(
+ config.vocab_size,
+ config.dim,
+ )
+ self.layers = nn.ModuleList(
+ TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer)
+ )
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+
+ if self.config.tie_word_embeddings is False:
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ self.register_buffer(
+ "freqs_cis",
+ precompute_freqs_cis(
+ config.max_seq_len,
+ config.dim // config.n_head,
+ config.rope_base,
+ ),
+ persistent=False,
+ )
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(
+ torch.ones(
+ config.max_seq_len,
+ config.max_seq_len,
+ dtype=torch.bool,
+ )
+ ),
+ persistent=False,
+ )
+
+ self.output = nn.Linear(
+ config.dim,
+ config.vocab_size,
+ bias=False,
+ )
+
+ # For kv cache
+ self.max_batch_size = -1
+ self.max_seq_len = -1
+
+ if init_weights:
+ self.apply(self._init_weights)
+
+ def setup_caches(
+ self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"
+ ):
+ if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size:
+ return
+
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_len = find_multiple(max_seq_len, 8)
+ self.max_seq_len = max_seq_len
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.attention.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_len,
+ self.config.n_local_heads,
+ head_dim,
+ dtype=dtype,
+ ).to(device)
+
+ def embed_base(self, x: Tensor, x_lens: Tensor) -> Tensor:
+ for bib in range(x.size(0)):
+ x[bib, x_lens[bib]:] = self.config.vocab_size - 1
+
+ x_emb = self.embeddings(x)
+ return x, x_emb
+
+ def forward(
+ self,
+ inp: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ input_pos: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ seq_len = inp.size(1)
+
+ # Here we want to merge the embeddings of the codebooks
+ # x = self.embed(inp)
+ x = inp.clone()
+
+ if input_pos is None:
+ freqs_cis = self.freqs_cis[:seq_len].repeat(inp.size(0), 1, 1, 1)
+ else:
+ freqs_cis = self.freqs_cis[input_pos]
+
+ # Note that the causal mask here follows the convention of scaled_dot_product_attention:
+ # False means masked out.
+ # For consistency, key_padding_mask uses True to mask out.
+ mask = None
+ if key_padding_mask is not None:
+ mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K)
+ mask = mask & key_padding_mask[:, None, None, :].logical_not()
+
+ for layer in self.layers:
+ if self.config.use_gradient_checkpointing and self.training:
+ x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True)
+ else:
+ x = layer(x, freqs_cis, mask)
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ if self.config.tie_word_embeddings:
+ token_logits = F.linear(slow_out, self.embeddings.weight)
+ else:
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def forward_generate(
+ self,
+ inp: Tensor,
+ input_pos: Optional[Tensor] = None,
+ kv_pos: Optional[Tensor] = None,
+ return_all: bool = False,
+ ) -> BaseTransformerForwardResult:
+ # This is used for generation, optimized for torch compile
+
+ x = inp
+ max_seq_len = self.max_seq_len
+
+ mask = self.causal_mask[None, None, kv_pos, :max_seq_len] # (B, N, Q, K)
+ freqs_cis = self.freqs_cis[input_pos]
+
+ for layer in self.layers:
+ x = layer(x, freqs_cis, mask, input_pos=kv_pos)
+
+ x = x[:, -1:]
+
+ # We got slow_out here
+ slow_out = self.norm(x)
+
+ token_logits = self.output(slow_out)
+
+ return BaseTransformerForwardResult(
+ logits=token_logits,
+ hidden_states=x,
+ )
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+class NaiveTransformer(BaseTransformer):
+ def __init__(self, config: NaiveModelArgs) -> None:
+ super().__init__(config, init_weights=False)
+ self.apply(self._init_weights)
+
+ def forward(
+ self,
+ inp: Tensor,
+ cond_lens: Tensor,
+ target: Tensor,
+ target_lens: Tensor,
+ key_padding_mask: Optional[Tensor] = None,
+ input_pos: Optional[Tensor] = None,
+ ) -> TransformerForwardResult:
+ parent_result = super().forward(
+ inp=inp,
+ key_padding_mask=key_padding_mask,
+ input_pos=input_pos,
+ )
+ token_logits = parent_result.logits
+
+ # construct targets for token_logits
+ token_targets = torch.full((token_logits.size(0), token_logits.size(1)), -100,
+ dtype=torch.long, device=target.device)
+ for bib in range(token_targets.size(0)):
+ token_targets[bib, cond_lens[bib] + 1:cond_lens[bib] + target_lens[bib] + 1] = target[bib, :target_lens[bib]]
+ token_targets[bib, cond_lens[bib] + target_lens[bib] + 1] = self.config.vocab_size - 1
+ return TransformerForwardResult(
+ token_logits=token_logits,
+ token_targets=token_targets,
+ )
+
+ def infer_slow(self, inp: Tensor, input_pos: Optional[Tensor] = None):
+ # no kv cache used
+ parent_result = super().forward(inp, input_pos=input_pos)
+ latent = parent_result.hidden_states[:, -1]
+ base_logits = parent_result.logits[:, -1]
+ base_sampled, _ = topk_sampling(base_logits, top_k=-1, top_p=1.0)
+ return base_sampled
+
+ def forward_generate(
+ self,
+ x: Tensor,
+ input_pos: Optional[Tensor] = None,
+ kv_pos: Optional[Tensor] = None,
+ vq_masks: Optional[Tensor] = None,
+ ) -> BaseTransformerForwardResult:
+ x = super().forward_generate(x, input_pos, kv_pos)
+ return x
+
+class NaiveWrapper(nn.Module):
+ def __init__(self, model: NaiveTransformer) -> None:
+ super().__init__()
+ self.model = model
+ self.sep_token_emb = nn.Parameter(torch.randn(model.config.dim))
+
+ def setup_caches(self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"):
+ self.model.setup_caches(max_batch_size, max_seq_len, dtype, device)
+
+ def forward(self, cond: Tensor, cond_lens: Tensor, x: Tensor, x_lens: Tensor) -> torch.Tensor:
+ # style_emb = self.style_in(style).unsqueeze(1) # [B, 1, D]
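+ # Pack each sample as [SEP, cond, SEP, x] along the time axis, then right-pad the batch to its longest sequence.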
+ sep_token_emb = self.sep_token_emb.expand(x.size(0), 1, -1)
+ _, x_emb = self.model.embed_base(x, x_lens)
+ emb_seq_list = []
+ for i in range(x.size(0)):
+ emb_seq = torch.cat([
+ sep_token_emb[i:i + 1],
+ cond[i:i+1, :cond_lens[i]],
+ sep_token_emb[i:i+1],
+ x_emb[i:i+1, :x_lens[i]]], dim=1)
+ emb_seq_list.append(emb_seq)
+ max_len = max([emb_seq.size(1) for emb_seq in emb_seq_list])
+ emb_seq = torch.cat([
+ F.pad(emb_seq, (0, 0, 0, max_len - emb_seq.size(1)), value=0)
+ for emb_seq in emb_seq_list
+ ], dim=0)
+ # input_pos = torch.arange(emb_seq.size(1), device=emb_seq.device).repeat(emb_seq.size(0), 1)
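+ # Positions are assigned per segment: 0..cond_len for [SEP, cond] and a fresh 0..x_len for [SEP, x], matching how infer()/generate() assign positions.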
+ input_pos = torch.zeros(emb_seq.size(0), emb_seq.size(1), device=emb_seq.device, dtype=torch.long)
+ for i in range(x.size(0)):
+ input_pos[i, :cond_lens[i] + 1] = torch.arange(cond_lens[i] + 1, device=emb_seq.device)
+ input_pos[i, cond_lens[i] + 1: cond_lens[i] + x_lens[i] + 2] = torch.arange(x_lens[i] + 1, device=emb_seq.device)
+ out = self.model(emb_seq, cond_lens, x, x_lens, input_pos=input_pos)
+ loss = F.cross_entropy(out.token_logits.transpose(1, 2), out.token_targets.long(), ignore_index=-100)
+ return loss
+
+ @torch.no_grad()
+ def infer(self, cond: Tensor) -> torch.Tensor:
+ sep_token_emb = self.sep_token_emb.expand(1, 1, -1)
+ emb_seq = torch.cat([sep_token_emb, cond, sep_token_emb], dim=1)
+ pred_codes = []
+ input_pos = torch.arange(cond.size(1) + 1, device=cond.device)
+ for i in tqdm(range(4000)):
+ input_pos = torch.cat([input_pos, torch.LongTensor([i]).to(cond.device)], dim=0)
+ base = self.model.infer_slow(emb_seq, input_pos)
+ if base == self.model.config.vocab_size - 1:
+ break
+ new_emb = self.model.embed_base(base, torch.LongTensor([1]).to(base.device))[1]
+ emb_seq = torch.cat([emb_seq, new_emb], dim=1)
+ pred_codes.append(base)
+ return torch.cat(pred_codes, dim=-1)
+
+ @torch.no_grad()
+ def generate(
+ self,
+ prompt_text,
+ prompt_target,
+ compiled_decode_fn = None,
+ **sampling_kwargs,
+ ):
+ sep_token_emb = self.sep_token_emb.expand(1, 1, -1)
+ emb_seq = torch.cat([sep_token_emb, prompt_text, sep_token_emb], dim=1)
+ input_pos = torch.arange(prompt_text.size(1) + 1, device=emb_seq.device)
+ input_pos = torch.cat([input_pos, torch.LongTensor([0]).to(emb_seq.device)])
+ prompt_target_emb = self.model.embed_base(prompt_target,torch.LongTensor([prompt_target.size(1)]).to(prompt_target.device))[1]
+ emb_seq = torch.cat([emb_seq, prompt_target_emb], dim=1)
+ input_pos = torch.cat([input_pos, torch.arange(prompt_target_emb.size(1)).to(input_pos.device) + 1])
+
+ pred_codes = []
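+ # Prefill: run the full prompt through the model once to populate the KV cache, then decode one token per step below.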
+ kv_pos = torch.arange(emb_seq.size(1), device=emb_seq.device)
+ next_tokens = self.decode_one_token_ar(emb_seq, input_pos, kv_pos, suppress_tokens=[self.model.config.vocab_size - 1], **sampling_kwargs)
+ pred_base = next_tokens[0]
+ pred_codes.append(pred_base)
+ new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1]
+ emb_seq = torch.cat([emb_seq, new_emb], dim=1)
+ for _ in tqdm(range(4000)):
+ suppress_eos = len(pred_codes) < 10
+ input_pos = input_pos[-1:] + 1
+ kv_pos = kv_pos[-1:] + 1
+ next_tokens = self.decode_one_token_ar(
+ emb_seq[:, -1:].reshape(1, 1, -1),
+ input_pos.reshape(1),
+ kv_pos.reshape(1),
+ previous_tokens=torch.cat(pred_codes),
+ suppress_tokens=[self.model.config.vocab_size - 1] if suppress_eos else None,
+ compiled_decode_fn=compiled_decode_fn,
+ **sampling_kwargs)
+ pred_base = next_tokens[0]
+ if pred_base == self.model.config.vocab_size - 1:
+ break
+ pred_codes.append(pred_base.clone())
+ new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1]
+ emb_seq = torch.cat([emb_seq, new_emb], dim=1)
+ return torch.stack(pred_codes, dim=-1)
+
+ def decode_one_token_ar(
+ self,
+ x: torch.Tensor,
+ input_pos: torch.Tensor,
+ kv_pos: torch.Tensor,
+ previous_tokens: torch.Tensor = None,
+ compiled_decode_fn = None,
+ **sampling_kwargs,
+ ) -> torch.Tensor:
+ if compiled_decode_fn is not None:
+ x = compiled_decode_fn(x, input_pos, kv_pos)
+ else:
+ x = self.model.forward_generate(x, input_pos, kv_pos)
+
+ sampling_kwargs_main = sampling_kwargs.copy()
+ codebooks = [
+ sample(
+ x.logits,
+ previous_tokens=(
+ previous_tokens[0] if previous_tokens is not None else None
+ ),
+ **sampling_kwargs_main,
+ )[0]
+ ]
+ codebooks = torch.stack(codebooks, dim=0)
+ return codebooks
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None:
+ super().__init__()
+ self.attention = Attention(config, use_sdpa=use_sdpa)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+ self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+ def forward(
+ self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None
+ ) -> Tensor:
+ h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+ out = h + self.feed_forward(self.ffn_norm(h))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: BaseModelArgs, use_sdpa: bool = True):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(
+ config.dim, total_head_dim, bias=config.attention_qkv_bias
+ )
+ self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.dropout = config.dropout
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.use_sdpa = use_sdpa
+ self._register_load_state_dict_pre_hook(self.load_hook)
+ self.qk_norm = config.qk_norm
+ self.qk_norm_groups = 1
+ self.qk_norm_scale = 10
+ self.qk_norm_dim_scale = False
+ self.qk_norm_q_scale = self.qk_norm_k_scale = 1
+
+ if self.qk_norm and self.qk_norm_dim_scale:
+ self.qk_norm_q_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim))
+ self.qk_norm_k_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim))
+
+ def load_hook(self, state_dict, prefix, *args):
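+ # Backward compatibility: fuse separate wq/wk/wv weights from older checkpoints into the packed wqkv projection.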
+ if prefix + "wq.weight" in state_dict:
+ wq = state_dict.pop(prefix + "wq.weight")
+ wk = state_dict.pop(prefix + "wk.weight")
+ wv = state_dict.pop(prefix + "wv.weight")
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(
+ self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+ if self.qk_norm:
+ qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
+ q, k = map(qk_l2norm, (q, k))
+ scale = self.qk_norm_scale
+
+ q = q * self.qk_norm_q_scale
+ k = k * self.qk_norm_k_scale
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+
+ if self.use_sdpa:
+ if mask is None:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout if self.training else 0.0,
+ is_causal=True,
+ # Leave attn_mask unset so SDPA can dispatch to the flash-attention kernel
+ )
+ else:
+ y = F.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+ else:
+ y = self.eq_scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=mask,
+ dropout_p=self.dropout if self.training else 0.0,
+ )
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ return self.wo(y)
+
+ def eq_scaled_dot_product_attention(
+ self,
+ query,
+ key,
+ value,
+ attn_mask=None,
+ dropout_p=0.0,
+ ) -> torch.Tensor:
+ # A reference (eager) implementation of scaled dot product attention.
+ # Slower than F.scaled_dot_product_attention, but avoids the CUDA errors the fused kernels can raise in some environments.
+
+ L, S = query.size(-2), key.size(-2)
+ scale_factor = 1 / math.sqrt(query.size(-1))
+ attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn_weight = query @ key.transpose(-2, -1) * scale_factor
+ attn_weight += attn_bias
+ attn_weight = torch.softmax(attn_weight, dim=-1)
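+ # dropout_p is already forced to 0.0 outside of training (see the caller), so train=True has no effect at inference time.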
+ attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+
+ return attn_weight @ value
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: BaseModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+ self.dropout = nn.Dropout(p=config.dropout)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor:
+ freqs = 1.0 / (
+ base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+ )
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
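+ # Storing (real, imag) pairs rather than a complex tensor keeps the cache castable to real low-precision dtypes such as bfloat16.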
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=torch.bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(x.size(0), xshaped.size(1), 1, xshaped.size(3), 2)
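+ # Rotate each (even, odd) feature pair by complex multiplication (a + ib)(cos + i sin), written out on the stacked real/imag components.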
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
+
+def top_k_top_p_filtering(
+ logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+ """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
+ Args:
+ logits: logits distribution shape (batch size, vocabulary size)
+ if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
+ if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
+ Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+ Make sure we keep at least min_tokens_to_keep per batch example in the output
+ From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+ """
+ if top_k > 0:
+ top_k = min(
+ max(top_k, min_tokens_to_keep), logits.size(-1)
+ ) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(
+ F.softmax(sorted_logits, dim=-1), dim=-1
+ )
+
+ # Remove tokens with cumulative probability above the threshold (the first token above the threshold is kept via the shift below)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+ # Shift the indices to the right to keep also the first token above the threshold
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+ ..., :-1
+ ].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ # scatter sorted tensors to original indexing
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ 1, sorted_indices, sorted_indices_to_remove
+ )
+ logits[indices_to_remove] = filter_value
+ return logits
+
+def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
+ # temperature: (`optional`) float
+ # The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
+ # top_k: (`optional`) int
+ # The number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to 10.
+ # top_p: (`optional`) float
+ # Keep the smallest set of most-probable tokens whose cumulative probability reaches top_p (nucleus sampling). Must be between 0 and 1. Defaults to 1.0.
+
+ # Temperature (higher temperature => more likely to sample low probability tokens)
+ if temperature != 1.0:
+ logits = logits / temperature
+ # Top-p/top-k filtering
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+ # Sample
+ token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+ logprobs = F.log_softmax(logits.float(), dim=-1)
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)]
+ return token, current_logprobs
+
+def sample(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ **sampling_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ probs = logits_to_probs(
+ logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs
+ )
+ idx_next = multinomial_sample_one_no_sync(probs)
+ return idx_next, probs
+
+def multinomial_sample_one_no_sync(
+ probs_sort,
+): # Does multinomial sampling without a cuda synchronization
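+ # Exponential-race (Gumbel-max style) trick: argmax(p_i / E_i) with E_i ~ Exp(1) is distributed as Categorical(p), avoiding the host sync of torch.multinomial.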
+ q = torch.empty_like(probs_sort).exponential_(1)
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
+
+
+def logits_to_probs(
+ logits,
+ previous_tokens: Optional[torch.Tensor] = None,
+ suppress_tokens: Optional[List[int]] = None,
+ temperature: torch.Tensor = 0.7,
+ top_p: torch.Tensor = 0.7,
+ repetition_penalty: torch.Tensor = 1.5,
+) -> torch.Tensor:
+ # Apply repetition penalty
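+ # Previously generated tokens have positive logits divided by the penalty and negative logits multiplied by it, making them less likely to be sampled again.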
+ if previous_tokens is not None:
+ previous_tokens = previous_tokens.long()
+ score = torch.gather(logits, dim=0, index=previous_tokens)
+ score = torch.where(
+ score < 0, score * repetition_penalty, score / repetition_penalty
+ )
+ logits.scatter_(dim=0, index=previous_tokens, src=score)
+ if suppress_tokens is not None:
+ for token in suppress_tokens:
+ logits[token] = -float("Inf")
+
+ # Apply top-p sampling
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cum_probs > top_p
+ sorted_indices_to_remove[0] = False # keep at least one option
+ indices_to_remove = sorted_indices_to_remove.scatter(
+ dim=0, index=sorted_indices, src=sorted_indices_to_remove
+ )
+ logits = logits.masked_fill(indices_to_remove, -float("Inf"))
+
+ logits = logits / max(temperature, 1e-5)
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ return probs
diff --git a/modules/v2/cfm.py b/modules/v2/cfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0ea58ef15f31c324bbcc061f15be8650824790e
--- /dev/null
+++ b/modules/v2/cfm.py
@@ -0,0 +1,173 @@
+import torch
+from tqdm import tqdm
+
+class CFM(torch.nn.Module):
+ def __init__(
+ self,
+ estimator: torch.nn.Module,
+ ):
+ super().__init__()
+ self.sigma_min = 1e-6
+ self.estimator = estimator
+ self.in_channels = estimator.in_channels
+ self.criterion = torch.nn.L1Loss()
+
+ @torch.inference_mode()
+ def inference(self,
+ mu: torch.Tensor,
+ x_lens: torch.Tensor,
+ prompt: torch.Tensor,
+ style: torch.Tensor,
+ n_timesteps=10,
+ temperature=1.0,
+ inference_cfg_rate=[0.5, 0.5],
+ random_voice=False,
+ ):
+ """Forward diffusion
+
+ Args:
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ x_lens (torch.Tensor): length of each mel-spectrogram
+ shape: (batch_size,)
+ prompt (torch.Tensor): prompt
+ shape: (batch_size, n_feats, prompt_len)
+ style (torch.Tensor): style
+ shape: (batch_size, style_dim)
+ n_timesteps (int): number of diffusion steps
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
+ inference_cfg_rate (list of two floats, optional): classifier-free guidance strengths; see solve_euler for how the two terms are combined. Defaults to [0.5, 0.5].
+
+ Returns:
+ sample: generated mel-spectrogram
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ B, T = mu.size(0), mu.size(1)
+ z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
+ t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
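+ # Warp the uniform schedule to t' = 1 - cos(pi/2 * t) (algebraically what the line above computes), spacing steps more densely near the start of the trajectory.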
+ return self.solve_euler(z, x_lens, prompt, mu, style, t_span, inference_cfg_rate, random_voice)
+
+ def solve_euler(self, x, x_lens, prompt, mu, style, t_span, inference_cfg_rate=[0.5, 0.5], random_voice=False,):
+ """
+ Fixed euler solver for ODEs.
+ Args:
+ x (torch.Tensor): random noise
+ t_span (torch.Tensor): n_timesteps interpolated
+ shape: (n_timesteps + 1,)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ x_lens (torch.Tensor): length of each mel-spectrogram
+ shape: (batch_size,)
+ prompt (torch.Tensor): prompt
+ shape: (batch_size, n_feats, prompt_len)
+ style (torch.Tensor): style
+ shape: (batch_size, style_dim)
+ inference_cfg_rate (list of two floats, optional): classifier-free guidance strengths for the two guidance terms. Defaults to [0.5, 0.5].
+ random_voice (bool, optional): if True, drop the speaker prompt and style so a random voice is sampled. Defaults to False.
+ """
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+ # apply prompt
+ prompt_len = prompt.size(-1)
+ prompt_x = torch.zeros_like(x)
+ prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
+ x[..., :prompt_len] = 0
+ for step in tqdm(range(1, len(t_span))):
+ if random_voice:
+ cfg_dphi_dt = self.estimator(
+ torch.cat([x, x], dim=0),
+ torch.cat([torch.zeros_like(prompt_x), torch.zeros_like(prompt_x)], dim=0),
+ torch.cat([x_lens, x_lens], dim=0),
+ torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+ torch.cat([torch.zeros_like(style), torch.zeros_like(style)], dim=0),
+ torch.cat([mu, torch.zeros_like(mu)], dim=0),
+ )
+ cond_txt, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2]
+ dphi_dt = ((1.0 + inference_cfg_rate[0]) * cond_txt - inference_cfg_rate[0] * uncond)
+ elif all(i == 0 for i in inference_cfg_rate):
+ dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu)
+ elif inference_cfg_rate[0] == 0:
+ # Classifier-Free Guidance inference introduced in VoiceBox
+ cfg_dphi_dt = self.estimator(
+ torch.cat([x, x], dim=0),
+ torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0),
+ torch.cat([x_lens, x_lens], dim=0),
+ torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+ torch.cat([style, torch.zeros_like(style)], dim=0),
+ torch.cat([mu, mu], dim=0),
+ )
+ cond_txt_spk, cond_txt = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2]
+ dphi_dt = ((1.0 + inference_cfg_rate[1]) * cond_txt_spk - inference_cfg_rate[1] * cond_txt)
+ elif inference_cfg_rate[1] == 0:
+ cfg_dphi_dt = self.estimator(
+ torch.cat([x, x], dim=0),
+ torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0),
+ torch.cat([x_lens, x_lens], dim=0),
+ torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+ torch.cat([style, torch.zeros_like(style)], dim=0),
+ torch.cat([mu, torch.zeros_like(mu)], dim=0),
+ )
+ cond_txt_spk, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2]
+ dphi_dt = ((1.0 + inference_cfg_rate[0]) * cond_txt_spk - inference_cfg_rate[0] * uncond)
+ else:
+ # Multi-condition Classifier-Free Guidance inference introduced in MegaTTS3
+ cfg_dphi_dt = self.estimator(
+ torch.cat([x, x, x], dim=0),
+ torch.cat([prompt_x, torch.zeros_like(prompt_x), torch.zeros_like(prompt_x)], dim=0),
+ torch.cat([x_lens, x_lens, x_lens], dim=0),
+ torch.cat([t.unsqueeze(0), t.unsqueeze(0), t.unsqueeze(0)], dim=0),
+ torch.cat([style, torch.zeros_like(style), torch.zeros_like(style)], dim=0),
+ torch.cat([mu, mu, torch.zeros_like(mu)], dim=0),
+ )
+ cond_txt_spk, cond_txt, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2], cfg_dphi_dt[2:3]
+ dphi_dt = (1.0 + inference_cfg_rate[0] + inference_cfg_rate[1]) * cond_txt_spk - \
+ inference_cfg_rate[0] * uncond - inference_cfg_rate[1] * cond_txt
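+ # Explicit Euler step of the flow ODE: x <- x + dt * dphi_dt.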
+ x = x + dt * dphi_dt
+ t = t + dt
+ if step < len(t_span) - 1:
+ dt = t_span[step + 1] - t
+ x[:, :, :prompt_len] = 0
+
+ return x
+
+ def forward(self, x1, x_lens, prompt_lens, mu, style):
+ """Computes diffusion loss
+
+ Args:
+ x1 (torch.Tensor): Target
+ shape: (batch_size, n_feats, mel_timesteps)
+ mask (torch.Tensor): target mask
+ shape: (batch_size, 1, mel_timesteps)
+ mu (torch.Tensor): output of encoder
+ shape: (batch_size, n_feats, mel_timesteps)
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
+ shape: (batch_size, spk_emb_dim)
+
+ Returns:
+ loss: conditional flow matching loss
+ y: conditional flow
+ shape: (batch_size, n_feats, mel_timesteps)
+ """
+ b, _, _ = x1.shape
+
+ # random timestep
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype)
+ # sample noise p(x_0)
+ z = torch.randn_like(x1)
+
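+ # Conditional flow matching: interpolate y_t = (1 - (1 - sigma_min) * t) * z + t * x1 and regress the velocity target u = x1 - (1 - sigma_min) * z.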
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+ u = x1 - (1 - self.sigma_min) * z
+ prompt = torch.zeros_like(x1)
+ for bib in range(b):
+ prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]]
+ # range covered by prompt are set to 0
+ y[bib, :, :prompt_lens[bib]] = 0
+
+ estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu)
+ loss = 0
+ for bib in range(b):
+ loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]])
+ loss /= b
+
+ return loss
diff --git a/modules/v2/dit_model.py b/modules/v2/dit_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4374ac86a4d4d0869788cdd16087115c4418ba5f
--- /dev/null
+++ b/modules/v2/dit_model.py
@@ -0,0 +1,250 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional, Union, Tuple, List
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+import time
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNorm, self).__init__()
+ self.linear = nn.Linear(d_model, 6 * d_model)
+ self.act = nn.SiLU()
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
+ def forward(self, x: Tensor, emb: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ emb = self.linear(self.act(emb))
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=-1)
+
+ x = self.norm(x) * (1 + scale_msa) + shift_msa
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
+
+class AdaptiveLayerNormFinal(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNormFinal, self).__init__()
+ self.linear = nn.Linear(d_model, 2 * d_model)
+ self.act = nn.SiLU()
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
+ def forward(self, x: Tensor, emb: Tensor) -> Tensor:
+ emb = self.linear(self.act(emb))
+ scale, shift = torch.chunk(emb, 2, dim=-1)
+
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
+@dataclass
+class ModelArgs:
+ block_size: int = 2048
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ uvit_skip_connection: bool = False
+ time_as_token: bool = False
+ dropout_rate: float = 0.1
+ attn_dropout_rate: float = 0.1
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+ self.norm = AdaptiveLayerNormFinal(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ self.max_batch_size = -1
+ self.max_seq_length = config.block_size
+
+ self.uvit_skip_connection = self.config.uvit_skip_connection
+ if self.uvit_skip_connection:
+ self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
+ self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
+ else:
+ self.layers_emit_skip = []
+ self.layers_receive_skip = []
+ freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+ self.config.rope_base)
+ self.register_buffer("freqs_cis", freqs_cis)
+
+ causal_mask = torch.tril(
+ torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+ )
+ self.register_buffer("causal_mask", causal_mask)
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ input_pos: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ mask = mask[..., input_pos]
+ freqs_cis = self.freqs_cis[input_pos]
+ for i, layer in enumerate(self.layers):
+ x = layer(x, c, freqs_cis, mask)
+ x = self.norm(x, c)
+ return x
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ ) -> Tensor:
+ normed_x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=c)
+ # attention
+ attn_output = self.attention(normed_x, freqs_cis, mask)
+ x = x + gate_msa * attn_output
+ normed_x = self.ffn_norm(x) * (1 + scale_mlp) + shift_mlp
+ ff_output = self.feed_forward(normed_x)
+ x = x + gate_mlp * ff_output
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ if is_cross_attention:
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+ else:
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ self.attn_dropout_rate = config.attn_dropout_rate
+
+ def forward(self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
+ context_seqlen = seqlen
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+ y = self.wo(y)
+ return y
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+ self.dropout = nn.Dropout(config.dropout_rate)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(
+ seq_len: int, n_elem: int, base: int = 10000,
+ dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
+
diff --git a/modules/v2/dit_wrapper.py b/modules/v2/dit_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f653239de475fba17794d6b8d7e28a8edd1b65a0
--- /dev/null
+++ b/modules/v2/dit_wrapper.py
@@ -0,0 +1,152 @@
+import torch
+from torch import nn
+import math
+
+from modules.v2.dit_model import ModelArgs, Transformer
+from modules.commons import sequence_mask
+
+from torch.nn.utils import weight_norm
+
+def modulate(x, shift, scale):
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+#################################################################################
+# Embedding Layers for Timesteps and Class Labels #
+#################################################################################
+
+class TimestepEmbedder(nn.Module):
+ """
+ Embeds scalar timesteps into vector representations.
+ """
+ def __init__(self, hidden_size, frequency_embedding_size=256):
+ super().__init__()
+ self.mlp = nn.Sequential(
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+ nn.SiLU(),
+ nn.Linear(hidden_size, hidden_size, bias=True),
+ )
+ self.frequency_embedding_size = frequency_embedding_size
+
+ @staticmethod
+ def timestep_embedding(t, dim, max_period=10000, scale=1000):
+ """
+ Create sinusoidal timestep embeddings.
+ :param t: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an (N, D) Tensor of positional embeddings.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=t.device)
+ args = scale * t[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+ def forward(self, t):
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
+ t_emb = self.mlp(t_freq)
+ return t_emb
+
+
+class DiT(torch.nn.Module):
+ def __init__(
+ self,
+ time_as_token,
+ style_as_token,
+ uvit_skip_connection,
+ block_size,
+ depth,
+ num_heads,
+ hidden_dim,
+ in_channels,
+ content_dim,
+ style_encoder_dim,
+ class_dropout_prob,
+ dropout_rate,
+ attn_dropout_rate,
+ ):
+ super(DiT, self).__init__()
+ self.time_as_token = time_as_token
+ self.style_as_token = style_as_token
+ self.uvit_skip_connection = uvit_skip_connection
+ model_args = ModelArgs(
+ block_size=block_size,
+ n_layer=depth,
+ n_head=num_heads,
+ dim=hidden_dim,
+ head_dim=hidden_dim // num_heads,
+ vocab_size=1, # we don't use this
+ uvit_skip_connection=self.uvit_skip_connection,
+ time_as_token=self.time_as_token,
+ dropout_rate=dropout_rate,
+ attn_dropout_rate=attn_dropout_rate,
+ )
+ self.transformer = Transformer(model_args)
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.num_heads = num_heads
+
+ self.x_embedder = weight_norm(nn.Linear(in_channels, hidden_dim, bias=True))
+
+ self.content_dim = content_dim # for continuous content
+ self.cond_projection = nn.Linear(content_dim, hidden_dim, bias=True) # continuous content
+
+ self.t_embedder = TimestepEmbedder(hidden_dim)
+
+ self.final_mlp = nn.Sequential(
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.SiLU(),
+ nn.Linear(hidden_dim, in_channels),
+ )
+
+ self.class_dropout_prob = class_dropout_prob
+
+ self.cond_x_merge_linear = nn.Linear(hidden_dim + in_channels + in_channels, hidden_dim)
+ self.style_in = nn.Linear(style_encoder_dim, hidden_dim)
+
+ def forward(self, x, prompt_x, x_lens, t, style, cond):
+ class_dropout = False
+ content_dropout = False
+ if self.training and torch.rand(1) < self.class_dropout_prob:
+ class_dropout = True
+ if self.training and torch.rand(1) < 0.5:
+ content_dropout = True
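+ # Classifier-free guidance training: class_dropout zeroes the prompt-mel channels and the style embedding below; content_dropout zeroes the content-condition channels.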
+ cond_in_module = self.cond_projection
+
+ B, _, T = x.size()
+
+ t1 = self.t_embedder(t) # (N, D)
+ cond = cond_in_module(cond)
+
+ x = x.transpose(1, 2)
+ prompt_x = prompt_x.transpose(1, 2)
+
+ x_in = torch.cat([x, prompt_x, cond], dim=-1)
+ if class_dropout:
+ x_in[..., self.in_channels:self.in_channels*2] = 0
+ if content_dropout:
+ x_in[..., self.in_channels*2:] = 0
+ x_in = self.cond_x_merge_linear(x_in) # (N, T, D)
+
+ style = self.style_in(style)
+ style = torch.zeros_like(style) if class_dropout else style
+ if self.style_as_token:
+ x_in = torch.cat([style.unsqueeze(1), x_in], dim=1)
+ if self.time_as_token:
+ x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1)
+ x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)).to(x.device).unsqueeze(1)
+ input_pos = torch.arange(x_in.size(1)).to(x.device)
+ x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1)
+ x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded)
+ x_res = x_res[:, 1:] if self.time_as_token else x_res
+ x_res = x_res[:, 1:] if self.style_as_token else x_res
+ x = self.final_mlp(x_res)
+ x = x.transpose(1, 2)
+ return x
diff --git a/modules/v2/length_regulator.py b/modules/v2/length_regulator.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efe5a62bc5afba06a8abe5051aace9ad97dbf3e
--- /dev/null
+++ b/modules/v2/length_regulator.py
@@ -0,0 +1,105 @@
+from typing import Tuple
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from modules.commons import sequence_mask
+import numpy as np
+
+# f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+def f0_to_coarse(f0, f0_bin):
+ f0_mel = 1127 * (1 + f0 / 700).log()
+ a = (f0_bin - 2) / (f0_mel_max - f0_mel_min)
+ b = f0_mel_min * a - 1.
+ f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel)
+ # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1))
+ f0_coarse = torch.round(f0_mel).long()
+ f0_coarse = f0_coarse * (f0_coarse > 0)
+ f0_coarse = f0_coarse + ((f0_coarse < 1) * 1)
+ f0_coarse = f0_coarse * (f0_coarse < f0_bin)
+ f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1))
+ return f0_coarse
+
+class InterpolateRegulator(nn.Module):
+ def __init__(
+ self,
+ channels: int,
+ sampling_ratios: Tuple,
+ is_discrete: bool = False,
+ in_channels: int = None, # only applies to continuous input
+ codebook_size: int = 1024, # for discrete only
+ out_channels: int = None,
+ groups: int = 1,
+ f0_condition: bool = False,
+ n_f0_bins: int = 512,
+ ):
+ super().__init__()
+ self.sampling_ratios = sampling_ratios
+ out_channels = out_channels or channels
+ model = nn.ModuleList([])
+ if len(sampling_ratios) > 0:
+ self.interpolate = True
+ for _ in sampling_ratios:
+ module = nn.Conv1d(channels, channels, 3, 1, 1)
+ norm = nn.GroupNorm(groups, channels)
+ act = nn.Mish()
+ model.extend([module, norm, act])
+ else:
+ self.interpolate = False
+ model.append(
+ nn.Conv1d(channels, out_channels, 1, 1) if channels != out_channels else nn.Identity()
+ )
+ self.model = nn.Sequential(*model)
+ self.embedding = nn.Embedding(codebook_size, channels)
+ self.is_discrete = is_discrete
+
+ self.mask_token = nn.Parameter(torch.zeros(1, channels))
+
+ if f0_condition:
+ self.f0_embedding = nn.Embedding(n_f0_bins, channels)
+ self.f0_condition = f0_condition
+ self.n_f0_bins = n_f0_bins
+ self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins)
+ self.f0_mask = nn.Parameter(torch.zeros(1, channels))
+ else:
+ self.f0_condition = False
+
+ if not is_discrete:
+ self.content_in_proj = nn.Linear(in_channels, channels)
+
+ def forward(self, x, ylens=None, f0=None):
+ if self.is_discrete:
+ if len(x.size()) == 2:
+ x = self.embedding(x)
+ else:
+ x = self.embedding(x[:, 0])
+ else:
+ x = self.content_in_proj(x)
+ # x in (B, T, D)
+
+ if self.interpolate:
+ mask = sequence_mask(ylens).unsqueeze(-1)
+ x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+ else:
+ x = x.transpose(1, 2).contiguous()
+ mask = None
+ # mask = mask[:, :x.size(2), :]
+ # ylens = ylens.clamp(max=x.size(2)).long()
+ if self.f0_condition:
+ if f0 is None:
+ x = x + self.f0_mask.unsqueeze(-1)
+ else:
+ # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T)
+ quantized_f0 = f0_to_coarse(f0, self.n_f0_bins)
+ quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long()
+ f0_emb = self.f0_embedding(quantized_f0)
+ f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest')
+ x = x + f0_emb
+ out = self.model(x).transpose(1, 2).contiguous()
+ out = out * mask if mask is not None else out
+ olens = ylens
+ return out, olens
diff --git a/modules/v2/model.py b/modules/v2/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a96dd0b6c58991ca3e203ca6c5247dc0413e48b4
--- /dev/null
+++ b/modules/v2/model.py
@@ -0,0 +1,302 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+class AdaptiveLayerNorm(nn.Module):
+ r"""Adaptive Layer Normalization"""
+
+ def __init__(self, d_model, norm) -> None:
+ super(AdaptiveLayerNorm, self).__init__()
+ self.project_layer = nn.Linear(d_model, 2 * d_model)
+ self.norm = norm
+ self.d_model = d_model
+ self.eps = self.norm.eps
+
+ def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor:
+ if embedding is None:
+ return self.norm(input)
+ weight, bias = torch.split(
+ self.project_layer(embedding),
+ split_size_or_sections=self.d_model,
+ dim=-1,
+ )
+ return weight * self.norm(input) + bias
+
+
+@dataclass
+class ModelArgs:
+ block_size: int = 2048
+ vocab_size: int = 32000
+ n_layer: int = 32
+ n_head: int = 32
+ dim: int = 4096
+ intermediate_size: int = None
+ n_local_heads: int = -1
+ head_dim: int = 64
+ rope_base: float = 10000
+ norm_eps: float = 1e-5
+ has_cross_attention: bool = False
+ context_dim: int = 0
+ uvit_skip_connection: bool = False
+ time_as_token: bool = False
+
+ def __post_init__(self):
+ if self.n_local_heads == -1:
+ self.n_local_heads = self.n_head
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ # self.head_dim = self.dim // self.n_head
+
+class Transformer(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.config = config
+
+ self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
+ self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ self.freqs_cis: Optional[Tensor] = None
+ self.mask_cache: Optional[Tensor] = None
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+
+ def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False):
+ if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+ return
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+ dtype = self.norm.project_layer.weight.dtype
+ device = self.norm.project_layer.weight.device
+
+ self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim,
+ self.config.rope_base, dtype).to(device)
+ self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
+ self.use_kv_cache = use_kv_cache
+ self.uvit_skip_connection = self.config.uvit_skip_connection
+ if self.uvit_skip_connection:
+ self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2]
+ self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2]
+ else:
+ self.layers_emit_skip = []
+ self.layers_receive_skip = []
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ input_pos: Optional[Tensor] = None,
+ mask: Optional[Tensor] = None,
+ context: Optional[Tensor] = None,
+ context_input_pos: Optional[Tensor] = None,
+ cross_attention_mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ assert self.freqs_cis is not None, "Caches must be initialized first"
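+ # When no mask is provided, fall back to the registered causal mask gathered at the current query positions.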
+ if mask is None: # in case of non-causal model
+ if not self.training and self.use_kv_cache:
+ mask = self.causal_mask[None, None, input_pos]
+ else:
+ mask = self.causal_mask[None, None, input_pos]
+ mask = mask[..., input_pos]
+ freqs_cis = self.freqs_cis[input_pos]
+ if context is not None:
+ context_freqs_cis = self.freqs_cis[context_input_pos]
+ else:
+ context_freqs_cis = None
+ skip_in_x_list = []
+ for i, layer in enumerate(self.layers):
+ if self.uvit_skip_connection and i in self.layers_receive_skip:
+ skip_in_x = skip_in_x_list.pop(-1)
+ else:
+ skip_in_x = None
+ x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x)
+ if self.uvit_skip_connection and i in self.layers_emit_skip:
+ skip_in_x_list.append(x)
+ x = self.norm(x, c)
+ return x
+
+ @classmethod
+ def from_name(cls, name: str):
+ return cls(ModelArgs.from_name(name))
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.attention = Attention(config)
+ self.feed_forward = FeedForward(config)
+ self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+ self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+
+ if config.has_cross_attention:
+ self.has_cross_attention = True
+ self.cross_attention = Attention(config, is_cross_attention=True)
+ self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps))
+ else:
+ self.has_cross_attention = False
+
+ if config.uvit_skip_connection:
+ self.skip_in_linear = nn.Linear(config.dim * 2, config.dim)
+ self.uvit_skip_connection = True
+ else:
+ self.uvit_skip_connection = False
+
+ self.time_as_token = config.time_as_token
+
+ def forward(self,
+ x: Tensor,
+ c: Tensor,
+ input_pos: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ cross_attention_mask: Optional[Tensor] = None,
+ skip_in_x: Optional[Tensor] = None,
+ ) -> Tensor:
+ c = None if self.time_as_token else c
+ if self.uvit_skip_connection and skip_in_x is not None:
+ x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1))
+ h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos)
+ if self.has_cross_attention:
+ h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis)
+ out = h + self.feed_forward(self.ffn_norm(h, c))
+ return out
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs, is_cross_attention: bool = False):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
+ # key, query, value projections for all heads, but in a batch
+ if is_cross_attention:
+ self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False)
+ self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False)
+ else:
+ self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+ self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False)
+ self.kv_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.n_local_heads = config.n_local_heads
+ self.dim = config.dim
+ # self._register_load_state_dict_pre_hook(self.load_hook)
+
+ # def load_hook(self, state_dict, prefix, *args):
+ # if prefix + "wq.weight" in state_dict:
+ # wq = state_dict.pop(prefix + "wq.weight")
+ # wk = state_dict.pop(prefix + "wk.weight")
+ # wv = state_dict.pop(prefix + "wv.weight")
+ # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+ def forward(self,
+ x: Tensor,
+ freqs_cis: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ context: Optional[Tensor] = None,
+ context_freqs_cis: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_local_heads * self.head_dim
+ if context is None:
+ q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
+ context_seqlen = seqlen
+ else:
+ q = self.wq(x)
+ k, v = self.wkv(context).split([kv_size, kv_size], dim=-1)
+ context_seqlen = context.shape[1]
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+ v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim)
+
+ q = apply_rotary_emb(q, freqs_cis)
+ k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head)
+
+ y = self.wo(y)
+ return y
+
+
+class FeedForward(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False)
+ self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim: int, eps: float = 1e-5):
+ super().__init__()
+ self.eps = eps
+ self.weight = nn.Parameter(torch.ones(dim))
+
+ def _norm(self, x):
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+ def forward(self, x: Tensor) -> Tensor:
+ output = self._norm(x.float()).type_as(x)
+ return output * self.weight
+
+
+def precompute_freqs_cis(
+ seq_len: int, n_elem: int, base: int = 10000,
+ dtype: torch.dtype = torch.bfloat16
+) -> Tensor:
+ freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+ t = torch.arange(seq_len, device=freqs.device)
+ freqs = torch.outer(t, freqs)
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+ cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+ return cache.to(dtype=dtype)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+ xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+ freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+ x_out2 = torch.stack(
+ [
+ xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+ xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+ ],
+ -1,
+ )
+
+ x_out2 = x_out2.flatten(3)
+ return x_out2.type_as(x)
diff --git a/modules/v2/vc_wrapper.py b/modules/v2/vc_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..359e805001c831bbc257f2da7d377a32967a27df
--- /dev/null
+++ b/modules/v2/vc_wrapper.py
@@ -0,0 +1,606 @@
+import spaces
+import torch
+import librosa
+import torchaudio
+import numpy as np
+from pydub import AudioSegment
+from hf_utils import load_custom_model_from_hf
+
+DEFAULT_REPO_ID = "Plachta/Seed-VC"
+DEFAULT_CFM_CHECKPOINT = "v2/cfm_small.pth"
+DEFAULT_AR_CHECKPOINT = "v2/ar_base.pth"
+
+DEFAULT_CE_REPO_ID = "Plachta/ASTRAL-quantization"
+DEFAULT_CE_NARROW_CHECKPOINT = "bsq32/bsq32_light.pth"
+DEFAULT_CE_WIDE_CHECKPOINT = "bsq2048/bsq2048_light.pth"
+
+DEFAULT_SE_REPO_ID = "funasr/campplus"
+DEFAULT_SE_CHECKPOINT = "campplus_cn_common.bin"
+
+class VoiceConversionWrapper(torch.nn.Module):
+ def __init__(
+ self,
+ sr: int,
+ hop_size: int,
+ mel_fn: callable,
+ cfm: torch.nn.Module,
+ cfm_length_regulator: torch.nn.Module,
+ content_extractor_narrow: torch.nn.Module,
+ content_extractor_wide: torch.nn.Module,
+ ar_length_regulator: torch.nn.Module,
+ ar: torch.nn.Module,
+ style_encoder: torch.nn.Module,
+ vocoder: torch.nn.Module,
+ ):
+ super(VoiceConversionWrapper, self).__init__()
+ self.sr = sr
+ self.hop_size = hop_size
+ self.mel_fn = mel_fn
+ self.cfm = cfm
+ self.cfm_length_regulator = cfm_length_regulator
+ self.content_extractor_narrow = content_extractor_narrow
+ self.content_extractor_wide = content_extractor_wide
+ self.vocoder = vocoder
+ self.ar_length_regulator = ar_length_regulator
+ self.ar = ar
+ self.style_encoder = style_encoder
+ # Set streaming parameters
+ self.overlap_frame_len = 16
+ self.bitrate = "320k"
+ self.compiled_decode_fn = None
+ self.dit_compiled = False
+ self.dit_max_context_len = 30 # in seconds
+ self.compile_len = 87 * self.dit_max_context_len
+
+ def compile_ar(self):
+ """
+ Compile the AR model for inference.
+ """
+ self.compiled_decode_fn = torch.compile(
+ self.ar.model.forward_generate,
+ fullgraph=True,
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
+ )
+
+ def compile_cfm(self):
+ self.cfm.estimator.transformer = torch.compile(
+ self.cfm.estimator.transformer,
+ fullgraph=True,
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
+ )
+ self.dit_compiled = True
+
+ @staticmethod
+ def strip_prefix(state_dict: dict, prefix: str = "module.") -> dict:
+ """
+ Strip the prefix from the state_dict keys.
+ """
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith(prefix):
+ new_key = k[len(prefix):]
+ else:
+ new_key = k
+ new_state_dict[new_key] = v
+ return new_state_dict
+
+ @staticmethod
+ def duration_reduction_func(token_seq, n_gram=1):
+ """
+ Args:
+ token_seq: (T,)
+ Returns:
+ reduced_token_seq: (T')
+ reduced_token_seq_len: T'
+ """
+ n_gram_seq = token_seq.unfold(0, n_gram, 1)
+ mask = torch.all(n_gram_seq[1:] != n_gram_seq[:-1], dim=1)
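+ # unfold yields every consecutive n-gram; the mask keeps positions whose n-gram differs from the previous one, collapsing runs of repeated tokens.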
+ reduced_token_seq = torch.cat(
+ (n_gram_seq[0, :n_gram], n_gram_seq[1:, -1][mask])
+ )
+ return reduced_token_seq, len(reduced_token_seq)
+
+ @staticmethod
+ def crossfade(chunk1, chunk2, overlap):
+ """Apply crossfade between two audio chunks."""
+ fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+ fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+ if len(chunk2) < overlap:
+ chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+ else:
+ chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+ return chunk2
+
+ def _stream_wave_chunks(self, vc_wave, processed_frames, vc_mel, overlap_wave_len,
+ generated_wave_chunks, previous_chunk, is_last_chunk, stream_output):
+ """
+ Helper method to handle streaming wave chunks.
+
+ Args:
+ vc_wave: The current wave chunk
+ processed_frames: Number of frames processed so far
+ vc_mel: The mel spectrogram
+ overlap_wave_len: Length of overlap between chunks
+ generated_wave_chunks: List of generated wave chunks
+ previous_chunk: Previous wave chunk for crossfading
+ is_last_chunk: Whether this is the last chunk
+ stream_output: Whether to stream the output
+
+ Returns:
+ Tuple of (processed_frames, previous_chunk, should_break, mp3_bytes, full_audio)
+ where should_break indicates if processing should stop
+ mp3_bytes is the MP3 bytes if streaming, None otherwise
+ full_audio is the full audio if this is the last chunk, None otherwise
+ """
+ mp3_bytes = None
+ full_audio = None
+
+ if processed_frames == 0:
+ if is_last_chunk:
+ output_wave = vc_wave[0].cpu().numpy()
+ generated_wave_chunks.append(output_wave)
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=self.sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+ full_audio = (self.sr, np.concatenate(generated_wave_chunks))
+ else:
+ return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+
+ return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+
+ output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+ generated_wave_chunks.append(output_wave)
+ previous_chunk = vc_wave[0, -overlap_wave_len:]
+ processed_frames += vc_mel.size(2) - self.overlap_frame_len
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=self.sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+
+ elif is_last_chunk:
+ output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+ generated_wave_chunks.append(output_wave)
+ processed_frames += vc_mel.size(2) - self.overlap_frame_len
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=self.sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+ full_audio = (self.sr, np.concatenate(generated_wave_chunks))
+ else:
+ return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+
+ return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+
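+ # Middle chunk: crossfade with the previous tail and store a new tail
+ # for the next crossfade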
+ else:
+ output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
+ generated_wave_chunks.append(output_wave)
+ previous_chunk = vc_wave[0, -overlap_wave_len:]
+ processed_frames += vc_mel.size(2) - self.overlap_frame_len
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=self.sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+
+ return processed_frames, previous_chunk, False, mp3_bytes, full_audio
+
+ def load_checkpoints(
+ self,
+ cfm_checkpoint_path=None,
+ ar_checkpoint_path=None,
+ ):
+ if cfm_checkpoint_path is None:
+ cfm_checkpoint_path = load_custom_model_from_hf(
+ repo_id=DEFAULT_REPO_ID,
+ model_filename=DEFAULT_CFM_CHECKPOINT,
+ )
+ if ar_checkpoint_path is None:
+ ar_checkpoint_path = load_custom_model_from_hf(
+ repo_id=DEFAULT_REPO_ID,
+ model_filename=DEFAULT_AR_CHECKPOINT,
+ )
+ # cfm
+ cfm_checkpoint = torch.load(cfm_checkpoint_path, map_location="cpu")
+ cfm_length_regulator_state_dict = self.strip_prefix(cfm_checkpoint["net"]['length_regulator'], "module.")
+ cfm_state_dict = self.strip_prefix(cfm_checkpoint["net"]['cfm'], "module.")
+ self.cfm.load_state_dict(cfm_state_dict, strict=False)
+ self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False)
+
+ # ar
+ ar_checkpoint = torch.load(ar_checkpoint_path, map_location="cpu")
+ ar_length_regulator_state_dict = self.strip_prefix(ar_checkpoint["net"]['length_regulator'], "module.")
+ ar_state_dict = self.strip_prefix(ar_checkpoint["net"]['ar'], "module.")
+ self.ar.load_state_dict(ar_state_dict, strict=False)
+ self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False)
+
+ # content extractor
+ content_extractor_narrow_checkpoint_path = load_custom_model_from_hf(
+ repo_id=DEFAULT_CE_REPO_ID,
+ model_filename=DEFAULT_CE_NARROW_CHECKPOINT,
+ )
+ content_extractor_narrow_checkpoint = torch.load(content_extractor_narrow_checkpoint_path, map_location="cpu")
+ self.content_extractor_narrow.load_state_dict(
+ content_extractor_narrow_checkpoint, strict=False
+ )
+
+ content_extractor_wide_checkpoint_path = load_custom_model_from_hf(
+ repo_id=DEFAULT_CE_REPO_ID,
+ model_filename=DEFAULT_CE_WIDE_CHECKPOINT,
+ )
+ content_extractor_wide_checkpoint = torch.load(content_extractor_wide_checkpoint_path, map_location="cpu")
+ self.content_extractor_wide.load_state_dict(
+ content_extractor_wide_checkpoint, strict=False
+ )
+
+ # style encoder
+ style_encoder_checkpoint_path = load_custom_model_from_hf(DEFAULT_SE_REPO_ID, DEFAULT_SE_CHECKPOINT, config_filename=None)
+ style_encoder_checkpoint = torch.load(style_encoder_checkpoint_path, map_location="cpu")
+ self.style_encoder.load_state_dict(style_encoder_checkpoint, strict=False)
+
+ def setup_ar_caches(self, max_batch_size=1, max_seq_len=4096, dtype=torch.float32, device=torch.device("cpu")):
+ self.ar.setup_caches(max_batch_size=max_batch_size, max_seq_len=max_seq_len, dtype=dtype, device=device)
+
+ def compute_style(self, waves_16k: torch.Tensor):
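+ # 80-dim Kaldi-style fbank features, mean-normalized over time, fed to the style encoder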
+ feat = torchaudio.compliance.kaldi.fbank(waves_16k,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000)
+ feat = feat - feat.mean(dim=0, keepdim=True)
+ style = self.style_encoder(feat.unsqueeze(0))
+ return style
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def convert_timbre(
+ self,
+ source_audio_path: str,
+ target_audio_path: str,
+ diffusion_steps: int = 30,
+ length_adjust: float = 1.0,
+ inference_cfg_rate: float = 0.5,
+ use_sway_sampling: bool = False,
+ use_amo_sampling: bool = False,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+ ):
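+ """
+ Timbre-only conversion: the source content tokens (and hence the source
+ prosody/duration) are kept as-is, and the CFM decoder is conditioned on the
+ target style embedding and mel prompt. Returns the converted waveform as a
+ numpy array.
+ """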
+ source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+ target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+ source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).to(device)
+ target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).to(device)
+
+ # get 16khz audio
+ source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+ target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+ source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+ target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+ # compute mel spectrogram
+ source_mel = self.mel_fn(source_wave_tensor)
+ target_mel = self.mel_fn(target_wave_tensor)
+ source_mel_len = source_mel.size(2)
+ target_mel_len = target_mel.size(2)
+
+ with torch.autocast(device_type=device.type, dtype=dtype):
+ # compute content features
+ _, source_content_indices, _ = self.content_extractor_wide(source_wave_16k_tensor, [source_wave_16k.size])
+ _, target_content_indices, _ = self.content_extractor_wide(target_wave_16k_tensor, [target_wave_16k.size])
+
+ # compute style features
+ target_style = self.compute_style(target_wave_16k_tensor)
+
+ # Length regulation
+ cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
+ prompt_condition, _ = self.cfm_length_regulator(target_content_indices, ylens=torch.LongTensor([target_mel_len]).to(device))
+
+ cat_condition = torch.cat([prompt_condition, cond], dim=1)
+ # generate mel spectrogram
+ vc_mel = self.cfm.inference(
+ cat_condition,
+ torch.LongTensor([cat_condition.size(1)]).to(device),
+ target_mel, target_style, diffusion_steps,
+ inference_cfg_rate=inference_cfg_rate,
+ sway_sampling=use_sway_sampling,
+ amo_sampling=use_amo_sampling,
+ )
+ vc_mel = vc_mel[:, :, target_mel_len:]
+ vc_wave = self.vocoder(vc_mel.float()).squeeze()[None]
+ return vc_wave.cpu().numpy()
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def convert_voice(
+ self,
+ source_audio_path: str,
+ target_audio_path: str,
+ diffusion_steps: int = 30,
+ length_adjust: float = 1.0,
+ inference_cfg_rate: float = 0.5,
+ top_p: float = 0.7,
+ temperature: float = 0.7,
+ repetition_penalty: float = 1.5,
+ use_sway_sampling: bool = False,
+ use_amo_sampling: bool = False,
+ device: torch.device = torch.device("cpu"),
+ dtype: torch.dtype = torch.float32,
+ ):
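+ """
+ Full voice conversion: the AR model re-generates content tokens conditioned
+ on the target speaker's tokens (converting prosody/style as well), then the
+ CFM decoder renders mel frames with the target style embedding. Returns the
+ converted waveform as a numpy array.
+ """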
+ source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+ target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+ source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).to(device)
+ target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).to(device)
+
+ # get 16khz audio
+ source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+ target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+ source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+ target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+ # compute mel spectrogram
+ source_mel = self.mel_fn(source_wave_tensor)
+ target_mel = self.mel_fn(target_wave_tensor)
+ source_mel_len = source_mel.size(2)
+ target_mel_len = target_mel.size(2)
+
+ with torch.autocast(device_type=device.type, dtype=dtype):
+ # compute content features
+ _, source_content_indices, _ = self.content_extractor_wide(source_wave_16k_tensor, [source_wave_16k.size])
+ _, target_content_indices, _ = self.content_extractor_wide(target_wave_16k_tensor, [target_wave_16k.size])
+
+ _, source_narrow_indices, _ = self.content_extractor_narrow(source_wave_16k_tensor,
+ [source_wave_16k.size], ssl_model=self.content_extractor_wide.ssl_model)
+ _, target_narrow_indices, _ = self.content_extractor_narrow(target_wave_16k_tensor,
+ [target_wave_16k.size], ssl_model=self.content_extractor_wide.ssl_model)
+
+ src_narrow_reduced, src_narrow_len = self.duration_reduction_func(source_narrow_indices[0], 1)
+ tgt_narrow_reduced, tgt_narrow_len = self.duration_reduction_func(target_narrow_indices[0], 1)
+
+ ar_cond = self.ar_length_regulator(torch.cat([tgt_narrow_reduced, src_narrow_reduced], dim=0)[None])[0]
+
+ ar_out = self.ar.generate(ar_cond, target_content_indices, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty)
+ ar_out_mel_len = torch.LongTensor([int(source_mel_len / source_content_indices.size(-1) * ar_out.size(-1) * length_adjust)]).to(device)
+ # compute style features
+ target_style = self.compute_style(target_wave_16k_tensor)
+
+ # Length regulation
+ cond, _ = self.cfm_length_regulator(ar_out, ylens=ar_out_mel_len)
+ prompt_condition, _ = self.cfm_length_regulator(target_content_indices, ylens=torch.LongTensor([target_mel_len]).to(device))
+
+ cat_condition = torch.cat([prompt_condition, cond], dim=1)
+ # generate mel spectrogram
+ vc_mel = self.cfm.inference(
+ cat_condition,
+ torch.LongTensor([cat_condition.size(1)]).to(device),
+ target_mel, target_style, diffusion_steps,
+ inference_cfg_rate=inference_cfg_rate,
+ sway_sampling=use_sway_sampling,
+ amo_sampling=use_amo_sampling,
+ )
+ vc_mel = vc_mel[:, :, target_mel_len:]
+ vc_wave = self.vocoder(vc_mel.float()).squeeze()[None]
+ return vc_wave.cpu().numpy()
+
+ def _process_content_features(self, audio_16k_tensor, is_narrow=False):
+ """Process audio through Whisper model to extract features."""
+ content_extractor_fn = self.content_extractor_narrow if is_narrow else self.content_extractor_wide
+ if audio_16k_tensor.size(-1) <= 16000 * 30:
+ # Compute content features
+ _, content_indices, _ = content_extractor_fn(audio_16k_tensor, [audio_16k_tensor.size(-1)], ssl_model=self.content_extractor_wide.ssl_model)
+ else:
+ # Process long audio in chunks
+ overlapping_time = 5 # 5 seconds
+ features_list = []
+ buffer = None
+ traversed_time = 0
+ while traversed_time < audio_16k_tensor.size(-1):
+ if buffer is None: # first chunk
+ chunk = audio_16k_tensor[:, traversed_time:traversed_time + 16000 * 30]
+ else:
+ chunk = torch.cat([
+ buffer,
+ audio_16k_tensor[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]
+ ], dim=-1)
+ _, chunk_content_indices, _ = content_extractor_fn(chunk, [chunk.size(-1)], ssl_model=self.content_extractor_wide.ssl_model)
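+ # content tokens are produced at 50 per second of 16 kHz audio, so drop the
+ # tokens covering the overlapped prefix on every chunk after the first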
+ if traversed_time == 0:
+ features_list.append(chunk_content_indices)
+ else:
+ features_list.append(chunk_content_indices[:, 50 * overlapping_time:])
+ buffer = chunk[:, -16000 * overlapping_time:]
+ traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+ content_indices = torch.cat(features_list, dim=1)
+
+ return content_indices
+
+ @spaces.GPU
+ @torch.no_grad()
+ @torch.inference_mode()
+ def convert_voice_with_streaming(
+ self,
+ source_audio_path: str,
+ target_audio_path: str,
+ diffusion_steps: int = 30,
+ length_adjust: float = 1.0,
+ intelligebility_cfg_rate: float = 0.7,
+ similarity_cfg_rate: float = 0.7,
+ top_p: float = 0.7,
+ temperature: float = 0.7,
+ repetition_penalty: float = 1.5,
+ convert_style: bool = False,
+ anonymization_only: bool = False,
+ device: torch.device = torch.device("cuda"),
+ dtype: torch.dtype = torch.float16,
+ stream_output: bool = True,
+ ):
+ """
+ Convert voice with streaming support for long audio files.
+
+ Args:
+ source_audio_path: Path to source audio file
+ target_audio_path: Path to target audio file
+ diffusion_steps: Number of diffusion steps (default: 30)
+ length_adjust: Length adjustment factor (default: 1.0)
+ intelligebility_cfg_rate: CFG rate for intelligibility (default: 0.7)
+ similarity_cfg_rate: CFG rate for similarity (default: 0.7)
+ top_p: Top-p sampling parameter (default: 0.7)
+ temperature: Temperature for sampling (default: 0.7)
+ repetition_penalty: Repetition penalty (default: 1.5)
+ convert_style: Whether to run the AR model so prosody/style is converted as well (default: False)
+ anonymization_only: Whether to ignore the target speaker and synthesize a random voice instead (default: False)
+ device: Device to use (default: cuda)
+ dtype: Data type to use (default: float16)
+ stream_output: Whether to stream the output (default: True)
+
+ Returns:
+ If stream_output is True, yields (mp3_bytes, full_audio) tuples
+ If stream_output is False, returns the full audio as a numpy array
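+
+ Example (sketch; `vc` is an already-constructed instance of this wrapper
+ and the file paths are placeholders):
+ for mp3_bytes, full_audio in vc.convert_voice_with_streaming(
+ "source.wav", "target.wav", convert_style=True):
+ ... # push mp3_bytes to the client; full_audio is set with the last chunk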
+ """
+ # Load audio
+ source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+ target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+
+ # Limit the reference (target) audio to dit_max_context_len - 5 seconds so it fits the DiT context window
+ target_wave = target_wave[:self.sr * (self.dit_max_context_len - 5)]
+
+ source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).float().to(device)
+ target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).float().to(device)
+
+ # Resample to 16kHz for feature extraction
+ source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+ target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+ source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+ target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+ # Compute mel spectrograms
+ source_mel = self.mel_fn(source_wave_tensor)
+ target_mel = self.mel_fn(target_wave_tensor)
+ source_mel_len = source_mel.size(2)
+ target_mel_len = target_mel.size(2)
+
+ # Set up chunk processing parameters
+ max_context_window = self.sr // self.hop_size * self.dit_max_context_len
+ overlap_wave_len = self.overlap_frame_len * self.hop_size
+
+ with torch.autocast(device_type=device.type, dtype=dtype):
+ # Compute content features
+ source_content_indices = self._process_content_features(source_wave_16k_tensor, is_narrow=False)
+ target_content_indices = self._process_content_features(target_wave_16k_tensor, is_narrow=False)
+ # Compute style features
+ target_style = self.compute_style(target_wave_16k_tensor)
+ prompt_condition, _ = self.cfm_length_regulator(target_content_indices,
+ ylens=torch.LongTensor([target_mel_len]).to(device))
+
+ # prepare for streaming
+ generated_wave_chunks = []
+ processed_frames = 0
+ previous_chunk = None
+ if convert_style:
+ with torch.autocast(device_type=device.type, dtype=dtype):
+ source_narrow_indices = self._process_content_features(source_wave_16k_tensor, is_narrow=True)
+ target_narrow_indices = self._process_content_features(target_wave_16k_tensor, is_narrow=True)
+ src_narrow_reduced, src_narrow_len = self.duration_reduction_func(source_narrow_indices[0], 1)
+ tgt_narrow_reduced, tgt_narrow_len = self.duration_reduction_func(target_narrow_indices[0], 1)
+ # Process src_narrow_reduced in chunks of max 1000 tokens
+ max_chunk_size = 1000
+
+ # Process src_narrow_reduced in chunks
+ for i in range(0, len(src_narrow_reduced), max_chunk_size):
+ is_last_chunk = i + max_chunk_size >= len(src_narrow_reduced)
+ with torch.autocast(device_type=device.type, dtype=dtype):
+ chunk = src_narrow_reduced[i:i + max_chunk_size]
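+ # anonymization: generate with an empty prompt (no target content tokens),
+ # so the target voice is not cloned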
+ if anonymization_only:
+ chunk_ar_cond = self.ar_length_regulator(chunk[None])[0]
+ chunk_ar_out = self.ar.generate(chunk_ar_cond, torch.zeros([1, 0]).long().to(device),
+ compiled_decode_fn=self.compiled_decode_fn,
+ top_p=top_p, temperature=temperature,
+ repetition_penalty=repetition_penalty)
+ else:
+ # For each chunk, we need to include tgt_narrow_reduced as context
+ chunk_ar_cond = self.ar_length_regulator(torch.cat([tgt_narrow_reduced, chunk], dim=0)[None])[0]
+ chunk_ar_out = self.ar.generate(chunk_ar_cond, target_content_indices, compiled_decode_fn=self.compiled_decode_fn,
+ top_p=top_p, temperature=temperature,
+ repetition_penalty=repetition_penalty)
+ # Estimate the output mel length from the source frame/token ratio
+ chunk_ar_out_mel_len = torch.LongTensor([int(source_mel_len / source_content_indices.size(-1) * chunk_ar_out.size(-1) * length_adjust)]).to(device)
+ # Length regulation
+ chunk_cond, _ = self.cfm_length_regulator(chunk_ar_out, ylens=chunk_ar_out_mel_len)
+ cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+ original_len = cat_condition.size(1)
+ # pad cat_condition to compile_len: the compiled DiT runs with a fixed sequence length, so zero-pad here and pass original_len separately
+ if self.dit_compiled:
+ cat_condition = torch.nn.functional.pad(cat_condition,
+ (0, 0, 0, self.compile_len - cat_condition.size(1),),
+ value=0)
+ # Voice Conversion
+ vc_mel = self.cfm.inference(
+ cat_condition,
+ torch.LongTensor([original_len]).to(device),
+ target_mel, target_style, diffusion_steps,
+ inference_cfg_rate=[intelligebility_cfg_rate, similarity_cfg_rate],
+ random_voice=anonymization_only,
+ )
+ vc_mel = vc_mel[:, :, target_mel_len:original_len]
+ vc_wave = self.vocoder(vc_mel).squeeze()[None]
+ processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks(
+ vc_wave, processed_frames, vc_mel, overlap_wave_len,
+ generated_wave_chunks, previous_chunk, is_last_chunk, stream_output
+ )
+
+ if stream_output and mp3_bytes is not None:
+ yield mp3_bytes, full_audio
+
+ if should_break:
+ if not stream_output:
+ return full_audio
+ break
+ else:
+ cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device))
+
+ # Process in chunks for streaming
+ max_source_window = max_context_window - target_mel.size(2)
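+ # each source chunk may use at most this many mel frames so that
+ # prompt + chunk still fits inside the DiT context window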
+
+ # Generate chunk by chunk and stream the output
+ while processed_frames < cond.size(1):
+ chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+ is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+ cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+ original_len = cat_condition.size(1)
+ # pad cat_condition to compile_len (fixed sequence length for the compiled DiT); original_len carries the true length
+ if self.dit_compiled:
+ cat_condition = torch.nn.functional.pad(cat_condition,
+ (0, 0, 0, self.compile_len - cat_condition.size(1),), value=0)
+ with torch.autocast(device_type=device.type, dtype=dtype):
+ # Voice Conversion
+ vc_mel = self.cfm.inference(
+ cat_condition,
+ torch.LongTensor([original_len]).to(device),
+ target_mel, target_style, diffusion_steps,
+ inference_cfg_rate=[intelligebility_cfg_rate, similarity_cfg_rate],
+ random_voice=anonymization_only,
+ )
+ vc_mel = vc_mel[:, :, target_mel_len:original_len]
+ vc_wave = self.vocoder(vc_mel).squeeze()[None]
+
+ processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks(
+ vc_wave, processed_frames, vc_mel, overlap_wave_len,
+ generated_wave_chunks, previous_chunk, is_last_chunk, stream_output
+ )
+
+ if stream_output and mp3_bytes is not None:
+ yield mp3_bytes, full_audio
+
+ if should_break:
+ if not stream_output:
+ return full_audio
+ break
+
+
diff --git a/requirements.txt b/requirements.txt
index e608c6a787058579c732680c3337566a0e36b259..4fa463bd62265380f5549538f043fa358200ef01 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,24 @@
---extra-index-url https://download.pytorch.org/whl/cu113
-torch
-torchvision
-torchaudio
-scipy==1.13.1
-onnxruntime-gpu==1.19.0
-librosa==0.10.2
-huggingface-hub
-munch
-einops
-descript-audio-codec
-git+https://github.com/openai/whisper.git
-pydub
-transformers
\ No newline at end of file
+--extra-index-url https://download.pytorch.org/whl/cu121
+torch==2.4.0
+torchvision==0.19.0
+torchaudio==2.4.0
+scipy==1.13.1
+librosa==0.10.2
+huggingface-hub==0.23.4
+munch==4.0.0
+einops==0.8.0
+descript-audio-codec==1.0.0
+gradio==5.23.0
+pydub==0.25.1
+resemblyzer
+jiwer==3.0.3
+transformers==4.46.3
+FreeSimpleGUI==5.1.1
+soundfile==0.12.1
+sounddevice==0.5.0
+modelscope==1.18.1
+funasr==1.1.5
+numpy==1.26.4
+hydra-core==1.3.2
+pyyaml
+python-dotenv
diff --git a/seed_vc_wrapper.py b/seed_vc_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..14caf0cfd3a18861e1a2b2c5fc3774aaa244e353
--- /dev/null
+++ b/seed_vc_wrapper.py
@@ -0,0 +1,463 @@
+import spaces
+import torch
+import torchaudio
+import librosa
+import numpy as np
+from pydub import AudioSegment
+import yaml
+from modules.commons import build_model, load_checkpoint, recursive_munch
+from hf_utils import load_custom_model_from_hf
+from modules.campplus.DTDNN import CAMPPlus
+from modules.bigvgan import bigvgan
+from modules.audio import mel_spectrogram
+from modules.rmvpe import RMVPE
+from transformers import AutoFeatureExtractor, WhisperModel
+
+class SeedVCWrapper:
+ def __init__(self, device=None):
+ """
+ Initialize the Seed-VC wrapper with all necessary models and configurations.
+
+ Args:
+ device: torch device to use. If None, will be automatically determined.
+ """
+ # Set device
+ if device is None:
+ if torch.cuda.is_available():
+ self.device = torch.device("cuda")
+ elif torch.backends.mps.is_available():
+ self.device = torch.device("mps")
+ else:
+ self.device = torch.device("cpu")
+ else:
+ self.device = device
+
+ # Load base model and configuration
+ self._load_base_model()
+
+ # Load F0 conditioned model
+ self._load_f0_model()
+
+ # Load additional modules
+ self._load_additional_modules()
+
+ # Set streaming parameters
+ self.overlap_frame_len = 16
+ self.bitrate = "320k"
+
+ def _load_base_model(self):
+ """Load the base DiT model for voice conversion."""
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
+ "Plachta/Seed-VC",
+ "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
+ "config_dit_mel_seed_uvit_whisper_small_wavenet.yml"
+ )
+ with open(dit_config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ model_params = recursive_munch(config['model_params'])
+ self.model = build_model(model_params, stage='DiT')
+ self.hop_length = config['preprocess_params']['spect_params']['hop_length']
+ self.sr = config['preprocess_params']['sr']
+
+ # Load checkpoints
+ self.model, _, _, _ = load_checkpoint(
+ self.model, None, dit_checkpoint_path,
+ load_only_params=True, ignore_modules=[], is_distributed=False
+ )
+ for key in self.model:
+ self.model[key].eval()
+ self.model[key].to(self.device)
+ self.model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+
+ # Set up mel spectrogram function
+ mel_fn_args = {
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+ "sampling_rate": self.sr,
+ "fmin": 0,
+ "fmax": None,
+ "center": False
+ }
+ self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)
+
+ # Load whisper model
+ whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer, 'whisper_name') else "openai/whisper-small"
+ self.whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(self.device)
+ del self.whisper_model.decoder
+ self.whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)
+
+ def _load_f0_model(self):
+ """Load the F0 conditioned model for voice conversion."""
+ dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
+ "Plachta/Seed-VC",
+ "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
+ "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml"
+ )
+ with open(dit_config_path, 'r') as f:
+ config = yaml.safe_load(f)
+ model_params = recursive_munch(config['model_params'])
+ self.model_f0 = build_model(model_params, stage='DiT')
+ self.hop_length_f0 = config['preprocess_params']['spect_params']['hop_length']
+ self.sr_f0 = config['preprocess_params']['sr']
+
+ # Load checkpoints
+ self.model_f0, _, _, _ = load_checkpoint(
+ self.model_f0, None, dit_checkpoint_path,
+ load_only_params=True, ignore_modules=[], is_distributed=False
+ )
+ for key in self.model_f0:
+ self.model_f0[key].eval()
+ self.model_f0[key].to(self.device)
+ self.model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)
+
+ # Set up mel spectrogram function for F0 model
+ mel_fn_args_f0 = {
+ "n_fft": config['preprocess_params']['spect_params']['n_fft'],
+ "win_size": config['preprocess_params']['spect_params']['win_length'],
+ "hop_size": config['preprocess_params']['spect_params']['hop_length'],
+ "num_mels": config['preprocess_params']['spect_params']['n_mels'],
+ "sampling_rate": self.sr_f0,
+ "fmin": 0,
+ "fmax": None,
+ "center": False
+ }
+ self.to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0)
+
+ def _load_additional_modules(self):
+ """Load additional modules like CAMPPlus, BigVGAN, and RMVPE."""
+ # Load CAMPPlus
+ campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None)
+ self.campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
+ self.campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
+ self.campplus_model.eval()
+ self.campplus_model.to(self.device)
+
+ # Load BigVGAN models
+ self.bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False)
+ self.bigvgan_model.remove_weight_norm()
+ self.bigvgan_model = self.bigvgan_model.eval().to(self.device)
+
+ self.bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False)
+ self.bigvgan_44k_model.remove_weight_norm()
+ self.bigvgan_44k_model = self.bigvgan_44k_model.eval().to(self.device)
+
+ # Load RMVPE for F0 extraction
+ model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
+ self.rmvpe = RMVPE(model_path, is_half=False, device=self.device)
+
+ @staticmethod
+ def adjust_f0_semitones(f0_sequence, n_semitones):
+ """Adjust F0 values by a number of semitones."""
+ factor = 2 ** (n_semitones / 12)
+ return f0_sequence * factor
+
+ @staticmethod
+ def crossfade(chunk1, chunk2, overlap):
+ """Apply crossfade between two audio chunks."""
+ fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2
+ fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2
+ if len(chunk2) < overlap:
+ chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)]
+ else:
+ chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out
+ return chunk2
+
+ def _stream_wave_chunks(self, vc_wave, processed_frames, vc_target, overlap_wave_len,
+ generated_wave_chunks, previous_chunk, is_last_chunk, stream_output, sr):
+ """
+ Helper method to handle streaming wave chunks.
+
+ Args:
+ vc_wave: The current wave chunk
+ processed_frames: Number of frames processed so far
+ vc_target: The target mel spectrogram
+ overlap_wave_len: Length of overlap between chunks
+ generated_wave_chunks: List of generated wave chunks
+ previous_chunk: Previous wave chunk for crossfading
+ is_last_chunk: Whether this is the last chunk
+ stream_output: Whether to stream the output
+ sr: Sample rate
+
+ Returns:
+ Tuple of (processed_frames, previous_chunk, should_break, mp3_bytes, full_audio)
+ where should_break indicates if processing should stop
+ mp3_bytes is the MP3 bytes if streaming, None otherwise
+ full_audio is the full audio if this is the last chunk, None otherwise
+ """
+ mp3_bytes = None
+ full_audio = None
+
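+ # First, last and middle chunks are handled separately below so that the
+ # overlap regions can be crossfaded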
+ if processed_frames == 0:
+ if is_last_chunk:
+ output_wave = vc_wave[0].cpu().numpy()
+ generated_wave_chunks.append(output_wave)
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+ full_audio = (sr, np.concatenate(generated_wave_chunks))
+ else:
+ return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+
+ return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+
+ output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy()
+ generated_wave_chunks.append(output_wave)
+ previous_chunk = vc_wave[0, -overlap_wave_len:]
+ processed_frames += vc_target.size(2) - self.overlap_frame_len
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+
+ elif is_last_chunk:
+ output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len)
+ generated_wave_chunks.append(output_wave)
+ processed_frames += vc_target.size(2) - self.overlap_frame_len
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+ full_audio = (sr, np.concatenate(generated_wave_chunks))
+ else:
+ return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks)
+
+ return processed_frames, previous_chunk, True, mp3_bytes, full_audio
+
+ else:
+ output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len)
+ generated_wave_chunks.append(output_wave)
+ previous_chunk = vc_wave[0, -overlap_wave_len:]
+ processed_frames += vc_target.size(2) - self.overlap_frame_len
+
+ if stream_output:
+ output_wave_int16 = (output_wave * 32768.0).astype(np.int16)
+ mp3_bytes = AudioSegment(
+ output_wave_int16.tobytes(), frame_rate=sr,
+ sample_width=output_wave_int16.dtype.itemsize, channels=1
+ ).export(format="mp3", bitrate=self.bitrate).read()
+
+ return processed_frames, previous_chunk, False, mp3_bytes, full_audio
+
+ def _process_whisper_features(self, audio_16k, is_source=True):
+ """Process audio through Whisper model to extract features."""
+ if audio_16k.size(-1) <= 16000 * 30:
+ # If audio is short enough, process in one go
+ inputs = self.whisper_feature_extractor(
+ [audio_16k.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True,
+ sampling_rate=16000
+ )
+ input_features = self.whisper_model._mask_input_features(
+ inputs.input_features, attention_mask=inputs.attention_mask
+ ).to(self.device)
+ outputs = self.whisper_model.encoder(
+ input_features.to(self.whisper_model.encoder.dtype),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ )
+ features = outputs.last_hidden_state.to(torch.float32)
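+ # the Whisper encoder emits one frame per 320 input samples (20 ms), so trim
+ # away the frames that only cover the feature extractor's 30 s padding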
+ features = features[:, :audio_16k.size(-1) // 320 + 1]
+ else:
+ # Process long audio in chunks
+ overlapping_time = 5 # 5 seconds
+ features_list = []
+ buffer = None
+ traversed_time = 0
+ while traversed_time < audio_16k.size(-1):
+ if buffer is None: # first chunk
+ chunk = audio_16k[:, traversed_time:traversed_time + 16000 * 30]
+ else:
+ chunk = torch.cat([
+ buffer,
+ audio_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)]
+ ], dim=-1)
+ inputs = self.whisper_feature_extractor(
+ [chunk.squeeze(0).cpu().numpy()],
+ return_tensors="pt",
+ return_attention_mask=True,
+ sampling_rate=16000
+ )
+ input_features = self.whisper_model._mask_input_features(
+ inputs.input_features, attention_mask=inputs.attention_mask
+ ).to(self.device)
+ outputs = self.whisper_model.encoder(
+ input_features.to(self.whisper_model.encoder.dtype),
+ head_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ )
+ chunk_features = outputs.last_hidden_state.to(torch.float32)
+ chunk_features = chunk_features[:, :chunk.size(-1) // 320 + 1]
+ if traversed_time == 0:
+ features_list.append(chunk_features)
+ else:
+ features_list.append(chunk_features[:, 50 * overlapping_time:])
+ buffer = chunk[:, -16000 * overlapping_time:]
+ traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+ features = torch.cat(features_list, dim=1)
+
+ return features
+
+ @spaces.GPU
+ @torch.no_grad()
+ @torch.inference_mode()
+ def convert_voice(self, source, target, diffusion_steps=10, length_adjust=1.0,
+ inference_cfg_rate=0.7, f0_condition=False, auto_f0_adjust=True,
+ pitch_shift=0, stream_output=True):
+ """
+ Convert both timbre and voice from source to target.
+
+ Args:
+ source: Path to source audio file
+ target: Path to target audio file
+ diffusion_steps: Number of diffusion steps (default: 10)
+ length_adjust: Length adjustment factor (default: 1.0)
+ inference_cfg_rate: Inference CFG rate (default: 0.7)
+ f0_condition: Whether to use F0 conditioning (default: False)
+ auto_f0_adjust: Whether to automatically adjust F0 (default: True)
+ pitch_shift: Pitch shift in semitones (default: 0)
+ stream_output: Whether to stream the output (default: True)
+
+ Returns:
+ If stream_output is True, yields (mp3_bytes, full_audio) tuples
+ If stream_output is False, returns the full audio as a numpy array
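+
+ Example (file paths are placeholders):
+ wrapper = SeedVCWrapper()
+ for mp3_bytes, full_audio in wrapper.convert_voice("source.wav", "target.wav"):
+ ... # stream mp3_bytes; full_audio arrives with the last chunk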
+ """
+ # Select appropriate models based on F0 condition
+ inference_module = self.model if not f0_condition else self.model_f0
+ mel_fn = self.to_mel if not f0_condition else self.to_mel_f0
+ bigvgan_fn = self.bigvgan_model if not f0_condition else self.bigvgan_44k_model
+ # 22.05 kHz / hop 256 for the base model, 44.1 kHz / hop 512 for the F0 model (taken from the loaded configs)
+ sr = self.sr if not f0_condition else self.sr_f0
+ hop_length = self.hop_length if not f0_condition else self.hop_length_f0
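+ # the DiT context covers 30 seconds of mel frames; the reference prompt and
+ # each source chunk have to share this window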
+ max_context_window = sr // hop_length * 30
+ overlap_wave_len = self.overlap_frame_len * hop_length
+
+ # Load audio
+ source_audio = librosa.load(source, sr=sr)[0]
+ ref_audio = librosa.load(target, sr=sr)[0]
+
+ # Process audio
+ source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device)
+ ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(self.device)
+
+ # Resample to 16kHz for feature extraction
+ ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
+ converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)
+
+ # Extract Whisper features
+ S_alt = self._process_whisper_features(converted_waves_16k, is_source=True)
+ S_ori = self._process_whisper_features(ref_waves_16k, is_source=False)
+
+ # Compute mel spectrograms
+ mel = mel_fn(source_audio.to(self.device).float())
+ mel2 = mel_fn(ref_audio.to(self.device).float())
+
+ # Set target lengths
+ target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device)
+ target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)
+
+ # Compute style features
+ feat2 = torchaudio.compliance.kaldi.fbank(
+ ref_waves_16k,
+ num_mel_bins=80,
+ dither=0,
+ sample_frequency=16000
+ )
+ feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
+ style2 = self.campplus_model(feat2.unsqueeze(0))
+
+ # Process F0 if needed
+ if f0_condition:
+ F0_ori = self.rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.03)
+ F0_alt = self.rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03)
+
+ if self.device == "mps":
+ F0_ori = torch.from_numpy(F0_ori).float().to(self.device)[None]
+ F0_alt = torch.from_numpy(F0_alt).float().to(self.device)[None]
+ else:
+ F0_ori = torch.from_numpy(F0_ori).to(self.device)[None]
+ F0_alt = torch.from_numpy(F0_alt).to(self.device)[None]
+
+ voiced_F0_ori = F0_ori[F0_ori > 1]
+ voiced_F0_alt = F0_alt[F0_alt > 1]
+
+ log_f0_alt = torch.log(F0_alt + 1e-5)
+ voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5)
+ voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5)
+ median_log_f0_ori = torch.median(voiced_log_f0_ori)
+ median_log_f0_alt = torch.median(voiced_log_f0_alt)
+
+ # Shift alt log f0 level to ori log f0 level
+ shifted_log_f0_alt = log_f0_alt.clone()
+ if auto_f0_adjust:
+ shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
+ shifted_f0_alt = torch.exp(shifted_log_f0_alt)
+ if pitch_shift != 0:
+ shifted_f0_alt[F0_alt > 1] = self.adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift)
+ else:
+ F0_ori = None
+ F0_alt = None
+ shifted_f0_alt = None
+
+ # Length regulation
+ cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(
+ S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt
+ )
+ prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator(
+ S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori
+ )
+
+ # Process in chunks for streaming
+ max_source_window = max_context_window - mel2.size(2)
+ processed_frames = 0
+ generated_wave_chunks = []
+ previous_chunk = None
+
+ # Generate chunk by chunk and stream the output
+ while processed_frames < cond.size(1):
+ chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
+ is_last_chunk = processed_frames + max_source_window >= cond.size(1)
+ cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
+
+ with torch.autocast(device_type=self.device.type, dtype=torch.float16):
+ # Voice Conversion
+ vc_target = inference_module.cfm.inference(
+ cat_condition,
+ torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
+ mel2, style2, None, diffusion_steps,
+ inference_cfg_rate=inference_cfg_rate
+ )
+ vc_target = vc_target[:, :, mel2.size(-1):]
+
+ vc_wave = bigvgan_fn(vc_target.float())[0]
+
+ processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks(
+ vc_wave, processed_frames, vc_target, overlap_wave_len,
+ generated_wave_chunks, previous_chunk, is_last_chunk, stream_output, sr
+ )
+
+ if stream_output and mp3_bytes is not None:
+ yield mp3_bytes, full_audio
+
+ if should_break:
+ if not stream_output:
+ return full_audio
+ break
+
+ if not stream_output:
+ return np.concatenate(generated_wave_chunks)
+
+ return None, None
\ No newline at end of file