diff --git a/.gitattributes b/.gitattributes index 5b2c2cefa8481bbfc91b2d73fe574301a8f16d78..fbe3882f375b9158c51a9ea099d6c0bfbf5979e1 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,43 +1,47 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text -examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text -examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text -examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text -examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text -examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text -examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text -examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text -examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors 
filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +examples/reference/dingzhen_0.wav filter=lfs diff=lfs merge=lfs -text +examples/reference/s3p2.wav filter=lfs diff=lfs merge=lfs -text +examples/source/source_s3.wav filter=lfs diff=lfs merge=lfs -text +examples/source/source_s4.wav filter=lfs diff=lfs merge=lfs -text +examples/source/Wiz[[:space:]]Khalifa,Charlie[[:space:]]Puth[[:space:]]-[[:space:]]See[[:space:]]You[[:space:]]Again[[:space:]]\[vocals\]_\[cut_28sec\].wav filter=lfs diff=lfs merge=lfs -text +examples/reference/trump_0.wav filter=lfs diff=lfs merge=lfs -text +examples/source/jay_0.wav filter=lfs diff=lfs merge=lfs -text +examples/source/TECHNOPOLIS[[:space:]]-[[:space:]]2085[[:space:]]\[vocals\]_\[cut_14sec\].wav filter=lfs diff=lfs merge=lfs -text +modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps filter=lfs diff=lfs merge=lfs -text +modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o filter=lfs diff=lfs merge=lfs -text +modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd filter=lfs diff=lfs merge=lfs -text +modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..28c115f84cdbae6dd450777a7bbe5c23f130a3ec --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +# general things to ignore +.DS_Store +build/ +build_contrib/ +dist/ +.cache/ +*.egg-info/ +*.egg +*.py[cod] +__pycache__/ +*.so +*~ + +# IDE +.vscode/ +.idea/ + +# misc +checkpoints/ +test_waves/ +reconstructed/ +.python-version +ruff.log +/configs/inuse/ +runs/ +/garbages/ +/flagged/ +/experimental/ diff --git a/README.md b/README.md index eb7e29cb828e7162313500060b2137ae6137e8dd..887bc8c01db2950bd731f3f2aaa2ec1fdcdfebe3 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,13 @@ ---- -title: Seed Voice Conversion -emoji: 🎤🔄 -colorFrom: green -colorTo: green -sdk: gradio -sdk_version: 4.42.0 -app_file: app.py -pinned: false -license: gpl-3.0 ---- - +--- +title: Seed Voice Conversion +emoji: 🎤🔄 +colorFrom: green +colorTo: green +sdk: gradio +sdk_version: 5.23.0 +app_file: app_v1v2.py +pinned: false +license: gpl-3.0 +--- + Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app_v1v2.py b/app_v1v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e63efa4da6092483cf663e81eacd0a56886c4f --- /dev/null +++ b/app_v1v2.py @@ -0,0 +1,175 @@ +import spaces +import gradio as gr +import torch +import yaml +import argparse +from seed_vc_wrapper import SeedVCWrapper + +# Set up device and torch configurations +if torch.cuda.is_available(): + device = torch.device("cuda") +elif torch.backends.mps.is_available(): + device = torch.device("mps") +else: + device = torch.device("cpu") + +torch._inductor.config.coordinate_descent_tuning = True +torch._inductor.config.triton.unique_kernel_names = True + +if hasattr(torch._inductor.config, "fx_graph_cache"): + # Experimental feature to 
reduce compilation times, will be on by default in future + torch._inductor.config.fx_graph_cache = True + +dtype = torch.float16 + +def load_v2_models(args): + from hydra.utils import instantiate + from omegaconf import DictConfig + cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r"))) + vc_wrapper = instantiate(cfg) + vc_wrapper.load_checkpoints() + vc_wrapper.to(device) + vc_wrapper.eval() + + vc_wrapper.setup_ar_caches(max_batch_size=1, max_seq_len=4096, dtype=dtype, device=device) + + if args.compile: + vc_wrapper.compile_ar() + # vc_wrapper.compile_cfm() + + return vc_wrapper + +def create_v1_interface(): + # Initialize the V1 wrapper + vc_wrapper = SeedVCWrapper() + + # Set up Gradio interface + description = ("Zero-shot voice conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) " + "for details and updates.
Note that any reference audio will be forcefully clipped to 25s if beyond this length.
" + "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.
" + "无需训练的 zero-shot 语音/歌声转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)
" + "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。
若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。") + + inputs = [ + gr.Audio(type="filepath", label="Source Audio / 源音频"), + gr.Audio(type="filepath", label="Reference Audio / 参考音频"), + gr.Slider(minimum=1, maximum=200, value=10, step=1, label="Diffusion Steps / 扩散步数", + info="10 by default, 50~100 for best quality / 默认为 10,50~100 为最佳质量"), + gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", + info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"), + gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Inference CFG Rate", + info="has subtle influence / 有微小影响"), + gr.Checkbox(label="Use F0 conditioned model / 启用F0输入", value=False, + info="Must set to true for singing voice conversion / 歌声转换时必须勾选"), + gr.Checkbox(label="Auto F0 adjust / 自动F0调整", value=True, + info="Roughly adjust F0 to match target voice. Only works when F0 conditioned model is used. / 粗略调整 F0 以匹配目标音色,仅在勾选 '启用F0输入' 时生效"), + gr.Slider(label='Pitch shift / 音调变换', minimum=-24, maximum=24, step=1, value=0, + info="Pitch shift in semitones, only works when F0 conditioned model is used / 半音数的音高变换,仅在勾选 '启用F0输入' 时生效"), + ] + + examples = [ + ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 25, 1.0, 0.7, False, True, 0], + ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 25, 1.0, 0.7, True, True, 0], + ["examples/source/Wiz Khalifa,Charlie Puth - See You Again [vocals]_[cut_28sec].wav", + "examples/reference/teio_0.wav", 100, 1.0, 0.7, True, False, 0], + ["examples/source/TECHNOPOLIS - 2085 [vocals]_[cut_14sec].wav", + "examples/reference/trump_0.wav", 50, 1.0, 0.7, True, False, -12], + ] + + outputs = [ + gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'), + gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav') + ] + + return gr.Interface( + fn=vc_wrapper.convert_voice, + description=description, + inputs=inputs, + outputs=outputs, + title="Seed Voice Conversion V1 (Voice & Singing Voice Conversion)", + examples=examples, + cache_examples=False, + ) + +def create_v2_interface(vc_wrapper): + # Set up Gradio interface + description = ("Zero-shot voice/style conversion with in-context learning. For local deployment please check [GitHub repository](https://github.com/Plachtaa/seed-vc) " + "for details and updates.
Note that any reference audio will be forcefully clipped to 25s if beyond this length.
" + "If total duration of source and reference audio exceeds 30s, source audio will be processed in chunks.
" + "Please click the 'convert style/emotion/accent' checkbox to convert the style, emotion, or accent of the source audio, or else only timbre conversion will be performed.
" + "Click the 'anonymization only' checkbox will ignore reference audio but convert source to an 'average voice' determined by model itself.
" + "无需训练的 zero-shot 语音/口音转换模型,若需本地部署查看[GitHub页面](https://github.com/Plachtaa/seed-vc)
" + "请注意,参考音频若超过 25 秒,则会被自动裁剪至此长度。
若源音频和参考音频的总时长超过 30 秒,源音频将被分段处理。" + "
请勾选 'convert style/emotion/accent' 以转换源音频的风格、情感或口音,否则仅执行音色转换。
" + "勾选 'anonymization only' 会无视参考音频而将源音频转换为某种由模型自身决定的 '平均音色'。
" + + "Credits to [Vevo](https://github.com/open-mmlab/Amphion/tree/main/models/vc/vevo)" + ) + inputs = [ + gr.Audio(type="filepath", label="Source Audio / 源音频"), + gr.Audio(type="filepath", label="Reference Audio / 参考音频"), + gr.Slider(minimum=1, maximum=200, value=30, step=1, label="Diffusion Steps / 扩散步数", + info="30 by default, 50~100 for best quality / 默认为 30,50~100 为最佳质量"), + gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Adjust / 长度调整", + info="<1.0 for speed-up speech, >1.0 for slow-down speech / <1.0 加速语速,>1.0 减慢语速"), + gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Intelligibility CFG Rate", + info="controls pronunciation intelligibility / 控制发音清晰度"), + gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.7, label="Similarity CFG Rate", + info="controls similarity to reference audio / 控制与参考音频的相似度"), + gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.9, label="Top-p", + info="AR model sampling top P"), + gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature", + info="AR model sampling temperature"), + gr.Slider(minimum=1.0, maximum=3.0, step=0.1, value=1.0, label="Repetition Penalty", + info="AR model sampling repetition penalty"), + gr.Checkbox(label="convert style/emotion/accent", value=False), + gr.Checkbox(label="anonymization only", value=False), + ] + + examples = [ + ["examples/source/yae_0.wav", "examples/reference/dingzhen_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False], + ["examples/source/jay_0.wav", "examples/reference/azuma_0.wav", 50, 1.0, 0.0, 0.7, 0.9, 1.0, 1.0, False, False], + ] + + outputs = [ + gr.Audio(label="Stream Output Audio / 流式输出", streaming=True, format='mp3'), + gr.Audio(label="Full Output Audio / 完整输出", streaming=False, format='wav') + ] + + return gr.Interface( + fn=vc_wrapper.convert_voice_with_streaming, + description=description, + inputs=inputs, + outputs=outputs, + title="Seed Voice Conversion V2 (Voice & Style Conversion)", + examples=examples, + cache_examples=False, + ) + +def main(args): + # Load V2 models + vc_wrapper_v2 = load_v2_models(args) + + # Create interfaces + v1_interface = create_v1_interface() + v2_interface = create_v2_interface(vc_wrapper_v2) + + # Create tabs + with gr.Blocks(title="Seed Voice Conversion") as demo: + gr.Markdown("# Seed Voice Conversion") + gr.Markdown("Choose between V1 (Voice & Singing Voice Conversion) or V2 (Voice & Style Conversion)") + + with gr.Tabs(): + with gr.TabItem("V2 - Voice & Style Conversion"): + v2_interface.render() + with gr.TabItem("V1 - Voice & Singing Voice Conversion"): + v1_interface.render() + + # Launch the combined interface + demo.launch() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--compile", type=bool, default=True) + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/configs/astral_quantization/default_2048.yml b/configs/astral_quantization/default_2048.yml new file mode 100644 index 0000000000000000000000000000000000000000..54f91e7cea7722a8cdb85c9855bffeeedb84e5db --- /dev/null +++ b/configs/astral_quantization/default_2048.yml @@ -0,0 +1,40 @@ +_target_: modules.astral_quantization.default_model.AstralQuantizer +tokenizer_name: "openai/whisper-small" +ssl_model_name: "facebook/hubert-large-ll60k" +ssl_output_layer: 18 +encoder: + _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage + dim: 512 + num_blocks: 12 + intermediate_dim: 1536 + dilation: 1 + input_dim: 1024 +quantizer: + _target_: 
modules.astral_quantization.bsq.BinarySphericalQuantize + codebook_size: 2048 # codebook size, must be a power of 2 + dim: 512 + entropy_loss_weight: 0.1 + diversity_gamma: 1.0 + spherical: True + enable_entropy_loss: True + soft_entropy_loss: True +decoder: + _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage + dim: 512 + num_blocks: 12 + intermediate_dim: 1536 + dilation: 1 + output_dim: 1024 + gin_channels: 192 +asr_decoder: + _target_: modules.astral_quantization.asr_decoder.ASRDecoder + hidden_dim: 768 + num_heads: 12 + depth: 12 + block_size: 4096 + in_channels: 512 + n_vocab: 51866 + bos_id: 50528 + eos_id: 50527 + dropout_rate: 0.0 + attn_dropout_rate: 0.0 \ No newline at end of file diff --git a/configs/astral_quantization/default_32.yml b/configs/astral_quantization/default_32.yml new file mode 100644 index 0000000000000000000000000000000000000000..bf160129893fb893eed26eabcb8a2da9c42a7159 --- /dev/null +++ b/configs/astral_quantization/default_32.yml @@ -0,0 +1,40 @@ +_target_: default_model.AstralQuantizer +tokenizer_name: "openai/whisper-small" +ssl_model_name: "facebook/hubert-large-ll60k" +ssl_output_layer: 18 +encoder: + _target_: modules.convnext.ConvNeXtV2Stage + dim: 512 + num_blocks: 12 + intermediate_dim: 1536 + dilation: 1 + input_dim: 1024 +quantizer: + _target_: modules.bsq.BinarySphericalQuantize + codebook_size: 32 # codebook size, must be a power of 2 + dim: 512 + entropy_loss_weight: 0.1 + diversity_gamma: 1.0 + spherical: True + enable_entropy_loss: True + soft_entropy_loss: True +decoder: + _target_: modules.convnext.ConvNeXtV2Stage + dim: 512 + num_blocks: 12 + intermediate_dim: 1536 + dilation: 1 + output_dim: 1024 + gin_channels: 192 +asr_decoder: + _target_: modules.asr_decoder.ASRDecoder + hidden_dim: 768 + num_heads: 12 + depth: 12 + block_size: 4096 + in_channels: 512 + n_vocab: 51866 + bos_id: 50528 + eos_id: 50527 + dropout_rate: 0.0 + attn_dropout_rate: 0.0 \ No newline at end of file diff --git a/configs/config.json b/configs/config.json new file mode 100644 index 0000000000000000000000000000000000000000..e74f0b4898f6e47e1d198b62cdba989784ce2bb0 --- /dev/null +++ b/configs/config.json @@ -0,0 +1 @@ +{"reference_audio_path": "D:/FAcodec/test_waves/kobe_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS 2.4", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS 2.4", "sr_type": "sr_model", "diffusion_steps": 10.0, "inference_cfg_rate": 0.0, "max_prompt_length": 3.0, "block_time": 0.7, "crossfade_length": 0.04, "extra_time": 0.5, "extra_time_right": 0.02} \ No newline at end of file diff --git a/configs/inuse/.gitignore b/configs/inuse/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/inuse/config.json b/configs/inuse/config.json new file mode 100644 index 0000000000000000000000000000000000000000..17d2df311cd2b19c1888c8fbb82effbfbc6edef3 --- /dev/null +++ b/configs/inuse/config.json @@ -0,0 +1 @@ +{"reference_audio_path": "D:/seed-vc/examples/reference/trump_0.wav", "sg_hostapi": "MME", "sg_wasapi_exclusive": false, "sg_input_device": "\u9ea6\u514b\u98ce (Razer BlackShark V2 HS USB", "sg_output_device": "\u626c\u58f0\u5668 (Razer BlackShark V2 HS USB", "sr_type": "sr_model", "diffusion_steps": 8.0, "inference_cfg_rate": 0.7, "max_prompt_length": 3.0, "block_time": 0.58, "crossfade_length": 0.04, "extra_time_ce": 2.5, "extra_time": 0.5, "extra_time_right": 0.02} \ No 
newline at end of file diff --git a/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml b/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml new file mode 100644 index 0000000000000000000000000000000000000000..0ec7ef4f6d8c7cf160687747ae01e0d39f6c128d --- /dev/null +++ b/configs/presets/config_dit_mel_seed_uvit_whisper_base_f0_44k.yml @@ -0,0 +1,98 @@ +log_dir: "./runs" +save_freq: 1 +log_interval: 10 +save_interval: 1000 +device: "cuda" +epochs: 1000 # number of epochs for first stage training (pre-training) +batch_size: 1 +batch_length: 100 # maximum duration of audio in a batch (in seconds) +max_len: 80 # maximum number of frames +pretrained_model: "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth" +pretrained_encoder: "" +load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters + +preprocess_params: + sr: 44100 + spect_params: + n_fft: 2048 + win_length: 2048 + hop_length: 512 + n_mels: 128 + fmin: 0 + fmax: "None" + +model_params: + dit_type: "DiT" # uDiT or DiT + reg_loss_type: "l1" # l1 or l2 + + timbre_shifter: + se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt" + ckpt_path: './modules/openvoice/checkpoints_v2/converter' + + vocoder: + type: "bigvgan" + name: "nvidia/bigvgan_v2_44khz_128band_512x" + + speech_tokenizer: + type: 'whisper' + name: "openai/whisper-small" + + style_encoder: + dim: 192 + campplus_path: "campplus_cn_common.bin" + + DAC: + encoder_dim: 64 + encoder_rates: [2, 5, 5, 6] + decoder_dim: 1536 + decoder_rates: [ 6, 5, 5, 2 ] + sr: 24000 + + length_regulator: + channels: 768 + is_discrete: false + in_channels: 768 + content_codebook_size: 2048 + sampling_ratios: [1, 1, 1, 1] + vector_quantize: false + n_codebooks: 1 + quantizer_dropout: 0.0 + f0_condition: true + n_f0_bins: 256 + + DiT: + hidden_dim: 768 + num_heads: 12 + depth: 17 + class_dropout_prob: 0.1 + block_size: 8192 + in_channels: 128 + style_condition: true + final_layer_type: 'mlp' + target: 'mel' # mel or codec + content_dim: 768 + content_codebook_size: 1024 + content_type: 'discrete' + f0_condition: true + n_f0_bins: 256 + content_codebooks: 1 + is_causal: false + long_skip_connection: false + zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token + time_as_token: false + style_as_token: false + uvit_skip_connection: true + add_resblock_in_transformer: false + + wavenet: + hidden_dim: 768 + num_layers: 8 + kernel_size: 5 + dilation_rate: 1 + p_dropout: 0.2 + style_condition: true + +loss_params: + base_lr: 0.0001 + lambda_mel: 45 + lambda_kl: 1.0 \ No newline at end of file diff --git a/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml b/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml new file mode 100644 index 0000000000000000000000000000000000000000..492910d163c7773c64f846ee55384e3e8b81ac00 --- /dev/null +++ b/configs/presets/config_dit_mel_seed_uvit_whisper_small_wavenet.yml @@ -0,0 +1,91 @@ +log_dir: "./runs" +save_freq: 1 +log_interval: 10 +save_interval: 1000 +device: "cuda" +epochs: 1000 # number of epochs for first stage training (pre-training) +batch_size: 2 +batch_length: 100 # maximum duration of audio in a batch (in seconds) +max_len: 80 # maximum number of frames +pretrained_model: "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth" +pretrained_encoder: "" +load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters + +preprocess_params: + sr: 22050 + 
spect_params: + n_fft: 1024 + win_length: 1024 + hop_length: 256 + n_mels: 80 + fmin: 0 + fmax: "None" + +model_params: + dit_type: "DiT" # uDiT or DiT + reg_loss_type: "l1" # l1 or l2 + + timbre_shifter: + se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt" + ckpt_path: './modules/openvoice/checkpoints_v2/converter' + + speech_tokenizer: + type: 'whisper' + name: "openai/whisper-small" + + style_encoder: + dim: 192 + campplus_path: "campplus_cn_common.bin" + + vocoder: + type: "bigvgan" + name: "nvidia/bigvgan_v2_22khz_80band_256x" + + length_regulator: + channels: 512 + is_discrete: false + in_channels: 768 + content_codebook_size: 2048 + sampling_ratios: [1, 1, 1, 1] + vector_quantize: false + n_codebooks: 1 + quantizer_dropout: 0.0 + f0_condition: false + n_f0_bins: 512 + + DiT: + hidden_dim: 512 + num_heads: 8 + depth: 13 + class_dropout_prob: 0.1 + block_size: 8192 + in_channels: 80 + style_condition: true + final_layer_type: 'wavenet' + target: 'mel' # mel or codec + content_dim: 512 + content_codebook_size: 1024 + content_type: 'discrete' + f0_condition: false + n_f0_bins: 512 + content_codebooks: 1 + is_causal: false + long_skip_connection: true + zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token + time_as_token: false + style_as_token: false + uvit_skip_connection: true + add_resblock_in_transformer: false + + wavenet: + hidden_dim: 512 + num_layers: 8 + kernel_size: 5 + dilation_rate: 1 + p_dropout: 0.2 + style_condition: true + +loss_params: + base_lr: 0.0001 + lambda_mel: 45 + lambda_kl: 1.0 \ No newline at end of file diff --git a/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml b/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml new file mode 100644 index 0000000000000000000000000000000000000000..e0677397377158dd30ffdf905946fbd297b36bd5 --- /dev/null +++ b/configs/presets/config_dit_mel_seed_uvit_xlsr_tiny.yml @@ -0,0 +1,82 @@ +log_dir: "./runs/" +save_freq: 1 +log_interval: 10 +save_interval: 500 +device: "cuda" +epochs: 1000 # number of epochs for first stage training (pre-training) +batch_size: 2 +batch_length: 100 # maximum duration of audio in a batch (in seconds) +max_len: 80 # maximum number of frames +pretrained_model: "DiT_uvit_tat_xlsr_ema.pth" +pretrained_encoder: "" +load_only_params: False # set to true if do not want to load epoch numbers and optimizer parameters + +preprocess_params: + sr: 22050 + spect_params: + n_fft: 1024 + win_length: 1024 + hop_length: 256 + n_mels: 80 + fmin: 0 + fmax: 8000 + +model_params: + dit_type: "DiT" # uDiT or DiT + reg_loss_type: "l1" # l1 or l2 + diffusion_type: "flow" + + timbre_shifter: + se_db_path: "./modules/openvoice/checkpoints_v2/converter/se_db.pt" + ckpt_path: './modules/openvoice/checkpoints_v2/converter' + + vocoder: + type: "hifigan" + + speech_tokenizer: + type: 'xlsr' + output_layer: 12 + name: 'facebook/wav2vec2-xls-r-300m' + + style_encoder: + dim: 192 + campplus_path: "campplus_cn_common.bin" + + length_regulator: + channels: 384 + is_discrete: false + in_channels: 1024 + content_codebook_size: 1024 + sampling_ratios: [1, 1, 1, 1] + vector_quantize: false + n_codebooks: 2 + quantizer_dropout: 0.0 + f0_condition: false + n_f0_bins: 512 + + DiT: + hidden_dim: 384 + num_heads: 6 + depth: 9 + class_dropout_prob: 0.1 + block_size: 8192 + in_channels: 80 + style_condition: true + final_layer_type: 'mlp' + target: 'mel' # mel or betavae + content_dim: 384 + content_codebook_size: 1024 + content_type: 'discrete' + f0_condition: false + n_f0_bins: 
512 + content_codebooks: 1 + is_causal: false + long_skip_connection: false + zero_prompt_speech_token: false # for prompt component, do not input corresponding speech token + time_as_token: true + style_as_token: true + uvit_skip_connection: true + add_resblock_in_transformer: false + +loss_params: + base_lr: 0.0001 \ No newline at end of file diff --git a/configs/v2/ar_base.yaml b/configs/v2/ar_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/v2/dit_small.yaml b/configs/v2/dit_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..60b93e953d7ec19039af147b7a52002b757a822a --- /dev/null +++ b/configs/v2/dit_small.yaml @@ -0,0 +1,17 @@ +_target_: modules.v2.cfm.CFM +estimator: + _target_: modules.v2.dit_wrapper.DiT + time_as_token: true + style_as_token: true + uvit_skip_connection: false + block_size: 8192 + depth: 13 + num_heads: 8 + hidden_dim: 512 + in_channels: 80 + content_dim: 512 + style_encoder_dim: 192 + class_dropout_prob: 0.1 + dropout_rate: 0.0 + attn_dropout_rate: 0.0 + diff --git a/configs/v2/vc_wrapper.yaml b/configs/v2/vc_wrapper.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c3fe5b84431f53ebd12ec60663f40f61f4a8c231 --- /dev/null +++ b/configs/v2/vc_wrapper.yaml @@ -0,0 +1,105 @@ +_target_: modules.v2.vc_wrapper.VoiceConversionWrapper +sr: 22050 +hop_size: 256 +mel_fn: + _target_: modules.audio.mel_spectrogram + _partial_: true + n_fft: 1024 + win_size: 1024 + hop_size: 256 + num_mels: 80 + sampling_rate: 22050 + fmin: 0 + fmax: null + center: False +cfm: + _target_: modules.v2.cfm.CFM + estimator: + _target_: modules.v2.dit_wrapper.DiT + time_as_token: true + style_as_token: true + uvit_skip_connection: false + block_size: 8192 + depth: 13 + num_heads: 8 + hidden_dim: 512 + in_channels: 80 + content_dim: 512 + style_encoder_dim: 192 + class_dropout_prob: 0.1 + dropout_rate: 0.0 + attn_dropout_rate: 0.0 +cfm_length_regulator: + _target_: modules.v2.length_regulator.InterpolateRegulator + channels: 512 + is_discrete: true + codebook_size: 2048 + sampling_ratios: [ 1, 1, 1, 1 ] + f0_condition: false +ar: + _target_: modules.v2.ar.NaiveWrapper + model: + _target_: modules.v2.ar.NaiveTransformer + config: + _target_: modules.v2.ar.NaiveModelArgs + dropout: 0.0 + rope_base: 10000.0 + dim: 768 + head_dim: 64 + n_local_heads: 2 + intermediate_size: 2304 + n_head: 12 + n_layer: 12 + vocab_size: 2049 # 1 + 1 for eos +ar_length_regulator: + _target_: modules.v2.length_regulator.InterpolateRegulator + channels: 768 + is_discrete: true + codebook_size: 32 + sampling_ratios: [ ] + f0_condition: false +style_encoder: + _target_: modules.campplus.DTDNN.CAMPPlus + feat_dim: 80 + embedding_size: 192 +content_extractor_narrow: + _target_: modules.astral_quantization.default_model.AstralQuantizer + tokenizer_name: "openai/whisper-small" + ssl_model_name: "facebook/hubert-large-ll60k" + ssl_output_layer: 18 + skip_ssl: true + encoder: &bottleneck_encoder + _target_: modules.astral_quantization.convnext.ConvNeXtV2Stage + dim: 512 + num_blocks: 12 + intermediate_dim: 1536 + dilation: 1 + input_dim: 1024 + quantizer: + _target_: modules.astral_quantization.bsq.BinarySphericalQuantize + codebook_size: 32 # codebook size, must be a power of 2 + dim: 512 + entropy_loss_weight: 0.1 + diversity_gamma: 1.0 + spherical: True + enable_entropy_loss: True + soft_entropy_loss: True +content_extractor_wide: + _target_: 
modules.astral_quantization.default_model.AstralQuantizer + tokenizer_name: "openai/whisper-small" + ssl_model_name: "facebook/hubert-large-ll60k" + ssl_output_layer: 18 + encoder: *bottleneck_encoder + quantizer: + _target_: modules.astral_quantization.bsq.BinarySphericalQuantize + codebook_size: 2048 # codebook size, must be a power of 2 + dim: 512 + entropy_loss_weight: 0.1 + diversity_gamma: 1.0 + spherical: True + enable_entropy_loss: True + soft_entropy_loss: True +vocoder: + _target_: modules.bigvgan.bigvgan.BigVGAN.from_pretrained + pretrained_model_name_or_path: "nvidia/bigvgan_v2_22khz_80band_256x" + use_cuda_kernel: false diff --git a/hf_utils.py b/hf_utils.py index 9f8c7f7d5f1b82efbd788c7327f76c0dc6a9355a..4ae986c1d88d5b1214c9d3101249e90c5c370ca9 100644 --- a/hf_utils.py +++ b/hf_utils.py @@ -2,7 +2,7 @@ import os from huggingface_hub import hf_hub_download -def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename="config.yml"): +def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", config_filename=None): os.makedirs("./checkpoints", exist_ok=True) model_path = hf_hub_download(repo_id=repo_id, filename=model_filename, cache_dir="./checkpoints") if config_filename is None: diff --git a/modules/__pycache__/audio.cpython-310.pyc b/modules/__pycache__/audio.cpython-310.pyc index 79bf4bb4261f2d91fe5d3efb024339a88274162c..651e7ad6c3e297013d527e4a9218ae35f2b92c41 100644 Binary files a/modules/__pycache__/audio.cpython-310.pyc and b/modules/__pycache__/audio.cpython-310.pyc differ diff --git a/modules/__pycache__/commons.cpython-310.pyc b/modules/__pycache__/commons.cpython-310.pyc index 9289cfe3ac9362aaeece6546f5b78bbfea6ca40b..5adfe7b95903d4c3134f74bdf68b5b413b848c97 100644 Binary files a/modules/__pycache__/commons.cpython-310.pyc and b/modules/__pycache__/commons.cpython-310.pyc differ diff --git a/modules/__pycache__/commons.cpython-38.pyc b/modules/__pycache__/commons.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfe6a34f14ee453f4a79de9b39d80fc62b4d7fda Binary files /dev/null and b/modules/__pycache__/commons.cpython-38.pyc differ diff --git a/modules/__pycache__/diffusion_transformer.cpython-310.pyc b/modules/__pycache__/diffusion_transformer.cpython-310.pyc index c721982be6feb37d7ae333799842c197225c9697..4bcbf046526a3c8e13cf180ebea709e89305d848 100644 Binary files a/modules/__pycache__/diffusion_transformer.cpython-310.pyc and b/modules/__pycache__/diffusion_transformer.cpython-310.pyc differ diff --git a/modules/__pycache__/flow_matching.cpython-310.pyc b/modules/__pycache__/flow_matching.cpython-310.pyc index 4f42f2602f27a9f430f3daf9ca3825ebe86172b5..4a6f71cb0ef1cd096cdeed48330dbce1aed55734 100644 Binary files a/modules/__pycache__/flow_matching.cpython-310.pyc and b/modules/__pycache__/flow_matching.cpython-310.pyc differ diff --git a/modules/__pycache__/length_regulator.cpython-310.pyc b/modules/__pycache__/length_regulator.cpython-310.pyc index 301c8f57e7713f62bf83b9c4fa712fe7680f26be..c2ada28a9ce6e0b8a61901f913b5ddde33b0c9ac 100644 Binary files a/modules/__pycache__/length_regulator.cpython-310.pyc and b/modules/__pycache__/length_regulator.cpython-310.pyc differ diff --git a/modules/__pycache__/rmvpe.cpython-310.pyc b/modules/__pycache__/rmvpe.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18f08923f7c01ba34a8719dd56cd1c05fdf4e33d Binary files /dev/null and b/modules/__pycache__/rmvpe.cpython-310.pyc differ diff --git 
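# Illustrative sketch (not taken from this patch): how the `_target_:` entries in
# configs/v2/vc_wrapper.yaml above are turned into live modules. hydra.utils.instantiate
# builds each object from its dotted class path recursively, which is what
# load_v2_models() in app_v1v2.py relies on; the path and method names below are copied
# from that function, everything else is example scaffolding.
import yaml
from omegaconf import DictConfig
from hydra.utils import instantiate

cfg = DictConfig(yaml.safe_load(open("configs/v2/vc_wrapper.yaml", "r")))
vc_wrapper = instantiate(cfg)   # recursively builds CFM, AR, CAMPPlus, BigVGAN, quantizers
vc_wrapper.load_checkpoints()   # load pretrained weights, as done in load_v2_models()
vc_wrapper.eval()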
a/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc b/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..918dc5ae503d2a7da3860b7637f9af7aa81ea466 Binary files /dev/null and b/modules/astral_quantization/__pycache__/bsq.cpython-310.pyc differ diff --git a/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc b/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..508b91e576abbb86db68531cda662c8e9daa0fe6 Binary files /dev/null and b/modules/astral_quantization/__pycache__/convnext.cpython-310.pyc differ diff --git a/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc b/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381bd7f361d729a1640b7eed29bbfdc0515cbcc0 Binary files /dev/null and b/modules/astral_quantization/__pycache__/default_model.cpython-310.pyc differ diff --git a/modules/astral_quantization/bsq.py b/modules/astral_quantization/bsq.py new file mode 100644 index 0000000000000000000000000000000000000000..1b70f3401dafa187fc19666d31a0877f533abe33 --- /dev/null +++ b/modules/astral_quantization/bsq.py @@ -0,0 +1,569 @@ +""" +Lookup Free Quantization +Proposed in https://arxiv.org/abs/2310.05737 + +In the simplest setup, each dimension is quantized into {-1, 1}. +An entropy penalty is used to encourage utilization. +""" + +from math import log2, ceil +from functools import partial, cache +from collections import namedtuple +from contextlib import nullcontext + +import torch.distributed as dist +from torch.distributed import nn as dist_nn + +import torch +from torch import nn, einsum +import torch.nn.functional as F +from torch.nn import Module +from torch.amp import autocast + +from einops import rearrange, reduce, pack, unpack + +# constants + +Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss']) + +LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment']) + +# distributed helpers + +@cache +def is_distributed(): + return dist.is_initialized() and dist.get_world_size() > 1 + +def maybe_distributed_mean(t): + if not is_distributed(): + return t + + dist_nn.all_reduce(t) + t = t / dist.get_world_size() + return t + +# helper functions + +def exists(v): + return v is not None + +def identity(t): + return t + +def default(*args): + for arg in args: + if exists(arg): + return arg() if callable(arg) else arg + return None + +def pack_one(t, pattern): + return pack([t], pattern) + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + +def l2norm(t): + return F.normalize(t, dim = -1) + +# entropy + +def log(t, eps = 1e-5): + return t.clamp(min = eps).log() + +def entropy(prob): + return (-prob * log(prob)).sum(dim=-1) + +# cosine sim linear + +class CosineSimLinear(Module): + def __init__( + self, + dim_in, + dim_out, + scale = 1. + ): + super().__init__() + self.scale = scale + self.weight = nn.Parameter(torch.randn(dim_in, dim_out)) + + def forward(self, x): + x = F.normalize(x, dim = -1) + w = F.normalize(self.weight, dim = 0) + return (x @ w) * self.scale + +def soft_entropy_loss(u, tau=1.0, gamma=1.0): + """ + Compute the soft entropy loss for Binary Spherical Quantization (BSQ). + + Args: + u (torch.Tensor): Input latent embeddings of shape (batch_size, L). + tau (float): Temperature scaling factor. 
+ gamma (float): Weight for the second entropy term. + + Returns: + torch.Tensor: Soft entropy loss. + """ + # Binary quantization: Generate implicit codebook corners + L = u.size(1) # Dimensionality of codebook + corners = torch.tensor([-1.0, 1.0], device=u.device) / (L**0.5) + + # Compute soft quantization probabilities for all dimensions + # q_hat(c|u) for each dimension + prob_matrix = torch.sigmoid(2 * tau * corners.unsqueeze(1) * u.unsqueeze(2)) # Shape: (batch_size, L, 2) + + # Entropy of q_hat(c|u) (independent along each dimension) + entropy_per_dim = -torch.sum(prob_matrix * prob_matrix.log(), dim=-1) # Shape: (batch_size, L) + entropy_term1 = entropy_per_dim.mean() + + # Expected probabilities for dataset entropy (approximation) + expected_probs = prob_matrix.mean(dim=0) # Mean across batch, shape: (L, 2) + entropy_term2 = -torch.sum(expected_probs * expected_probs.log(), dim=-1).mean() + + # Final entropy loss + loss = entropy_term1 - gamma * entropy_term2 + return loss + +# class + +class BinarySphericalQuantize(Module): + def __init__( + self, + *, + dim = None, + codebook_size = None, + entropy_loss_weight = 0.1, + commitment_loss_weight = 0., + diversity_gamma = 1., + straight_through_activation = nn.Identity(), + num_codebooks = 1, + keep_num_codebooks_dim = None, + codebook_scale = 1., # for residual LFQ, codebook scaled down by 2x at each layer + frac_per_sample_entropy = 0.25, # make less than 1. to only use a random fraction of the probs for per sample entropy + has_projections = None, + projection_has_bias = True, + soft_clamp_input_value = None, + cosine_sim_project_in = False, + cosine_sim_project_in_scale = None, + channel_first = None, + experimental_softplus_entropy_loss = False, + entropy_loss_offset = 5., # how much to shift the loss before softplus + spherical = True, # from https://arxiv.org/abs/2406.07548 + force_quantization_f32 = True, # will force the quantization step to be full precision + enable_entropy_loss = True, + soft_entropy_loss = True, + ): + super().__init__() + + # some assert validations + + assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ' + assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})' + + codebook_size = default(codebook_size, lambda: 2 ** dim) + self.codebook_size = codebook_size + + codebook_dim = int(log2(codebook_size)) + codebook_dims = codebook_dim * num_codebooks + dim = default(dim, codebook_dims) + + has_projections = default(has_projections, dim != codebook_dims) + + if cosine_sim_project_in: + cosine_sim_project_in = default(cosine_sim_project_in_scale, codebook_scale) + project_in_klass = partial(CosineSimLinear, scale = cosine_sim_project_in) + else: + project_in_klass = partial(nn.Linear, bias = projection_has_bias) + + self.project_in = project_in_klass(dim, codebook_dims) if has_projections else nn.Identity() + self.project_out = nn.Linear(codebook_dims, dim, bias = projection_has_bias) if has_projections else nn.Identity() + self.has_projections = has_projections + + self.dim = dim + self.codebook_dim = codebook_dim + self.num_codebooks = num_codebooks + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + # channel first + + self.channel_first = channel_first + + # straight through 
activation + + self.activation = straight_through_activation + + # whether to use BSQ (binary spherical quantization) + + self.spherical = spherical + self.maybe_l2norm = (lambda t: l2norm(t) * self.codebook_scale) if spherical else identity + + # entropy aux loss related weights + + assert 0 < frac_per_sample_entropy <= 1. + self.frac_per_sample_entropy = frac_per_sample_entropy + + self.diversity_gamma = diversity_gamma + self.entropy_loss_weight = entropy_loss_weight + + # codebook scale + + self.codebook_scale = codebook_scale + + # commitment loss + + self.commitment_loss_weight = commitment_loss_weight + + # whether to soft clamp the input value from -value to value + + self.soft_clamp_input_value = soft_clamp_input_value + assert not exists(soft_clamp_input_value) or soft_clamp_input_value >= codebook_scale + + # whether to make the entropy loss positive through a softplus (experimental, please report if this worked or not in discussions) + + self.entropy_loss_offset = entropy_loss_offset + self.experimental_softplus_entropy_loss = experimental_softplus_entropy_loss + + # for no auxiliary loss, during inference + + self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1)) + self.register_buffer('zero', torch.tensor(0.), persistent = False) + + # whether to force quantization step to be f32 + + self.force_quantization_f32 = force_quantization_f32 + + # codes + self.enable_entropy_loss = enable_entropy_loss + self.soft_entropy_loss = soft_entropy_loss + if codebook_size <= 100000: + all_codes = torch.arange(codebook_size) + bits = ((all_codes[..., None].int() & self.mask) != 0).float() + codebook = self.bits_to_codes(bits) + + self.register_buffer('codebook', codebook.float(), persistent = False) + else: + all_codes = torch.arange(pow(2, 16)) + mask = 2 ** torch.arange(16 - 1, -1, -1) + bits = ((all_codes[..., None].int() & mask) != 0).float() + codebook = self.bits_to_codes(bits) + + self.register_buffer('codebook', codebook.float(), persistent = False) + + def bits_to_codes(self, bits): + return bits * self.codebook_scale * 2 - self.codebook_scale + + @property + def dtype(self): + return self.codebook.dtype + + def indices_to_codes( + self, + indices, + project_out = True + ): + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + should_transpose = default(self.channel_first, is_img_or_video) + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, '... -> ... 1') + + # indices to codes, which are bits of either -1 or 1 + + bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype) + + codes = self.bits_to_codes(bits) + + codes = self.maybe_l2norm(codes) + + codes = rearrange(codes, '... c d -> ... (c d)') + + # whether to project codes out to original dimensions + # if the input feature dimensions were not log2(codebook size) + + if project_out: + codes = self.project_out(codes) + + # rearrange codes back to original shape + + if should_transpose: + codes = rearrange(codes, 'b ... 
d -> b d ...') + + return codes + + def bits_to_z(self, bits): + # assert bits must contain only -1 and 1 + assert torch.all(bits.abs() == 1) + quantized = bits.float() + quantized = self.maybe_l2norm(quantized) + z = self.project_out(quantized) + return z + + def forward( + self, + x, + inv_temperature = 100., + return_loss_breakdown = False, + mask = None, + return_bits = False + ): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension, which is also log2(codebook size) + c - number of codebook dim + """ + + is_img_or_video = x.ndim >= 4 + should_transpose = default(self.channel_first, is_img_or_video) + + # standardize image or video into (batch, seq, dimension) + + if should_transpose: + x = rearrange(x, 'b d ... -> b ... d') + x, ps = pack_one(x, 'b * d') + + assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}' + + x = self.project_in(x) + + # maybe soft clamp + + if exists(self.soft_clamp_input_value): + clamp_value = self.soft_clamp_input_value + x = (x / clamp_value).tanh() * clamp_value + + # split out number of codebooks + + x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks) + + # maybe l2norm + + x = self.maybe_l2norm(x) + + # whether to force quantization step to be full precision or not + + force_f32 = self.force_quantization_f32 + + quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext + + with quantization_context(): + + if force_f32: + orig_dtype = x.dtype + x = x.float() + + # quantize by eq 3. + + original_input = x + + codebook_value = torch.ones_like(x) * self.codebook_scale + quantized = torch.where(x > 0, codebook_value, -codebook_value) + if return_bits: + return quantized + + # calculate indices + + indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum') + + # maybe l2norm + + quantized = self.maybe_l2norm(quantized) + + # use straight-through gradients (optionally with custom activation fn) if training + + if self.training: + x = self.activation(x) + x = x + (quantized - x).detach() + else: + x = quantized + + # entropy aux loss + if self.soft_entropy_loss: + entropy_aux_loss = soft_entropy_loss(x, tau=1.0, gamma=1.0) + elif self.training and self.enable_entropy_loss: + + if force_f32: + codebook = self.codebook.float() + + codebook = self.maybe_l2norm(codebook) + + # whether to only use a fraction of probs, for reducing memory + + if self.frac_per_sample_entropy < 1.: + # account for mask + if exists(mask): + original_input = original_input[mask] + original_input = rearrange(original_input, 'b n ... -> (b n) ...') + + rand_mask = torch.randn(self.codebook_dim).argsort(dim = -1) < 16 + + sampled_input = original_input[..., rand_mask] + + sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook) + + sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1) + + per_sample_probs = sampled_prob + else: + if exists(mask): + original_input = original_input[mask] + original_input = rearrange(original_input, 'b n ... -> (b n) ...') + # the same as euclidean distance up to a constant + distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook) + + prob = (-distance * inv_temperature).softmax(dim = -1) + + per_sample_probs = prob + + # calculate per sample entropy + + per_sample_entropy = entropy(per_sample_probs).mean() + + # distribution over all available tokens in the batch + + avg_prob = reduce(per_sample_probs, '... 
c d -> c d', 'mean') + + avg_prob = maybe_distributed_mean(avg_prob) + + codebook_entropy = entropy(avg_prob).mean() + + # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions + # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch + + entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy + else: + # if not training, just return dummy 0 + entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero + + # whether to make the entropy loss positive or not through a (shifted) softplus + + if self.training and self.experimental_softplus_entropy_loss: + entropy_aux_loss = F.softplus(entropy_aux_loss + self.entropy_loss_offset) + + # commit loss + + if self.training and self.commitment_loss_weight > 0.: + + commit_loss = F.mse_loss(original_input, quantized.detach(), reduction = 'none') + + if exists(mask): + commit_loss = commit_loss[mask] + + commit_loss = commit_loss.mean() + else: + commit_loss = self.zero + + # input back to original dtype if needed + + if force_f32: + x = x.type(orig_dtype) + + # merge back codebook dim + + x = rearrange(x, 'b n c d -> b n (c d)') + + # project out to feature dimension if needed + + x = self.project_out(x) + + # reconstitute image or video dimensions + + if should_transpose: + x = unpack_one(x, ps, 'b * d') + x = rearrange(x, 'b ... d -> b d ...') + + indices = unpack_one(indices, ps, 'b * c') + + # whether to remove single codebook dim + + if not self.keep_num_codebooks_dim: + indices = rearrange(indices, '... 1 -> ...') + + # complete aux loss + + aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight + + # returns + + ret = Return(x, indices, aux_loss) + + if not return_loss_breakdown: + return ret + + return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss) + +class GroupedResidualBSQ(Module): + def __init__( + self, + *, + dim, + groups = 1, + accept_image_fmap = False, + **kwargs + ): + super().__init__() + self.dim = dim + self.groups = groups + assert (dim % groups) == 0 + dim_per_group = dim // groups + + self.accept_image_fmap = accept_image_fmap + + self.rvqs = nn.ModuleList([]) + + for _ in range(groups): + self.rvqs.append(LFQ( + dim = dim_per_group, + **kwargs + )) + + self.codebook_size = self.rvqs[0].codebook_size + + @property + def codebooks(self): + return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs)) + + @property + def split_dim(self): + return 1 if self.accept_image_fmap else -1 + + def get_codes_from_indices(self, indices): + codes = tuple(rvq.get_codes_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices)) + return torch.stack(codes) + + def get_output_from_indices(self, indices): + outputs = tuple(rvq.get_output_from_indices(chunk_indices) for rvq, chunk_indices in zip(self.rvqs, indices)) + return torch.cat(outputs, dim = self.split_dim) + + def forward( + self, + x, + return_all_codes = False + ): + shape, split_dim = x.shape, self.split_dim + assert shape[split_dim] == self.dim + + # split the feature dimension into groups + + x = x.chunk(self.groups, dim = split_dim) + + forward_kwargs = dict( + ) + + # invoke residual vq on each group + + out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x)) + out = tuple(zip(*out)) + + # otherwise, get all the zipped outputs and combine them + + quantized, all_indices, *maybe_aux_loss = out + + quantized = torch.cat(quantized, dim = 
split_dim) + all_indices = torch.stack(all_indices) + + ret = (quantized, all_indices, *maybe_aux_loss) + return ret diff --git a/modules/astral_quantization/convnext.py b/modules/astral_quantization/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..7bef9e282ac332b7169339eeda082d0581d739d1 --- /dev/null +++ b/modules/astral_quantization/convnext.py @@ -0,0 +1,209 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import List + + +class ConvNextV2LayerNorm(nn.Module): + r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height, + width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError(f"Unsupported data format: {self.data_format}") + self.normalized_shape = (normalized_shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.data_format == "channels_last": + x = torch.nn.functional.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + input_dtype = x.dtype + x = x.float() + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = x.to(dtype=input_dtype) + x = self.weight[None, :, None] * x + self.bias[None, :, None] + return x + + +class GRN(nn.Module): + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=1, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + +class InterpolationLayer(nn.Module): + def __init__(self, ): # this is a default of 1 / 50 * (44100 / 512) / 4 + super().__init__() + pass + + def forward(self, x: torch.Tensor, target_len: torch.Tensor, *args, **kwargs) -> torch.Tensor: + x = F.interpolate(x, size=target_len, mode='linear') + return x + +class ConvNeXtV2Stage(nn.Module): + def __init__( + self, + dim: int = 512, + intermediate_dim: int = 2048, + num_blocks: int = 1, + dilation: int = 1, + downsample_layer_indices: List[int] = None, + downsample_factors: List[int] = None, + upsample_layer_indices: List[int] = None, + upsample_factors: List[int] = None, + interpolation_layer_indices: List[int] = None, + input_dim: int = None, + output_dim: int = None, + gin_channels: int = 0, + ): + super().__init__() + # maybe downsample layers + if downsample_layer_indices is not None: + assert downsample_factors is not None + self.downsample_blocks = nn.ModuleList( + [ + nn.Sequential( + ConvNextV2LayerNorm(dim, data_format="channels_first"), + nn.Conv1d( + dim, dim, kernel_size=downsample_factor, stride=downsample_factor + ), + ) for _, downsample_factor in zip(downsample_layer_indices, downsample_factors) + ] + ) + self.downsample_layer_indices = downsample_layer_indices + else: + self.downsample_blocks = nn.ModuleList() + self.downsample_layer_indices = [] + + # maybe upsample layers + if 
upsample_layer_indices is not None: + assert upsample_factors is not None + self.upsample_blocks = nn.ModuleList( + [ + nn.Sequential( + ConvNextV2LayerNorm(dim, data_format="channels_first"), + nn.ConvTranspose1d( + dim, dim, kernel_size=upsample_factor, stride=upsample_factor + ), + ) for _, upsample_factor in zip(upsample_layer_indices, upsample_factors) + ] + ) + self.upsample_layer_indices = upsample_layer_indices + else: + self.upsample_blocks = nn.ModuleList() + self.upsample_layer_indices = [] + + # maybe interpolation layers + if interpolation_layer_indices is not None: + self.interpolation_blocks = nn.ModuleList( + [ + InterpolationLayer() + for _ in interpolation_layer_indices + ] + ) + self.interpolation_layer_indices = interpolation_layer_indices + else: + self.interpolation_blocks = nn.ModuleList() + self.interpolation_layer_indices = [] + + # main blocks + self.blocks = nn.ModuleList( + [ + ConvNeXtV2Block( + dim=dim, + intermediate_dim=intermediate_dim, + dilation=dilation, + ) + for _ in range(num_blocks) + ] + ) + # maybe input and output projections + if input_dim is not None and input_dim != dim: + self.input_projection = nn.Conv1d(input_dim, dim, kernel_size=1) + else: + self.input_projection = nn.Identity() + if output_dim is not None and output_dim != dim: + self.output_projection = nn.Conv1d(dim, output_dim, kernel_size=1) + else: + self.output_projection = nn.Identity() + + if gin_channels > 0: + self.gin = nn.Conv1d(gin_channels, dim, kernel_size=1) + + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + x = self.input_projection(x) # B, D, T + if hasattr(self, 'gin'): + g = kwargs['g'] + x = x + self.gin(g) + # pad to a multiple of cumprod(downsample_factors) + if len(self.downsample_blocks) > 0: + downsample_factor = 1 + for factor in self.downsample_blocks: + downsample_factor *= factor[1].stride[0] + pad_len = downsample_factor - x.size(-1) % downsample_factor + if pad_len > 0: + x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1) + + # main blocks + for layer_idx, block in enumerate(self.blocks): + if layer_idx in self.downsample_layer_indices: + x = self.downsample_blocks[self.downsample_layer_indices.index(layer_idx)](x) + if layer_idx in self.upsample_layer_indices: + x = self.upsample_blocks[self.upsample_layer_indices.index(layer_idx)](x) + if layer_idx in self.interpolation_layer_indices: + x = self.interpolation_blocks[self.interpolation_layer_indices.index(layer_idx)](x, target_len=kwargs['target_len']) + x = block(x) + x = self.output_projection(x) + return x + + def setup_caches(self, *args, **kwargs): + pass + + +class ConvNeXtV2Block(nn.Module): + def __init__( + self, + dim: int, + intermediate_dim: int, + dilation: int = 1, + ): + super().__init__() + padding = (dilation * (7 - 1)) // 2 + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation + ) # depthwise conv + self.norm = ConvNextV2LayerNorm(dim, data_format="channels_first") + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(intermediate_dim) + self.pwconv2 = nn.Linear(intermediate_dim, dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = self.norm(x) + x = x.transpose(1, 2) # b d n -> b n d + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.transpose(1, 2) # b n d -> b d n + return residual + x \ No newline at end of file diff --git 
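# Illustrative sketch (not taken from this patch): shape contract of the ConvNeXtV2Stage
# defined above, using the encoder settings from configs/astral_quantization/default_2048.yml.
# The tensor sizes are made up for the example; input is channels-first (batch, dim, frames).
import torch
from modules.astral_quantization.convnext import ConvNeXtV2Stage

stage = ConvNeXtV2Stage(dim=512, intermediate_dim=1536, num_blocks=12, dilation=1, input_dim=1024)
ssl_feats = torch.randn(1, 1024, 250)   # e.g. 5 s of 50 Hz SSL features, 1024-dim
out = stage(ssl_feats)                  # -> (1, 512, 250): projected to `dim`, length unchanged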
a/modules/astral_quantization/default_model.py b/modules/astral_quantization/default_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca50d1cb1b177a89955c9f13451e5ed159b753b5
--- /dev/null
+++ b/modules/astral_quantization/default_model.py
@@ -0,0 +1,73 @@
+import torch
+from transformers import AutoTokenizer, AutoModel, Wav2Vec2FeatureExtractor
+
+class AstralQuantizer(torch.nn.Module):
+    def __init__(
+        self,
+        tokenizer_name: str,
+        ssl_model_name: str,
+        ssl_output_layer: int,
+        encoder: torch.nn.Module,
+        quantizer: torch.nn.Module,
+        skip_ssl: bool = False,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.quantizer = quantizer
+        self.tokenizer_name = tokenizer_name
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+        # Load SSL model from Huggingface
+        self.ssl_model_name = ssl_model_name
+        self.ssl_output_layer = ssl_output_layer
+        self.ssl_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(ssl_model_name)
+
+        if skip_ssl:  # in case the same SSL model has been loaded somewhere else
+            self.ssl_model = None
+        else:
+            self.ssl_model = AutoModel.from_pretrained(ssl_model_name).eval()
+            self.ssl_model.encoder.layers = self.ssl_model.encoder.layers[:ssl_output_layer]
+            self.ssl_model.encoder.layer_norm = torch.nn.Identity()
+
+    def load_separate_checkpoint(self, checkpoint_path):
+        params = torch.load(checkpoint_path, map_location='cpu')['net']
+        for key in params.keys():
+            for k in list(params[key].keys()):
+                if k.startswith("module."):
+                    params[key][k[len("module."):]] = params[key][k]
+                    del params[key][k]
+        self.encoder.load_state_dict(params['encoder'])
+        self.quantizer.load_state_dict(params['vq'])
+        if getattr(self, 'decoder', None) is not None:
+            self.decoder.load_state_dict(params['decoder'])
+        if getattr(self, 'asr_decoder', None) is not None:
+            self.asr_decoder.load_state_dict(params['predictor'], strict=False)
+
+    def forward(self, waves_16k, wave_16k_lens, ssl_model=None):
+        ssl_fn = self.ssl_model if self.ssl_model else ssl_model
+        assert ssl_fn is not None, "In case in-class SSL model loading is skipped, external ssl_model must be provided"
+        waves_16k_input_list = [
+            waves_16k[bib, :wave_16k_lens[bib]].cpu().numpy()
+            for bib in range(len(waves_16k))
+        ]
+        alt_inputs = self.ssl_feature_extractor(
+            waves_16k_input_list,
+            return_tensors='pt',
+            return_attention_mask=True,
+            padding=True,
+            sampling_rate=16000
+        ).to(waves_16k.device)
+        feature_lens = alt_inputs.data['attention_mask'].sum(-1) // 320  # frame rate of hubert is 50 Hz
+
+        outputs = ssl_fn(
+            alt_inputs.input_values,
+            attention_mask=alt_inputs.attention_mask,
+        )
+        last_hidden_states = outputs.last_hidden_state
+        last_hidden_states = last_hidden_states[:, :feature_lens.max(), :]
+        feature_lens = feature_lens.clamp(max=last_hidden_states.size(1))
+        last_hidden_states = last_hidden_states.transpose(1, 2)
+        x_hidden = self.encoder(last_hidden_states, feature_lens)
+        x_hidden = x_hidden.transpose(1, 2)
+        x_quantized, indices = self.quantizer(x_hidden)[:2]
+        return x_quantized, indices, feature_lens
\ No newline at end of file
diff --git a/modules/astral_quantization/transformer.py b/modules/astral_quantization/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..015ec417d02e3442d8ea120c3802807c73e89bb4
--- /dev/null
+++ b/modules/astral_quantization/transformer.py
@@ -0,0 +1,254 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
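+#
+# NOTE: this module appears to be adapted from Meta's gpt-fast reference Transformer;
+# it adds AdaptiveLayerNorm conditioning and optional cross-attention blocks, and is
+# likely the encoder backbone plugged into AstralQuantizer (see default_model.py above).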
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +import time + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if embedding is None: + return self.norm(input) + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + has_cross_attention: bool = False + context_dim: int = 0 + is_causal: bool = False + dropout_rate: float = 0.1 + attn_dropout_rate: float = 0.1 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + # self.head_dim = self.dim // self.n_head + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.max_batch_size = -1 + self.max_seq_length = config.block_size + freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, + self.config.rope_base) + self.register_buffer("freqs_cis", freqs_cis) + + causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + self.register_buffer("causal_mask", causal_mask) + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_input_pos: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + ) -> Tensor: + if mask is None: + mask = self.causal_mask[:x.size(1), :x.size(1)] + else: + mask = mask[..., input_pos] + freqs_cis = self.freqs_cis[input_pos] + if context is not None: + context_freqs_cis = self.freqs_cis[context_input_pos] + else: + context_freqs_cis = None + skip_in_x_list = [] + for i, layer in enumerate(self.layers): + x = layer(x, c, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask) + x = self.norm(x, c) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + if config.has_cross_attention: + self.has_cross_attention = True + self.cross_attention 
= Attention(config, is_cross_attention=True) + self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + else: + self.has_cross_attention = False + + def forward(self, + x: Tensor, + c: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + ) -> Tensor: + #time_attn_start = time.time() + h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask) + #print(f"time take for attention of sequence length {x.shape[1]} is {time.time() - time_attn_start}") + if self.has_cross_attention: + h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, context, context_freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h, c)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs, is_cross_attention: bool = False): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + if is_cross_attention: + self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False) + self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False) + else: + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.attn_dropout_rate = config.attn_dropout_rate + + def forward(self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + if context is None: + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + else: + q = self.wq(x) + k, v = self.wkv(context).split([kv_size, kv_size], dim=-1) + context_seqlen = context.shape[1] + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=self.attn_dropout_rate if self.training else 0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x))) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + 
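# RMSNorm (as in LLaMA / gpt-fast): y = x * rsqrt(mean(x^2) + eps) * weight,
+ # with no mean subtraction or bias; the norm is computed in float32 for stability. +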
self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + diff --git a/modules/audio.py b/modules/audio.py index abe783b0e0af630319700c931eb51d2ce375282b..ae677ffb1c124b557b3dbe0343ae415f5281cddb 100644 --- a/modules/audio.py +++ b/modules/audio.py @@ -1,82 +1,82 @@ -import numpy as np -import torch -import torch.utils.data -from librosa.filters import mel as librosa_mel_fn -from scipy.io.wavfile import read - -MAX_WAV_VALUE = 32768.0 - - -def load_wav(full_path): - sampling_rate, data = read(full_path) - return data, sampling_rate - - -def dynamic_range_compression(x, C=1, clip_val=1e-5): - return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) - - -def dynamic_range_decompression(x, C=1): - return np.exp(x) / C - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - -mel_basis = {} -hann_window = {} - - -def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window # pylint: disable=global-statement - if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[str(sampling_rate) + "_" + str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - - spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + 
str(fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window # pylint: disable=global-statement + if f"{str(sampling_rate)}_{str(fmax)}_{str(y.device)}" not in mel_basis: + mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(sampling_rate) + "_" + str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[str(sampling_rate) + "_" + str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(sampling_rate) + "_" + str(fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec diff --git a/modules/bigvgan/__pycache__/activations.cpython-310.pyc b/modules/bigvgan/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bedc597b24c34a77805b188e40b71f1d5f221118 Binary files /dev/null and b/modules/bigvgan/__pycache__/activations.cpython-310.pyc differ diff --git a/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc b/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dba1f3aac51089857f5a8226f45af8441fbf65a3 Binary files /dev/null and b/modules/bigvgan/__pycache__/bigvgan.cpython-310.pyc differ diff --git a/modules/bigvgan/__pycache__/env.cpython-310.pyc b/modules/bigvgan/__pycache__/env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c5f5045d1715338beb0335539cb76f229124f54 Binary files /dev/null and b/modules/bigvgan/__pycache__/env.cpython-310.pyc differ diff --git a/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc b/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58995f652192d5e524e02445b22451a1d8ea87b2 Binary files 
/dev/null and b/modules/bigvgan/__pycache__/meldataset.cpython-310.pyc differ diff --git a/modules/bigvgan/__pycache__/utils.cpython-310.pyc b/modules/bigvgan/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..763d6d7ead834b85c9ea9aa27f306f8d59041fd5 Binary files /dev/null and b/modules/bigvgan/__pycache__/utils.cpython-310.pyc differ diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7692a9a3df20d2fe2b7923398589e05741607fd2 Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e13bedd060a666d59013e80b43573fad05e3516 Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/activation1d.cpython-310.pyc differ diff --git a/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc b/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..089a26db09f702d6a6984743f4befb52b1c80fbf Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/__pycache__/load.cpython-310.pyc differ diff --git a/modules/bigvgan/alias_free_activation/cuda/activation1d.py b/modules/bigvgan/alias_free_activation/cuda/activation1d.py index fc0d313cb265170943fb7cb16742b031038f7859..76797ef1424b99a160803d53cb6c24fe20599bd2 100644 --- a/modules/bigvgan/alias_free_activation/cuda/activation1d.py +++ b/modules/bigvgan/alias_free_activation/cuda/activation1d.py @@ -3,10 +3,10 @@ import torch import torch.nn as nn -from alias_free_activation.torch.resample import UpSample1d, DownSample1d +from ..torch.resample import UpSample1d, DownSample1d # load fused CUDA kernel: this enables importing anti_alias_activation_cuda -from alias_free_activation.cuda import load +from ..cuda import load anti_alias_activation_cuda = load.load() diff --git a/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps new file mode 100644 index 0000000000000000000000000000000000000000..ce9bc6ec886c4bda48c73677998abe0b2b73bfcc --- /dev/null +++ b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_deps @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e233713716a5778577f244b0f310944ff26d3079ce0e42491791da7d42e363c1 +size 522068 diff --git a/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log new file mode 100644 index 0000000000000000000000000000000000000000..bd3c097a5622dde1a5c17fd152e04750a1dedded --- /dev/null +++ b/modules/bigvgan/alias_free_activation/cuda/build/.ninja_log @@ -0,0 +1,7 @@ +# ninja log v5 +9 39554 7516864785377831 anti_alias_activation.o 3a177f31dd72c43c +13 152601 7516865914203767 anti_alias_activation_cuda.cuda.o 2d613e7382d803fd +152628 153062 7516865920541751 anti_alias_activation_cuda.pyd f6366e9bdfb27f7 +128 50503 7654004565901584 anti_alias_activation.o 9ed3213f2e0d0858 +133 176837 7654005827401976 anti_alias_activation_cuda.cuda.o a679b6661c609136 +176839 177401 
7654005835005523 anti_alias_activation_cuda.pyd f6366e9bdfb27f7 diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o new file mode 100644 index 0000000000000000000000000000000000000000..812f06975323c9e1937fa01c943e31ae02322145 --- /dev/null +++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:74c2824b05582070b69f51ec588aadb268c4fddf18fbb4590f901d1cdf32185c +size 3246655 diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o new file mode 100644 index 0000000000000000000000000000000000000000..329fb7a9b147a0af665ff7f7686bd0cc915ecc84 --- /dev/null +++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.cuda.o @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:86c48de557041de7ebaff7926b5f346cc5e4e2dddc6cf5b88409f6cb161db0f4 +size 4724513 diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp new file mode 100644 index 0000000000000000000000000000000000000000..3093a741ef126748042cafaef5c368f3ec5e2d3f Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.exp differ diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib new file mode 100644 index 0000000000000000000000000000000000000000..1be22a5e2a68606c56c961333b67a251bf40d8ea Binary files /dev/null and b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.lib differ diff --git a/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd new file mode 100644 index 0000000000000000000000000000000000000000..dc51b91fc3d147e24ad0155ee31809556bb7208a --- /dev/null +++ b/modules/bigvgan/alias_free_activation/cuda/build/anti_alias_activation_cuda.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db37ea2dd31dfe67e68ee6019877d14638c41724ff9342c55f638f4d2cda3d03 +size 2454528 diff --git a/modules/bigvgan/alias_free_activation/cuda/build/build.ninja b/modules/bigvgan/alias_free_activation/cuda/build/build.ninja new file mode 100644 index 0000000000000000000000000000000000000000..8c41cf88948be2657b26c226365a86b99278764a --- /dev/null +++ b/modules/bigvgan/alias_free_activation/cuda/build/build.ninja @@ -0,0 +1,38 @@ +ninja_required_version = 1.3 +cxx = cl +nvcc = C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin\nvcc + +cflags = -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include /std:c++17 -O3 /MD /wd4819 /wd4251 /wd4244 /wd4267 /wd4275 /wd4018 /wd4190 /wd4624 /wd4067 /wd4068 /EHsc +post_cflags = +cuda_cflags = -Xcudafe --diag_suppress=dll_interface_conflict_dllexport_assumed -Xcudafe --diag_suppress=dll_interface_conflict_none_assumed -Xcudafe 
--diag_suppress=field_without_dll_interface -Xcudafe --diag_suppress=base_class_has_different_dll_interface -Xcompiler /EHsc -Xcompiler /wd4068 -Xcompiler /wd4067 -Xcompiler /wd4624 -Xcompiler /wd4190 -Xcompiler /wd4018 -Xcompiler /wd4275 -Xcompiler /wd4267 -Xcompiler /wd4244 -Xcompiler /wd4251 -Xcompiler /wd4819 -Xcompiler /MD -DTORCH_EXTENSION_NAME=anti_alias_activation_cuda -DTORCH_API_INCLUDE_EXTENSION_H -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include -ID:\Anaconda\envs\vocos\lib\site-packages\torch\include\torch\csrc\api\include "-IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\include" -ID:\Anaconda\envs\vocos\Include -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++17 -O3 -gencode arch=compute_70,code=sm_70 --use_fast_math -U__CUDA_NO_HALF_OPERATORS__ -U__CUDA_NO_HALF_CONVERSIONS__ --expt-relaxed-constexpr --expt-extended-lambda -gencode arch=compute_80,code=sm_80 +cuda_post_cflags = +cuda_dlink_post_cflags = +sycl_dlink_post_cflags = +ldflags = /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:D:\Anaconda\envs\vocos\lib\site-packages\torch\lib torch_python.lib /LIBPATH:D:\Anaconda\envs\vocos\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\lib\x64" cudart.lib + +rule compile + command = cl /showIncludes $cflags -c $in /Fo$out $post_cflags + deps = msvc + +rule cuda_compile + depfile = $out.d + deps = gcc + command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags + + + + + +rule link + command = "D$:\Visual Studio\VC\Tools\MSVC\14.29.30133\bin\Hostx86\x64/link.exe" $in /nologo $ldflags /out:$out + +build anti_alias_activation.o: compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation.cpp +build anti_alias_activation_cuda.cuda.o: cuda_compile D$:\seed-vc\modules\bigvgan\alias_free_activation\cuda\anti_alias_activation_cuda.cu + + + + + +build anti_alias_activation_cuda.pyd: link anti_alias_activation.o anti_alias_activation_cuda.cuda.o + +default anti_alias_activation_cuda.pyd diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76ee62bc7f4fd61a4da3faf3b5b608082eecd92a Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6235f3b7476dd85812d5100a2d6ea941bca1280b Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/act.cpython-310.pyc differ diff --git a/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb021ff5f7292a9def47b079d56dbfcc5a18f150 Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/filter.cpython-310.pyc differ diff --git 
a/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc b/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17524739fecc6a6ffecbc9ec37007645cd6b422d Binary files /dev/null and b/modules/bigvgan/alias_free_activation/torch/__pycache__/resample.cpython-310.pyc differ diff --git a/modules/bigvgan/bigvgan.py b/modules/bigvgan/bigvgan.py index 5a1196fa9fc6bca4276e23d5fe659e3f5af9b04a..41d6e44a2cb59a39d51cb4994d05804dc497dda7 100644 --- a/modules/bigvgan/bigvgan.py +++ b/modules/bigvgan/bigvgan.py @@ -42,15 +42,15 @@ class AMPBlock1(torch.nn.Module): """ def __init__( - self, - h: AttrDict, - channels: int, - kernel_size: int = 3, - dilation: tuple = (1, 3, 5), - activation: str = None, + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, ): super().__init__() - + self.h = h self.convs1 = nn.ModuleList( @@ -93,7 +93,7 @@ class AMPBlock1(torch.nn.Module): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from .alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -161,15 +161,15 @@ class AMPBlock2(torch.nn.Module): """ def __init__( - self, - h: AttrDict, - channels: int, - kernel_size: int = 3, - dilation: tuple = (1, 3, 5), - activation: str = None, + self, + h: AttrDict, + channels: int, + kernel_size: int = 3, + dilation: tuple = (1, 3, 5), + activation: str = None, ): super().__init__() - + self.h = h self.convs = nn.ModuleList( @@ -193,7 +193,7 @@ class AMPBlock2(torch.nn.Module): # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from .alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -270,7 +270,7 @@ class BigVGAN( # Select which Activation1d, lazy-load cuda version to ensure backward compatibility if self.h.get("use_cuda_kernel", False): - from alias_free_activation.cuda.activation1d import ( + from .alias_free_activation.cuda.activation1d import ( Activation1d as CudaActivation1d, ) @@ -304,7 +304,7 @@ class BigVGAN( [ weight_norm( ConvTranspose1d( - h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** i), h.upsample_initial_channel // (2 ** (i + 1)), k, u, @@ -320,7 +320,7 @@ class BigVGAN( for i in range(len(self.ups)): ch = h.upsample_initial_channel // (2 ** (i + 1)) for j, (k, d) in enumerate( - zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) ): self.resblocks.append( resblock_class(h, ch, k, d, activation=h.activation) @@ -412,20 +412,20 @@ class BigVGAN( @classmethod def _from_pretrained( - cls, - *, - model_id: str, - revision: str, - cache_dir: str, - force_download: bool, - proxies: Optional[Dict], - resume_download: bool, - local_files_only: bool, - token: Union[str, bool, None], - map_location: str = "cpu", # Additional argument - strict: bool = False, # Additional argument - use_cuda_kernel: bool = False, - **model_kwargs, + cls, + *, + model_id: str, + revision: str, + cache_dir: str, + force_download: bool, + proxies: Optional[Dict], + resume_download: bool, + local_files_only: bool, + token: Union[str, bool, None], + map_location: str = "cpu", # Additional 
argument + strict: bool = False, # Additional argument + use_cuda_kernel: bool = False, + **model_kwargs, ): """Load Pytorch pretrained weights and return the loaded model.""" @@ -489,4 +489,4 @@ class BigVGAN( model.remove_weight_norm() model.load_state_dict(checkpoint_dict["generator"]) - return model + return model \ No newline at end of file diff --git a/modules/commons.py b/modules/commons.py index 350208e50ba8630e53c30847db345cc3ace77473..fb0f6f89ef550e7570beffdef1d438e7dc51259f 100644 --- a/modules/commons.py +++ b/modules/commons.py @@ -1,490 +1,476 @@ -import math -import numpy as np -import torch -from torch import nn -from torch.nn import functional as F -from munch import Munch -import json - - -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def intersperse(lst, item): - result = [item] * (len(lst) * 2 + 1) - result[1::2] = lst - return result - - -def kl_divergence(m_p, logs_p, m_q, logs_q): - """KL(P||Q)""" - kl = (logs_q - logs_p) - 0.5 - kl += ( - 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) - ) - return kl - - -def rand_gumbel(shape): - """Sample from the Gumbel distribution, protect from overflows.""" - uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 - return -torch.log(-torch.log(uniform_samples)) - - -def rand_gumbel_like(x): - g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) - return g - - -def slice_segments(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, :, idx_str:idx_end] - return ret - - -def slice_segments_audio(x, ids_str, segment_size=4): - ret = torch.zeros_like(x[:, :segment_size]) - for i in range(x.size(0)): - idx_str = ids_str[i] - idx_end = idx_str + segment_size - ret[i] = x[i, idx_str:idx_end] - return ret - - -def rand_slice_segments(x, x_lengths=None, segment_size=4): - b, d, t = x.size() - if x_lengths is None: - x_lengths = t - ids_str_max = x_lengths - segment_size + 1 - ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to( - dtype=torch.long - ) - ret = slice_segments(x, ids_str, segment_size) - return ret, ids_str - - -def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): - position = torch.arange(length, dtype=torch.float) - num_timescales = channels // 2 - log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( - num_timescales - 1 - ) - inv_timescales = min_timescale * torch.exp( - torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment - ) - scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) - signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) - signal = F.pad(signal, [0, 0, 0, channels % 2]) - signal = signal.view(1, channels, length) - return signal - - -def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return x + 
signal.to(dtype=x.dtype, device=x.device) - - -def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): - b, channels, length = x.size() - signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) - return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) - - -def subsequent_mask(length): - mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) - return mask - - -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - -def convert_pad_shape(pad_shape): - l = pad_shape[::-1] - pad_shape = [item for sublist in l for item in sublist] - return pad_shape - - -def shift_1d(x): - x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] - return x - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - -def avg_with_mask(x, mask): - assert mask.dtype == torch.float, "Mask should be float" - - if mask.ndim == 2: - mask = mask.unsqueeze(1) - - if mask.shape[1] == 1: - mask = mask.expand_as(x) - - return (x * mask).sum() / mask.sum() - - -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - device = duration.device - - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - -def clip_grad_value_(parameters, clip_value, norm_type=2): - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - parameters = list(filter(lambda p: p.grad is not None, parameters)) - norm_type = float(norm_type) - if clip_value is not None: - clip_value = float(clip_value) - - total_norm = 0 - for p in parameters: - param_norm = p.grad.data.norm(norm_type) - total_norm += param_norm.item() ** norm_type - if clip_value is not None: - p.grad.data.clamp_(min=-clip_value, max=clip_value) - total_norm = total_norm ** (1.0 / norm_type) - return total_norm - - -def log_norm(x, mean=-4, std=4, dim=2): - """ - normalized log mel -> mel -> norm -> log(norm) - """ - x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) - return x - - -def load_F0_models(path): - # load F0 model - from .JDC.model import JDCNet - - F0_model = JDCNet(num_class=1, seq_len=192) - params = torch.load(path, map_location="cpu")["net"] - F0_model.load_state_dict(params) - _ = F0_model.train() - - return F0_model - - -def modify_w2v_forward(self, output_layer=15): - """ - change forward method of w2v encoder to get its intermediate layer output - :param self: - :param layer: - :return: - """ - from transformers.modeling_outputs import BaseModelOutput - - def forward( - hidden_states, - attention_mask=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - conv_attention_mask = attention_mask - if attention_mask is not None: - # make sure padded tokens output 0 - 
hidden_states = hidden_states.masked_fill( - ~attention_mask.bool().unsqueeze(-1), 0.0 - ) - - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to( - dtype=hidden_states.dtype - ) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], - 1, - attention_mask.shape[-1], - attention_mask.shape[-1], - ) - - hidden_states = self.dropout(hidden_states) - - if self.embed_positions is not None: - relative_position_embeddings = self.embed_positions(hidden_states) - else: - relative_position_embeddings = None - - deepspeed_zero3_is_enabled = False - - for i, layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) - dropout_probability = torch.rand([]) - - skip_the_layer = ( - True - if self.training and (dropout_probability < self.config.layerdrop) - else False - ) - if not skip_the_layer or deepspeed_zero3_is_enabled: - # under deepspeed zero3 all gpus must run in sync - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer.__call__, - hidden_states, - attention_mask, - relative_position_embeddings, - output_attentions, - conv_attention_mask, - ) - else: - layer_outputs = layer( - hidden_states, - attention_mask=attention_mask, - relative_position_embeddings=relative_position_embeddings, - output_attentions=output_attentions, - conv_attention_mask=conv_attention_mask, - ) - hidden_states = layer_outputs[0] - - if skip_the_layer: - layer_outputs = (None, None) - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if i == output_layer - 1: - break - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attentions] - if v is not None - ) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - return forward - - -MATPLOTLIB_FLAG = False - - -def plot_spectrogram_to_numpy(spectrogram): - global MATPLOTLIB_FLAG - if not MATPLOTLIB_FLAG: - import matplotlib - import logging - - matplotlib.use("Agg") - MATPLOTLIB_FLAG = True - mpl_logger = logging.getLogger("matplotlib") - mpl_logger.setLevel(logging.WARNING) - import matplotlib.pylab as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(10, 2)) - im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") - plt.colorbar(im, ax=ax) - plt.xlabel("Frames") - plt.ylabel("Channels") - plt.tight_layout() - - fig.canvas.draw() - data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") - data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - return data - - -def normalize_f0(f0_sequence): - # Remove unvoiced frames (replace with -1) - voiced_indices = np.where(f0_sequence > 0)[0] - f0_voiced = f0_sequence[voiced_indices] - - # Convert to log scale - log_f0 = np.log2(f0_voiced) - - # Calculate mean and standard deviation - mean_f0 = np.mean(log_f0) - std_f0 = np.std(log_f0) - - # Normalize the F0 sequence - normalized_f0 = (log_f0 - mean_f0) / std_f0 - - # Create the normalized F0 sequence with unvoiced frames - normalized_sequence = np.zeros_like(f0_sequence) - normalized_sequence[voiced_indices] = normalized_f0 - 
normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames - - return normalized_sequence - - -def build_model(args, stage="DiT"): - if stage == "DiT": - from modules.flow_matching import CFM - from modules.length_regulator import InterpolateRegulator - - length_regulator = InterpolateRegulator( - channels=args.length_regulator.channels, - sampling_ratios=args.length_regulator.sampling_ratios, - is_discrete=args.length_regulator.is_discrete, - in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None, - vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False, - codebook_size=args.length_regulator.content_codebook_size, - n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1, - quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0, - f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False, - n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512, - ) - cfm = CFM(args) - nets = Munch( - cfm=cfm, - length_regulator=length_regulator, - ) - elif stage == 'codec': - from dac.model.dac import Encoder - from modules.quantize import ( - FAquantizer, - ) - - encoder = Encoder( - d_model=args.DAC.encoder_dim, - strides=args.DAC.encoder_rates, - d_latent=1024, - causal=args.causal, - lstm=args.lstm, - ) - - quantizer = FAquantizer( - in_dim=1024, - n_p_codebooks=1, - n_c_codebooks=args.n_c_codebooks, - n_t_codebooks=2, - n_r_codebooks=3, - codebook_size=1024, - codebook_dim=8, - quantizer_dropout=0.5, - causal=args.causal, - separate_prosody_encoder=args.separate_prosody_encoder, - timbre_norm=args.timbre_norm, - ) - - nets = Munch( - encoder=encoder, - quantizer=quantizer, - ) - else: - raise ValueError(f"Unknown stage: {stage}") - - return nets - - -def load_checkpoint( - model, - optimizer, - path, - load_only_params=True, - ignore_modules=[], - is_distributed=False, -): - state = torch.load(path, map_location="cpu") - params = state["net"] - for key in model: - if key in params and key not in ignore_modules: - if not is_distributed: - # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix - for k in list(params[key].keys()): - if k.startswith("module."): - params[key][k[len("module.") :]] = params[key][k] - del params[key][k] - model_state_dict = model[key].state_dict() - # 过滤出形状匹配的键值对 - filtered_state_dict = { - k: v - for k, v in params[key].items() - if k in model_state_dict and v.shape == model_state_dict[k].shape - } - skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys()) - if skipped_keys: - print( - f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}" - ) - print("%s loaded" % key) - model[key].load_state_dict(filtered_state_dict, strict=False) - _ = [model[key].eval() for key in model] - - if not load_only_params: - epoch = state["epoch"] + 1 - iters = state["iters"] - optimizer.load_state_dict(state["optimizer"]) - optimizer.load_scheduler_state_dict(state["scheduler"]) - - else: - epoch = 0 - iters = 0 - - return model, optimizer, epoch, iters - - -def recursive_munch(d): - if isinstance(d, dict): - return Munch((k, recursive_munch(v)) for k, v in d.items()) - elif isinstance(d, list): - return [recursive_munch(v) for v in d] - else: - return d +import math +import numpy as np 
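+# argparse provides ArgumentTypeError for the str2bool helper defined below
+# (typical use: add_argument(..., type=str2bool) so "true"/"false" strings parse to bools).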
+import torch +from torch import nn +from torch.nn import functional as F +from munch import Munch +import json +import argparse + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def intersperse(lst, item): + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +def kl_divergence(m_p, logs_p, m_q, logs_q): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def slice_segments_audio(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to( + dtype=torch.long + ) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def 
subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def avg_with_mask(x, mask): + assert mask.dtype == torch.float, "Mask should be float" + + if mask.ndim == 2: + mask = mask.unsqueeze(1) + + if mask.shape[1] == 1: + mask = mask.expand_as(x) + + return (x * mask).sum() / mask.sum() + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def log_norm(x, mean=-4, std=4, dim=2): + """ + normalized log mel -> mel -> norm -> log(norm) + """ + x = torch.log(torch.exp(x * std + mean).norm(dim=dim)) + return x + + +def load_F0_models(path): + # load F0 model + from .JDC.model import JDCNet + + F0_model = JDCNet(num_class=1, seq_len=192) + params = torch.load(path, map_location="cpu")["net"] + F0_model.load_state_dict(params) + _ = F0_model.train() + + return F0_model + + +def modify_w2v_forward(self, output_layer=15): + """ + change forward method of w2v encoder to get its intermediate layer output + :param self: + :param layer: + :return: + """ + from transformers.modeling_outputs import BaseModelOutput + + def forward( + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + conv_attention_mask = attention_mask + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states = hidden_states.masked_fill( + ~attention_mask.bool().unsqueeze(-1), 0.0 + ) + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to( + dtype=hidden_states.dtype + ) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = 
attention_mask.expand( + attention_mask.shape[0], + 1, + attention_mask.shape[-1], + attention_mask.shape[-1], + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + deepspeed_zero3_is_enabled = False + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = ( + True + if self.training and (dropout_probability < self.config.layerdrop) + else False + ) + if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + relative_position_embeddings, + output_attentions, + conv_attention_mask, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + conv_attention_mask=conv_attention_mask, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if i == output_layer - 1: + break + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions] + if v is not None + ) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + return forward + + +MATPLOTLIB_FLAG = False + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + import logging + + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + mpl_logger = logging.getLogger("matplotlib") + mpl_logger.setLevel(logging.WARNING) + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10, 2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def normalize_f0(f0_sequence): + # Remove unvoiced frames (replace with -1) + voiced_indices = np.where(f0_sequence > 0)[0] + f0_voiced = f0_sequence[voiced_indices] + + # Convert to log scale + log_f0 = np.log2(f0_voiced) + + # Calculate mean and standard deviation + mean_f0 = np.mean(log_f0) + std_f0 = np.std(log_f0) + + # Normalize the F0 sequence + normalized_f0 = (log_f0 - mean_f0) / std_f0 + + # Create the normalized F0 sequence with unvoiced frames + normalized_sequence = np.zeros_like(f0_sequence) + normalized_sequence[voiced_indices] = normalized_f0 + normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames + + return normalized_sequence + + +def build_model(args, stage="DiT"): + if stage == "DiT": + from modules.flow_matching import CFM + from modules.length_regulator import InterpolateRegulator + + length_regulator = InterpolateRegulator( + 
channels=args.length_regulator.channels, + sampling_ratios=args.length_regulator.sampling_ratios, + is_discrete=args.length_regulator.is_discrete, + in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None, + codebook_size=args.length_regulator.content_codebook_size, + f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False, + n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512, + ) + cfm = CFM(args) + nets = Munch( + cfm=cfm, + length_regulator=length_regulator, + ) + else: + raise ValueError(f"Unknown stage: {stage}") + + return nets + + +def load_checkpoint( + model, + optimizer, + path, + load_only_params=True, + ignore_modules=[], + is_distributed=False, + load_ema=False, +): + state = torch.load(path, map_location="cpu") + params = state["net"] + if load_ema and "ema" in state: + print("Loading EMA") + for key in model: + i = 0 + for param_name in params[key]: + if "input_pos" in param_name: + continue + assert params[key][param_name].shape == state["ema"][key][0][i].shape + params[key][param_name] = state["ema"][key][0][i].clone() + i += 1 + for key in model: + if key in params and key not in ignore_modules: + if not is_distributed: + # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix + for k in list(params[key].keys()): + if k.startswith("module."): + params[key][k[len("module.") :]] = params[key][k] + del params[key][k] + model_state_dict = model[key].state_dict() + # 过滤出形状匹配的键值对 + filtered_state_dict = { + k: v + for k, v in params[key].items() + if k in model_state_dict and v.shape == model_state_dict[k].shape + } + skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys()) + if skipped_keys: + print( + f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}" + ) + print("%s loaded" % key) + model[key].load_state_dict(filtered_state_dict, strict=False) + _ = [model[key].eval() for key in model] + + if not load_only_params: + epoch = state["epoch"] + 1 + iters = state["iters"] + optimizer.load_state_dict(state["optimizer"]) + optimizer.load_scheduler_state_dict(state["scheduler"]) + + else: + epoch = 0 + iters = 0 + + return model, optimizer, epoch, iters + + +def recursive_munch(d): + if isinstance(d, dict): + return Munch((k, recursive_munch(v)) for k, v in d.items()) + elif isinstance(d, list): + return [recursive_munch(v) for v in d] + else: + return d diff --git a/modules/diffusion_transformer.py b/modules/diffusion_transformer.py index b7f40975e52d1cc7944192bff30e2e7341e4fedb..f9b468fa6701a72fab3e55e31dadc814e10c78f1 100644 --- a/modules/diffusion_transformer.py +++ b/modules/diffusion_transformer.py @@ -1,240 +1,537 @@ -import torch -from torch import nn -import math - -from modules.gpt_fast.model import ModelArgs, Transformer -# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer -from modules.wavenet import WN -from modules.commons import sequence_mask - -from torch.nn.utils import weight_norm - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - -################################################################################# -# Embedding Layers for Timesteps and Class Labels # -################################################################################# - -class TimestepEmbedder(nn.Module): - """ - Embeds scalar timesteps into vector representations. 
- """ - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias=True), - nn.SiLU(), - nn.Linear(hidden_size, hidden_size, bias=True), - ) - self.frequency_embedding_size = frequency_embedding_size - self.max_period = 10000 - self.scale = 1000 - - half = frequency_embedding_size // 2 - freqs = torch.exp( - -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ) - self.register_buffer("freqs", freqs) - - def timestep_embedding(self, t): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py - - args = self.scale * t[:, None].float() * self.freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if self.frequency_embedding_size % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t) - t_emb = self.mlp(t_freq) - return t_emb - - -class StyleEmbedder(nn.Module): - """ - Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. - """ - def __init__(self, input_size, hidden_size, dropout_prob): - super().__init__() - use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size) - self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True)) - self.input_size = input_size - self.dropout_prob = dropout_prob - - def forward(self, labels, train, force_drop_ids=None): - use_dropout = self.dropout_prob > 0 - if (train and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) - else: - labels = self.style_in(labels) - embeddings = labels - return embeddings - -class FinalLayer(nn.Module): - """ - The final layer of DiT. 
- """ - def __init__(self, hidden_size, patch_size, out_channels): - super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) - self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)) - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(hidden_size, 2 * hidden_size, bias=True) - ) - - def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) - x = modulate(self.norm_final(x), shift, scale) - x = self.linear(x) - return x - -class DiT(torch.nn.Module): - def __init__( - self, - args - ): - super(DiT, self).__init__() - self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False - self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False - self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False - model_args = ModelArgs( - block_size=16384,#args.DiT.block_size, - n_layer=args.DiT.depth, - n_head=args.DiT.num_heads, - dim=args.DiT.hidden_dim, - head_dim=args.DiT.hidden_dim // args.DiT.num_heads, - vocab_size=1024, - uvit_skip_connection=self.uvit_skip_connection, - ) - self.transformer = Transformer(model_args) - self.in_channels = args.DiT.in_channels - self.out_channels = args.DiT.in_channels - self.num_heads = args.DiT.num_heads - - self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True)) - - self.content_type = args.DiT.content_type # 'discrete' or 'continuous' - self.content_codebook_size = args.DiT.content_codebook_size # for discrete content - self.content_dim = args.DiT.content_dim # for continuous content - self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content - self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content - - self.is_causal = args.DiT.is_causal - - self.n_f0_bins = args.DiT.n_f0_bins - self.f0_bins = torch.arange(2, 1024, 1024 // args.DiT.n_f0_bins) - self.f0_embedder = nn.Embedding(args.DiT.n_f0_bins, args.DiT.hidden_dim) - self.f0_condition = args.DiT.f0_condition - - self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim) - self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim) - # self.style_embedder1 = weight_norm(nn.Linear(1024, args.DiT.hidden_dim, bias=True)) - # self.style_embedder2 = weight_norm(nn.Linear(1024, args.style_encoder.dim, bias=True)) - - input_pos = torch.arange(16384) - self.register_buffer("input_pos", input_pos) - - self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) - self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1) - self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet - if self.final_layer_type == 'wavenet': - self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim, - kernel_size=args.wavenet.kernel_size, - dilation_rate=args.wavenet.dilation_rate, - n_layers=args.wavenet.num_layers, - gin_channels=args.wavenet.hidden_dim, - p_dropout=args.wavenet.p_dropout, - causal=False) - self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim) - else: - self.final_mlp = nn.Sequential( - nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim), - nn.SiLU(), - nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels), - ) - self.transformer_style_condition = args.DiT.style_condition - self.wavenet_style_condition = args.wavenet.style_condition - assert args.DiT.style_condition == 
args.wavenet.style_condition - - self.class_dropout_prob = args.DiT.class_dropout_prob - self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim) - self.res_projection = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) # residual connection from tranformer output to final output - self.long_skip_connection = args.DiT.long_skip_connection - self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim) - - self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 + - args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token), - args.DiT.hidden_dim) - if self.style_as_token: - self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim) - - def setup_caches(self, max_batch_size, max_seq_length): - self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False) - def forward(self, x, prompt_x, x_lens, t, style, cond, f0=None, mask_content=False): - class_dropout = False - if self.training and torch.rand(1) < self.class_dropout_prob: - class_dropout = True - if not self.training and mask_content: - class_dropout = True - # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection - cond_in_module = self.cond_projection - - B, _, T = x.size() - - - t1 = self.t_embedder(t) # (N, D) - - cond = cond_in_module(cond) - if self.f0_condition and f0 is not None: - quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T) - cond = cond + self.f0_embedder(quantized_f0) - - x = x.transpose(1, 2) - prompt_x = prompt_x.transpose(1, 2) - - x_in = torch.cat([x, prompt_x, cond], dim=-1) - if self.transformer_style_condition and not self.style_as_token: - x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1) - if class_dropout: - x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0 - x_in = self.cond_x_merge_linear(x_in) # (N, T, D) - - if self.style_as_token: - style = self.style_in(style) - style = torch.zeros_like(style) if class_dropout else style - x_in = torch.cat([style.unsqueeze(1), x_in], dim=1) - if self.time_as_token: - x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1) - x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) - input_pos = self.input_pos[:x_in.size(1)] # (T,) - x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None - x_res = self.transformer(x_in, None if self.time_as_token else t1.unsqueeze(1), input_pos, x_mask_expanded) - x_res = x_res[:, 1:] if self.time_as_token else x_res - x_res = x_res[:, 1:] if self.style_as_token else x_res - if self.long_skip_connection: - x_res = self.skip_linear(torch.cat([x_res, x], dim=-1)) - if self.final_layer_type == 'wavenet': - x = self.conv1(x_res) - x = x.transpose(1, 2) - t2 = self.t_embedder2(t) - x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection( - x_res) # long residual connection - x = self.final_layer(x, t1).transpose(1, 2) - x = self.conv2(x) - else: - x = self.final_mlp(x_res) - x = x.transpose(1, 2) - return x +import torch +from torch import nn +import math + +# from modules.torchscript_modules.gpt_fast_model import ModelArgs, Transformer +from modules.wavenet import WN +from modules.commons import sequence_mask + +from torch.nn.utils import weight_norm + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
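# ---------------------------------------------------------------------------
# Minimal sketch of the classifier-free-guidance condition dropout performed in
# the DiT.forward shown above: with probability class_dropout_prob during
# training (or when mask_content=True at inference) every channel beyond the
# first `in_channels` noisy-input channels of the merged feature is zeroed, so
# the model also learns an unconditional velocity field. Toy shapes and the
# helper name `drop_condition` are illustrative assumptions only.
import torch

def drop_condition(x_in: torch.Tensor, in_channels: int, p: float, training: bool) -> torch.Tensor:
    # x_in: (B, T, in_channels + cond_channels), a concat of [x, prompt_x, cond, style]
    if training and torch.rand(1).item() < p:
        x_in = x_in.clone()
        x_in[..., in_channels:] = 0.0  # keep the noisy input, blank all conditioning
    return x_in
# ---------------------------------------------------------------------------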
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if embedding is None: + return self.norm(input) + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + has_cross_attention: bool = False + context_dim: int = 0 + uvit_skip_connection: bool = False + time_as_token: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + # self.head_dim = self.dim // self.n_head + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.norm.project_layer.weight.dtype + device = self.norm.project_layer.weight.device + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, + self.config.rope_base, dtype).to(device) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device) + self.use_kv_cache = use_kv_cache + self.uvit_skip_connection = self.config.uvit_skip_connection + if self.uvit_skip_connection: + self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2] + self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2] + else: + self.layers_emit_skip = [] + self.layers_receive_skip = [] + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_input_pos: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + ) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + if mask is 
None: # in case of non-causal model + if not self.training and self.use_kv_cache: + mask = self.causal_mask[None, None, input_pos] + else: + mask = self.causal_mask[None, None, input_pos] + mask = mask[..., input_pos] + freqs_cis = self.freqs_cis[input_pos] + if context is not None: + context_freqs_cis = self.freqs_cis[context_input_pos] + else: + context_freqs_cis = None + skip_in_x_list = [] + for i, layer in enumerate(self.layers): + if self.uvit_skip_connection and i in self.layers_receive_skip: + skip_in_x = skip_in_x_list.pop(-1) + else: + skip_in_x = None + x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x) + if self.uvit_skip_connection and i in self.layers_emit_skip: + skip_in_x_list.append(x) + x = self.norm(x, c) + return x + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + if config.has_cross_attention: + self.has_cross_attention = True + self.cross_attention = Attention(config, is_cross_attention=True) + self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + else: + self.has_cross_attention = False + + if config.uvit_skip_connection: + self.skip_in_linear = nn.Linear(config.dim * 2, config.dim) + self.uvit_skip_connection = True + else: + self.uvit_skip_connection = False + + self.time_as_token = config.time_as_token + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + skip_in_x: Optional[Tensor] = None, + ) -> Tensor: + c = None if self.time_as_token else c + if self.uvit_skip_connection and skip_in_x is not None: + x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1)) + h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos) + if self.has_cross_attention: + h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h, c)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs, is_cross_attention: bool = False): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + if is_cross_attention: + self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False) + self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False) + else: + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + # self._register_load_state_dict_pre_hook(self.load_hook) + + # def load_hook(self, state_dict, prefix, *args): + # if prefix + "wq.weight" in state_dict: + # wq = state_dict.pop(prefix + 
"wq.weight") + # wk = state_dict.pop(prefix + "wk.weight") + # wv = state_dict.pop(prefix + "wv.weight") + # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + if context is None: + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + else: + q = self.wq(x) + k, v = self.wkv(context).split([kv_size, kv_size], dim=-1) + context_seqlen = context.shape[1] + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + 
Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + self.max_period = 10000 + self.scale = 1000 + + half = frequency_embedding_size // 2 + freqs = torch.exp( + -math.log(self.max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ) + self.register_buffer("freqs", freqs) + + def timestep_embedding(self, t): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + + args = self.scale * t[:, None].float() * self.freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.frequency_embedding_size % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t) + t_emb = self.mlp(t_freq) + return t_emb + + +class StyleEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, input_size, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(int(use_cfg_embedding), hidden_size) + self.style_in = weight_norm(nn.Linear(input_size, hidden_size, bias=True)) + self.input_size = input_size + self.dropout_prob = dropout_prob + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + else: + labels = self.style_in(labels) + embeddings = labels + return embeddings + +class FinalLayer(nn.Module): + """ + The final layer of DiT. 
+ """ + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = weight_norm(nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + +class DiT(torch.nn.Module): + def __init__( + self, + args + ): + super(DiT, self).__init__() + self.time_as_token = args.DiT.time_as_token if hasattr(args.DiT, 'time_as_token') else False + self.style_as_token = args.DiT.style_as_token if hasattr(args.DiT, 'style_as_token') else False + self.uvit_skip_connection = args.DiT.uvit_skip_connection if hasattr(args.DiT, 'uvit_skip_connection') else False + model_args = ModelArgs( + block_size=16384,#args.DiT.block_size, + n_layer=args.DiT.depth, + n_head=args.DiT.num_heads, + dim=args.DiT.hidden_dim, + head_dim=args.DiT.hidden_dim // args.DiT.num_heads, + vocab_size=1024, + uvit_skip_connection=self.uvit_skip_connection, + time_as_token=self.time_as_token, + ) + self.transformer = Transformer(model_args) + self.in_channels = args.DiT.in_channels + self.out_channels = args.DiT.in_channels + self.num_heads = args.DiT.num_heads + + self.x_embedder = weight_norm(nn.Linear(args.DiT.in_channels, args.DiT.hidden_dim, bias=True)) + + self.content_type = args.DiT.content_type # 'discrete' or 'continuous' + self.content_codebook_size = args.DiT.content_codebook_size # for discrete content + self.content_dim = args.DiT.content_dim # for continuous content + self.cond_embedder = nn.Embedding(args.DiT.content_codebook_size, args.DiT.hidden_dim) # discrete content + self.cond_projection = nn.Linear(args.DiT.content_dim, args.DiT.hidden_dim, bias=True) # continuous content + + self.is_causal = args.DiT.is_causal + + self.t_embedder = TimestepEmbedder(args.DiT.hidden_dim) + + input_pos = torch.arange(16384) + self.register_buffer("input_pos", input_pos) + + self.final_layer_type = args.DiT.final_layer_type # mlp or wavenet + if self.final_layer_type == 'wavenet': + self.t_embedder2 = TimestepEmbedder(args.wavenet.hidden_dim) + self.conv1 = nn.Linear(args.DiT.hidden_dim, args.wavenet.hidden_dim) + self.conv2 = nn.Conv1d(args.wavenet.hidden_dim, args.DiT.in_channels, 1) + self.wavenet = WN(hidden_channels=args.wavenet.hidden_dim, + kernel_size=args.wavenet.kernel_size, + dilation_rate=args.wavenet.dilation_rate, + n_layers=args.wavenet.num_layers, + gin_channels=args.wavenet.hidden_dim, + p_dropout=args.wavenet.p_dropout, + causal=False) + self.final_layer = FinalLayer(args.wavenet.hidden_dim, 1, args.wavenet.hidden_dim) + self.res_projection = nn.Linear(args.DiT.hidden_dim, + args.wavenet.hidden_dim) # residual connection from tranformer output to final output + self.wavenet_style_condition = args.wavenet.style_condition + assert args.DiT.style_condition == args.wavenet.style_condition + else: + self.final_mlp = nn.Sequential( + nn.Linear(args.DiT.hidden_dim, args.DiT.hidden_dim), + nn.SiLU(), + nn.Linear(args.DiT.hidden_dim, args.DiT.in_channels), + ) + self.transformer_style_condition = args.DiT.style_condition + + + self.class_dropout_prob = args.DiT.class_dropout_prob + self.content_mask_embedder = nn.Embedding(1, args.DiT.hidden_dim) + + self.long_skip_connection = args.DiT.long_skip_connection + 
self.skip_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels, args.DiT.hidden_dim) + + self.cond_x_merge_linear = nn.Linear(args.DiT.hidden_dim + args.DiT.in_channels * 2 + + args.style_encoder.dim * self.transformer_style_condition * (not self.style_as_token), + args.DiT.hidden_dim) + if self.style_as_token: + self.style_in = nn.Linear(args.style_encoder.dim, args.DiT.hidden_dim) + + def setup_caches(self, max_batch_size, max_seq_length): + self.transformer.setup_caches(max_batch_size, max_seq_length, use_kv_cache=False) + def forward(self, x, prompt_x, x_lens, t, style, cond, mask_content=False): + class_dropout = False + if self.training and torch.rand(1) < self.class_dropout_prob: + class_dropout = True + if not self.training and mask_content: + class_dropout = True + # cond_in_module = self.cond_embedder if self.content_type == 'discrete' else self.cond_projection + cond_in_module = self.cond_projection + + B, _, T = x.size() + + + t1 = self.t_embedder(t) # (N, D) + + cond = cond_in_module(cond) + + x = x.transpose(1, 2) + prompt_x = prompt_x.transpose(1, 2) + + x_in = torch.cat([x, prompt_x, cond], dim=-1) + if self.transformer_style_condition and not self.style_as_token: + x_in = torch.cat([x_in, style[:, None, :].repeat(1, T, 1)], dim=-1) + if class_dropout: + x_in[..., self.in_channels:] = x_in[..., self.in_channels:] * 0 + x_in = self.cond_x_merge_linear(x_in) # (N, T, D) + + if self.style_as_token: + style = self.style_in(style) + style = torch.zeros_like(style) if class_dropout else style + x_in = torch.cat([style.unsqueeze(1), x_in], dim=1) + if self.time_as_token: + x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1) + x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token).to(x.device).unsqueeze(1) + input_pos = self.input_pos[:x_in.size(1)] # (T,) + x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) if not self.is_causal else None + x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) + x_res = x_res[:, 1:] if self.time_as_token else x_res + x_res = x_res[:, 1:] if self.style_as_token else x_res + if self.long_skip_connection: + x_res = self.skip_linear(torch.cat([x_res, x], dim=-1)) + if self.final_layer_type == 'wavenet': + x = self.conv1(x_res) + x = x.transpose(1, 2) + t2 = self.t_embedder2(t) + x = self.wavenet(x, x_mask, g=t2.unsqueeze(2)).transpose(1, 2) + self.res_projection( + x_res) # long residual connection + x = self.final_layer(x, t1).transpose(1, 2) + x = self.conv2(x) + else: + x = self.final_mlp(x_res) + x = x.transpose(1, 2) + return x \ No newline at end of file diff --git a/modules/flow_matching.py b/modules/flow_matching.py index c2581c620f884b7c4b60164729240c310198d74a..61389183c6604edf80a9517ead197dd6aa097740 100644 --- a/modules/flow_matching.py +++ b/modules/flow_matching.py @@ -49,6 +49,7 @@ class BASECFM(torch.nn.Module, ABC): B, T = mu.size(0), mu.size(1) z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + # t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) return self.solve_euler(z, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate) def solve_euler(self, x, x_lens, prompt, mu, style, f0, t_span, inference_cfg_rate=0.5): @@ -66,7 +67,7 @@ class BASECFM(torch.nn.Module, ABC): shape: (batch_size, spk_emb_dim) cond: Not used but kept for future purposes """ - t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + t, _, _ = t_span[0], t_span[-1], t_span[1] - 
t_span[0] # I am storing this because I can later plot it by putting a debugger here and saving it to a file # Or in future might add like a return_all_steps flag @@ -79,16 +80,28 @@ class BASECFM(torch.nn.Module, ABC): if self.zero_prompt_speech_token: mu[..., :prompt_len] = 0 for step in tqdm(range(1, len(t_span))): - dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu, f0) - # Classifier-Free Guidance inference introduced in VoiceBox + dt = t_span[step] - t_span[step - 1] if inference_cfg_rate > 0: - cfg_dphi_dt = self.estimator( - x, torch.zeros_like(prompt_x), x_lens, t.unsqueeze(0), - torch.zeros_like(style), - torch.zeros_like(mu), None + # Stack original and CFG (null) inputs for batched processing + stacked_prompt_x = torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0) + stacked_style = torch.cat([style, torch.zeros_like(style)], dim=0) + stacked_mu = torch.cat([mu, torch.zeros_like(mu)], dim=0) + stacked_x = torch.cat([x, x], dim=0) + stacked_t = torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0) + + # Perform a single forward pass for both original and CFG inputs + stacked_dphi_dt = self.estimator( + stacked_x, stacked_prompt_x, x_lens, stacked_t, stacked_style, stacked_mu, ) - dphi_dt = ((1.0 + inference_cfg_rate) * dphi_dt - - inference_cfg_rate * cfg_dphi_dt) + + # Split the output back into the original and CFG components + dphi_dt, cfg_dphi_dt = stacked_dphi_dt.chunk(2, dim=0) + + # Apply CFG formula + dphi_dt = (1.0 + inference_cfg_rate) * dphi_dt - inference_cfg_rate * cfg_dphi_dt + else: + dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu) + x = x + dt * dphi_dt t = t + dt sol.append(x) @@ -97,8 +110,7 @@ class BASECFM(torch.nn.Module, ABC): x[:, :, :prompt_len] = 0 return sol[-1] - - def forward(self, x1, x_lens, prompt_lens, mu, style, f0=None): + def forward(self, x1, x_lens, prompt_lens, mu, style): """Computes diffusion loss Args: @@ -134,13 +146,13 @@ class BASECFM(torch.nn.Module, ABC): if self.zero_prompt_speech_token: mu[bib, :, :prompt_lens[bib]] = 0 - estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu, f0) + estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(1).squeeze(1), style, mu, prompt_lens) loss = 0 for bib in range(b): loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]]) loss /= b - return loss, y + return loss, estimator_out + (1 - self.sigma_min) * z diff --git a/modules/length_regulator.py b/modules/length_regulator.py index a896c6ced97e409ba657f60af59a2f82e1688e65..8bc875326f8b846a09fbb9602d3ebf3ba6cc3b0f 100644 --- a/modules/length_regulator.py +++ b/modules/length_regulator.py @@ -1,141 +1,141 @@ -from typing import Tuple -import torch -import torch.nn as nn -from torch.nn import functional as F -from modules.commons import sequence_mask -import numpy as np -from dac.nn.quantize import VectorQuantize - -# f0_bin = 256 -f0_max = 1100.0 -f0_min = 50.0 -f0_mel_min = 1127 * np.log(1 + f0_min / 700) -f0_mel_max = 1127 * np.log(1 + f0_max / 700) - -def f0_to_coarse(f0, f0_bin): - f0_mel = 1127 * (1 + f0 / 700).log() - a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) - b = f0_mel_min * a - 1. 
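# ---------------------------------------------------------------------------
# Minimal sketch of the batched classifier-free-guidance step used in
# solve_euler above: conditional and null inputs are concatenated along the
# batch axis so a single estimator call produces both branches, then
# dphi/dt = (1 + w) * cond - w * uncond with w = inference_cfg_rate. The helper
# name `cfg_velocity` and the estimator call signature mirror the hunk above
# but are assumptions for illustration.
import torch

def cfg_velocity(estimator, x, prompt_x, x_lens, t, style, mu, w: float):
    stacked = estimator(
        torch.cat([x, x], dim=0),
        torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0),
        x_lens,
        torch.cat([t, t], dim=0),                  # t: (1,) current time step
        torch.cat([style, torch.zeros_like(style)], dim=0),
        torch.cat([mu, torch.zeros_like(mu)], dim=0),
    )
    cond, uncond = stacked.chunk(2, dim=0)
    return (1.0 + w) * cond - w * uncond
# ---------------------------------------------------------------------------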
- f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) - # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) - f0_coarse = torch.round(f0_mel).long() - f0_coarse = f0_coarse * (f0_coarse > 0) - f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) - f0_coarse = f0_coarse * (f0_coarse < f0_bin) - f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) - return f0_coarse - -class InterpolateRegulator(nn.Module): - def __init__( - self, - channels: int, - sampling_ratios: Tuple, - is_discrete: bool = False, - in_channels: int = None, # only applies to continuous input - vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input - codebook_size: int = 1024, # for discrete only - out_channels: int = None, - groups: int = 1, - n_codebooks: int = 1, # number of codebooks - quantizer_dropout: float = 0.0, # dropout for quantizer - f0_condition: bool = False, - n_f0_bins: int = 512, - ): - super().__init__() - self.sampling_ratios = sampling_ratios - out_channels = out_channels or channels - model = nn.ModuleList([]) - if len(sampling_ratios) > 0: - self.interpolate = True - for _ in sampling_ratios: - module = nn.Conv1d(channels, channels, 3, 1, 1) - norm = nn.GroupNorm(groups, channels) - act = nn.Mish() - model.extend([module, norm, act]) - else: - self.interpolate = False - model.append( - nn.Conv1d(channels, out_channels, 1, 1) - ) - self.model = nn.Sequential(*model) - self.embedding = nn.Embedding(codebook_size, channels) - self.is_discrete = is_discrete - - self.mask_token = nn.Parameter(torch.zeros(1, channels)) - - self.n_codebooks = n_codebooks - if n_codebooks > 1: - self.extra_codebooks = nn.ModuleList([ - nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1) - ]) - self.extra_codebook_mask_tokens = nn.ParameterList([ - nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1) - ]) - self.quantizer_dropout = quantizer_dropout - - if f0_condition: - self.f0_embedding = nn.Embedding(n_f0_bins, channels) - self.f0_condition = f0_condition - self.n_f0_bins = n_f0_bins - self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins) - self.f0_mask = nn.Parameter(torch.zeros(1, channels)) - else: - self.f0_condition = False - - if not is_discrete: - self.content_in_proj = nn.Linear(in_channels, channels) - if vector_quantize: - self.vq = VectorQuantize(channels, codebook_size, 8) - - def forward(self, x, ylens=None, n_quantizers=None, f0=None): - # apply token drop - if self.training: - n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks - dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],)) - n_dropout = int(x.shape[0] * self.quantizer_dropout) - n_quantizers[:n_dropout] = dropout[:n_dropout] - n_quantizers = n_quantizers.to(x.device) - # decide whether to drop for each sample in batch - else: - n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers) - if self.is_discrete: - if self.n_codebooks > 1: - assert len(x.size()) == 3 - x_emb = self.embedding(x[:, 0]) - for i, emb in enumerate(self.extra_codebooks): - x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1]) - # add mask token if not using this codebook - # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i] - x = x_emb - elif self.n_codebooks == 1: - if len(x.size()) == 2: - x = self.embedding(x) - else: - x = self.embedding(x[:, 0]) - else: - x = self.content_in_proj(x) - # x in (B, T, D) - mask = 
sequence_mask(ylens).unsqueeze(-1) - if self.interpolate: - x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') - else: - x = x.transpose(1, 2).contiguous() - mask = mask[:, :x.size(2), :] - ylens = ylens.clamp(max=x.size(2)).long() - if self.f0_condition: - if f0 is None: - x = x + self.f0_mask.unsqueeze(-1) - else: - #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T) - quantized_f0 = f0_to_coarse(f0, self.n_f0_bins) - quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long() - f0_emb = self.f0_embedding(quantized_f0) - f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') - x = x + f0_emb - out = self.model(x).transpose(1, 2).contiguous() - if hasattr(self, 'vq'): - out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2)) - out_q = out_q.transpose(1, 2) - return out_q * mask, ylens, codes, commitment_loss, codebook_loss - olens = ylens - return out * mask, olens, None, None, None +from typing import Tuple +import torch +import torch.nn as nn +from torch.nn import functional as F +from modules.commons import sequence_mask +import numpy as np +from dac.nn.quantize import VectorQuantize + +# f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + +def f0_to_coarse(f0, f0_bin): + f0_mel = 1127 * (1 + f0 / 700).log() + a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) + b = f0_mel_min * a - 1. + f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) + # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) + f0_coarse = torch.round(f0_mel).long() + f0_coarse = f0_coarse * (f0_coarse > 0) + f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) + f0_coarse = f0_coarse * (f0_coarse < f0_bin) + f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) + return f0_coarse + +class InterpolateRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + is_discrete: bool = False, + in_channels: int = None, # only applies to continuous input + vector_quantize: bool = False, # whether to use vector quantization, only applies to continuous input + codebook_size: int = 1024, # for discrete only + out_channels: int = None, + groups: int = 1, + n_codebooks: int = 1, # number of codebooks + quantizer_dropout: float = 0.0, # dropout for quantizer + f0_condition: bool = False, + n_f0_bins: int = 512, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + self.interpolate = True + for _ in sampling_ratios: + module = nn.Conv1d(channels, channels, 3, 1, 1) + norm = nn.GroupNorm(groups, channels) + act = nn.Mish() + model.extend([module, norm, act]) + else: + self.interpolate = False + model.append( + nn.Conv1d(channels, out_channels, 1, 1) + ) + self.model = nn.Sequential(*model) + self.embedding = nn.Embedding(codebook_size, channels) + self.is_discrete = is_discrete + + self.mask_token = nn.Parameter(torch.zeros(1, channels)) + + self.n_codebooks = n_codebooks + if n_codebooks > 1: + self.extra_codebooks = nn.ModuleList([ + nn.Embedding(codebook_size, channels) for _ in range(n_codebooks - 1) + ]) + self.extra_codebook_mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, channels)) for _ in range(n_codebooks - 1) + ]) + self.quantizer_dropout = quantizer_dropout + + if f0_condition: + self.f0_embedding = nn.Embedding(n_f0_bins, channels) + self.f0_condition = 
f0_condition + self.n_f0_bins = n_f0_bins + self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins) + self.f0_mask = nn.Parameter(torch.zeros(1, channels)) + else: + self.f0_condition = False + + if not is_discrete: + self.content_in_proj = nn.Linear(in_channels, channels) + if vector_quantize: + self.vq = VectorQuantize(channels, codebook_size, 8) + + def forward(self, x, ylens=None, n_quantizers=None, f0=None): + # apply token drop + if self.training: + n_quantizers = torch.ones((x.shape[0],)) * self.n_codebooks + dropout = torch.randint(1, self.n_codebooks + 1, (x.shape[0],)) + n_dropout = int(x.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(x.device) + # decide whether to drop for each sample in batch + else: + n_quantizers = torch.ones((x.shape[0],), device=x.device) * (self.n_codebooks if n_quantizers is None else n_quantizers) + if self.is_discrete: + if self.n_codebooks > 1: + assert len(x.size()) == 3 + x_emb = self.embedding(x[:, 0]) + for i, emb in enumerate(self.extra_codebooks): + x_emb = x_emb + (n_quantizers > i+1)[..., None, None] * emb(x[:, i+1]) + # add mask token if not using this codebook + # x_emb = x_emb + (n_quantizers <= i+1)[..., None, None] * self.extra_codebook_mask_tokens[i] + x = x_emb + elif self.n_codebooks == 1: + if len(x.size()) == 2: + x = self.embedding(x) + else: + x = self.embedding(x[:, 0]) + else: + x = self.content_in_proj(x) + # x in (B, T, D) + mask = sequence_mask(ylens).unsqueeze(-1) + if self.interpolate: + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + else: + x = x.transpose(1, 2).contiguous() + mask = mask[:, :x.size(2), :] + ylens = ylens.clamp(max=x.size(2)).long() + if self.f0_condition: + if f0 is None: + x = x + self.f0_mask.unsqueeze(-1) + else: + #quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T) + quantized_f0 = f0_to_coarse(f0, self.n_f0_bins) + quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long() + f0_emb = self.f0_embedding(quantized_f0) + f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + x = x + f0_emb + out = self.model(x).transpose(1, 2).contiguous() + if hasattr(self, 'vq'): + out_q, commitment_loss, codebook_loss, codes, out, = self.vq(out.transpose(1, 2)) + out_q = out_q.transpose(1, 2) + return out_q * mask, ylens, codes, commitment_loss, codebook_loss + olens = ylens + return out * mask, olens, None, None, None diff --git a/modules/rmvpe.py b/modules/rmvpe.py index 066a9eebdbcb16ab2d9de0e4738ad3575405907f..44ae2e0ec9fde661dd8360d3fd731fe66b5ab51c 100644 --- a/modules/rmvpe.py +++ b/modules/rmvpe.py @@ -486,7 +486,13 @@ class RMVPE: self.resample_kernel = {} self.is_half = is_half if device is None: - device = "cuda:0" if torch.cuda.is_available() else "cpu" + #device = "cuda:0" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda:0" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" self.device = device self.mel_extractor = MelSpectrogram( is_half, 128, 16000, 1024, 160, None, 30, 8000 @@ -572,6 +578,37 @@ class RMVPE: # t3 = ttime() # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0)) return f0 + def infer_from_audio_batch(self, audio, thred=0.03): + # torch.cuda.synchronize() + # t0 = ttime() + if not torch.is_tensor(audio): + audio = torch.from_numpy(audio) + mel = self.mel_extractor( + audio.float().to(self.device), center=True + ) + # 
print(123123123,mel.device.type) + # torch.cuda.synchronize() + # t1 = ttime() + hidden = self.mel2hidden(mel) + # torch.cuda.synchronize() + # t2 = ttime() + # print(234234,hidden.device.type) + if "privateuseone" not in str(self.device): + hidden = hidden.cpu().numpy() + else: + pass + if self.is_half == True: + hidden = hidden.astype("float32") + + f0s = [] + for bib in range(hidden.shape[0]): + f0s.append(self.decode(hidden[bib], thred=thred)) + f0s = np.stack(f0s) + f0s = torch.from_numpy(f0s).to(self.device) + # torch.cuda.synchronize() + # t3 = ttime() + # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0)) + return f0s def to_local_average_cents(self, salience, thred=0.05): # t0 = ttime() diff --git a/modules/v2/__pycache__/ar.cpython-310.pyc b/modules/v2/__pycache__/ar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ef2289697c46cebf33dec7e22ad3832e22768d5 Binary files /dev/null and b/modules/v2/__pycache__/ar.cpython-310.pyc differ diff --git a/modules/v2/__pycache__/cfm.cpython-310.pyc b/modules/v2/__pycache__/cfm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fff9c2cda4430b1f3849eb7e0941491130030e32 Binary files /dev/null and b/modules/v2/__pycache__/cfm.cpython-310.pyc differ diff --git a/modules/v2/__pycache__/dit_model.cpython-310.pyc b/modules/v2/__pycache__/dit_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30550b7c8fa982077b5988cb11ecebcffa45ea8a Binary files /dev/null and b/modules/v2/__pycache__/dit_model.cpython-310.pyc differ diff --git a/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc b/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c16d4408efe4500b7caf352370c03562b0a6b43 Binary files /dev/null and b/modules/v2/__pycache__/dit_wrapper.cpython-310.pyc differ diff --git a/modules/v2/__pycache__/length_regulator.cpython-310.pyc b/modules/v2/__pycache__/length_regulator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d20e70c5fa10e54c6dd8abba283d448e34b66b96 Binary files /dev/null and b/modules/v2/__pycache__/length_regulator.cpython-310.pyc differ diff --git a/modules/v2/__pycache__/model.cpython-310.pyc b/modules/v2/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..202b549ed99cfcf71724feae469c648534d13b30 Binary files /dev/null and b/modules/v2/__pycache__/model.cpython-310.pyc differ diff --git a/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc b/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5ceaa1e09ce09776cb8e26c81e94ccf4850512d Binary files /dev/null and b/modules/v2/__pycache__/vc_wrapper.cpython-310.pyc differ diff --git a/modules/v2/ar.py b/modules/v2/ar.py new file mode 100644 index 0000000000000000000000000000000000000000..1c38b99e14c0aa93bf4a908c231437904886ad8e --- /dev/null +++ b/modules/v2/ar.py @@ -0,0 +1,763 @@ +import dataclasses +import json +import math +from collections import OrderedDict +from functools import partial, wraps +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Tuple, List +from tqdm import tqdm + +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor +from torch.nn import functional as F +from torch.utils.checkpoint import checkpoint + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + 
return n + return n + k - (n % k) + +def l2norm(t, groups = 1): + t = rearrange(t, '... (g d) -> ... g d', g = groups) + t = F.normalize(t, p = 2, dim = -1) + return rearrange(t, '... g d -> ... (g d)') + +@dataclass +class BaseModelArgs: + model_type: str = "base" + + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + max_seq_len: int = 4096 + dropout: float = 0.0 + tie_word_embeddings: bool = True + attention_qkv_bias: bool = False + + # Gradient checkpointing + use_gradient_checkpointing: bool = False + + # Initialize the model + initializer_range: float = 0.02 + + qk_norm: bool = False + layerscale: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + def save(self, path: str): + with open(path, "w") as f: + json.dump(self.__dict__, f, indent=4, sort_keys=True, ensure_ascii=False) + + +@dataclass +class NaiveModelArgs(BaseModelArgs): + model_type: str = "naive" + + +class KVCache(nn.Module): + def __init__( + self, max_batch_size, max_seq_len, n_heads, head_dim, dtype=torch.bfloat16 + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_len, head_dim) + self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) + self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + k_out = self.k_cache + v_out = self.v_cache + k_out[:, :, input_pos] = k_val + v_out[:, :, input_pos] = v_val + + return k_out, v_out + + +@dataclass +class TransformerForwardResult: + token_logits: Tensor + token_targets: Tensor + + +@dataclass +class BaseTransformerForwardResult: + logits: Tensor + hidden_states: Tensor + + +class BaseTransformer(nn.Module): + def __init__( + self, + config: BaseModelArgs, + init_weights: bool = True, + ) -> None: + super().__init__() + self.config = config + + # Slow transformer + self.embeddings = nn.Embedding( + config.vocab_size, + config.dim, + ) + self.layers = nn.ModuleList( + TransformerBlock(config, use_sdpa=True) for _ in range(config.n_layer) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + + if self.config.tie_word_embeddings is False: + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + self.register_buffer( + "freqs_cis", + precompute_freqs_cis( + config.max_seq_len, + config.dim // config.n_head, + config.rope_base, + ), + persistent=False, + ) + self.register_buffer( + "causal_mask", + torch.tril( + torch.ones( + config.max_seq_len, + config.max_seq_len, + dtype=torch.bool, + ) + ), + persistent=False, + ) + + self.output = nn.Linear( + config.dim, + config.vocab_size, + bias=False, + ) + + # For kv cache + self.max_batch_size = -1 + self.max_seq_len = -1 + + if init_weights: + self.apply(self._init_weights) + + def setup_caches( + self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda" + ): + if self.max_seq_len >= max_seq_len and self.max_batch_size >= max_batch_size: + return + + head_dim = self.config.dim // self.config.n_head + max_seq_len = find_multiple(max_seq_len, 8) + 
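# ---------------------------------------------------------------------------
# Minimal sketch of how the KVCache defined above is filled during decoding:
# keys/values for the current positions are written into pre-allocated buffers
# and the full buffers are returned to the attention layer. Toy shapes are
# assumptions for illustration.
import torch

B, H, S, D = 1, 4, 16, 8                       # batch, kv heads, max_seq_len, head_dim
k_cache = torch.zeros(B, H, S, D)
v_cache = torch.zeros(B, H, S, D)

def update(input_pos, k_val, v_val):
    # input_pos: (s,); k_val, v_val: (B, H, s, D)
    k_cache[:, :, input_pos] = k_val
    v_cache[:, :, input_pos] = v_val
    return k_cache, v_cache

k, v = update(torch.tensor([0, 1, 2]), torch.randn(B, H, 3, D), torch.randn(B, H, 3, D))
# ---------------------------------------------------------------------------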
self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, + max_seq_len, + self.config.n_local_heads, + head_dim, + dtype=dtype, + ).to(device) + + def embed_base(self, x: Tensor, x_lens: Tensor) -> Tensor: + for bib in range(x.size(0)): + x[bib, x_lens[bib]:] = self.config.vocab_size - 1 + + x_emb = self.embeddings(x) + return x, x_emb + + def forward( + self, + inp: Tensor, + key_padding_mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> BaseTransformerForwardResult: + seq_len = inp.size(1) + + # Here we want to merge the embeddings of the codebooks + # x = self.embed(inp) + x = inp.clone() + + if input_pos is None: + freqs_cis = self.freqs_cis[:seq_len].repeat(inp.size(0), 1, 1, 1) + else: + freqs_cis = self.freqs_cis[input_pos] + + # Not that the causal mask here follows the definition of scaled_dot_product_attention + # That is, FALSE means masked out + # To maintain consistency, key_padding_mask use TRUE to mask out + mask = None + if key_padding_mask is not None: + mask = self.causal_mask[None, None, :seq_len, :seq_len] # (B, N, Q, K) + mask = mask & key_padding_mask[:, None, None, :].logical_not() + + for layer in self.layers: + if self.config.use_gradient_checkpointing and self.training: + x = checkpoint(layer, x, freqs_cis, mask, use_reentrant=True) + else: + x = layer(x, freqs_cis, mask) + + # We got slow_out here + slow_out = self.norm(x) + + if self.config.tie_word_embeddings: + token_logits = F.linear(slow_out, self.embeddings.weight) + else: + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def forward_generate( + self, + inp: Tensor, + input_pos: Optional[Tensor] = None, + kv_pos: Optional[Tensor] = None, + return_all: bool = False, + ) -> BaseTransformerForwardResult: + # This is used for generation, optimized for torch compile + + x = inp + max_seq_len = self.max_seq_len + + mask = self.causal_mask[None, None, kv_pos, :max_seq_len] # (B, N, Q, K) + freqs_cis = self.freqs_cis[input_pos] + + for layer in self.layers: + x = layer(x, freqs_cis, mask, input_pos=kv_pos) + + x = x[:, -1:] + + # We got slow_out here + slow_out = self.norm(x) + + token_logits = self.output(slow_out) + + return BaseTransformerForwardResult( + logits=token_logits, + hidden_states=x, + ) + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + +class NaiveTransformer(BaseTransformer): + def __init__(self, config: NaiveModelArgs) -> None: + super().__init__(config, init_weights=False) + self.apply(self._init_weights) + + def forward( + self, + inp: Tensor, + cond_lens: Tensor, + target: Tensor, + target_lens: Tensor, + key_padding_mask: Optional[Tensor] = None, + input_pos: Optional[Tensor] = None, + ) -> TransformerForwardResult: + parent_result = super().forward( + inp=inp, + key_padding_mask=key_padding_mask, + input_pos=input_pos, + ) + token_logits = parent_result.logits + + # construct targets for token_logits + token_targets = torch.zeros(token_logits.size(0), token_logits.size(1), dtype=torch.long, + device=target.device) - 100 + for bib in range(token_targets.size(0)): + 
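# ---------------------------------------------------------------------------
# Minimal worked example of the target layout built by the surrounding loop in
# NaiveTransformer.forward: targets stay at the ignore index -100 over the
# conditioning span, carry the target tokens shifted by cond_len + 1, and end
# with the EOS id vocab_size - 1. Numbers are toy values for illustration.
import torch

V, cond_len, target_len, L = 10, 3, 4, 12
target = torch.tensor([5, 6, 7, 8])
token_targets = torch.full((L,), -100, dtype=torch.long)
token_targets[cond_len + 1 : cond_len + 1 + target_len] = target
token_targets[cond_len + 1 + target_len] = V - 1   # EOS
# token_targets -> [-100, -100, -100, -100, 5, 6, 7, 8, 9, -100, -100, -100]
# ---------------------------------------------------------------------------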
token_targets[bib, cond_lens[bib] + 1:cond_lens[bib] + target_lens[bib] + 1] = target[bib, :target_lens[bib]] + token_targets[bib, cond_lens[bib] + target_lens[bib] + 1] = self.config.vocab_size - 1 + return TransformerForwardResult( + token_logits=token_logits, + token_targets=token_targets, + ) + + def infer_slow(self, inp: Tensor, input_pos: Optional[Tensor] = None): + # no kv cache used + parent_result = super().forward(inp, input_pos=input_pos) + latent = parent_result.hidden_states[:, -1] + base_logits = parent_result.logits[:, -1] + base_sampled, _ = topk_sampling(base_logits, top_k=-1, top_p=1.0) + return base_sampled + + def forward_generate( + self, + x: Tensor, + input_pos: Optional[Tensor] = None, + kv_pos: Optional[Tensor] = None, + vq_masks: Optional[Tensor] = None, + ) -> TransformerForwardResult: + x = super().forward_generate(x, input_pos, kv_pos, vq_masks) + return x + +class NaiveWrapper(nn.Module): + def __init__(self, model: NaiveTransformer) -> None: + super().__init__() + self.model = model + self.sep_token_emb = nn.Parameter(torch.randn(model.config.dim)) + + def setup_caches(self, max_batch_size: int, max_seq_len: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda"): + self.model.setup_caches(max_batch_size, max_seq_len, dtype, device) + + def forward(self, cond: Tensor, cond_lens: Tensor, x: Tensor, x_lens: Tensor) -> torch.Tensor: + # style_emb = self.style_in(style).unsqueeze(1) # [B, 1, D] + sep_token_emb = self.sep_token_emb.expand(x.size(0), 1, -1) + _, x_emb = self.model.embed_base(x, x_lens) + emb_seq_list = [] + for i in range(x.size(0)): + emb_seq = torch.cat([ + sep_token_emb[i:i + 1], + cond[i:i+1, :cond_lens[i]], + sep_token_emb[i:i+1], + x_emb[i:i+1, :x_lens[i]]], dim=1) + emb_seq_list.append(emb_seq) + max_len = max([emb_seq.size(1) for emb_seq in emb_seq_list]) + emb_seq = torch.cat([ + F.pad(emb_seq, (0, 0, 0, max_len - emb_seq.size(1)), value=0) + for emb_seq in emb_seq_list + ], dim=0) + # input_pos = torch.arange(emb_seq.size(1), device=emb_seq.device).repeat(emb_seq.size(0), 1) + input_pos = torch.zeros(emb_seq.size(0), emb_seq.size(1), device=emb_seq.device, dtype=torch.long) + for i in range(x.size(0)): + input_pos[i, :cond_lens[i] + 1] = torch.arange(cond_lens[i] + 1, device=emb_seq.device) + input_pos[i, cond_lens[i] + 1: cond_lens[i] + x_lens[i] + 2] = torch.arange(x_lens[i] + 1, device=emb_seq.device) + out = self.model(emb_seq, cond_lens, x, x_lens, input_pos=input_pos) + loss = F.cross_entropy(out.token_logits.transpose(1, 2), out.token_targets.long(), ignore_index=-100) + return loss + + @torch.no_grad() + def infer(self, cond: Tensor) -> torch.Tensor: + sep_token_emb = self.sep_token_emb.expand(1, 1, -1) + emb_seq = torch.cat([sep_token_emb, cond, sep_token_emb], dim=1) + pred_codes = [] + input_pos = torch.arange(cond.size(1) + 1, device=cond.device) + for i in tqdm(range(4000)): + input_pos = torch.cat([input_pos, torch.LongTensor([i]).to(cond.device)], dim=0) + base = self.model.infer_slow(emb_seq, input_pos) + if base == self.model.config.vocab_size - 1: + break + new_emb = self.model.embed_base(base, torch.LongTensor([1]).to(base.device))[1] + emb_seq = torch.cat([emb_seq, new_emb], dim=1) + pred_codes.append(base) + return torch.cat(pred_codes, dim=-1) + + @torch.no_grad() + def generate( + self, + prompt_text, + prompt_target, + compiled_decode_fn = None, + **sampling_kwargs, + ): + sep_token_emb = self.sep_token_emb.expand(1, 1, -1) + emb_seq = torch.cat([sep_token_emb, prompt_text, sep_token_emb], dim=1) + 
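# ---------------------------------------------------------------------------
# Minimal worked example of the rotary position ids assembled just below in
# generate(): positions count through [SEP, text...] and then restart at the
# second SEP for the target stream. Toy lengths (3 text frames, 2 target-prompt
# tokens) are assumptions for illustration.
import torch

n_text, n_tgt_prompt = 3, 2
pos = torch.arange(n_text + 1)                           # SEP + text  -> 0, 1, 2, 3
pos = torch.cat([pos, torch.tensor([0])])                # second SEP restarts the count
pos = torch.cat([pos, torch.arange(n_tgt_prompt) + 1])   # target prompt -> 1, 2
# pos -> tensor([0, 1, 2, 3, 0, 1, 2]); generated tokens then continue 3, 4, ...
# ---------------------------------------------------------------------------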
input_pos = torch.arange(prompt_text.size(1) + 1, device=emb_seq.device) + input_pos = torch.cat([input_pos, torch.LongTensor([0]).to(emb_seq.device)]) + prompt_target_emb = self.model.embed_base(prompt_target,torch.LongTensor([prompt_target.size(1)]).to(prompt_target.device))[1] + emb_seq = torch.cat([emb_seq, prompt_target_emb], dim=1) + input_pos = torch.cat([input_pos, torch.arange(prompt_target_emb.size(1)).to(input_pos.device) + 1]) + + pred_codes = [] + kv_pos = torch.arange(emb_seq.size(1), device=emb_seq.device) + next_tokens = self.decode_one_token_ar(emb_seq, input_pos, kv_pos, suppress_tokens=[self.model.config.vocab_size - 1], **sampling_kwargs) + pred_base = next_tokens[0] + pred_codes.append(pred_base) + new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1] + emb_seq = torch.cat([emb_seq, new_emb], dim=1) + for _ in tqdm(range(4000)): + suppress_eos = len(pred_codes) < 10 + input_pos = input_pos[-1:] + 1 + kv_pos = kv_pos[-1:] + 1 + next_tokens = self.decode_one_token_ar( + emb_seq[:, -1:].reshape(1, 1, -1), + input_pos.reshape(1), + kv_pos.reshape(1), + previous_tokens=torch.cat(pred_codes), + suppress_tokens=[self.model.config.vocab_size - 1] if suppress_eos else None, + compiled_decode_fn=compiled_decode_fn, + **sampling_kwargs) + pred_base = next_tokens[0] + if pred_base == self.model.config.vocab_size - 1: + break + pred_codes.append(pred_base.clone()) + new_emb = self.model.embed_base(pred_base.unsqueeze(0), torch.LongTensor([1]).to(pred_base.device))[1] + emb_seq = torch.cat([emb_seq, new_emb], dim=1) + return torch.stack(pred_codes, dim=-1) + + def decode_one_token_ar( + self, + x: torch.Tensor, + input_pos: torch.Tensor, + kv_pos: torch.Tensor, + previous_tokens: torch.Tensor = None, + compiled_decode_fn = None, + **sampling_kwargs, + ) -> torch.Tensor: + if compiled_decode_fn is not None: + x = compiled_decode_fn(x, input_pos, kv_pos) + else: + x = self.model.forward_generate(x, input_pos, kv_pos) + + sampling_kwargs_main = sampling_kwargs.copy() + codebooks = [ + sample( + x.logits, + previous_tokens=( + previous_tokens[0] if previous_tokens is not None else None + ), + **sampling_kwargs_main, + )[0] + ] + codebooks = torch.stack(codebooks, dim=0) + return codebooks + +class TransformerBlock(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True) -> None: + super().__init__() + self.attention = Attention(config, use_sdpa=use_sdpa) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Tensor = None + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: BaseModelArgs, use_sdpa: bool = True): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear( + config.dim, total_head_dim, bias=config.attention_qkv_bias + ) + self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.kv_cache = None + + self.dropout = config.dropout + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.use_sdpa = use_sdpa + 
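# A minimal standalone sketch of the fused QKV projection set up just above: wqkv
# packs query, key and value weights into a single Linear, and the split in forward()
# recovers q with n_head heads and k/v with n_local_heads heads (grouped-query
# attention). Dimensions here are small toy values and assume dim == n_head * head_dim,
# matching the split used in this module.
import torch
import torch.nn as nn

dim, n_head, n_local_heads, head_dim = 64, 8, 2, 8
total_head_dim = (n_head + 2 * n_local_heads) * head_dim   # 96
wqkv = nn.Linear(dim, total_head_dim, bias=False)

x = torch.randn(1, 4, dim)                                 # (batch, seq, dim)
kv_size = n_local_heads * head_dim                         # 16
q, k, v = wqkv(x).split([dim, kv_size, kv_size], dim=-1)
q = q.view(1, 4, n_head, head_dim)
k = k.view(1, 4, n_local_heads, head_dim)
v = v.view(1, 4, n_local_heads, head_dim)
print(q.shape, k.shape, v.shape)  # (1, 4, 8, 8) (1, 4, 2, 8) (1, 4, 2, 8)
# k and v are later repeat_interleave'd by n_head // n_local_heads (here 4) so every
# query head attends over a shared key/value group.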
self._register_load_state_dict_pre_hook(self.load_hook) + self.qk_norm = config.qk_norm + self.qk_norm_groups = 1 + self.qk_norm_scale = 10 + self.qk_norm_dim_scale = False + self.qk_norm_q_scale = self.qk_norm_k_scale = 1 + + if self.qk_norm and self.qk_norm_dim_scale: + self.qk_norm_q_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim)) + self.qk_norm_k_scale = nn.Parameter(torch.ones(self.n_head, 1, self.head_dim)) + def load_hook(self, state_dict, prefix, *args): + if prefix + "wq.weight" in state_dict: + wq = state_dict.pop(prefix + "wq.weight") + wk = state_dict.pop(prefix + "wk.weight") + wv = state_dict.pop(prefix + "wv.weight") + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) + + if self.qk_norm: + qk_l2norm = partial(l2norm, groups = self.qk_norm_groups) + q, k = map(qk_l2norm, (q, k)) + scale = self.qk_norm_scale + + q = q * self.qk_norm_q_scale + k = k * self.qk_norm_k_scale + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + + if self.use_sdpa: + if mask is None: + y = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + # No third party attn_mask here to use flash_attention + ) + else: + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + else: + y = self.eq_scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + ) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + return self.wo(y) + + def eq_scaled_dot_product_attention( + self, + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + ) -> torch.Tensor: + # This is a standard scaled dot product attention + # It's low efficient, but it doesn't raise cuda error + + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) + attn_bias = torch.zeros(1, 1, L, S, dtype=query.dtype, device=query.device) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + + return attn_weight @ value + + +class FeedForward(nn.Module): + def __init__(self, config: BaseModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + self.dropout = nn.Dropout(p=config.dropout) + + def forward(self, x: Tensor) -> Tensor: + return 
self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x))) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis(seq_len: int, n_elem: int, base: int = 10000) -> Tensor: + freqs = 1.0 / ( + base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) + ) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=torch.bfloat16) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(x.size(0), xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min( + max(top_k, min_tokens_to_keep), logits.size(-1) + ) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ + ..., :-1 + ].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + +def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): + # temperature: (`optional`) float + # The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. 
+ # top_k: (`optional`) int + # The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50. + # top_p: (`optional`) float + # The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1. + + # Temperature (higher temperature => more likely to sample low probability tokens) + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + # Sample + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + logprobs = F.log_softmax(logits.float(), dim=-1) + current_logprobs = logprobs[torch.arange(logprobs.shape[0]), token.squeeze(1)] + return token, current_logprobs + +def sample( + logits, + previous_tokens: Optional[torch.Tensor] = None, + **sampling_kwargs, +) -> Tuple[torch.Tensor, torch.Tensor]: + probs = logits_to_probs( + logits=logits[0, -1], previous_tokens=previous_tokens, **sampling_kwargs + ) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs( + logits, + previous_tokens: Optional[torch.Tensor] = None, + suppress_tokens: Optional[List[int]] = None, + temperature: torch.Tensor = 0.7, + top_p: torch.Tensor = 0.7, + repetition_penalty: torch.Tensor = 1.5, +) -> torch.Tensor: + # Apply repetition penalty + if previous_tokens is not None: + previous_tokens = previous_tokens.long() + score = torch.gather(logits, dim=0, index=previous_tokens) + score = torch.where( + score < 0, score * repetition_penalty, score / repetition_penalty + ) + logits.scatter_(dim=0, index=previous_tokens, src=score) + if suppress_tokens is not None: + for token in suppress_tokens: + logits[token] = -float("Inf") + + # Apply top-p sampling + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p + sorted_indices_to_remove[0] = False # keep at least one option + indices_to_remove = sorted_indices_to_remove.scatter( + dim=0, index=sorted_indices, src=sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, -float("Inf")) + + logits = logits / max(temperature, 1e-5) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs diff --git a/modules/v2/cfm.py b/modules/v2/cfm.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ea58ef15f31c324bbcc061f15be8650824790e --- /dev/null +++ b/modules/v2/cfm.py @@ -0,0 +1,173 @@ +import torch +from tqdm import tqdm + +class CFM(torch.nn.Module): + def __init__( + self, + estimator: torch.nn.Module, + ): + super().__init__() + self.sigma_min = 1e-6 + self.estimator = estimator + self.in_channels = estimator.in_channels + self.criterion = torch.nn.L1Loss() + + @torch.inference_mode() + def inference(self, + mu: torch.Tensor, + x_lens: torch.Tensor, + prompt: torch.Tensor, + style: torch.Tensor, + n_timesteps=10, + temperature=1.0, + inference_cfg_rate=[0.5, 0.5], + random_voice=False, + ): + """Forward diffusion + + Args: + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + x_lens (torch.Tensor): length of each 
mel-spectrogram + shape: (batch_size,) + prompt (torch.Tensor): prompt + shape: (batch_size, n_feats, prompt_len) + style (torch.Tensor): style + shape: (batch_size, style_dim) + n_timesteps (int): number of diffusion steps + temperature (float, optional): temperature for scaling noise. Defaults to 1.0. + inference_cfg_rate (float, optional): Classifier-Free Guidance inference introduced in VoiceBox. Defaults to 0.5. + + Returns: + sample: generated mel-spectrogram + shape: (batch_size, n_feats, mel_timesteps) + """ + B, T = mu.size(0), mu.size(1) + z = torch.randn([B, self.in_channels, T], device=mu.device) * temperature + t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) + t_span = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span) + return self.solve_euler(z, x_lens, prompt, mu, style, t_span, inference_cfg_rate, random_voice) + def solve_euler(self, x, x_lens, prompt, mu, style, t_span, inference_cfg_rate=[0.5, 0.5], random_voice=False,): + """ + Fixed euler solver for ODEs. + Args: + x (torch.Tensor): random noise + t_span (torch.Tensor): n_timesteps interpolated + shape: (n_timesteps + 1,) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + x_lens (torch.Tensor): length of each mel-spectrogram + shape: (batch_size,) + prompt (torch.Tensor): prompt + shape: (batch_size, n_feats, prompt_len) + style (torch.Tensor): style + shape: (batch_size, style_dim) + inference_cfg_rate (float, optional): Classifier-Free Guidance inference introduced in VoiceBox. Defaults to 0.5. + sway_sampling (bool, optional): Sway sampling. Defaults to False. + amo_sampling (bool, optional): AMO sampling. Defaults to False. + """ + t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] + + # apply prompt + prompt_len = prompt.size(-1) + prompt_x = torch.zeros_like(x) + prompt_x[..., :prompt_len] = prompt[..., :prompt_len] + x[..., :prompt_len] = 0 + for step in tqdm(range(1, len(t_span))): + if random_voice: + cfg_dphi_dt = self.estimator( + torch.cat([x, x], dim=0), + torch.cat([torch.zeros_like(prompt_x), torch.zeros_like(prompt_x)], dim=0), + torch.cat([x_lens, x_lens], dim=0), + torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0), + torch.cat([torch.zeros_like(style), torch.zeros_like(style)], dim=0), + torch.cat([mu, torch.zeros_like(mu)], dim=0), + ) + cond_txt, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2] + dphi_dt = ((1.0 + inference_cfg_rate[0]) * cond_txt - inference_cfg_rate[0] * uncond) + elif all(i == 0 for i in inference_cfg_rate): + dphi_dt = self.estimator(x, prompt_x, x_lens, t.unsqueeze(0), style, mu) + elif inference_cfg_rate[0] == 0: + # Classifier-Free Guidance inference introduced in VoiceBox + cfg_dphi_dt = self.estimator( + torch.cat([x, x], dim=0), + torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0), + torch.cat([x_lens, x_lens], dim=0), + torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0), + torch.cat([style, torch.zeros_like(style)], dim=0), + torch.cat([mu, mu], dim=0), + ) + cond_txt_spk, cond_txt = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2] + dphi_dt = ((1.0 + inference_cfg_rate[1]) * cond_txt_spk - inference_cfg_rate[1] * cond_txt) + elif inference_cfg_rate[1] == 0: + cfg_dphi_dt = self.estimator( + torch.cat([x, x], dim=0), + torch.cat([prompt_x, torch.zeros_like(prompt_x)], dim=0), + torch.cat([x_lens, x_lens], dim=0), + torch.cat([t.unsqueeze(0), t.unsqueeze(0)], dim=0), + torch.cat([style, torch.zeros_like(style)], dim=0), + torch.cat([mu, torch.zeros_like(mu)], dim=0), + ) + cond_txt_spk, uncond = cfg_dphi_dt[0:1], 
cfg_dphi_dt[1:2] + dphi_dt = ((1.0 + inference_cfg_rate[0]) * cond_txt_spk - inference_cfg_rate[0] * uncond) + else: + # Multi-condition Classifier-Free Guidance inference introduced in MegaTTS3 + cfg_dphi_dt = self.estimator( + torch.cat([x, x, x], dim=0), + torch.cat([prompt_x, torch.zeros_like(prompt_x), torch.zeros_like(prompt_x)], dim=0), + torch.cat([x_lens, x_lens, x_lens], dim=0), + torch.cat([t.unsqueeze(0), t.unsqueeze(0), t.unsqueeze(0)], dim=0), + torch.cat([style, torch.zeros_like(style), torch.zeros_like(style)], dim=0), + torch.cat([mu, mu, torch.zeros_like(mu)], dim=0), + ) + cond_txt_spk, cond_txt, uncond = cfg_dphi_dt[0:1], cfg_dphi_dt[1:2], cfg_dphi_dt[2:3] + dphi_dt = (1.0 + inference_cfg_rate[0] + inference_cfg_rate[1]) * cond_txt_spk - \ + inference_cfg_rate[0] * uncond - inference_cfg_rate[1] * cond_txt + x = x + dt * dphi_dt + t = t + dt + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + x[:, :, :prompt_len] = 0 + + return x + + def forward(self, x1, x_lens, prompt_lens, mu, style): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_feats, mel_timesteps) + mask (torch.Tensor): target mask + shape: (batch_size, 1, mel_timesteps) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_feats, mel_timesteps) + spks (torch.Tensor, optional): speaker embedding. Defaults to None. + shape: (batch_size, spk_emb_dim) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_feats, mel_timesteps) + """ + b, _, t = x1.shape + + # random timestep + t = torch.rand([b, 1, 1], device=mu.device, dtype=x1.dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + prompt = torch.zeros_like(x1) + for bib in range(b): + prompt[bib, :, :prompt_lens[bib]] = x1[bib, :, :prompt_lens[bib]] + # range covered by prompt are set to 0 + y[bib, :, :prompt_lens[bib]] = 0 + + estimator_out = self.estimator(y, prompt, x_lens, t.squeeze(), style, mu) + loss = 0 + for bib in range(b): + loss += self.criterion(estimator_out[bib, :, prompt_lens[bib]:x_lens[bib]], u[bib, :, prompt_lens[bib]:x_lens[bib]]) + loss /= b + + return loss diff --git a/modules/v2/dit_model.py b/modules/v2/dit_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4374ac86a4d4d0869788cdd16087115c4418ba5f --- /dev/null +++ b/modules/v2/dit_model.py @@ -0,0 +1,250 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
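# A minimal standalone sketch of the conditional flow matching quantities used by
# CFM.forward above: the noisy sample is the near-linear interpolation
# y_t = (1 - (1 - sigma_min) * t) * z + t * x1 and the regression target is its
# constant time derivative u = x1 - (1 - sigma_min) * z. Tensors are random toy
# values for illustration only.
import torch

sigma_min = 1e-6
x1 = torch.randn(2, 80, 10)   # target mel, (batch, n_feats, frames)
z = torch.randn_like(x1)      # noise sample
t = torch.rand(2, 1, 1)       # one random timestep per batch item

y = (1 - (1 - sigma_min) * t) * z + t * x1
u = x1 - (1 - sigma_min) * z

# Endpoint check: y equals z at t = 0 and matches x1 at t = 1 up to sigma_min * z.
y0 = (1 - (1 - sigma_min) * torch.zeros_like(t)) * z + torch.zeros_like(t) * x1
y1 = (1 - (1 - sigma_min) * torch.ones_like(t)) * z + torch.ones_like(t) * x1
assert torch.allclose(y0, z) and torch.allclose(y1, x1, atol=1e-4)

# The inference-time schedule in CFM.inference warps uniform steps:
# t + (-(cos(pi/2 * t) - 1 + t)) simplifies to 1 - cos(pi/2 * t), which takes
# smaller steps near t = 0 and larger ones near t = 1.
t_span = torch.linspace(0, 1, 11)
warped = t_span + (-1) * (torch.cos(torch.pi / 2 * t_span) - 1 + t_span)
assert torch.allclose(warped, 1 - torch.cos(torch.pi / 2 * t_span))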
+from dataclasses import dataclass +from typing import Optional, Union, Tuple, List + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +import time + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.linear = nn.Linear(d_model, 6 * d_model) + self.act = nn.SiLU() + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, x: Tensor, emb: Tensor) -> Tuple[Tensor]: + emb = self.linear(self.act(emb)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=-1) + + x = self.norm(x) * (1 + scale_msa) + shift_msa + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp + +class AdaptiveLayerNormFinal(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNormFinal, self).__init__() + self.linear = nn.Linear(d_model, 2 * d_model) + self.act = nn.SiLU() + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, x: Tensor, emb: Tensor) -> Tuple[Tensor]: + emb = self.linear(self.act(emb)) + scale, shift = torch.chunk(emb, 2, dim=-1) + + x = self.norm(x) * (1 + scale) + shift + return x + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + uvit_skip_connection: bool = False + time_as_token: bool = False + dropout_rate: float = 0.1 + attn_dropout_rate: float = 0.1 + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + # self.head_dim = self.dim // self.n_head + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = AdaptiveLayerNormFinal(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.max_batch_size = -1 + self.max_seq_length = config.block_size + + self.uvit_skip_connection = self.config.uvit_skip_connection + if self.uvit_skip_connection: + self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2] + self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2] + else: + self.layers_emit_skip = [] + self.layers_receive_skip = [] + freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, + self.config.rope_base) + self.register_buffer("freqs_cis", freqs_cis) + + causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + self.register_buffer("causal_mask", causal_mask) + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + ) -> Tensor: + mask = mask[..., input_pos] + freqs_cis = self.freqs_cis[input_pos] + for i, layer in enumerate(self.layers): + x = layer(x, c, freqs_cis, mask) + x = self.norm(x, c) + return x + + +class TransformerBlock(nn.Module): + def __init__(self, 
config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps) + self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + def forward(self, + x: Tensor, + c: Tensor, + freqs_cis: Tensor, + mask: Tensor, + ) -> Tensor: + normed_x, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attention_norm(x, emb=c) + # attention + attn_output = self.attention(normed_x, freqs_cis, mask) + x = x + gate_msa * attn_output + normed_x = self.ffn_norm(x) * (1 + scale_mlp) + shift_mlp + ff_output = self.feed_forward(normed_x) + x = x + gate_mlp * ff_output + return x + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs, is_cross_attention: bool = False): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + if is_cross_attention: + self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False) + self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False) + else: + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) + self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self.attn_dropout_rate = config.attn_dropout_rate + + def forward(self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x))) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype 
= torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) + diff --git a/modules/v2/dit_wrapper.py b/modules/v2/dit_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f653239de475fba17794d6b8d7e28a8edd1b65a0 --- /dev/null +++ b/modules/v2/dit_wrapper.py @@ -0,0 +1,152 @@ +import torch +from torch import nn +import math + +from modules.v2.dit_model import ModelArgs, Transformer +from modules.commons import sequence_mask + +from torch.nn.utils import weight_norm + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000, scale=1000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
+ """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = scale * t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class DiT(torch.nn.Module): + def __init__( + self, + time_as_token, + style_as_token, + uvit_skip_connection, + block_size, + depth, + num_heads, + hidden_dim, + in_channels, + content_dim, + style_encoder_dim, + class_dropout_prob, + dropout_rate, + attn_dropout_rate, + ): + super(DiT, self).__init__() + self.time_as_token = time_as_token + self.style_as_token = style_as_token + self.uvit_skip_connection = uvit_skip_connection + model_args = ModelArgs( + block_size=block_size, + n_layer=depth, + n_head=num_heads, + dim=hidden_dim, + head_dim=hidden_dim // num_heads, + vocab_size=1, # we don't use this + uvit_skip_connection=self.uvit_skip_connection, + time_as_token=self.time_as_token, + dropout_rate=dropout_rate, + attn_dropout_rate=attn_dropout_rate, + ) + self.transformer = Transformer(model_args) + self.in_channels = in_channels + self.out_channels = in_channels + self.num_heads = num_heads + + self.x_embedder = weight_norm(nn.Linear(in_channels, hidden_dim, bias=True)) + + self.content_dim = content_dim # for continuous content + self.cond_projection = nn.Linear(content_dim, hidden_dim, bias=True) # continuous content + + self.t_embedder = TimestepEmbedder(hidden_dim) + + self.final_mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.SiLU(), + nn.Linear(hidden_dim, in_channels), + ) + + self.class_dropout_prob = class_dropout_prob + + self.cond_x_merge_linear = nn.Linear(hidden_dim + in_channels + in_channels, hidden_dim) + self.style_in = nn.Linear(style_encoder_dim, hidden_dim) + + def forward(self, x, prompt_x, x_lens, t, style, cond): + class_dropout = False + content_dropout = False + if self.training and torch.rand(1) < self.class_dropout_prob: + class_dropout = True + if self.training and torch.rand(1) < 0.5: + content_dropout = True + cond_in_module = self.cond_projection + + B, _, T = x.size() + + t1 = self.t_embedder(t) # (N, D) + cond = cond_in_module(cond) + + x = x.transpose(1, 2) + prompt_x = prompt_x.transpose(1, 2) + + x_in = torch.cat([x, prompt_x, cond], dim=-1) + if class_dropout: + x_in[..., self.in_channels:self.in_channels*2] = 0 + if content_dropout: + x_in[..., self.in_channels*2:] = 0 + x_in = self.cond_x_merge_linear(x_in) # (N, T, D) + + style = self.style_in(style) + style = torch.zeros_like(style) if class_dropout else style + if self.style_as_token: + x_in = torch.cat([style.unsqueeze(1), x_in], dim=1) + if self.time_as_token: + x_in = torch.cat([t1.unsqueeze(1), x_in], dim=1) + x_mask = sequence_mask(x_lens + self.style_as_token + self.time_as_token, max_length=x_in.size(1)).to(x.device).unsqueeze(1) + input_pos = torch.arange(x_in.size(1)).to(x.device) + x_mask_expanded = x_mask[:, None, :].repeat(1, 1, x_in.size(1), 1) + x_res = self.transformer(x_in, t1.unsqueeze(1), input_pos, x_mask_expanded) + x_res = x_res[:, 1:] if self.time_as_token else x_res + x_res = x_res[:, 1:] if self.style_as_token else x_res + x = self.final_mlp(x_res) + x = 
x.transpose(1, 2) + return x diff --git a/modules/v2/length_regulator.py b/modules/v2/length_regulator.py new file mode 100644 index 0000000000000000000000000000000000000000..7efe5a62bc5afba06a8abe5051aace9ad97dbf3e --- /dev/null +++ b/modules/v2/length_regulator.py @@ -0,0 +1,105 @@ +from typing import Tuple +import torch +import torch.nn as nn +from torch.nn import functional as F +from modules.commons import sequence_mask +import numpy as np + +# f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + +def f0_to_coarse(f0, f0_bin): + f0_mel = 1127 * (1 + f0 / 700).log() + a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) + b = f0_mel_min * a - 1. + f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) + # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) + f0_coarse = torch.round(f0_mel).long() + f0_coarse = f0_coarse * (f0_coarse > 0) + f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) + f0_coarse = f0_coarse * (f0_coarse < f0_bin) + f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) + return f0_coarse + +class InterpolateRegulator(nn.Module): + def __init__( + self, + channels: int, + sampling_ratios: Tuple, + is_discrete: bool = False, + in_channels: int = None, # only applies to continuous input + codebook_size: int = 1024, # for discrete only + out_channels: int = None, + groups: int = 1, + f0_condition: bool = False, + n_f0_bins: int = 512, + ): + super().__init__() + self.sampling_ratios = sampling_ratios + out_channels = out_channels or channels + model = nn.ModuleList([]) + if len(sampling_ratios) > 0: + self.interpolate = True + for _ in sampling_ratios: + module = nn.Conv1d(channels, channels, 3, 1, 1) + norm = nn.GroupNorm(groups, channels) + act = nn.Mish() + model.extend([module, norm, act]) + else: + self.interpolate = False + model.append( + nn.Conv1d(channels, out_channels, 1, 1) if channels != out_channels else nn.Identity() + ) + self.model = nn.Sequential(*model) + self.embedding = nn.Embedding(codebook_size, channels) + self.is_discrete = is_discrete + + self.mask_token = nn.Parameter(torch.zeros(1, channels)) + + if f0_condition: + self.f0_embedding = nn.Embedding(n_f0_bins, channels) + self.f0_condition = f0_condition + self.n_f0_bins = n_f0_bins + self.f0_bins = torch.arange(2, 1024, 1024 // n_f0_bins) + self.f0_mask = nn.Parameter(torch.zeros(1, channels)) + else: + self.f0_condition = False + + if not is_discrete: + self.content_in_proj = nn.Linear(in_channels, channels) + + def forward(self, x, ylens=None, f0=None): + if self.is_discrete: + if len(x.size()) == 2: + x = self.embedding(x) + else: + x = self.embedding(x[:, 0]) + else: + x = self.content_in_proj(x) + # x in (B, T, D) + + if self.interpolate: + mask = sequence_mask(ylens).unsqueeze(-1) + x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + else: + x = x.transpose(1, 2).contiguous() + mask = None + # mask = mask[:, :x.size(2), :] + # ylens = ylens.clamp(max=x.size(2)).long() + if self.f0_condition: + if f0 is None: + x = x + self.f0_mask.unsqueeze(-1) + else: + # quantized_f0 = torch.bucketize(f0, self.f0_bins.to(f0.device)) # (N, T) + quantized_f0 = f0_to_coarse(f0, self.n_f0_bins) + quantized_f0 = quantized_f0.clamp(0, self.n_f0_bins - 1).long() + f0_emb = self.f0_embedding(quantized_f0) + f0_emb = F.interpolate(f0_emb.transpose(1, 2).contiguous(), size=ylens.max(), mode='nearest') + x = x + f0_emb + out = self.model(x).transpose(1, 2).contiguous() + out = out * mask if mask 
is not None else out + olens = ylens + return out, olens diff --git a/modules/v2/model.py b/modules/v2/model.py new file mode 100644 index 0000000000000000000000000000000000000000..a96dd0b6c58991ca3e203ca6c5247dc0413e48b4 --- /dev/null +++ b/modules/v2/model.py @@ -0,0 +1,302 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if embedding is None: + return self.norm(input) + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +@dataclass +class ModelArgs: + block_size: int = 2048 + vocab_size: int = 32000 + n_layer: int = 32 + n_head: int = 32 + dim: int = 4096 + intermediate_size: int = None + n_local_heads: int = -1 + head_dim: int = 64 + rope_base: float = 10000 + norm_eps: float = 1e-5 + has_cross_attention: bool = False + context_dim: int = 0 + uvit_skip_connection: bool = False + time_as_token: bool = False + + def __post_init__(self): + if self.n_local_heads == -1: + self.n_local_heads = self.n_head + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + # self.head_dim = self.dim // self.n_head + +class Transformer(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) + self.norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + self.freqs_cis: Optional[Tensor] = None + self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length, use_kv_cache=False): + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + dtype = self.norm.project_layer.weight.dtype + device = self.norm.project_layer.weight.device + + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, + self.config.rope_base, dtype).to(device) + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device) + self.use_kv_cache = use_kv_cache + self.uvit_skip_connection = self.config.uvit_skip_connection + if self.uvit_skip_connection: + self.layers_emit_skip = [i for i in range(self.config.n_layer) if i < self.config.n_layer // 2] + self.layers_receive_skip = [i for i in range(self.config.n_layer) if i > self.config.n_layer // 2] + else: + self.layers_emit_skip = [] + self.layers_receive_skip 
= [] + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Optional[Tensor] = None, + mask: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_input_pos: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + ) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + if mask is None: # in case of non-causal model + if not self.training and self.use_kv_cache: + mask = self.causal_mask[None, None, input_pos] + else: + mask = self.causal_mask[None, None, input_pos] + mask = mask[..., input_pos] + freqs_cis = self.freqs_cis[input_pos] + if context is not None: + context_freqs_cis = self.freqs_cis[context_input_pos] + else: + context_freqs_cis = None + skip_in_x_list = [] + for i, layer in enumerate(self.layers): + if self.uvit_skip_connection and i in self.layers_receive_skip: + skip_in_x = skip_in_x_list.pop(-1) + else: + skip_in_x = None + x = layer(x, c, input_pos, freqs_cis, mask, context, context_freqs_cis, cross_attention_mask, skip_in_x) + if self.uvit_skip_connection and i in self.layers_emit_skip: + skip_in_x_list.append(x) + x = self.norm(x, c) + return x + + @classmethod + def from_name(cls, name: str): + return cls(ModelArgs.from_name(name)) + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + self.attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + + if config.has_cross_attention: + self.has_cross_attention = True + self.cross_attention = Attention(config, is_cross_attention=True) + self.cross_attention_norm = AdaptiveLayerNorm(config.dim, RMSNorm(config.dim, eps=config.norm_eps)) + else: + self.has_cross_attention = False + + if config.uvit_skip_connection: + self.skip_in_linear = nn.Linear(config.dim * 2, config.dim) + self.uvit_skip_connection = True + else: + self.uvit_skip_connection = False + + self.time_as_token = config.time_as_token + + def forward(self, + x: Tensor, + c: Tensor, + input_pos: Tensor, + freqs_cis: Tensor, + mask: Tensor, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + cross_attention_mask: Optional[Tensor] = None, + skip_in_x: Optional[Tensor] = None, + ) -> Tensor: + c = None if self.time_as_token else c + if self.uvit_skip_connection and skip_in_x is not None: + x = self.skip_in_linear(torch.cat([x, skip_in_x], dim=-1)) + h = x + self.attention(self.attention_norm(x, c), freqs_cis, mask, input_pos) + if self.has_cross_attention: + h = h + self.cross_attention(self.cross_attention_norm(h, c), freqs_cis, cross_attention_mask, input_pos, context, context_freqs_cis) + out = h + self.feed_forward(self.ffn_norm(h, c)) + return out + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs, is_cross_attention: bool = False): + super().__init__() + assert config.dim % config.n_head == 0 + + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim + # key, query, value projections for all heads, but in a batch + if is_cross_attention: + self.wq = nn.Linear(config.dim, config.n_head * config.head_dim, bias=False) + self.wkv = nn.Linear(config.context_dim, 2 * config.n_local_heads * config.head_dim, bias=False) + else: + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + self.wo = nn.Linear(config.head_dim * config.n_head, config.dim, bias=False) 
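# A minimal standalone sketch of the U-ViT style skip wiring driven by
# layers_emit_skip / layers_receive_skip in the Transformer above: first-half layers
# push their outputs onto a stack, second-half layers pop one and fuse it with their
# input through a concat + Linear. The per-layer Linear modules here are toy
# stand-ins for the real TransformerBlock.
import torch
import torch.nn as nn

n_layer, dim = 6, 16
layers_emit_skip = [i for i in range(n_layer) if i < n_layer // 2]      # [0, 1, 2]
layers_receive_skip = [i for i in range(n_layer) if i > n_layer // 2]   # [4, 5]
skip_in_linear = nn.Linear(dim * 2, dim)
blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layer))

x = torch.randn(1, 8, dim)
skip_stack = []
for i, block in enumerate(blocks):
    if i in layers_receive_skip and skip_stack:
        x = skip_in_linear(torch.cat([x, skip_stack.pop(-1)], dim=-1))
    x = block(x)
    if i in layers_emit_skip:
        skip_stack.append(x)
print(x.shape)  # torch.Size([1, 8, 16])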
+ self.kv_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + # self._register_load_state_dict_pre_hook(self.load_hook) + + # def load_hook(self, state_dict, prefix, *args): + # if prefix + "wq.weight" in state_dict: + # wq = state_dict.pop(prefix + "wq.weight") + # wk = state_dict.pop(prefix + "wk.weight") + # wv = state_dict.pop(prefix + "wv.weight") + # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + def forward(self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + context: Optional[Tensor] = None, + context_freqs_cis: Optional[Tensor] = None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_local_heads * self.head_dim + if context is None: + q, k, v = self.wqkv(x).split([kv_size, kv_size, kv_size], dim=-1) + context_seqlen = seqlen + else: + q = self.wq(x) + k, v = self.wkv(context).split([kv_size, kv_size], dim=-1) + context_seqlen = context.shape[1] + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + v = v.view(bsz, context_seqlen, self.n_local_heads, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, context_freqs_cis if context_freqs_cis is not None else freqs_cis) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.head_dim * self.n_head) + + y = self.wo(y) + return y + + +class FeedForward(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) + self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +def precompute_freqs_cis( + seq_len: int, n_elem: int, base: int = 10000, + dtype: torch.dtype = torch.bfloat16 +) -> Tensor: + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) + t = torch.arange(seq_len, device=freqs.device) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) + return cache.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git 
a/modules/v2/vc_wrapper.py b/modules/v2/vc_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..359e805001c831bbc257f2da7d377a32967a27df --- /dev/null +++ b/modules/v2/vc_wrapper.py @@ -0,0 +1,606 @@ +import spaces +import torch +import librosa +import torchaudio +import numpy as np +from pydub import AudioSegment +from hf_utils import load_custom_model_from_hf + +DEFAULT_REPO_ID = "Plachta/Seed-VC" +DEFAULT_CFM_CHECKPOINT = "v2/cfm_small.pth" +DEFAULT_AR_CHECKPOINT = "v2/ar_base.pth" + +DEFAULT_CE_REPO_ID = "Plachta/ASTRAL-quantization" +DEFAULT_CE_NARROW_CHECKPOINT = "bsq32/bsq32_light.pth" +DEFAULT_CE_WIDE_CHECKPOINT = "bsq2048/bsq2048_light.pth" + +DEFAULT_SE_REPO_ID = "funasr/campplus" +DEFAULT_SE_CHECKPOINT = "campplus_cn_common.bin" + +class VoiceConversionWrapper(torch.nn.Module): + def __init__( + self, + sr: int, + hop_size: int, + mel_fn: callable, + cfm: torch.nn.Module, + cfm_length_regulator: torch.nn.Module, + content_extractor_narrow: torch.nn.Module, + content_extractor_wide: torch.nn.Module, + ar_length_regulator: torch.nn.Module, + ar: torch.nn.Module, + style_encoder: torch.nn.Module, + vocoder: torch.nn.Module, + ): + super(VoiceConversionWrapper, self).__init__() + self.sr = sr + self.hop_size = hop_size + self.mel_fn = mel_fn + self.cfm = cfm + self.cfm_length_regulator = cfm_length_regulator + self.content_extractor_narrow = content_extractor_narrow + self.content_extractor_wide = content_extractor_wide + self.vocoder = vocoder + self.ar_length_regulator = ar_length_regulator + self.ar = ar + self.style_encoder = style_encoder + # Set streaming parameters + self.overlap_frame_len = 16 + self.bitrate = "320k" + self.compiled_decode_fn = None + self.dit_compiled = False + self.dit_max_context_len = 30 # in seconds + self.compile_len = 87 * self.dit_max_context_len + + def compile_ar(self): + """ + Compile the AR model for inference. + """ + self.compiled_decode_fn = torch.compile( + self.ar.model.forward_generate, + fullgraph=True, + backend="inductor" if torch.cuda.is_available() else "aot_eager", + mode="reduce-overhead" if torch.cuda.is_available() else None, + ) + + def compile_cfm(self): + self.cfm.estimator.transformer = torch.compile( + self.cfm.estimator.transformer, + fullgraph=True, + backend="inductor" if torch.cuda.is_available() else "aot_eager", + mode="reduce-overhead" if torch.cuda.is_available() else None, + ) + self.dit_compiled = True + + @staticmethod + def strip_prefix(state_dict: dict, prefix: str = "module.") -> dict: + """ + Strip the prefix from the state_dict keys. 
+ """ + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith(prefix): + new_key = k[len(prefix):] + else: + new_key = k + new_state_dict[new_key] = v + return new_state_dict + + @staticmethod + def duration_reduction_func(token_seq, n_gram=1): + """ + Args: + token_seq: (T,) + Returns: + reduced_token_seq: (T') + reduced_token_seq_len: T' + """ + n_gram_seq = token_seq.unfold(0, n_gram, 1) + mask = torch.all(n_gram_seq[1:] != n_gram_seq[:-1], dim=1) + reduced_token_seq = torch.cat( + (n_gram_seq[0, :n_gram], n_gram_seq[1:, -1][mask]) + ) + return reduced_token_seq, len(reduced_token_seq) + + @staticmethod + def crossfade(chunk1, chunk2, overlap): + """Apply crossfade between two audio chunks.""" + fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2 + fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2 + if len(chunk2) < overlap: + chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)] + else: + chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out + return chunk2 + + def _stream_wave_chunks(self, vc_wave, processed_frames, vc_mel, overlap_wave_len, + generated_wave_chunks, previous_chunk, is_last_chunk, stream_output): + """ + Helper method to handle streaming wave chunks. + + Args: + vc_wave: The current wave chunk + processed_frames: Number of frames processed so far + vc_mel: The mel spectrogram + overlap_wave_len: Length of overlap between chunks + generated_wave_chunks: List of generated wave chunks + previous_chunk: Previous wave chunk for crossfading + is_last_chunk: Whether this is the last chunk + stream_output: Whether to stream the output + + Returns: + Tuple of (processed_frames, previous_chunk, should_break, mp3_bytes, full_audio) + where should_break indicates if processing should stop + mp3_bytes is the MP3 bytes if streaming, None otherwise + full_audio is the full audio if this is the last chunk, None otherwise + """ + mp3_bytes = None + full_audio = None + + if processed_frames == 0: + if is_last_chunk: + output_wave = vc_wave[0].cpu().numpy() + generated_wave_chunks.append(output_wave) + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=self.sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + full_audio = (self.sr, np.concatenate(generated_wave_chunks)) + else: + return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks) + + return processed_frames, previous_chunk, True, mp3_bytes, full_audio + + output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy() + generated_wave_chunks.append(output_wave) + previous_chunk = vc_wave[0, -overlap_wave_len:] + processed_frames += vc_mel.size(2) - self.overlap_frame_len + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=self.sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + + elif is_last_chunk: + output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len) + generated_wave_chunks.append(output_wave) + processed_frames += vc_mel.size(2) - self.overlap_frame_len + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), 
frame_rate=self.sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + full_audio = (self.sr, np.concatenate(generated_wave_chunks)) + else: + return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks) + + return processed_frames, previous_chunk, True, mp3_bytes, full_audio + + else: + output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len) + generated_wave_chunks.append(output_wave) + previous_chunk = vc_wave[0, -overlap_wave_len:] + processed_frames += vc_mel.size(2) - self.overlap_frame_len + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=self.sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + + return processed_frames, previous_chunk, False, mp3_bytes, full_audio + + def load_checkpoints( + self, + cfm_checkpoint_path = None, + ar_checkpoint_path = None, + ): + if cfm_checkpoint_path is None: + cfm_checkpoint_path = load_custom_model_from_hf( + repo_id=DEFAULT_REPO_ID, + model_filename=DEFAULT_CFM_CHECKPOINT, + ) + if ar_checkpoint_path is None: + ar_checkpoint_path = load_custom_model_from_hf( + repo_id=DEFAULT_REPO_ID, + model_filename=DEFAULT_AR_CHECKPOINT, + ) + # cfm + cfm_checkpoint = torch.load(cfm_checkpoint_path, map_location="cpu") + cfm_length_regulator_state_dict = self.strip_prefix(cfm_checkpoint["net"]['length_regulator'], "module.") + cfm_state_dict = self.strip_prefix(cfm_checkpoint["net"]['cfm'], "module.") + self.cfm.load_state_dict(cfm_state_dict, strict=False) + self.cfm_length_regulator.load_state_dict(cfm_length_regulator_state_dict, strict=False) + + # ar + ar_checkpoint = torch.load(ar_checkpoint_path, map_location="cpu") + ar_length_regulator_state_dict = self.strip_prefix(ar_checkpoint["net"]['length_regulator'], "module.") + ar_state_dict = self.strip_prefix(ar_checkpoint["net"]['ar'], "module.") + self.ar.load_state_dict(ar_state_dict, strict=False) + self.ar_length_regulator.load_state_dict(ar_length_regulator_state_dict, strict=False) + + # content extractor + content_extractor_narrow_checkpoint_path = load_custom_model_from_hf( + repo_id=DEFAULT_CE_REPO_ID, + model_filename=DEFAULT_CE_NARROW_CHECKPOINT, + ) + content_extractor_narrow_checkpoint = torch.load(content_extractor_narrow_checkpoint_path, map_location="cpu") + self.content_extractor_narrow.load_state_dict( + content_extractor_narrow_checkpoint, strict=False + ) + + content_extractor_wide_checkpoint_path = load_custom_model_from_hf( + repo_id=DEFAULT_CE_REPO_ID, + model_filename=DEFAULT_CE_WIDE_CHECKPOINT, + ) + content_extractor_wide_checkpoint = torch.load(content_extractor_wide_checkpoint_path, map_location="cpu") + self.content_extractor_wide.load_state_dict( + content_extractor_wide_checkpoint, strict=False + ) + + # style encoder + style_encoder_checkpoint_path = load_custom_model_from_hf(DEFAULT_SE_REPO_ID, DEFAULT_SE_CHECKPOINT, config_filename=None) + style_encoder_checkpoint = torch.load(style_encoder_checkpoint_path, map_location="cpu") + self.style_encoder.load_state_dict(style_encoder_checkpoint, strict=False) + + def setup_ar_caches(self, max_batch_size=1, max_seq_len=4096, dtype=torch.float32, device=torch.device("cpu")): + self.ar.setup_caches(max_batch_size=max_batch_size, max_seq_len=max_seq_len, dtype=dtype, device=device) + + def 
compute_style(self, waves_16k: torch.Tensor): + feat = torchaudio.compliance.kaldi.fbank(waves_16k, + num_mel_bins=80, + dither=0, + sample_frequency=16000) + feat = feat - feat.mean(dim=0, keepdim=True) + style = self.style_encoder(feat.unsqueeze(0)) + return style + + @torch.no_grad() + @torch.inference_mode() + def convert_timbre( + self, + source_audio_path: str, + target_audio_path: str, + diffusion_steps: int = 30, + length_adjust: float = 1.0, + inference_cfg_rate: float = 0.5, + use_sway_sampling: bool = False, + use_amo_sampling: bool = False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ): + source_wave = librosa.load(source_audio_path, sr=self.sr)[0] + target_wave = librosa.load(target_audio_path, sr=self.sr)[0] + source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).to(device) + target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).to(device) + + # get 16khz audio + source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000) + target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000) + source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device) + target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device) + + # compute mel spectrogram + source_mel = self.mel_fn(source_wave_tensor) + target_mel = self.mel_fn(target_wave_tensor) + source_mel_len = source_mel.size(2) + target_mel_len = target_mel.size(2) + + with torch.autocast(device_type=device.type, dtype=dtype): + # compute content features + _, source_content_indices, _ = self.content_extractor_wide(source_wave_16k_tensor, [source_wave_16k.size]) + _, target_content_indices, _ = self.content_extractor_wide(target_wave_16k_tensor, [target_wave_16k.size]) + + # compute style features + target_style = self.compute_style(target_wave_16k_tensor) + + # Length regulation + cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device)) + prompt_condition, _, = self.cfm_length_regulator(target_content_indices, ylens=torch.LongTensor([target_mel_len]).to(device)) + + cat_condition = torch.cat([prompt_condition, cond], dim=1) + # generate mel spectrogram + vc_mel = self.cfm.inference( + cat_condition, + torch.LongTensor([cat_condition.size(1)]).to(device), + target_mel, target_style, diffusion_steps, + inference_cfg_rate=inference_cfg_rate, + sway_sampling=use_sway_sampling, + amo_sampling=use_amo_sampling, + ) + vc_mel = vc_mel[:, :, target_mel_len:] + vc_wave = self.vocoder(vc_mel.float()).squeeze()[None] + return vc_wave.cpu().numpy() + + @torch.no_grad() + @torch.inference_mode() + def convert_voice( + self, + source_audio_path: str, + target_audio_path: str, + diffusion_steps: int = 30, + length_adjust: float = 1.0, + inference_cfg_rate: float = 0.5, + top_p: float = 0.7, + temperature: float = 0.7, + repetition_penalty: float = 1.5, + use_sway_sampling: bool = False, + use_amo_sampling: bool = False, + device: torch.device = torch.device("cpu"), + dtype: torch.dtype = torch.float32, + ): + source_wave = librosa.load(source_audio_path, sr=self.sr)[0] + target_wave = librosa.load(target_audio_path, sr=self.sr)[0] + source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).to(device) + target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).to(device) + + # get 16khz audio + source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000) + target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000) + 
source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device) + target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device) + + # compute mel spectrogram + source_mel = self.mel_fn(source_wave_tensor) + target_mel = self.mel_fn(target_wave_tensor) + source_mel_len = source_mel.size(2) + target_mel_len = target_mel.size(2) + + with torch.autocast(device_type=device.type, dtype=dtype): + # compute content features + _, source_content_indices, _ = self.content_extractor_wide(source_wave_16k_tensor, [source_wave_16k.size]) + _, target_content_indices, _ = self.content_extractor_wide(target_wave_16k_tensor, [target_wave_16k.size]) + + _, source_narrow_indices, _ = self.content_extractor_narrow(source_wave_16k_tensor, + [source_wave_16k.size], ssl_model=self.content_extractor_wide.ssl_model) + _, target_narrow_indices, _ = self.content_extractor_narrow(target_wave_16k_tensor, + [target_wave_16k.size], ssl_model=self.content_extractor_wide.ssl_model) + + src_narrow_reduced, src_narrow_len = self.duration_reduction_func(source_narrow_indices[0], 1) + tgt_narrow_reduced, tgt_narrow_len = self.duration_reduction_func(target_narrow_indices[0], 1) + + ar_cond = self.ar_length_regulator(torch.cat([tgt_narrow_reduced, src_narrow_reduced], dim=0)[None])[0] + + ar_out = self.ar.generate(ar_cond, target_content_indices, top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty) + ar_out_mel_len = torch.LongTensor([int(source_mel_len / source_content_indices.size(-1) * ar_out.size(-1) * length_adjust)]).to(device) + # compute style features + target_style = self.compute_style(target_wave_16k_tensor) + + # Length regulation + cond, _ = self.cfm_length_regulator(ar_out, ylens=torch.LongTensor([ar_out_mel_len]).to(device)) + prompt_condition, _, = self.cfm_length_regulator(target_content_indices, ylens=torch.LongTensor([target_mel_len]).to(device)) + + cat_condition = torch.cat([prompt_condition, cond], dim=1) + # generate mel spectrogram + vc_mel = self.cfm.inference( + cat_condition, + torch.LongTensor([cat_condition.size(1)]).to(device), + target_mel, target_style, diffusion_steps, + inference_cfg_rate=inference_cfg_rate, + sway_sampling=use_sway_sampling, + amo_sampling=use_amo_sampling, + ) + vc_mel = vc_mel[:, :, target_mel_len:] + vc_wave = self.vocoder(vc_mel.float()).squeeze()[None] + return vc_wave.cpu().numpy() + + def _process_content_features(self, audio_16k_tensor, is_narrow=False): + """Process audio through Whisper model to extract features.""" + content_extractor_fn = self.content_extractor_narrow if is_narrow else self.content_extractor_wide + if audio_16k_tensor.size(-1) <= 16000 * 30: + # Compute content features + _, content_indices, _ = content_extractor_fn(audio_16k_tensor, [audio_16k_tensor.size(-1)], ssl_model=self.content_extractor_wide.ssl_model) + else: + # Process long audio in chunks + overlapping_time = 5 # 5 seconds + features_list = [] + buffer = None + traversed_time = 0 + while traversed_time < audio_16k_tensor.size(-1): + if buffer is None: # first chunk + chunk = audio_16k_tensor[:, traversed_time:traversed_time + 16000 * 30] + else: + chunk = torch.cat([ + buffer, + audio_16k_tensor[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)] + ], dim=-1) + _, chunk_content_indices, _ = content_extractor_fn(chunk, [chunk.size(-1)], ssl_model=self.content_extractor_wide.ssl_model) + if traversed_time == 0: + features_list.append(chunk_content_indices) + else: + features_list.append(chunk_content_indices[:, 50 * 
overlapping_time:])
+                buffer = chunk[:, -16000 * overlapping_time:]
+                traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
+            content_indices = torch.cat(features_list, dim=1)
+
+        return content_indices
+
+    @spaces.GPU
+    @torch.no_grad()
+    @torch.inference_mode()
+    def convert_voice_with_streaming(
+        self,
+        source_audio_path: str,
+        target_audio_path: str,
+        diffusion_steps: int = 30,
+        length_adjust: float = 1.0,
+        intelligebility_cfg_rate: float = 0.7,
+        similarity_cfg_rate: float = 0.7,
+        top_p: float = 0.7,
+        temperature: float = 0.7,
+        repetition_penalty: float = 1.5,
+        convert_style: bool = False,
+        anonymization_only: bool = False,
+        device: torch.device = torch.device("cuda"),
+        dtype: torch.dtype = torch.float16,
+        stream_output: bool = True,
+    ):
+        """
+        Convert voice with streaming support for long audio files.
+
+        Args:
+            source_audio_path: Path to source audio file
+            target_audio_path: Path to target audio file
+            diffusion_steps: Number of diffusion steps (default: 30)
+            length_adjust: Length adjustment factor (default: 1.0)
+            intelligebility_cfg_rate: CFG rate for intelligibility (default: 0.7)
+            similarity_cfg_rate: CFG rate for similarity (default: 0.7)
+            top_p: Top-p sampling parameter (default: 0.7)
+            temperature: Temperature for sampling (default: 0.7)
+            repetition_penalty: Repetition penalty (default: 1.5)
+            convert_style: Whether to convert prosody/style with the AR model (default: False)
+            anonymization_only: Whether to discard the target speaker and use a random voice (default: False)
+            device: Device to use (default: cuda)
+            dtype: Data type to use (default: float16)
+            stream_output: Whether to stream the output (default: True)
+
+        Returns:
+            If stream_output is True, yields (mp3_bytes, full_audio) tuples
+            If stream_output is False, returns the full audio as a numpy array
+        """
+        # Load audio
+        source_wave = librosa.load(source_audio_path, sr=self.sr)[0]
+        target_wave = librosa.load(target_audio_path, sr=self.sr)[0]
+
+        # Limit reference audio to (dit_max_context_len - 5) seconds
+        target_wave = target_wave[:self.sr * (self.dit_max_context_len - 5)]
+
+        source_wave_tensor = torch.tensor(source_wave).unsqueeze(0).float().to(device)
+        target_wave_tensor = torch.tensor(target_wave).unsqueeze(0).float().to(device)
+
+        # Resample to 16kHz for feature extraction
+        source_wave_16k = librosa.resample(source_wave, orig_sr=self.sr, target_sr=16000)
+        target_wave_16k = librosa.resample(target_wave, orig_sr=self.sr, target_sr=16000)
+        source_wave_16k_tensor = torch.tensor(source_wave_16k).unsqueeze(0).to(device)
+        target_wave_16k_tensor = torch.tensor(target_wave_16k).unsqueeze(0).to(device)
+
+        # Compute mel spectrograms
+        source_mel = self.mel_fn(source_wave_tensor)
+        target_mel = self.mel_fn(target_wave_tensor)
+        source_mel_len = source_mel.size(2)
+        target_mel_len = target_mel.size(2)
+
+        # Set up chunk processing parameters
+        max_context_window = self.sr // self.hop_size * self.dit_max_context_len
+        overlap_wave_len = self.overlap_frame_len * self.hop_size
+
+        with torch.autocast(device_type=device.type, dtype=dtype):
+            # Compute content features
+            source_content_indices = self._process_content_features(source_wave_16k_tensor, is_narrow=False)
+            target_content_indices = self._process_content_features(target_wave_16k_tensor, is_narrow=False)
+            # Compute style features
+            target_style = self.compute_style(target_wave_16k_tensor)
+            prompt_condition, _, = self.cfm_length_regulator(target_content_indices,
+                                                             ylens=torch.LongTensor([target_mel_len]).to(device))
+
+        # prepare for streaming
+        generated_wave_chunks = []
+        processed_frames = 0
+        previous_chunk = None
+        if convert_style:
+            with torch.autocast(device_type=device.type, dtype=dtype):
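+                # convert_style path: reduce the narrow (prosody-level) token sequences and let the
+                # AR model re-generate them in the target style before CFM mel synthesis below.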
+ source_narrow_indices = self._process_content_features(source_wave_16k_tensor, is_narrow=True) + target_narrow_indices = self._process_content_features(target_wave_16k_tensor, is_narrow=True) + src_narrow_reduced, src_narrow_len = self.duration_reduction_func(source_narrow_indices[0], 1) + tgt_narrow_reduced, tgt_narrow_len = self.duration_reduction_func(target_narrow_indices[0], 1) + # Process src_narrow_reduced in chunks of max 1000 tokens + max_chunk_size = 1000 + + # Process src_narrow_reduced in chunks + for i in range(0, len(src_narrow_reduced), max_chunk_size): + is_last_chunk = i + max_chunk_size >= len(src_narrow_reduced) + with torch.autocast(device_type=device.type, dtype=dtype): + chunk = src_narrow_reduced[i:i + max_chunk_size] + if anonymization_only: + chunk_ar_cond = self.ar_length_regulator(chunk[None])[0] + chunk_ar_out = self.ar.generate(chunk_ar_cond, torch.zeros([1, 0]).long().to(device), + compiled_decode_fn=self.compiled_decode_fn, + top_p=top_p, temperature=temperature, + repetition_penalty=repetition_penalty) + else: + # For each chunk, we need to include tgt_narrow_reduced as context + chunk_ar_cond = self.ar_length_regulator(torch.cat([tgt_narrow_reduced, chunk], dim=0)[None])[0] + chunk_ar_out = self.ar.generate(chunk_ar_cond, target_content_indices, compiled_decode_fn=self.compiled_decode_fn, + top_p=top_p, temperature=temperature, + repetition_penalty=repetition_penalty) + chunkar_out_mel_len = torch.LongTensor([int(source_mel_len / source_content_indices.size( + -1) * chunk_ar_out.size(-1) * length_adjust)]).to(device) + # Length regulation + chunk_cond, _ = self.cfm_length_regulator(chunk_ar_out, ylens=torch.LongTensor([chunkar_out_mel_len]).to(device)) + cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1) + original_len = cat_condition.size(1) + # pad cat_condition to compile_len + if self.dit_compiled: + cat_condition = torch.nn.functional.pad(cat_condition, + (0, 0, 0, self.compile_len - cat_condition.size(1),), + value=0) + # Voice Conversion + vc_mel = self.cfm.inference( + cat_condition, + torch.LongTensor([original_len]).to(device), + target_mel, target_style, diffusion_steps, + inference_cfg_rate=[intelligebility_cfg_rate, similarity_cfg_rate], + random_voice=anonymization_only, + ) + vc_mel = vc_mel[:, :, target_mel_len:original_len] + vc_wave = self.vocoder(vc_mel).squeeze()[None] + processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks( + vc_wave, processed_frames, vc_mel, overlap_wave_len, + generated_wave_chunks, previous_chunk, is_last_chunk, stream_output + ) + + if stream_output and mp3_bytes is not None: + yield mp3_bytes, full_audio + + if should_break: + if not stream_output: + return full_audio + break + else: + cond, _ = self.cfm_length_regulator(source_content_indices, ylens=torch.LongTensor([source_mel_len]).to(device)) + + # Process in chunks for streaming + max_source_window = max_context_window - target_mel.size(2) + + # Generate chunk by chunk and stream the output + while processed_frames < cond.size(1): + chunk_cond = cond[:, processed_frames:processed_frames + max_source_window] + is_last_chunk = processed_frames + max_source_window >= cond.size(1) + cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1) + original_len = cat_condition.size(1) + # pad cat_condition to compile_len + if self.dit_compiled: + cat_condition = torch.nn.functional.pad(cat_condition, + (0, 0, 0, self.compile_len - cat_condition.size(1),), value=0) + with 
torch.autocast(device_type=device.type, dtype=dtype): + # Voice Conversion + vc_mel = self.cfm.inference( + cat_condition, + torch.LongTensor([original_len]).to(device), + target_mel, target_style, diffusion_steps, + inference_cfg_rate=[intelligebility_cfg_rate, similarity_cfg_rate], + random_voice=anonymization_only, + ) + vc_mel = vc_mel[:, :, target_mel_len:original_len] + vc_wave = self.vocoder(vc_mel).squeeze()[None] + + processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks( + vc_wave, processed_frames, vc_mel, overlap_wave_len, + generated_wave_chunks, previous_chunk, is_last_chunk, stream_output + ) + + if stream_output and mp3_bytes is not None: + yield mp3_bytes, full_audio + + if should_break: + if not stream_output: + return full_audio + break + + diff --git a/requirements.txt b/requirements.txt index e608c6a787058579c732680c3337566a0e36b259..4fa463bd62265380f5549538f043fa358200ef01 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,14 +1,24 @@ ---extra-index-url https://download.pytorch.org/whl/cu113 -torch -torchvision -torchaudio -scipy==1.13.1 -onnxruntime-gpu==1.19.0 -librosa==0.10.2 -huggingface-hub -munch -einops -descript-audio-codec -git+https://github.com/openai/whisper.git -pydub -transformers \ No newline at end of file +--extra-index-url https://download.pytorch.org/whl/cu121 +torch==2.4.0 +torchvision==0.19.0 +torchaudio==2.4.0 +scipy==1.13.1 +librosa==0.10.2 +huggingface-hub==0.23.4 +munch==4.0.0 +einops==0.8.0 +descript-audio-codec==1.0.0 +gradio==5.23.0 +pydub==0.25.1 +resemblyzer +jiwer==3.0.3 +transformers==4.46.3 +FreeSimpleGUI==5.1.1 +soundfile==0.12.1 +sounddevice==0.5.0 +modelscope==1.18.1 +funasr==1.1.5 +numpy==1.26.4 +hydra-core==1.3.2 +pyyaml +python-dotenv diff --git a/seed_vc_wrapper.py b/seed_vc_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..14caf0cfd3a18861e1a2b2c5fc3774aaa244e353 --- /dev/null +++ b/seed_vc_wrapper.py @@ -0,0 +1,463 @@ +import spaces +import torch +import torchaudio +import librosa +import numpy as np +from pydub import AudioSegment +import yaml +from modules.commons import build_model, load_checkpoint, recursive_munch +from hf_utils import load_custom_model_from_hf +from modules.campplus.DTDNN import CAMPPlus +from modules.bigvgan import bigvgan +from modules.audio import mel_spectrogram +from modules.rmvpe import RMVPE +from transformers import AutoFeatureExtractor, WhisperModel + +class SeedVCWrapper: + def __init__(self, device=None): + """ + Initialize the Seed-VC wrapper with all necessary models and configurations. + + Args: + device: torch device to use. If None, will be automatically determined. 
+ """ + # Set device + if device is None: + if torch.cuda.is_available(): + self.device = torch.device("cuda") + elif torch.backends.mps.is_available(): + self.device = torch.device("mps") + else: + self.device = torch.device("cpu") + else: + self.device = device + + # Load base model and configuration + self._load_base_model() + + # Load F0 conditioned model + self._load_f0_model() + + # Load additional modules + self._load_additional_modules() + + # Set streaming parameters + self.overlap_frame_len = 16 + self.bitrate = "320k" + + def _load_base_model(self): + """Load the base DiT model for voice conversion.""" + dit_checkpoint_path, dit_config_path = load_custom_model_from_hf( + "Plachta/Seed-VC", + "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth", + "config_dit_mel_seed_uvit_whisper_small_wavenet.yml" + ) + config = yaml.safe_load(open(dit_config_path, 'r')) + model_params = recursive_munch(config['model_params']) + self.model = build_model(model_params, stage='DiT') + self.hop_length = config['preprocess_params']['spect_params']['hop_length'] + self.sr = config['preprocess_params']['sr'] + + # Load checkpoints + self.model, _, _, _ = load_checkpoint( + self.model, None, dit_checkpoint_path, + load_only_params=True, ignore_modules=[], is_distributed=False + ) + for key in self.model: + self.model[key].eval() + self.model[key].to(self.device) + self.model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192) + + # Set up mel spectrogram function + mel_fn_args = { + "n_fft": config['preprocess_params']['spect_params']['n_fft'], + "win_size": config['preprocess_params']['spect_params']['win_length'], + "hop_size": config['preprocess_params']['spect_params']['hop_length'], + "num_mels": config['preprocess_params']['spect_params']['n_mels'], + "sampling_rate": self.sr, + "fmin": 0, + "fmax": None, + "center": False + } + self.to_mel = lambda x: mel_spectrogram(x, **mel_fn_args) + + # Load whisper model + whisper_name = model_params.speech_tokenizer.whisper_name if hasattr(model_params.speech_tokenizer, 'whisper_name') else "openai/whisper-small" + self.whisper_model = WhisperModel.from_pretrained(whisper_name, torch_dtype=torch.float16).to(self.device) + del self.whisper_model.decoder + self.whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name) + + def _load_f0_model(self): + """Load the F0 conditioned model for voice conversion.""" + dit_checkpoint_path, dit_config_path = load_custom_model_from_hf( + "Plachta/Seed-VC", + "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth", + "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml" + ) + config = yaml.safe_load(open(dit_config_path, 'r')) + model_params = recursive_munch(config['model_params']) + self.model_f0 = build_model(model_params, stage='DiT') + self.hop_length_f0 = config['preprocess_params']['spect_params']['hop_length'] + self.sr_f0 = config['preprocess_params']['sr'] + + # Load checkpoints + self.model_f0, _, _, _ = load_checkpoint( + self.model_f0, None, dit_checkpoint_path, + load_only_params=True, ignore_modules=[], is_distributed=False + ) + for key in self.model_f0: + self.model_f0[key].eval() + self.model_f0[key].to(self.device) + self.model_f0.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192) + + # Set up mel spectrogram function for F0 model + mel_fn_args_f0 = { + "n_fft": config['preprocess_params']['spect_params']['n_fft'], + "win_size": config['preprocess_params']['spect_params']['win_length'], + "hop_size": 
config['preprocess_params']['spect_params']['hop_length'], + "num_mels": config['preprocess_params']['spect_params']['n_mels'], + "sampling_rate": self.sr_f0, + "fmin": 0, + "fmax": None, + "center": False + } + self.to_mel_f0 = lambda x: mel_spectrogram(x, **mel_fn_args_f0) + + def _load_additional_modules(self): + """Load additional modules like CAMPPlus, BigVGAN, and RMVPE.""" + # Load CAMPPlus + campplus_ckpt_path = load_custom_model_from_hf("funasr/campplus", "campplus_cn_common.bin", config_filename=None) + self.campplus_model = CAMPPlus(feat_dim=80, embedding_size=192) + self.campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu")) + self.campplus_model.eval() + self.campplus_model.to(self.device) + + # Load BigVGAN models + self.bigvgan_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_22khz_80band_256x', use_cuda_kernel=False) + self.bigvgan_model.remove_weight_norm() + self.bigvgan_model = self.bigvgan_model.eval().to(self.device) + + self.bigvgan_44k_model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x', use_cuda_kernel=False) + self.bigvgan_44k_model.remove_weight_norm() + self.bigvgan_44k_model = self.bigvgan_44k_model.eval().to(self.device) + + # Load RMVPE for F0 extraction + model_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None) + self.rmvpe = RMVPE(model_path, is_half=False, device=self.device) + + @staticmethod + def adjust_f0_semitones(f0_sequence, n_semitones): + """Adjust F0 values by a number of semitones.""" + factor = 2 ** (n_semitones / 12) + return f0_sequence * factor + + @staticmethod + def crossfade(chunk1, chunk2, overlap): + """Apply crossfade between two audio chunks.""" + fade_out = np.cos(np.linspace(0, np.pi / 2, overlap)) ** 2 + fade_in = np.cos(np.linspace(np.pi / 2, 0, overlap)) ** 2 + if len(chunk2) < overlap: + chunk2[:overlap] = chunk2[:overlap] * fade_in[:len(chunk2)] + (chunk1[-overlap:] * fade_out)[:len(chunk2)] + else: + chunk2[:overlap] = chunk2[:overlap] * fade_in + chunk1[-overlap:] * fade_out + return chunk2 + + def _stream_wave_chunks(self, vc_wave, processed_frames, vc_target, overlap_wave_len, + generated_wave_chunks, previous_chunk, is_last_chunk, stream_output, sr): + """ + Helper method to handle streaming wave chunks. 
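+        Intermediate chunks hold back the last overlap_wave_len samples and crossfade them into the
+        next chunk; when stream_output is True each finished chunk is encoded to MP3 bytes.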
+ + Args: + vc_wave: The current wave chunk + processed_frames: Number of frames processed so far + vc_target: The target mel spectrogram + overlap_wave_len: Length of overlap between chunks + generated_wave_chunks: List of generated wave chunks + previous_chunk: Previous wave chunk for crossfading + is_last_chunk: Whether this is the last chunk + stream_output: Whether to stream the output + sr: Sample rate + + Returns: + Tuple of (processed_frames, previous_chunk, should_break, mp3_bytes, full_audio) + where should_break indicates if processing should stop + mp3_bytes is the MP3 bytes if streaming, None otherwise + full_audio is the full audio if this is the last chunk, None otherwise + """ + mp3_bytes = None + full_audio = None + + if processed_frames == 0: + if is_last_chunk: + output_wave = vc_wave[0].cpu().numpy() + generated_wave_chunks.append(output_wave) + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + full_audio = (sr, np.concatenate(generated_wave_chunks)) + else: + return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks) + + return processed_frames, previous_chunk, True, mp3_bytes, full_audio + + output_wave = vc_wave[0, :-overlap_wave_len].cpu().numpy() + generated_wave_chunks.append(output_wave) + previous_chunk = vc_wave[0, -overlap_wave_len:] + processed_frames += vc_target.size(2) - self.overlap_frame_len + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + + elif is_last_chunk: + output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0].cpu().numpy(), overlap_wave_len) + generated_wave_chunks.append(output_wave) + processed_frames += vc_target.size(2) - self.overlap_frame_len + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + full_audio = (sr, np.concatenate(generated_wave_chunks)) + else: + return processed_frames, previous_chunk, True, None, np.concatenate(generated_wave_chunks) + + return processed_frames, previous_chunk, True, mp3_bytes, full_audio + + else: + output_wave = self.crossfade(previous_chunk.cpu().numpy(), vc_wave[0, :-overlap_wave_len].cpu().numpy(), overlap_wave_len) + generated_wave_chunks.append(output_wave) + previous_chunk = vc_wave[0, -overlap_wave_len:] + processed_frames += vc_target.size(2) - self.overlap_frame_len + + if stream_output: + output_wave_int16 = (output_wave * 32768.0).astype(np.int16) + mp3_bytes = AudioSegment( + output_wave_int16.tobytes(), frame_rate=sr, + sample_width=output_wave_int16.dtype.itemsize, channels=1 + ).export(format="mp3", bitrate=self.bitrate).read() + + return processed_frames, previous_chunk, False, mp3_bytes, full_audio + + def _process_whisper_features(self, audio_16k, is_source=True): + """Process audio through Whisper model to extract features.""" + if audio_16k.size(-1) <= 16000 * 30: + # If audio is short enough, process in one go + inputs = self.whisper_feature_extractor( + 
[audio_16k.squeeze(0).cpu().numpy()], + return_tensors="pt", + return_attention_mask=True, + sampling_rate=16000 + ) + input_features = self.whisper_model._mask_input_features( + inputs.input_features, attention_mask=inputs.attention_mask + ).to(self.device) + outputs = self.whisper_model.encoder( + input_features.to(self.whisper_model.encoder.dtype), + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ) + features = outputs.last_hidden_state.to(torch.float32) + features = features[:, :audio_16k.size(-1) // 320 + 1] + else: + # Process long audio in chunks + overlapping_time = 5 # 5 seconds + features_list = [] + buffer = None + traversed_time = 0 + while traversed_time < audio_16k.size(-1): + if buffer is None: # first chunk + chunk = audio_16k[:, traversed_time:traversed_time + 16000 * 30] + else: + chunk = torch.cat([ + buffer, + audio_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)] + ], dim=-1) + inputs = self.whisper_feature_extractor( + [chunk.squeeze(0).cpu().numpy()], + return_tensors="pt", + return_attention_mask=True, + sampling_rate=16000 + ) + input_features = self.whisper_model._mask_input_features( + inputs.input_features, attention_mask=inputs.attention_mask + ).to(self.device) + outputs = self.whisper_model.encoder( + input_features.to(self.whisper_model.encoder.dtype), + head_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ) + chunk_features = outputs.last_hidden_state.to(torch.float32) + chunk_features = chunk_features[:, :chunk.size(-1) // 320 + 1] + if traversed_time == 0: + features_list.append(chunk_features) + else: + features_list.append(chunk_features[:, 50 * overlapping_time:]) + buffer = chunk[:, -16000 * overlapping_time:] + traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time + features = torch.cat(features_list, dim=1) + + return features + + @spaces.GPU + @torch.no_grad() + @torch.inference_mode() + def convert_voice(self, source, target, diffusion_steps=10, length_adjust=1.0, + inference_cfg_rate=0.7, f0_condition=False, auto_f0_adjust=True, + pitch_shift=0, stream_output=True): + """ + Convert both timbre and voice from source to target. 
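+        Uses the 22.05 kHz base model, or the 44.1 kHz F0-conditioned model with RMVPE pitch
+        extraction when f0_condition is True.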
+ + Args: + source: Path to source audio file + target: Path to target audio file + diffusion_steps: Number of diffusion steps (default: 10) + length_adjust: Length adjustment factor (default: 1.0) + inference_cfg_rate: Inference CFG rate (default: 0.7) + f0_condition: Whether to use F0 conditioning (default: False) + auto_f0_adjust: Whether to automatically adjust F0 (default: True) + pitch_shift: Pitch shift in semitones (default: 0) + stream_output: Whether to stream the output (default: True) + + Returns: + If stream_output is True, yields (mp3_bytes, full_audio) tuples + If stream_output is False, returns the full audio as a numpy array + """ + # Select appropriate models based on F0 condition + inference_module = self.model if not f0_condition else self.model_f0 + mel_fn = self.to_mel if not f0_condition else self.to_mel_f0 + bigvgan_fn = self.bigvgan_model if not f0_condition else self.bigvgan_44k_model + sr = 22050 if not f0_condition else 44100 + hop_length = 256 if not f0_condition else 512 + max_context_window = sr // hop_length * 30 + overlap_wave_len = self.overlap_frame_len * hop_length + + # Load audio + source_audio = librosa.load(source, sr=sr)[0] + ref_audio = librosa.load(target, sr=sr)[0] + + # Process audio + source_audio = torch.tensor(source_audio).unsqueeze(0).float().to(self.device) + ref_audio = torch.tensor(ref_audio[:sr * 25]).unsqueeze(0).float().to(self.device) + + # Resample to 16kHz for feature extraction + ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000) + converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000) + + # Extract Whisper features + S_alt = self._process_whisper_features(converted_waves_16k, is_source=True) + S_ori = self._process_whisper_features(ref_waves_16k, is_source=False) + + # Compute mel spectrograms + mel = mel_fn(source_audio.to(self.device).float()) + mel2 = mel_fn(ref_audio.to(self.device).float()) + + # Set target lengths + target_lengths = torch.LongTensor([int(mel.size(2) * length_adjust)]).to(mel.device) + target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device) + + # Compute style features + feat2 = torchaudio.compliance.kaldi.fbank( + ref_waves_16k, + num_mel_bins=80, + dither=0, + sample_frequency=16000 + ) + feat2 = feat2 - feat2.mean(dim=0, keepdim=True) + style2 = self.campplus_model(feat2.unsqueeze(0)) + + # Process F0 if needed + if f0_condition: + F0_ori = self.rmvpe.infer_from_audio(ref_waves_16k[0], thred=0.03) + F0_alt = self.rmvpe.infer_from_audio(converted_waves_16k[0], thred=0.03) + + if self.device == "mps": + F0_ori = torch.from_numpy(F0_ori).float().to(self.device)[None] + F0_alt = torch.from_numpy(F0_alt).float().to(self.device)[None] + else: + F0_ori = torch.from_numpy(F0_ori).to(self.device)[None] + F0_alt = torch.from_numpy(F0_alt).to(self.device)[None] + + voiced_F0_ori = F0_ori[F0_ori > 1] + voiced_F0_alt = F0_alt[F0_alt > 1] + + log_f0_alt = torch.log(F0_alt + 1e-5) + voiced_log_f0_ori = torch.log(voiced_F0_ori + 1e-5) + voiced_log_f0_alt = torch.log(voiced_F0_alt + 1e-5) + median_log_f0_ori = torch.median(voiced_log_f0_ori) + median_log_f0_alt = torch.median(voiced_log_f0_alt) + + # Shift alt log f0 level to ori log f0 level + shifted_log_f0_alt = log_f0_alt.clone() + if auto_f0_adjust: + shifted_log_f0_alt[F0_alt > 1] = log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori + shifted_f0_alt = torch.exp(shifted_log_f0_alt) + if pitch_shift != 0: + shifted_f0_alt[F0_alt > 1] = self.adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch_shift) + 
else: + F0_ori = None + F0_alt = None + shifted_f0_alt = None + + # Length regulation + cond, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator( + S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt + ) + prompt_condition, _, codes, commitment_loss, codebook_loss = inference_module.length_regulator( + S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori + ) + + # Process in chunks for streaming + max_source_window = max_context_window - mel2.size(2) + processed_frames = 0 + generated_wave_chunks = [] + previous_chunk = None + + # Generate chunk by chunk and stream the output + while processed_frames < cond.size(1): + chunk_cond = cond[:, processed_frames:processed_frames + max_source_window] + is_last_chunk = processed_frames + max_source_window >= cond.size(1) + cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16): + # Voice Conversion + vc_target = inference_module.cfm.inference( + cat_condition, + torch.LongTensor([cat_condition.size(1)]).to(mel2.device), + mel2, style2, None, diffusion_steps, + inference_cfg_rate=inference_cfg_rate + ) + vc_target = vc_target[:, :, mel2.size(-1):] + + vc_wave = bigvgan_fn(vc_target.float())[0] + + processed_frames, previous_chunk, should_break, mp3_bytes, full_audio = self._stream_wave_chunks( + vc_wave, processed_frames, vc_target, overlap_wave_len, + generated_wave_chunks, previous_chunk, is_last_chunk, stream_output, sr + ) + + if stream_output and mp3_bytes is not None: + yield mp3_bytes, full_audio + + if should_break: + if not stream_output: + return full_audio + break + + if not stream_output: + return np.concatenate(generated_wave_chunks) + + return None, None \ No newline at end of file
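
For reference, a minimal sketch of how the streaming generator in seed_vc_wrapper.py can be consumed (the file names below are placeholders, and this assumes the Hugging Face checkpoints referenced above are reachable):

import soundfile as sf
from seed_vc_wrapper import SeedVCWrapper

vc = SeedVCWrapper()  # picks cuda, then mps, then cpu automatically

# With stream_output=True, convert_voice is a generator yielding (mp3_bytes, full_audio);
# full_audio stays None until the final chunk, where it is (sample_rate, np.ndarray).
final_audio = None
for mp3_bytes, full_audio in vc.convert_voice(
    source="source.wav",        # placeholder path
    target="reference.wav",     # placeholder path
    diffusion_steps=10,
    length_adjust=1.0,
    inference_cfg_rate=0.7,
    f0_condition=False,
    stream_output=True,
):
    if full_audio is not None:
        final_audio = full_audio

if final_audio is not None:
    sr, wav = final_audio
    sf.write("converted.wav", wav, sr)  # placeholder output path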