diff --git a/README.md b/README.md index 51223e306ee278ec0896ceffe2ded040ab566535..9f548a8582d579f88df5945c0daa2c2e499a36e3 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,13 @@ --- -title: ThinkSound -emoji: 🌍 -colorFrom: green +title: Test +emoji: 📚 +colorFrom: gray colorTo: gray sdk: gradio sdk_version: 5.35.0 app_file: app.py pinned: false -license: apache-2.0 -short_description: 'demo of ThinkSound ' +license: mit --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..214b9d9afc82ff3426a846fe8a4b3fbf3f206d6e --- /dev/null +++ b/app.py @@ -0,0 +1,331 @@ +from prefigure.prefigure import get_all_args, push_wandb_config +import json +import os +os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp" +import re +import torch +import torchaudio +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.tuner import Tuner +from lightning.pytorch import seed_everything +import random +from datetime import datetime +# from think_sound.data.dataset import create_dataloader_from_config +from think_sound.data.datamodule import DataModule +from think_sound.models import create_model_from_config +from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model +from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config +from think_sound.training.utils import copy_state_dict +from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler +from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils +from torch.utils.data import Dataset +from typing import Optional, Union +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder +from torchvision.utils import save_image +from transformers import AutoProcessor +import torch.nn.functional as F +import gradio as gr +import tempfile +import subprocess +from huggingface_hub import hf_hub_download + +_CLIP_SIZE = 224 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + +def pad_to_square(video_tensor): + if len(video_tensor.shape) != 4: + raise ValueError("Input tensor must have shape (l, c, h, w)") + + l, c, h, w = video_tensor.shape + max_side = max(h, w) + + pad_h = max_side - h + pad_w = max_side - w + + padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + + video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0) + + return video_padded + + +class VGGSound(Dataset): + + def __init__( + self, + sample_rate: int = 44_100, + duration_sec: float = 9.0, + audio_samples: Optional[int] = 397312, + normalize_audio: bool = False, + ): + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = self.audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + 
+            v2.Lambda(pad_to_square),  # pad to a square first
+            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True),
+        ])
+        self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
+        self.sync_transform = v2.Compose([
+            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
+            v2.CenterCrop(_SYNC_SIZE),
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+
+        self.resampler = {}
+
+    def sample(self, video_path, label):
+        video_id = video_path
+
+        reader = StreamingMediaDecoder(video_path)
+        reader.add_basic_video_stream(
+            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
+            frame_rate=_CLIP_FPS,
+            format='rgb24',
+        )
+        reader.add_basic_video_stream(
+            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
+            frame_rate=_SYNC_FPS,
+            format='rgb24',
+        )
+
+        reader.fill_buffer()
+        data_chunk = reader.pop_chunks()
+
+        clip_chunk = data_chunk[0]
+        sync_chunk = data_chunk[1]
+
+        if sync_chunk is None:
+            raise RuntimeError(f'Sync video returned None {video_id}')
+
+        clip_chunk = clip_chunk[:self.clip_expected_length]
+        # import ipdb
+        # ipdb.set_trace()
+        if clip_chunk.shape[0] != self.clip_expected_length:
+            current_length = clip_chunk.shape[0]
+            padding_needed = self.clip_expected_length - current_length
+
+            # Allow at most 3 frames of padding
+            assert padding_needed < 4, f'padding of at most 3 frames allowed, but {padding_needed} needed'
+
+            # If the assertion passes, proceed with padding
+            if padding_needed > 0:
+                last_frame = clip_chunk[-1]
+                # Repeat the last frame to reach the expected length
+                padding = last_frame.repeat(padding_needed, 1, 1, 1)
+                clip_chunk = torch.cat((clip_chunk, padding), dim=0)
+            # raise RuntimeError(f'CLIP video wrong length {video_id}, '
+            #                    f'expected {self.clip_expected_length}, '
+            #                    f'got {clip_chunk.shape[0]}')
+
+        # save_image(clip_chunk[0] / 255.0, 'ori.png')
+        clip_chunk = pad_to_square(clip_chunk)
+
+        clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
+
+        sync_chunk = sync_chunk[:self.sync_expected_length]
+        if sync_chunk.shape[0] != self.sync_expected_length:
+            # pad with the last frame, but only by a small number of frames
+            current_length = sync_chunk.shape[0]
+            last_frame = sync_chunk[-1]
+            # repeat the last frame to reach the expected length
+            padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
+            assert self.sync_expected_length - current_length < 12, f'sync stream can pad at most 11 frames, but {self.sync_expected_length - current_length} needed'
+            sync_chunk = torch.cat((sync_chunk, padding), dim=0)
+            # raise RuntimeError(f'Sync video wrong length {video_id}, '
+            #                    f'expected {self.sync_expected_length}, '
+            #                    f'got {sync_chunk.shape[0]}')
+
+        sync_chunk = self.sync_transform(sync_chunk)
+        # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
+        #     and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
+        data = {
+            'id': video_id,
+            'caption': label,
+            # 'audio': audio_chunk,
+            'clip_video': clip_chunk,
+            'sync_video': sync_chunk,
+        }
+
+        return data
+
+# pick the devices
+if torch.cuda.is_available():
+    device = 'cuda'
+    extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
+else:
+    device = 'cpu'
+    extra_device = 'cpu'
+
+vae_ckpt = hf_hub_download(repo_id="UncleWang233/occdata", filename="epoch=3-step=100000.ckpt", repo_type="dataset")
+synchformer_ckpt = hf_hub_download(repo_id="UncleWang233/occdata", filename="synchformer_state_dict.pth", repo_type="dataset")
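For reference, the two video streams prepared above have fixed expected lengths for the default 9-second clip: int(8.0 * 9.0) = 72 frames for the CLIP branch and int(25.0 * 9.0) = 225 frames for the Synchformer branch, and the CLIP branch is zero-padded to a square while the sync branch is resized and center-cropped instead. A minimal standalone sketch (restating the pad_to_square helper and FPS constants defined above so it runs on its own) illustrating the resulting shapes:

import torch
import torch.nn.functional as F

_CLIP_FPS, _SYNC_FPS, duration_sec = 8.0, 25.0, 9.0

def pad_to_square(video_tensor):
    # same rule as the helper above: zero-pad the shorter spatial side of an (l, c, h, w) clip
    l, c, h, w = video_tensor.shape
    side = max(h, w)
    pad_h, pad_w = side - h, side - w
    return F.pad(video_tensor, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))

clip_frames = torch.zeros(int(_CLIP_FPS * duration_sec), 3, 360, 640)  # a 9 s landscape clip at 8 fps
print(pad_to_square(clip_frames).shape)   # torch.Size([72, 3, 640, 640])
print(int(_SYNC_FPS * duration_sec))      # 225 frames expected by the Synchformer stream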
+feature_extractor = FeaturesUtils(
+    vae_ckpt=vae_ckpt,
+    vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json',
+    enable_conditions=True,
+    synchformer_ckpt=synchformer_ckpt
+).eval().to(extra_device)
+
+preprocesser = VGGSound()
+
+args = get_all_args()
+
+seed = 10086
+
+seed_everything(seed, workers=True)
+
+
+# Load the JSON model config (hard-coded path rather than args.model_config)
+with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f:
+    model_config = json.load(f)
+
+model = create_model_from_config(model_config)
+
+## speed up with torch.compile
+if args.compile:
+    model = torch.compile(model)
+
+if args.pretrained_ckpt_path:
+    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path, prefix='diffusion.'))  # autoencoder. diffusion.
+
+if args.remove_pretransform_weight_norm == "pre_load":
+    remove_weight_norm_from_model(model.pretransform)
+
+
+load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
+# new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
+model.pretransform.load_state_dict(load_vae_state)
+
+# Remove weight_norm from the pretransform if specified
+if args.remove_pretransform_weight_norm == "post_load":
+    remove_weight_norm_from_model(model.pretransform)
+ckpt_path = hf_hub_download(repo_id="UncleWang233/occdata", filename="epoch=10-step=68000.ckpt", repo_type="dataset")
+training_wrapper = create_training_wrapper_from_config(model_config, model)
+# choose map_location according to the available device when loading the weights
+if device == 'cuda':
+    training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
+else:
+    training_wrapper.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict'])
+
+def get_audio(video_path, caption):
+    # allow an empty caption
+    if caption is None:
+        caption = ''
+    timer = Timer(duration="00:15:00:00")
+    data = preprocesser.sample(video_path, caption)
+
+    preprocessed_data = {}
+    metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
+    preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
+    preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)
+
+    t5_features = feature_extractor.encode_t5_text(data['caption'])
+    preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)
+
+    clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
+    preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)
+
+    sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
+    preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)
+    preprocessed_data['video_exist'] = torch.tensor(True)
+
+    metadata = [preprocessed_data]
+
+    batch_size = 1
+    length = 194
+    with torch.amp.autocast(device):
+        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)
+
+    video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
+    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
+    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat
+
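+    # NOTE: `length` is presumably the latent sequence length handed to the diffusion
+    # model. Assuming the stable_audio_2_0 VAE configured above downsamples waveforms
+    # by 2048x, the 397312-sample (~9 s) clips prepared by VGGSound give
+    # 397312 / 2048 = 194 latent frames, matching `length = 194`; the noise tensor
+    # below is shaped (batch_size, io_channels, length) accordingly.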
+    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
+    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
+    with torch.amp.autocast(device):
+        model = training_wrapper.diffusion.model
+        if training_wrapper.diffusion_objective == "v":
+            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
+        elif training_wrapper.diffusion_objective == "rectified_flow":
+            import time
+            start_time = time.time()
+            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
+            end_time = time.time()
+            execution_time = end_time - start_time
+            print(f"Sampling took {execution_time:.2f} s")
+        if training_wrapper.diffusion.pretransform is not None:
+            fakes = training_wrapper.diffusion.pretransform.decode(fakes)
+
+    audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+    # write the generated audio to a temporary file
+    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
+        torchaudio.save(tmp_audio.name, audios[0], 44100)
+        audio_path = tmp_audio.name
+    return audio_path
+
+# Build the final video: mux the generated audio with the original video via ffmpeg
+
+def synthesize_video_with_audio(video_file, caption):
+    # allow an empty caption
+    if caption is None:
+        caption = ''
+    audio_path = get_audio(video_file, caption)
+    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
+        output_video_path = tmp_video.name
+    # ffmpeg command: replace the original audio track with the generated audio
+    cmd = [
+        'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
+        '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
+        '-shortest', output_video_path
+    ]
+    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    return output_video_path
+
+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("# ThinkSound\nUpload a video and an optional caption, and get the video back with generated audio!")
+    with gr.Row():
+        video_input = gr.Video(label="Upload video")
+        caption_input = gr.Textbox(label="Caption (optional)", placeholder="May be left empty", lines=1)
+    output_video = gr.Video(label="Output video")
+    btn = gr.Button("Start synthesis")
+    btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video)
+
+    gr.Examples(
+        examples=[
+            ["./examples/1_mute.mp4", "Playing Trumpet"],
+            ["./examples/2_mute.mp4", "Axe striking"],
+            ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier"],
+            ["./examples/4_mute.mp4", "train passing by"],
+            ["./examples/5_mute.mp4", "Lighting Firecrackers"]
+        ],
+        inputs=[video_input, caption_input],
+    )
+
+demo.launch(share=True)
+
diff --git a/data_utils/__init__.py b/data_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/__pycache__/__init__.cpython-310.pyc b/data_utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6649c3a8b1e3891099f9dcc6544cfe78b2fd0819 Binary files /dev/null and b/data_utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/data_utils/__pycache__/utils.cpython-310.pyc b/data_utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce286899e4b3e6d8ef183749a5428ec34618bb19 Binary files /dev/null and b/data_utils/__pycache__/utils.cpython-310.pyc differ diff --git a/data_utils/__pycache__/utils.cpython-39.pyc b/data_utils/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46823f149a2bbd0d09481eff2ee6d1ad86fe151b Binary files
/dev/null and b/data_utils/__pycache__/utils.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/LICENSE b/data_utils/ext/synchformer/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2f70bf24b6f45f458998bdf5746376c4832352ea --- /dev/null +++ b/data_utils/ext/synchformer/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Vladimir Iashin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/data_utils/ext/synchformer/__init__.py b/data_utils/ext/synchformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9eff0160aa046d712d9330c4201b0ccd4c0c51b0 --- /dev/null +++ b/data_utils/ext/synchformer/__init__.py @@ -0,0 +1 @@ +from data_utils.ext.synchformer.synchformer import Synchformer diff --git a/data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce33b394f89c7e27b9d3ce6128905759ae97585 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43e5f377eba172a01f2d255bd1892d94fe4eaf1c Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..410b25263432859b32e10180b22403ba2ed6b511 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af8c9377841c9c77b52249865cbc42d98c14278f Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0df0dd022851746361ae2410b5d9508764e1bbd Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc 
differ diff --git a/data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cca707887580daa686ca15539a1fe0dd93b73884 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5880bf322ab6106b6e97b8d3b5c5c835e114b424 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae02f00ddcb30ca46ebf32b5debb946d6618d81 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa6ddadb705a39423b4768966a3a0ea41e76b648 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10ce47b92644955ec4a037d3703ba997e7162429 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc b/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15ebefdf8dbdb0c826516b654134f0db46f00282 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc differ diff --git a/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc b/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62566399384b6ab2d9d619a9e7ff0364821a0a51 Binary files /dev/null and b/data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc differ diff --git a/data_utils/ext/synchformer/divided_224_16x4.yaml b/data_utils/ext/synchformer/divided_224_16x4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f9d20b76302a8af7928391643bd4b2d184e970aa --- /dev/null +++ b/data_utils/ext/synchformer/divided_224_16x4.yaml @@ -0,0 +1,84 @@ +TRAIN: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 32 + EVAL_PERIOD: 5 + CHECKPOINT_PERIOD: 5 + AUTO_RESUME: True + CHECKPOINT_EPOCH_RESET: True + CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth +DATA: + NUM_FRAMES: 16 + SAMPLING_RATE: 4 + TRAIN_JITTER_SCALES: [256, 320] + TRAIN_CROP_SIZE: 224 + TEST_CROP_SIZE: 224 + INPUT_CHANNEL_NUM: [3] + MEAN: [0.5, 0.5, 0.5] + STD: [0.5, 0.5, 0.5] + PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2 + PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames + INV_UNIFORM_SAMPLE: True + RANDOM_FLIP: False + REVERSE_INPUT_CHANNEL: True + USE_RAND_AUGMENT: True + RE_PROB: 
0.0 + USE_REPEATED_AUG: False + USE_RANDOM_RESIZE_CROPS: False + COLORJITTER: False + GRAYSCALE: False + GAUSSIAN: False +SOLVER: + BASE_LR: 1e-4 + LR_POLICY: steps_with_relative_lrs + LRS: [1, 0.1, 0.01] + STEPS: [0, 20, 30] + MAX_EPOCH: 35 + MOMENTUM: 0.9 + WEIGHT_DECAY: 5e-2 + WARMUP_EPOCHS: 0.0 + OPTIMIZING_METHOD: adamw + USE_MIXED_PRECISION: True + SMOOTHING: 0.2 +SLOWFAST: + ALPHA: 8 +VIT: + PATCH_SIZE: 16 + PATCH_SIZE_TEMP: 2 + CHANNELS: 3 + EMBED_DIM: 768 + DEPTH: 12 + NUM_HEADS: 12 + MLP_RATIO: 4 + QKV_BIAS: True + VIDEO_INPUT: True + TEMPORAL_RESOLUTION: 8 + USE_MLP: True + DROP: 0.0 + POS_DROPOUT: 0.0 + DROP_PATH: 0.2 + IM_PRETRAINED: True + HEAD_DROPOUT: 0.0 + HEAD_ACT: tanh + PRETRAINED_WEIGHTS: vit_1k + ATTN_LAYER: divided +MODEL: + NUM_CLASSES: 174 + ARCH: slow + MODEL_NAME: VisionTransformer + LOSS_FUNC: cross_entropy +TEST: + ENABLE: True + DATASET: Ssv2 + BATCH_SIZE: 64 + NUM_ENSEMBLE_VIEWS: 1 + NUM_SPATIAL_CROPS: 3 +DATA_LOADER: + NUM_WORKERS: 4 + PIN_MEMORY: True +NUM_GPUS: 8 +NUM_SHARDS: 4 +RNG_SEED: 0 +OUTPUT_DIR: . +TENSORBOARD: + ENABLE: True diff --git a/data_utils/ext/synchformer/motionformer.py b/data_utils/ext/synchformer/motionformer.py new file mode 100644 index 0000000000000000000000000000000000000000..148b5d3c7f021a8dfe38f7134a919c21d35e6bab --- /dev/null +++ b/data_utils/ext/synchformer/motionformer.py @@ -0,0 +1,400 @@ +import logging +from pathlib import Path + +import einops +import torch +from omegaconf import OmegaConf +from timm.layers import trunc_normal_ +from torch import nn + +from data_utils.ext.synchformer.utils import check_if_file_exists_else_download +from data_utils.ext.synchformer.video_model_builder import VisionTransformer + +FILE2URL = { + # cfg + 'motionformer_224_16x4.yaml': + 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml', + 'joint_224_16x4.yaml': + 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml', + 'divided_224_16x4.yaml': + 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml', + # ckpt + 'ssv2_motionformer_224_16x4.pyth': + 'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth', + 'ssv2_joint_224_16x4.pyth': + 'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth', + 'ssv2_divided_224_16x4.pyth': + 'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth', +} + + +class MotionFormer(VisionTransformer): + ''' This class serves three puposes: + 1. Renames the class to MotionFormer. + 2. Downloads the cfg from the original repo and patches it if needed. + 3. Takes care of feature extraction by redefining .forward() + - if `extract_features=True` and `factorize_space_time=False`, + the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D) + and spatial and temporal transformer encoder layers are used. + - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True` + the output is of shape (B, D) and spatial and temporal transformer encoder layers + are used as well as the global representation is extracted from segments (extra pos emb + is added). 
+ ''' + + def __init__( + self, + extract_features: bool = False, + ckpt_path: str = None, + factorize_space_time: bool = None, + agg_space_module: str = None, + agg_time_module: str = None, + add_global_repr: bool = True, + agg_segments_module: str = None, + max_segments: int = None, + ): + self.extract_features = extract_features + self.ckpt_path = ckpt_path + self.factorize_space_time = factorize_space_time + + if self.ckpt_path is not None: + check_if_file_exists_else_download(self.ckpt_path, FILE2URL) + ckpt = torch.load(self.ckpt_path, map_location='cpu') + mformer_ckpt2cfg = { + 'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml', + 'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml', + 'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml', + } + # init from motionformer ckpt or from our Stage I ckpt + # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to + # load the state dict differently + was_pt_on_avclip = self.ckpt_path.endswith( + '.pt') # checks if it is a stage I ckpt (FIXME: a bit generic) + if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())): + cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name] + elif was_pt_on_avclip: + # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it) + s1_cfg = ckpt.get('args', None) # Stage I cfg + if s1_cfg is not None: + s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path + # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch + if s1_vfeat_extractor_ckpt_path is not None: + cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name] + else: + cfg_fname = 'divided_224_16x4.yaml' + else: + cfg_fname = 'divided_224_16x4.yaml' + else: + raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.') + else: + was_pt_on_avclip = False + cfg_fname = 'divided_224_16x4.yaml' + # logging.info(f'No ckpt_path provided, using {cfg_fname} config.') + + if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']: + pos_emb_type = 'separate' + elif cfg_fname == 'joint_224_16x4.yaml': + pos_emb_type = 'joint' + + self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname + + check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL) + mformer_cfg = OmegaConf.load(self.mformer_cfg_path) + logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}') + + # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`) + mformer_cfg.VIT.ATTN_DROPOUT = 0.0 + mformer_cfg.VIT.POS_EMBED = pos_emb_type + mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True + mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none' # guessing + mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg'] + + # finally init VisionTransformer with the cfg + super().__init__(mformer_cfg) + + # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt + if (self.ckpt_path is not None) and (not was_pt_on_avclip): + _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False) + if len(_ckpt_load_status.missing_keys) > 0 or len( + _ckpt_load_status.unexpected_keys) > 0: + logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed.' 
\ + f'Missing keys: {_ckpt_load_status.missing_keys}, ' \ + f'Unexpected keys: {_ckpt_load_status.unexpected_keys}') + else: + logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.') + + if self.extract_features: + assert isinstance(self.norm, + nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights' + # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger + self.pre_logits = nn.Identity() + # we don't need the classification head (saving memory) + self.head = nn.Identity() + self.head_drop = nn.Identity() + # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer) + transf_enc_layer_kwargs = dict( + d_model=self.embed_dim, + nhead=self.num_heads, + activation=nn.GELU(), + batch_first=True, + dim_feedforward=self.mlp_ratio * self.embed_dim, + dropout=self.drop_rate, + layer_norm_eps=1e-6, + norm_first=True, + ) + # define adapters if needed + if self.factorize_space_time: + if agg_space_module == 'TransformerEncoderLayer': + self.spatial_attn_agg = SpatialTransformerEncoderLayer( + **transf_enc_layer_kwargs) + elif agg_space_module == 'AveragePooling': + self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t', + then_permute_pattern='BS D t -> BS t D') + if agg_time_module == 'TransformerEncoderLayer': + self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs) + elif agg_time_module == 'AveragePooling': + self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D') + elif 'Identity' in agg_time_module: + self.temp_attn_agg = nn.Identity() + # define a global aggregation layer (aggregarate over segments) + self.add_global_repr = add_global_repr + if add_global_repr: + if agg_segments_module == 'TransformerEncoderLayer': + # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D) + # we need to add pos emb (PE) because previously we added the same PE for each segment + pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1 + self.global_attn_agg = TemporalTransformerEncoderLayer( + add_pos_emb=True, + pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT, + pos_max_len=pos_max_len, + **transf_enc_layer_kwargs) + elif agg_segments_module == 'AveragePooling': + self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D') + + if was_pt_on_avclip: + # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors) + # and keep only the state_dict of the feat extractor + ckpt_weights = dict() + for k, v in ckpt['state_dict'].items(): + if k.startswith(('module.v_encoder.', 'v_encoder.')): + k = k.replace('module.', '').replace('v_encoder.', '') + ckpt_weights[k] = v + _load_status = self.load_state_dict(ckpt_weights, strict=False) + if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0: + logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. 
\n' \ + f'Missing keys ({len(_load_status.missing_keys)}): ' \ + f'{_load_status.missing_keys}, \n' \ + f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \ + f'{_load_status.unexpected_keys} \n' \ + f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.') + else: + logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.') + + # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1 + # but it used to calculate the number of patches, so we need to set keep it + self.patch_embed.requires_grad_(False) + + def forward(self, x): + ''' + x is of shape (B, S, C, T, H, W) where S is the number of segments. + ''' + # Batch, Segments, Channels, T=frames, Height, Width + B, S, C, T, H, W = x.shape + # Motionformer expects a tensor of shape (1, B, C, T, H, W). + # The first dimension (1) is a dummy dimension to make the input tensor and won't be used: + # see `video_model_builder.video_input`. + # x = x.unsqueeze(0) # (1, B, S, C, T, H, W) + + orig_shape = (B, S, C, T, H, W) + x = x.view(B * S, C, T, H, W) # flatten batch and segments + x = self.forward_segments(x, orig_shape=orig_shape) + # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D)) + x = x.view(B, S, *x.shape[1:]) + # x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity` + + return x # x is (B, S, ...) + + def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor: + '''x is of shape (1, BS, C, T, H, W) where S is the number of segments.''' + x, x_mask = self.forward_features(x) + + assert self.extract_features + + # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + x = x[:, + 1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC) + x = self.norm(x) + x = self.pre_logits(x) + if self.factorize_space_time: + x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D) + + x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D) + x = self.temp_attn_agg( + x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity` + + return x + + def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor: + ''' + feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8 + Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions. + From `self.patch_embed_3d`, it follows that we could reshape feats with: + `feats.transpose(1, 2).view(B*S, D, t, h, w)` + ''' + B, S, C, T, H, W = orig_shape + D = self.embed_dim + + # num patches in each dimension + t = T // self.patch_embed_3d.z_block_size + h = self.patch_embed_3d.height + w = self.patch_embed_3d.width + + feats = feats.permute(0, 2, 1) # (B*S, D, T) + feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w) + + return feats + + +class BaseEncoderLayer(nn.TransformerEncoderLayer): + ''' + This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token + to the sequence and outputs the CLS token's representation. + This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream + and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream stream. + We also, optionally, add a positional embedding to the input sequence which + allows to reuse it for global aggregation (of segments) for both streams. 
+ ''' + + def __init__(self, + add_pos_emb: bool = False, + pos_emb_drop: float = None, + pos_max_len: int = None, + *args_transformer_enc, + **kwargs_transformer_enc): + super().__init__(*args_transformer_enc, **kwargs_transformer_enc) + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim)) + trunc_normal_(self.cls_token, std=.02) + + # add positional embedding + self.add_pos_emb = add_pos_emb + if add_pos_emb: + self.pos_max_len = 1 + pos_max_len # +1 (for CLS) + self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim)) + self.pos_drop = nn.Dropout(pos_emb_drop) + trunc_normal_(self.pos_emb, std=.02) + + self.apply(self._init_weights) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None): + ''' x is of shape (B, N, D); if provided x_mask is of shape (B, N)''' + batch_dim = x.shape[0] + + # add CLS token + cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension + x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D) + if x_mask is not None: + cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool, + device=x_mask.device) # 1=keep; 0=mask + x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len) + B, N = x_mask_w_cls.shape + # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks + x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\ + .expand(-1, self.self_attn.num_heads, N, -1)\ + .reshape(B * self.self_attn.num_heads, N, N) + assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool' + x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask) + else: + x_mask_w_cls = None + + # add positional embedding + if self.add_pos_emb: + seq_len = x.shape[ + 1] # (don't even think about moving it before the CLS token concatenation) + assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})' + x = x + self.pos_emb[:, :seq_len, :] + x = self.pos_drop(x) + + # apply encoder layer (calls nn.TransformerEncoderLayer.forward); + x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D) + + # CLS token is expected to hold spatial information for each frame + x = x[:, 0, :] # (batch_dim, D) + + return x + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token', 'pos_emb'} + + +class SpatialTransformerEncoderLayer(BaseEncoderLayer): + ''' Aggregates spatial dimensions by applying attention individually to each frame. ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + ''' x is of shape (B*S, D, t, h, w) where S is the number of segments. + if specified x_mask (B*S, t, h, w), 0=masked, 1=kept + Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. 
''' + BS, D, t, h, w = x.shape + + # time as a batch dimension and flatten spatial dimensions as sequence + x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D') + # similar to mask + if x_mask is not None: + x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)') + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D) + + # reshape back to (B*S, t, D) + x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t) + + # (B*S, t, D) + return x + + +class TemporalTransformerEncoderLayer(BaseEncoderLayer): + ''' Aggregates temporal dimension with attention. Also used with pos emb as global aggregation + in both streams. ''' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x): + ''' x is of shape (B*S, t, D) where S is the number of segments. + Returns a tensor of shape (B*S, D) pooling temporal information. ''' + BS, t, D = x.shape + + # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation + x = super().forward(x) # (B*S, D) + + return x # (B*S, D) + + +class AveragePooling(nn.Module): + + def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None: + ''' patterns are e.g. "bs t d -> bs d" ''' + super().__init__() + # TODO: need to register them as buffers (but fails because these are strings) + self.reduce_fn = 'mean' + self.avg_pattern = avg_pattern + self.then_permute_pattern = then_permute_pattern + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor: + x = einops.reduce(x, self.avg_pattern, self.reduce_fn) + if self.then_permute_pattern is not None: + x = einops.rearrange(x, self.then_permute_pattern) + return x diff --git a/data_utils/ext/synchformer/synchformer.py b/data_utils/ext/synchformer/synchformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd580fa6cc1701eedeeebc5fcc3951755207df96 --- /dev/null +++ b/data_utils/ext/synchformer/synchformer.py @@ -0,0 +1,55 @@ +import logging +from typing import Any, Mapping + +import torch +from torch import nn + +from data_utils.ext.synchformer.motionformer import MotionFormer + + +class Synchformer(nn.Module): + + def __init__(self): + super().__init__() + + self.vfeat_extractor = MotionFormer(extract_features=True, + factorize_space_time=True, + agg_space_module='TransformerEncoderLayer', + agg_time_module='torch.nn.Identity', + add_global_repr=False) + + # self.vfeat_extractor = instantiate_from_config(vfeat_extractor) + # self.afeat_extractor = instantiate_from_config(afeat_extractor) + # # bridging the s3d latent dim (1024) into what is specified in the config + # # to match e.g. the transformer dim + # self.vproj = instantiate_from_config(vproj) + # self.aproj = instantiate_from_config(aproj) + # self.transformer = instantiate_from_config(transformer) + + def forward(self, vis): + B, S, Tv, C, H, W = vis.shape + vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W) + # feat extractors return a tuple of segment-level and global features (ignored for sync) + # (B, S, tv, D), e.g. 
(B, 7, 8, 768) + vis = self.vfeat_extractor(vis) + return vis + + def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True): + # discard all entries except vfeat_extractor + sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')} + + return super().load_state_dict(sd, strict) + + +if __name__ == "__main__": + model = Synchformer().cuda().eval() + sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True) + model.load_state_dict(sd) + + vid = torch.randn(2, 7, 16, 3, 224, 224).cuda() + features = model.extract_vfeats(vid, for_loop=False).detach().cpu() + print(features.shape) + + # extract and save the state dict only + # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model'] + # torch.save(sd, './ext_weights/synchformer_state_dict.pth') diff --git a/data_utils/ext/synchformer/utils.py b/data_utils/ext/synchformer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a797eb9c66f04b7c29934bfc384c935cdf441a62 --- /dev/null +++ b/data_utils/ext/synchformer/utils.py @@ -0,0 +1,92 @@ +from hashlib import md5 +from pathlib import Path + +import requests +from tqdm import tqdm + +PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a' +FNAME2LINK = { + # S3: Synchability: AudioSet (run 2) + '24-01-22T20-34-52.pt': + f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt', + 'cfg-24-01-22T20-34-52.yaml': + f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml', + # S2: Synchformer: AudioSet (run 2) + '24-01-04T16-39-21.pt': + f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt', + 'cfg-24-01-04T16-39-21.yaml': + f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml', + # S2: Synchformer: AudioSet (run 1) + '23-08-28T11-23-23.pt': + f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt', + 'cfg-23-08-28T11-23-23.yaml': + f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml', + # S2: Synchformer: LRS3 (run 2) + '23-12-23T18-33-57.pt': + f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt', + 'cfg-23-12-23T18-33-57.yaml': + f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml', + # S2: Synchformer: VGS (run 2) + '24-01-02T10-00-53.pt': + f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt', + 'cfg-24-01-02T10-00-53.yaml': + f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml', + # SparseSync: ft VGGSound-Full + '22-09-21T21-00-52.pt': + f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt', + 'cfg-22-09-21T21-00-52.yaml': + f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml', + # SparseSync: ft VGGSound-Sparse + '22-07-28T15-49-45.pt': + f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt', + 'cfg-22-07-28T15-49-45.yaml': + f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml', + # SparseSync: only pt on LRS3 + '22-07-13T22-25-49.pt': + f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt', + 'cfg-22-07-13T22-25-49.yaml': + f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml', + # SparseSync: feature extractors + 'ResNetAudio-22-08-04T09-51-04.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt', # 2s + 'ResNetAudio-22-08-03T23-14-49.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt', # 3s + 'ResNetAudio-22-08-03T23-14-28.pt': + 
f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt', # 4s + 'ResNetAudio-22-06-24T08-10-33.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt', # 5s + 'ResNetAudio-22-06-24T17-31-07.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt', # 6s + 'ResNetAudio-22-06-24T23-57-11.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt', # 7s + 'ResNetAudio-22-06-25T04-35-42.pt': + f'{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt', # 8s +} + + +def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024): + '''Checks if file exists, if not downloads it from the link to the path''' + path = Path(path) + if not path.exists(): + path.parent.mkdir(exist_ok=True, parents=True) + link = fname2link.get(path.name, None) + if link is None: + raise ValueError(f'Cant find the checkpoint file: {path}.', + f'Please download it manually and ensure the path exists.') + with requests.get(fname2link[path.name], stream=True) as r: + total_size = int(r.headers.get('content-length', 0)) + with tqdm(total=total_size, unit='B', unit_scale=True) as pbar: + with open(path, 'wb') as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def get_md5sum(path): + hash_md5 = md5() + with open(path, 'rb') as f: + for chunk in iter(lambda: f.read(4096 * 8), b''): + hash_md5.update(chunk) + md5sum = hash_md5.hexdigest() + return md5sum diff --git a/data_utils/ext/synchformer/video_model_builder.py b/data_utils/ext/synchformer/video_model_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..da6decd3ab8a2f938e7df66f046451fff6413b5f --- /dev/null +++ b/data_utils/ext/synchformer/video_model_builder.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# Copyright 2020 Ross Wightman +# Modified Model definition + +from collections import OrderedDict +from functools import partial + +import torch +import torch.nn as nn +from timm.layers import trunc_normal_ + +from data_utils.ext.synchformer import vit_helper + + +class VisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage """ + + def __init__(self, cfg): + super().__init__() + self.img_size = cfg.DATA.TRAIN_CROP_SIZE + self.patch_size = cfg.VIT.PATCH_SIZE + self.in_chans = cfg.VIT.CHANNELS + if cfg.TRAIN.DATASET == "Epickitchens": + self.num_classes = [97, 300] + else: + self.num_classes = cfg.MODEL.NUM_CLASSES + self.embed_dim = cfg.VIT.EMBED_DIM + self.depth = cfg.VIT.DEPTH + self.num_heads = cfg.VIT.NUM_HEADS + self.mlp_ratio = cfg.VIT.MLP_RATIO + self.qkv_bias = cfg.VIT.QKV_BIAS + self.drop_rate = cfg.VIT.DROP + self.drop_path_rate = cfg.VIT.DROP_PATH + self.head_dropout = cfg.VIT.HEAD_DROPOUT + self.video_input = cfg.VIT.VIDEO_INPUT + self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION + self.use_mlp = cfg.VIT.USE_MLP + self.num_features = self.embed_dim + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT + self.head_act = cfg.VIT.HEAD_ACT + self.cfg = cfg + + # Patch Embedding + self.patch_embed = vit_helper.PatchEmbed(img_size=224, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim) + + # 3D Patch Embedding + self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size, + temporal_resolution=self.temporal_resolution, + patch_size=self.patch_size, + in_chans=self.in_chans, + embed_dim=self.embed_dim, + z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP) + self.patch_embed_3d.proj.weight.data = torch.zeros_like( + self.patch_embed_3d.proj.weight.data) + + # Number of patches + if self.video_input: + num_patches = self.patch_embed.num_patches * self.temporal_resolution + else: + num_patches = self.patch_embed.num_patches + self.num_patches = num_patches + + # CLS token + self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + trunc_normal_(self.cls_token, std=.02) + + # Positional embedding + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim)) + self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT) + trunc_normal_(self.pos_embed, std=.02) + + if self.cfg.VIT.POS_EMBED == "joint": + self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim)) + trunc_normal_(self.st_embed, std=.02) + elif self.cfg.VIT.POS_EMBED == "separate": + self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim)) + + # Layer Blocks + dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] + if self.cfg.VIT.ATTN_LAYER == "divided": + self.blocks = nn.ModuleList([ + vit_helper.DividedSpaceTimeBlock( + attn_type=cfg.VIT.ATTN_LAYER, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + ) for i in range(self.depth) + ]) + else: + self.blocks = nn.ModuleList([ + vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER, + dim=self.embed_dim, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + qkv_bias=self.qkv_bias, + drop=self.drop_rate, + attn_drop=self.attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE) + for i in range(self.depth) + ]) + self.norm = 
norm_layer(self.embed_dim) + + # MLP head + if self.use_mlp: + hidden_dim = self.embed_dim + if self.head_act == 'tanh': + # logging.info("Using TanH activation in MLP") + act = nn.Tanh() + elif self.head_act == 'gelu': + # logging.info("Using GELU activation in MLP") + act = nn.GELU() + else: + # logging.info("Using ReLU activation in MLP") + act = nn.ReLU() + self.pre_logits = nn.Sequential( + OrderedDict([ + ('fc', nn.Linear(self.embed_dim, hidden_dim)), + ('act', act), + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier Head + self.head_drop = nn.Dropout(p=self.head_dropout) + if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: + for a, i in enumerate(range(len(self.num_classes))): + setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i])) + else: + self.head = nn.Linear(self.embed_dim, + self.num_classes) if self.num_classes > 0 else nn.Identity() + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + if self.cfg.VIT.POS_EMBED == "joint": + return {'pos_embed', 'cls_token', 'st_embed'} + else: + return {'pos_embed', 'cls_token', 'temp_embed'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()) + + def forward_features(self, x): + # if self.video_input: + # x = x[0] + B = x.shape[0] + + # Tokenize input + # if self.cfg.VIT.PATCH_SIZE_TEMP > 1: + # for simplicity of mapping between content dimensions (input x) and token dims (after patching) + # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details): + + # apply patching on input + x = self.patch_embed_3d(x) + tok_mask = None + + # else: + # tok_mask = None + # # 2D tokenization + # if self.video_input: + # x = x.permute(0, 2, 1, 3, 4) + # (B, T, C, H, W) = x.shape + # x = x.reshape(B * T, C, H, W) + + # x = self.patch_embed(x) + + # if self.video_input: + # (B2, T2, D2) = x.shape + # x = x.reshape(B, T * T2, D2) + + # Append CLS token + cls_tokens = self.cls_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + # if tok_mask is not None: + # # prepend 1(=keep) to the mask to account for the CLS token as well + # tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1) + + # Interpolate positinoal embeddings + # if self.cfg.DATA.TRAIN_CROP_SIZE != 224: + # pos_embed = self.pos_embed + # N = pos_embed.shape[1] - 1 + # npatch = int((x.size(1) - 1) / self.temporal_resolution) + # class_emb = pos_embed[:, 0] + # pos_embed = pos_embed[:, 1:] + # dim = x.shape[-1] + # pos_embed = torch.nn.functional.interpolate( + # pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + # scale_factor=math.sqrt(npatch / N), + # mode='bicubic', + # ) + # pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + # new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1) + # else: + new_pos_embed = self.pos_embed + npatch = self.patch_embed.num_patches + + # Add positional embeddings to input + if self.video_input: + if self.cfg.VIT.POS_EMBED == "separate": + cls_embed = self.pos_embed[:, 
0, :].unsqueeze(1) + tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1) + tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1) + total_pos_embed = tile_pos_embed + tile_temporal_embed + total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1) + x = x + total_pos_embed + elif self.cfg.VIT.POS_EMBED == "joint": + x = x + self.st_embed + else: + # image input + x = x + new_pos_embed + + # Apply positional dropout + x = self.pos_drop(x) + + # Encoding using transformer layers + for i, blk in enumerate(self.blocks): + x = blk(x, + seq_len=npatch, + num_frames=self.temporal_resolution, + approx=self.cfg.VIT.APPROX_ATTN_TYPE, + num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM, + tok_mask=tok_mask) + + ### v-iashin: I moved it to the forward pass + # x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + ### + return x, tok_mask + + # def forward(self, x): + # x = self.forward_features(x) + # ### v-iashin: here. This should leave the same forward output as before + # x = self.norm(x)[:, 0] + # x = self.pre_logits(x) + # ### + # x = self.head_drop(x) + # if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1: + # output = [] + # for head in range(len(self.num_classes)): + # x_out = getattr(self, "head%d" % head)(x) + # if not self.training: + # x_out = torch.nn.functional.softmax(x_out, dim=-1) + # output.append(x_out) + # return output + # else: + # x = self.head(x) + # if not self.training: + # x = torch.nn.functional.softmax(x, dim=-1) + # return x diff --git a/data_utils/ext/synchformer/vit_helper.py b/data_utils/ext/synchformer/vit_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..6af730a135bf49240ec439c81c9ad0aa5c9a505e --- /dev/null +++ b/data_utils/ext/synchformer/vit_helper.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
+# Copyright 2020 Ross Wightman +# Modified Model definition +"""Video models.""" + +import math + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from timm.layers import to_2tuple +from torch import einsum +from torch.nn import functional as F + +default_cfgs = { + 'vit_1k': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + 'vit_1k_large': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', +} + + +def qkv_attn(q, k, v, tok_mask: torch.Tensor = None): + sim = einsum('b i d, b j d -> b i j', q, k) + # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N) + if tok_mask is not None: + BSH, N = tok_mask.shape + sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0, + float('-inf')) # 1 - broadcasts across N + attn = sim.softmax(dim=-1) + out = einsum('b i j, b j d -> b i d', attn, v) + return out + + +class DividedAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + # init to zeros + self.qkv.weight.data.fill_(0) + self.qkv.bias.data.fill_(0) + self.proj.weight.data.fill_(1) + self.proj.bias.data.fill_(0) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims): + # num of heads variable + h = self.num_heads + + # project x to q, k, v vaalues + q, k, v = self.qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + if tok_mask is not None: + # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d + assert len(tok_mask.shape) == 2 + tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1]) + + # Scale q + q *= self.scale + + # Take out cls_q, cls_k, cls_v + (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v)) + # the same for masking + if tok_mask is not None: + cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:] + else: + cls_mask, mask_ = None, None + + # let CLS token attend to key / values of all patches across time and space + cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask) + + # rearrange across time or space + q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims), + (q_, k_, v_)) + + # expand CLS token keys and values across time or space and concat + r = q_.shape[0] // cls_k.shape[0] + cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v)) + + k_ = torch.cat((cls_k, k_), dim=1) + v_ = torch.cat((cls_v, v_), dim=1) + + # the same for masking (if provided) + if tok_mask is not None: + # since mask does not have the latent dim (d), we need to remove it from einops dims + mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''), + **einops_dims) + cls_mask = repeat(cls_mask, 'b () -> (b r) ()', + r=r) # expand cls_mask across time or space + mask_ = torch.cat((cls_mask, mask_), dim=1) + + # attention + out = qkv_attn(q_, k_, v_, tok_mask=mask_) + + # merge back time or space + out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims) + + # concat back the cls token + out = torch.cat((cls_out, 
out), dim=1) + + # merge back the heads + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + + ## to out + x = self.proj(out) + x = self.proj_drop(x) + return x + + +class DividedSpaceTimeBlock(nn.Module): + + def __init__(self, + dim=768, + num_heads=12, + attn_type='divided', + mlp_ratio=4., + qkv_bias=False, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + + self.einops_from_space = 'b (f n) d' + self.einops_to_space = '(b f) n d' + self.einops_from_time = 'b (f n) d' + self.einops_to_time = '(b n) f d' + + self.norm1 = norm_layer(dim) + + self.attn = DividedAttention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + + self.timeattn = DividedAttention(dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop) + + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + self.norm3 = norm_layer(dim) + + def forward(self, + x, + seq_len=196, + num_frames=8, + approx='none', + num_landmarks=128, + tok_mask: torch.Tensor = None): + time_output = self.timeattn(self.norm3(x), + self.einops_from_time, + self.einops_to_time, + n=seq_len, + tok_mask=tok_mask) + time_residual = x + time_output + + space_output = self.attn(self.norm1(time_residual), + self.einops_from_space, + self.einops_to_space, + f=num_frames, + tok_mask=tok_mask) + space_residual = time_residual + self.drop_path(space_output) + + x = space_residual + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = img_size if type(img_size) is tuple else to_2tuple(img_size) + patch_size = img_size if type(patch_size) is tuple else to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ Image to Patch Embedding """ + + def __init__(self, + img_size=224, + temporal_resolution=4, + in_chans=3, + patch_size=16, + z_block_size=2, + embed_dim=768, + flatten=True): + super().__init__() + self.height = (img_size // patch_size) + self.width = (img_size // patch_size) + ### v-iashin: these two are incorrect + # self.frames = (temporal_resolution // z_block_size) + # self.num_patches = self.height * self.width * self.frames + self.z_block_size = z_block_size + ### + self.proj = 
nn.Conv3d(in_chans, + embed_dim, + kernel_size=(z_block_size, patch_size, patch_size), + stride=(z_block_size, patch_size, patch_size)) + self.flatten = flatten + + def forward(self, x): + B, C, T, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) + return x + + +class HeadMLP(nn.Module): + + def __init__(self, n_input, n_classes, n_hidden=512, p=0.1): + super(HeadMLP, self).__init__() + self.n_input = n_input + self.n_classes = n_classes + self.n_hidden = n_hidden + if n_hidden is None: + # use linear classifier + self.block_forward = nn.Sequential(nn.Dropout(p=p), + nn.Linear(n_input, n_classes, bias=True)) + else: + # use simple MLP classifier + self.block_forward = nn.Sequential(nn.Dropout(p=p), + nn.Linear(n_input, n_hidden, bias=True), + nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True), + nn.Dropout(p=p), + nn.Linear(n_hidden, n_classes, bias=True)) + print(f"Dropout-NLP: {p}") + + def forward(self, x): + return self.block_forward(x) + + +def _conv_filter(state_dict, patch_size=16): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k: + v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + out_dict[k] = v + return out_dict + + +def adapt_input_conv(in_chans, conv_weight, agg='sum'): + conv_type = conv_weight.dtype + conv_weight = conv_weight.float() + O, I, J, K = conv_weight.shape + if in_chans == 1: + if I > 3: + assert conv_weight.shape[1] % 3 == 0 + # For models with space2depth stems + conv_weight = conv_weight.reshape(O, I // 3, 3, J, K) + conv_weight = conv_weight.sum(dim=2, keepdim=False) + else: + if agg == 'sum': + print("Summing conv1 weights") + conv_weight = conv_weight.sum(dim=1, keepdim=True) + else: + print("Averaging conv1 weights") + conv_weight = conv_weight.mean(dim=1, keepdim=True) + elif in_chans != 3: + if I != 3: + raise NotImplementedError('Weight format not supported by conversion.') + else: + if agg == 'sum': + print("Summing conv1 weights") + repeat = int(math.ceil(in_chans / 3)) + conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :] + conv_weight *= (3 / float(in_chans)) + else: + print("Averaging conv1 weights") + conv_weight = conv_weight.mean(dim=1, keepdim=True) + conv_weight = conv_weight.repeat(1, in_chans, 1, 1) + conv_weight = conv_weight.to(conv_type) + return conv_weight + + +def load_pretrained(model, + cfg=None, + num_classes=1000, + in_chans=3, + filter_fn=None, + strict=True, + progress=False): + # Load state dict + assert (f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]") + state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS]) + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + input_convs = 'patch_embed.proj' + if input_convs is not None and in_chans != 3: + if isinstance(input_convs, str): + input_convs = (input_convs, ) + for input_conv_name in input_convs: + weight_name = input_conv_name + '.weight' + try: + state_dict[weight_name] = adapt_input_conv(in_chans, + state_dict[weight_name], + agg='avg') + print( + f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)' + ) + except NotImplementedError as e: + del state_dict[weight_name] + strict = False + print( + f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.' 
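+                    # The incompatible weight was removed from state_dict above, so the
+                    # copy loop further down simply skips it and this conv keeps its
+                    # random initialization.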
+ ) + + classifier_name = 'head' + label_offset = cfg.get('label_offset', 0) + pretrain_classes = 1000 + if num_classes != pretrain_classes: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + elif label_offset > 0: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + + loaded_state = state_dict + self_state = model.state_dict() + all_names = set(self_state.keys()) + saved_names = set([]) + for name, param in loaded_state.items(): + param = param + if 'module.' in name: + name = name.replace('module.', '') + if name in self_state.keys() and param.shape == self_state[name].shape: + saved_names.add(name) + self_state[name].copy_(param) + else: + print(f"didnt load: {name} of shape: {param.shape}") + print("Missing Keys:") + print(all_names - saved_names) diff --git a/data_utils/utils.py b/data_utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68cb806d9a44d9f6fa891f715b4d64f6534c3bd1 --- /dev/null +++ b/data_utils/utils.py @@ -0,0 +1,115 @@ +"""Utility functions.""" +import contextlib +import csv +import json +import os +import pathlib +import warnings + +import numpy as np + + +def save_args(filename, args): + """Save the command-line arguments.""" + args_dict = {} + for key, value in vars(args).items(): + if isinstance(value, pathlib.Path): + args_dict[key] = str(value) + else: + args_dict[key] = value + save_json(filename, args_dict) + + +def inverse_dict(d): + """Return the inverse dictionary.""" + return {v: k for k, v in d.items()} + + +def save_txt(filename, data): + """Save a list to a TXT file.""" + with open(filename, "w", encoding="utf8") as f: + for item in data: + f.write(f"{item}\n") + + +def load_txt(filename): + """Load a TXT file as a list.""" + with open(filename, encoding="utf8") as f: + return [line.strip() for line in f] + + +def save_json(filename, data): + """Save data as a JSON file.""" + with open(filename, "w", encoding="utf8") as f: + json.dump(data, f) + + +def load_json(filename): + """Load data from a JSON file.""" + with open(filename, encoding="utf8") as f: + return json.load(f) + + +def save_csv(filename, data, header=""): + """Save data as a CSV file.""" + np.savetxt( + filename, data, fmt="%d", delimiter=",", header=header, comments="" + ) + + +def load_csv(filename, skiprows=1): + """Load data from a CSV file.""" + return np.loadtxt(filename, dtype=int, delimiter=",", skiprows=skiprows) + + +def load_csv_text(filename, headerless=True): + """Read a CSV file into a list of dictionaries or lists.""" + with open(filename) as f: + if headerless: + return [row for row in csv.reader(f)] + reader = csv.DictReader(f) + return [ + {field: row[field] for field in reader.fieldnames} + for row in reader + ] + + +def ignore_exceptions(func): + """Decorator that ignores all errors and warnings.""" + + def inner(*args, **kwargs): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + return func(*args, **kwargs) + except Exception: + return None + + return inner + + +def suppress_outputs(func): + """Decorator that suppresses writing to stdout and stderr.""" + + 
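+    # Illustrative usage of these decorators (the loader below is hypothetical and only
+    # sketches the intended pattern):
+    #
+    #     @suppress_outputs
+    #     @ignore_exceptions
+    #     def load_optional_asset(path):
+    #         return np.load(path)   # prints/warnings are silenced; failures return None
+    #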
def inner(*args, **kwargs): + devnull = open(os.devnull, "w") + with contextlib.redirect_stdout(devnull): + with contextlib.redirect_stderr(devnull): + return func(*args, **kwargs) + + return inner + + +def resolve_paths(func): + """Decorator that resolves all paths.""" + + def inner(*args, **kwargs): + parsed = func(*args, **kwargs) + for key in vars(parsed).keys(): + if isinstance(getattr(parsed, key), pathlib.Path): + setattr( + parsed, key, getattr(parsed, key).expanduser().resolve() + ) + return parsed + + return inner diff --git a/data_utils/v2a_utils/__init__.py b/data_utils/v2a_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10d232e6b7d73a07d1d4b7f08ac3a581b9110a89 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be016caa49d423f4b0a52f66ac3b210b52423ce1 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c38b83aea9638b59c45ec959b4c1cc5281f0f89e Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b80d86ad166b7715cf5e7e0b8aef626c262b78f4 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1996d974e0dd8489b9ee54b64093d893a8bb8c6f Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa209c10ef36183642053047c60abdfa9b715796 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471714af83c05d9612ddda5bb09dfab70813ec34 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e133a860468cbc37baa9acea75ec1a2b7c5c456 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc differ diff 
--git a/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c33953aafdaa73f59beeced690bee714dcee756a Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6577c2bb22d4d84bbbc2013c74a74e30ab89b71d Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6565541958c14e9294afcfc29f9ce764fe04e83 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcc14fe670bd8fe46989ecba1e5bdc9a4e47f937 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e62be420b89e4fcf0d698b6f0ff6aac91cb8889e Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34cd3500ab00860c8727a17891e7ccbfd084fa76 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..174ca8efd7276497218351d82f28f9a139fa5d16 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..510649119739eb6dab8d71fdecec9429910c06b7 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f414d5b164c9b3eb63b322f1d46061bdcfd94c9 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6f8d690d16e6cc85cc4b9d946ca2a7d7e02d54f Binary files /dev/null and 
b/data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d15160ffbd28c168fd85ba1cc0107ac9a05cd830 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e68e23e6f0e5f2487d985741a09cfab54638f231 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc b/data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ac3fd63b2d590d77c06801489338005bc10c953 Binary files /dev/null and b/data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc differ diff --git a/data_utils/v2a_utils/feature_utils_224.py b/data_utils/v2a_utils/feature_utils_224.py new file mode 100644 index 0000000000000000000000000000000000000000..520d51f30ff47813b482742de59a50e004a0420a --- /dev/null +++ b/data_utils/v2a_utils/feature_utils_224.py @@ -0,0 +1,182 @@ +from typing import Literal, Optional +import json +import open_clip +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from open_clip import create_model_from_pretrained +from torchvision.transforms import Normalize +from think_sound.models.factory import create_model_from_config +from think_sound.models.utils import load_ckpt_state_dict +from think_sound.training.utils import copy_state_dict +from transformers import AutoModel +from transformers import AutoProcessor +from transformers import T5EncoderModel, AutoTokenizer +import logging +from data_utils.ext.synchformer import Synchformer + +log = logging.getLogger() + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = text_outputs[0] + pooled_output = text_outputs[1] + text_features = self.text_projection(pooled_output) + + return text_features, last_hidden_state + + clip_model.get_text_features = new_get_text_features.__get__(clip_model) + return clip_model + + +class FeaturesUtils(nn.Module): + + def __init__( + self, + *, + vae_ckpt: Optional[str] = None, + vae_config: Optional[str] = None, + synchformer_ckpt: Optional[str] = None, + 
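+        # enable_conditions controls whether the conditioning encoders below
+        # (MetaCLIP, T5 and Synchformer) are instantiated; vae_ckpt/vae_config
+        # load the latent autoencoder that backs encode_audio().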
enable_conditions: bool = True, + need_vae_encoder: bool = True, + ): + super().__init__() + + if enable_conditions: + self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + self.clip_model = patch_clip(self.clip_model) + self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl") + self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl") + self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b") + # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073], + # std=[0.26862954, 0.26130258, 0.27577711]) + self.synchformer = Synchformer() + self.synchformer.load_state_dict( + torch.load(synchformer_ckpt, weights_only=True, map_location='cpu')) + + # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' + else: + self.clip_model = None + self.synchformer = None + self.tokenizer = None + + if vae_ckpt is not None: + with open(vae_config) as f: + vae_config = json.load(f) + self.vae = create_model_from_config(vae_config) + print(f"Loading model checkpoint from {vae_ckpt}") + # Load checkpoint + copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt,prefix='autoencoder.'))#,prefix='autoencoder.' + else: + self.tod = None + + def compile(self): + if self.clip_model is not None: + self.clip_model.encode_image = torch.compile(self.clip_model.encode_image) + self.clip_model.encode_text = torch.compile(self.clip_model.encode_text) + if self.synchformer is not None: + self.synchformer = torch.compile(self.synchformer) + + + def train(self, mode: bool) -> None: + return super().train(False) + + @torch.inference_mode() + def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor: + assert self.clip_model is not None, 'CLIP is not loaded' + # x: (B, T, C, H, W) H/W: 384 + b, t, c, h, w = x.shape + + assert c == 3 and h == 224 and w == 224 + # x = self.clip_preprocess(x) + x = rearrange(x, 'b t c h w -> (b t) c h w') + outputs = [] + if batch_size < 0: + batch_size = b * t + for i in range(0, b * t, batch_size): + outputs.append(self.clip_model.get_image_features(x[i:i + batch_size])) + x = torch.cat(outputs, dim=0) + # x = self.clip_model.encode_image(x, normalize=True) + x = rearrange(x, '(b t) d -> b t d', b=b) + return x + + @torch.inference_mode() + def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor: + assert self.synchformer is not None, 'Synchformer is not loaded' + # x: (B, T, C, H, W) H/W: 384 + b, t, c, h, w = x.shape + # import ipdb + # ipdb.set_trace() + assert c == 3 and h == 224 and w == 224 + + # partition the video + segment_size = 16 + step_size = 8 + num_segments = (t - segment_size) // step_size + 1 + segments = [] + for i in range(num_segments): + segments.append(x[:, i * step_size:i * step_size + segment_size]) + x = torch.stack(segments, dim=1) # (B, S, T, C, H, W) + + outputs = [] + if batch_size < 0: + batch_size = b + x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w') + for i in range(0, b * num_segments, batch_size): + outputs.append(self.synchformer(x[i:i + batch_size])) + x = torch.cat(outputs, dim=0) + x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b) + return x + + @torch.inference_mode() + def encode_text(self, text: list[str]) -> torch.Tensor: + assert self.clip_model is not None, 'CLIP is not loaded' + # assert self.tokenizer is not None, 'Tokenizer is not loaded' + # x: (B, L) + tokens = self.clip_processor(text=text, truncation=True, max_length=77, 
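+                                      # 77 is CLIP's fixed text context length; longer
+                                      # prompts are truncated and shorter ones padded to it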
padding="max_length",return_tensors="pt").to(self.device) + return self.clip_model.get_text_features(**tokens) + + @torch.inference_mode() + def encode_t5_text(self, text: list[str]) -> torch.Tensor: + assert self.t5_model is not None, 'T5 model is not loaded' + assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded' + # x: (B, L) + inputs = self.t5_tokenizer(text, + truncation=True, + max_length=77, + padding="max_length", + return_tensors="pt").to(self.device) + return self.t5_model(**inputs).last_hidden_state + + @torch.inference_mode() + def encode_audio(self, x) -> torch.Tensor: + x = self.vae.encode(x) + return x + + @property + def device(self): + return next(self.parameters()).device + + @property + def dtype(self): + return next(self.parameters()).dtype diff --git a/defaults.ini b/defaults.ini new file mode 100644 index 0000000000000000000000000000000000000000..2b7cf194a68b3bab4d28eb3aade5cf71c01a7065 --- /dev/null +++ b/defaults.ini @@ -0,0 +1,64 @@ + +[DEFAULTS] + +#name of the run +name = think_sound + +# the batch size +batch_size = 8 +test_batch_size = 32 + +# predict ckpt directory +ckpt_dir = "" + +# number of GPUs to use for training +num_gpus = 1 + +# number of nodes to use for training +num_nodes = 1 + +# Multi-GPU strategy for PyTorch Lightning +strategy = "" + +# Precision to use for training +precision = "bf16-mixed" + +# number of CPU workers for the DataLoader +num_workers = 8 + +# the random seed +seed = 42 + +# Batches for gradient accumulation +accum_batches = 1 + +# Number of steps between checkpoints +checkpoint_every = 2000 + +# trainer checkpoint file to restart training from +ckpt_path = '' + +# model checkpoint file to start a new training run from +pretrained_ckpt_path = '' + +# Checkpoint path for the pretransform model if needed +pretransform_ckpt_path = '' + +# configuration model specifying model hyperparameters +model_config = '' + +# configuration for datasets +dataset_config = '' + +# directory to save the checkpoints in +save_dir = '' + +# gradient_clip_val passed into PyTorch Lightning Trainer +gradient_clip_val = 0.0 + +# remove the weight norm from the pretransform model +remove_pretransform_weight_norm = '' + +compile = False + +repeat_num = 5 \ No newline at end of file diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..8973636ce2541938d73f37c2e77b53240d5742a2 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,232 @@ +absl-py==2.2.2 +accelerate==1.7.0 +aeiou==0.0.20 +aiobotocore==2.22.0 +aiofiles==24.1.0 +aiohappyeyeballs==2.6.1 +aiohttp==3.11.18 +aioitertools==0.12.0 +aiosignal==1.3.2 +alias-free-torch==0.0.6 +altair==5.5.0 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.9.0 +appdirs==1.4.4 +argbind==0.3.9 +asttokens==3.0.0 +async-timeout==5.0.1 +attrs==25.3.0 +audioread==3.0.1 +auraloss==0.4.0 +av==14.4.0 +bleach==6.2.0 +bokeh==3.7.3 +botocore==1.37.3 +braceexpand==0.1.7 +certifi==2025.4.26 +cffi==1.17.1 +charset-normalizer==3.4.2 +clean-fid==0.1.35 +click==8.1.8 +clip-anytorch==2.6.0 +cloudpickle==3.1.1 +colorcet==3.1.0 +colorlog==6.9.0 +configparser==7.2.0 +contourpy==1.3.2 +cycler==0.12.1 +Cython==3.1.1 +dctorch==0.1.2 +decorator==5.2.1 +decord==0.6.0 +descript-audio-codec==1.0.0 +descript-audiotools==0.7.2 +docker-pycreds==0.4.0 +docstring_parser==0.16 +einops==0.7.0 +einops-exts==0.0.4 +ema-pytorch==0.2.3 +encodec==0.1.1 +exceptiongroup==1.3.0 +executing==2.2.0 +fastapi==0.115.12 +fastcore==1.8.2 +ffmpeg==1.4 +ffmpy==0.5.0 +filelock==3.18.0 
+fire==0.7.0 +flatten-dict==0.4.2 +fonttools==4.58.0 +frozenlist==1.6.0 +fsspec==2025.5.0 +ftfy==6.3.1 +future==1.0.0 +gin-config==0.5.0 +gitdb==4.0.12 +GitPython==3.1.44 +gradio==5.31.0 +gradio_client==1.10.1 +groovy==0.1.2 +grpcio==1.71.0 +h11==0.16.0 +h5py==3.13.0 +holoviews==1.20.2 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.31.4 +hydra-colorlog==1.2.0 +hydra-core==1.3.2 +idna==3.10 +imageio==2.37.0 +importlib-resources==5.12.0 +importlib_metadata==8.7.0 +ipython==8.36.0 +jedi==0.19.2 +Jinja2==3.1.6 +jmespath==1.0.1 +joblib==1.5.0 +jsonmerge==1.9.2 +jsonschema==4.23.0 +jsonschema-specifications==2025.4.1 +julius==0.2.7 +k-diffusion==0.1.1 +kiwisolver==1.4.8 +kornia==0.8.1 +kornia_rs==0.1.9 +laion-clap==1.1.4 +lazy_loader==0.4 +librosa==0.9.2 +lightning==2.5.1.post0 +lightning-utilities==0.14.3 +linkify-it-py==2.0.3 +llvmlite==0.43.0 +local-attention==1.8.6 +Markdown==3.8 +markdown-it-py==3.0.0 +markdown2==2.5.3 +MarkupSafe==3.0.2 +matplotlib==3.10.3 +matplotlib-inline==0.1.7 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.4.4 +narwhals==1.40.0 +networkx==3.4.2 +nitrous_ema==0.0.1 +numba==0.60.0 +numpy==1.23.5 +omegaconf==2.3.0 +open_clip_torch==2.32.0 +opencv-python==4.11.0.86 +orjson==3.10.18 +packaging==24.2 +pandas==2.0.2 +panel==1.7.0 +param==2.2.0 +parso==0.8.4 +pathtools==0.1.2 +pedalboard==0.7.4 +pexpect==4.9.0 +pillow==11.2.1 +platformdirs==4.3.8 +plotly==6.1.1 +pooch==1.8.2 +prefigure==0.0.9 +progressbar==2.5 +prompt_toolkit==3.0.51 +propcache==0.3.1 +protobuf==3.19.6 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pycparser==2.22 +pydantic==2.11.5 +pydantic_core==2.33.2 +pydub==0.25.1 +Pygments==2.19.1 +pyloudnorm==0.1.1 +pynndescent==0.5.13 +pyparsing==3.2.3 +pystoi==0.4.1 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.0 +python-multipart==0.0.20 +pytorch-lightning==2.5.1.post0 +pytz==2025.2 +pyviz_comms==3.0.4 +PyWavelets==1.4.1 +PyYAML==6.0.2 +randomname==0.2.1 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.3 +resampy==0.4.3 +rich==14.0.0 +rpds-py==0.25.1 +ruff==0.11.11 +s3fs==2025.5.0 +safehttpx==0.1.6 +safetensors==0.5.3 +scikit-image==0.24.0 +scikit-learn==1.6.1 +scipy==1.15.3 +semantic-version==2.10.0 +sentencepiece==0.1.99 +sentry-sdk==2.29.1 +setproctitle==1.3.6 +shellingham==1.5.4 +six==1.17.0 +smmap==5.0.2 +sniffio==1.3.1 +SoundFile==0.10.2 +stack-data==0.6.3 +starlette==0.46.2 +sympy==1.13.1 +tensorboard==2.19.0 +tensorboard-data-server==0.7.2 +tensordict==0.8.3 +termcolor==3.1.0 +threadpoolctl==3.6.0 +tifffile==2025.5.10 +timm==1.0.15 +tokenizers==0.21.1 +tomlkit==0.13.2 +torch==2.6.0 +torch-stoi==0.2.3 +torchaudio==2.6.0 +torchdiffeq==0.2.5 +torchlibrosa==0.1.0 +torchmetrics==0.11.4 +torchsde==0.2.6 +torchvision==0.21.0 +tornado==6.5.1 +tqdm==4.67.1 +traitlets==5.14.3 +trampoline==0.1.2 +transformers==4.52.3 +triton==3.2.0 +typer==0.15.4 +typing-inspection==0.4.1 +typing_extensions==4.13.2 +tzdata==2025.2 +uc-micro-py==1.0.3 +umap-learn==0.5.7 +urllib3==2.4.0 +uvicorn==0.34.2 +v-diffusion-pytorch==0.0.2 +vector-quantize-pytorch==1.9.14 +wandb==0.15.4 +wcwidth==0.2.13 +webdataset==0.2.48 +webencodings==0.5.1 +websockets==15.0.1 +Werkzeug==3.1.3 +wget==3.2 +wrapt==1.17.2 +x-transformers==1.26.6 +xyzservices==2025.4.0 +yarl==1.20.0 +zipp==3.21.0 +git+https://github.com/patrick-kidger/torchcubicspline.git diff --git a/think_sound/__init__.py b/think_sound/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22446be50eb6617222c50b007a38d06490cbab41 --- /dev/null +++ b/think_sound/__init__.py @@ 
-0,0 +1,2 @@ +from .models.factory import create_model_from_config, create_model_from_config_path +from .models.pretrained import get_pretrained_model \ No newline at end of file diff --git a/think_sound/__pycache__/__init__.cpython-310.pyc b/think_sound/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3122a65e35c929cf1509f27a8c66e31d56299281 Binary files /dev/null and b/think_sound/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/__pycache__/__init__.cpython-38.pyc b/think_sound/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0050985c5802bad6bcb67109e9c81e99c54603d0 Binary files /dev/null and b/think_sound/__pycache__/__init__.cpython-38.pyc differ diff --git a/think_sound/__pycache__/__init__.cpython-39.pyc b/think_sound/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b045e33239dbb5291b54962bfa6855d0a29d0c2b Binary files /dev/null and b/think_sound/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_channel_ssl.json b/think_sound/configs/model_configs/audiossl/flow_audio_channel_ssl.json new file mode 100644 index 0000000000000000000000000000000000000000..5e3b3cc543c1da956020a441514594308cf804b0 --- /dev/null +++ b/think_sound/configs/model_configs/audiossl/flow_audio_channel_ssl.json @@ -0,0 +1,98 @@ +{ + "model_type": "diffusion_infill", + "sample_size": 441000, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + } + }, + "diffusion": { + "input_concat_ids": ["x_ctx"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "project_cond_tokens": false, + "input_concat_dim": 64, + "transformer_type": "continuous_transformer", + "ctx_drop": 0.1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "timestep_sampler": "logit_normal", + "diffusion_objective": "rectified_flow", + "frac_lengths_mask": [0.7, 1.0], + "min_span_len": 10, + "ctx_drop": 0.1, + "r_drop": 0.2, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_ssl.json b/think_sound/configs/model_configs/audiossl/flow_audio_ssl.json new file mode 100644 index 
0000000000000000000000000000000000000000..9a79e2be2e5cf0150505155909ce127277ddefef --- /dev/null +++ b/think_sound/configs/model_configs/audiossl/flow_audio_ssl.json @@ -0,0 +1,97 @@ +{ + "model_type": "diffusion_infill", + "sample_size": 440320, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "diffusion": { + "input_concat_ids": ["x_ctx"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "project_cond_tokens": false, + "input_concat_dim": 64, + "transformer_type": "continuous_transformer", + "ctx_drop": 0.1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "timestep_sampler": "logit_normal", + "diffusion_objective": "rectified_flow", + "frac_lengths_mask": [0.7, 1.0], + "min_span_len": 10, + "ctx_drop": 0.1, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 100, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl.json new file mode 100644 index 0000000000000000000000000000000000000000..f9ea6f38ebbd73b6dab02edfd8259b8efcf999c9 --- /dev/null +++ b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl.json @@ -0,0 +1,98 @@ +{ + "model_type": "diffusion_infill", + "sample_size": 441000, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + } + }, + "diffusion": { + "input_concat_ids": ["x_ctx"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "project_cond_tokens": false, + "input_concat_dim": 64, + "transformer_type": "continuous_transformer", + "ctx_drop": 0.1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + 
"timestep_sampler": "logit_normal", + "diffusion_objective": "rectified_flow", + "frac_lengths_mask": [0.7, 1.0], + "min_span_len": 10, + "ctx_drop": 0.1, + "r_drop": 0.0, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.0.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.0.json new file mode 100644 index 0000000000000000000000000000000000000000..f269e1545a8f9c451c460fdb9fbce23c9d6cf7d3 --- /dev/null +++ b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.0.json @@ -0,0 +1,98 @@ +{ + "model_type": "diffusion_infill", + "sample_size": 441000, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + } + }, + "diffusion": { + "input_concat_ids": ["x_ctx"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "project_cond_tokens": false, + "input_concat_dim": 64, + "transformer_type": "continuous_transformer", + "ctx_drop": 0.1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "timestep_sampler": "logit_normal", + "diffusion_objective": "rectified_flow", + "frac_lengths_mask": [0.7, 1.0], + "min_span_len": 10, + "ctx_drop": 0.0, + "r_drop": 0.0, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.2.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.2.json new file mode 100644 index 0000000000000000000000000000000000000000..12a7d852975536a2f57ee7db3c42e291e1d140c0 --- /dev/null +++ b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_drop0.2.json @@ -0,0 +1,98 @@ +{ + "model_type": "diffusion_infill", + "sample_size": 441000, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + 
"type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + } + }, + "diffusion": { + "input_concat_ids": ["x_ctx"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "project_cond_tokens": false, + "input_concat_dim": 64, + "transformer_type": "continuous_transformer", + "ctx_drop": 0.1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "timestep_sampler": "logit_normal", + "diffusion_objective": "rectified_flow", + "frac_lengths_mask": [0.7, 1.0], + "min_span_len": 10, + "ctx_drop": 0.2, + "r_drop": 0.0, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_lr.json b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_lr.json new file mode 100644 index 0000000000000000000000000000000000000000..abe1e45098f507bfa832c6410e0e35711dc243f1 --- /dev/null +++ b/think_sound/configs/model_configs/audiossl/flow_audio_token_ssl_lr.json @@ -0,0 +1,98 @@ +{ + "model_type": "diffusion_infill", + "sample_size": 441000, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "requires_grad": false, + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + } + }, + "diffusion": { + "input_concat_ids": ["x_ctx"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "project_cond_tokens": false, + "input_concat_dim": 64, + "transformer_type": "continuous_transformer", + "ctx_drop": 0.1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "timestep_sampler": "logit_normal", + "diffusion_objective": "rectified_flow", + "frac_lengths_mask": [0.7, 1.0], + "min_span_len": 10, + "ctx_drop": 0.1, + "r_drop": 0.0, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + 
"inv_gamma": 100000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/dac_2048_32_vae.json b/think_sound/configs/model_configs/autoencoders/dac_2048_32_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..25457472a9d4b0d096abc1f7b197d6f4fb8a7fa7 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/dac_2048_32_vae.json @@ -0,0 +1,71 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "dac", + "config": { + "latent_dim": 64, + "d_model": 128, + "strides": [4, 8, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "latent_dim": 32, + "channels": 1536, + "rates": [8, 8, 8, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 32, + "downsampling_ratio": 2048, + "io_channels": 1 + }, + "training": { + "learning_rate": 1e-4, + "warmup_steps": 0, + "use_ema": false, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128, 64, 32], + "hop_lengths": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/encodec_musicgen_rvq.json b/think_sound/configs/model_configs/autoencoders/encodec_musicgen_rvq.json new file mode 100644 index 0000000000000000000000000000000000000000..e76bd3d9a12ae028f3038562ce8082b8eadca116 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/encodec_musicgen_rvq.json @@ -0,0 +1,88 @@ +{ + "model_type": "autoencoder", + "sample_size": 32000, + "sample_rate": 32000, + "audio_channels": 1, + "model": { + "encoder": { + "type": "seanet", + "config": { + "channels": 1, + "dimension": 128, + "n_filters": 64, + "ratios": [4, 4, 5, 8], + "n_residual_layers": 1, + "dilation_base": 2, + "lstm": 2, + "norm": "weight_norm" + } + }, + "decoder": { + "type": "seanet", + "config": { + "channels": 1, + "dimension": 128, + "n_filters": 64, + "ratios": [4, 4, 5, 8], + "n_residual_layers": 1, + "dilation_base": 2, + "lstm": 2, + "norm": "weight_norm" + } + }, + "bottleneck": { + "type": "rvq", + "config": { + "num_quantizers": 4, + "codebook_size": 2048, + "dim": 128, + "decay": 0.99, + "threshold_ema_dead_code": 2 + } + }, + "latent_dim": 128, + "downsampling_ratio": 640, + "io_channels": 1 + }, + "training": { + "learning_rate": 1e-4, + "warmup_steps": 0, + "use_ema": true, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + 
"type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..f1128edad22c618120c7e28c2bf7a68bd1a015e9 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/foa_audio_vae.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + }, + "training": { + "learning_rate": 3e-5, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-5, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 6e-5, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256.json new file mode 100644 index 0000000000000000000000000000000000000000..5feb255de582b9aae0024daafd2ec5ad77edbb11 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + 
"type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256_lr1e-4.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256_lr1e-4.json new file mode 100644 index 0000000000000000000000000000000000000000..063d69ff1ce1a371e60dd1fc176434d50acbd788 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_256_lr1e-4.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 4, + "channels": 256, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 256, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 256, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 128, + "downsampling_ratio": 2048, + "io_channels": 4 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 2e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 
1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/foa_audio_vae_decoder.json b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_decoder.json new file mode 100644 index 0000000000000000000000000000000000000000..bb604d623300dac618e00f98f73f7eab5fd04228 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/foa_audio_vae_decoder.json @@ -0,0 +1,124 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 4, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 4, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 4 + }, + "training": { + "learning_rate": 3e-5, + "warmup_steps": 0, + "latent_mask_ratio": 0.1, + "encoder_freeze_on_warmup": true, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-5, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 6e-5, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/ori.json b/think_sound/configs/model_configs/autoencoders/ori.json new file mode 100644 index 0000000000000000000000000000000000000000..3aa762f2a4bb3ff631fd53401c5ec22e524e9bf2 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/ori.json @@ -0,0 +1,122 @@ +{ + 
"model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/speech_vae.json b/think_sound/configs/model_configs/autoencoders/speech_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..7d8a9a7b15ca48be660b55d48a87244b3d53f27a --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/speech_vae.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 16000, + "audio_channels": 1, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 1, + "channels": 64, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 1, + "channels": 64, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 1 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + 
"lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/speech_vae_44k.json b/think_sound/configs/model_configs/autoencoders/speech_vae_44k.json new file mode 100644 index 0000000000000000000000000000000000000000..6f77e2e17823517af3d5cede126d40acf8b5f5dc --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/speech_vae_44k.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-5, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-5, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json 
b/think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..26dcb25f3322e79422c7ab288aace9f23e711768 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json @@ -0,0 +1,111 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "dac", + "config": { + "in_channels": 2, + "latent_dim": 128, + "d_model": 128, + "strides": [4, 4, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "out_channels": 2, + "latent_dim": 64, + "channels": 1536, + "rates": [8, 8, 4, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 1024, + "io_channels": 2 + }, + "training": { + "learning_rate": 1e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1e-4 + } + }, + "scheduler": { + "type": "ExponentialLR", + "config": { + "gamma": 0.999996 + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1e-4 + } + }, + "scheduler": { + "type": "ExponentialLR", + "config": { + "gamma": 0.999996 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 32, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-6 + } + } + }, + "demo": { + "demo_every": 2000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json b/think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..95f72495502f5b0a725378a0b9f51a56bb9da910 --- /dev/null +++ b/think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json @@ -0,0 +1,122 @@ +{ + "model_type": "autoencoder", + "sample_size": 65536, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + }, + "training": { + "learning_rate": 1.5e-4, + "warmup_steps": 0, + "use_ema": true, + "optimizer_configs": { + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 1.5e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + }, + "discriminator": { + 
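These JSON files are not standalone; they are consumed by the training and inference entry points. A hedged usage sketch, assuming a stable-audio-tools-style factory (`create_model_from_config` from `think_sound.models`; the exact signature is an assumption here, so verify it against the repo):

```python
import json

# Hedged usage sketch: load one of these model_config files and build the model.
# create_model_from_config is assumed to take the parsed config dict, following
# the stable-audio-tools convention; this is illustrative, not verified code.
from think_sound.models import create_model_from_config

config_path = "think_sound/configs/model_configs/autoencoders/stable_audio_1_0_vae.json"
with open(config_path) as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)
sample_rate = model_config["sample_rate"]   # 44100
sample_size = model_config["sample_size"]   # 65536 samples per training crop
```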
"optimizer": { + "type": "AdamW", + "config": { + "betas": [0.8, 0.99], + "lr": 3e-4, + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 200000, + "power": 0.5, + "warmup": 0.999 + } + } + } + }, + "loss_configs": { + "discriminator": { + "type": "encodec", + "config": { + "filters": 64, + "n_ffts": [2048, 1024, 512, 256, 128], + "hop_lengths": [512, 256, 128, 64, 32], + "win_lengths": [2048, 1024, 512, 256, 128] + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0 + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32], + "hop_sizes": [512, 256, 128, 64, 32, 16, 8], + "win_lengths": [2048, 1024, 512, 256, 128, 64, 32], + "perceptual_weighting": true + }, + "weights": { + "mrstft": 1.0 + } + }, + "time": { + "type": "l1", + "weights": { + "l1": 0.0 + } + }, + "bottleneck": { + "type": "kl", + "weights": { + "kl": 1e-4 + } + } + }, + "demo": { + "demo_every": 10000 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base.json new file mode 100644 index 0000000000000000000000000000000000000000..a57f9e4abc99157128f505c6f5e5188101808f9b --- /dev/null +++ b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + "sample_size": 65536, + "sample_rate": 48000, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json new file mode 100644 index 0000000000000000000000000000000000000000..4319a56731f981d2de1a294c2727e087475d1633 --- /dev/null +++ b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_16k.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + "sample_size": 65536, + "sample_rate": 16000, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json new file mode 100644 index 0000000000000000000000000000000000000000..fedb83fa3c741d7c1d4a7215e909862a81730805 --- /dev/null +++ b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_base_44k.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + "sample_size": 65536, + "sample_rate": 44100, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 4e-5, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_large.json b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_large.json new file mode 100644 index 0000000000000000000000000000000000000000..f9f96a455ad9e40b4ea624bda4b9c209fea4bcca --- /dev/null +++ b/think_sound/configs/model_configs/dance_diffusion/dance_diffusion_large.json @@ -0,0 +1,18 @@ +{ + "model_type": "diffusion_uncond", + 
"sample_size": 131072, + "sample_rate": 48000, + "model": { + "type": "DAU1d", + "config": { + "n_attn_layers": 5 + } + }, + "training": { + "learning_rate": 1e-4, + "demo": { + "demo_every": 2000, + "demo_steps": 250 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/mono2audio/flow_audio_mono.json b/think_sound/configs/model_configs/mono2audio/flow_audio_mono.json new file mode 100644 index 0000000000000000000000000000000000000000..bd8ae0d8aff26b0f4cf437c161ed2c6840b94f4f --- /dev/null +++ b/think_sound/configs/model_configs/mono2audio/flow_audio_mono.json @@ -0,0 +1,102 @@ +{ + "model_type": "diffusion_prior", + "sample_size": 440320, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "prior_type": "mono_stereo", + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "video", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["video"], + "input_concat_ids": ["source"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "input_concat_dim": 64, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "learning_rate": 5e-4, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2, + "demo_steps": 100 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/mono2audio/flow_audio_mono_aural_loss.json b/think_sound/configs/model_configs/mono2audio/flow_audio_mono_aural_loss.json new file mode 100644 index 0000000000000000000000000000000000000000..6b19cf7e750f69766c2a5a0185178ce1964c3787 --- /dev/null +++ b/think_sound/configs/model_configs/mono2audio/flow_audio_mono_aural_loss.json @@ -0,0 +1,103 @@ +{ + "model_type": "diffusion_prior", + "sample_size": 440320, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "prior_type": "mono_stereo", + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 
2 + } + }, + "conditioning": { + "configs": [ + { + "id": "video", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["video"], + "input_concat_ids": ["source"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "input_concat_dim": 64, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "learning_rate": 5e-4, + "use_reconstruction_loss": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 500, + "demo_steps": 100 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/mono2audio/flow_audio_mono_lr1e4.json b/think_sound/configs/model_configs/mono2audio/flow_audio_mono_lr1e4.json new file mode 100644 index 0000000000000000000000000000000000000000..bd8ae0d8aff26b0f4cf437c161ed2c6840b94f4f --- /dev/null +++ b/think_sound/configs/model_configs/mono2audio/flow_audio_mono_lr1e4.json @@ -0,0 +1,102 @@ +{ + "model_type": "diffusion_prior", + "sample_size": 440320, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "prior_type": "mono_stereo", + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "video", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["video"], + "input_concat_ids": ["source"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "input_concat_dim": 64, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "learning_rate": 5e-4, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2, + "demo_steps": 100 + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/txt2audio/stable_audio_1_0.json b/think_sound/configs/model_configs/txt2audio/stable_audio_1_0.json new file mode 100644 index 0000000000000000000000000000000000000000..22db891d8529f894a26a0c7f7d173ef2ae84b744 --- /dev/null +++ b/think_sound/configs/model_configs/txt2audio/stable_audio_1_0.json @@ -0,0 
+1,107 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 4194304, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "dac", + "config": { + "in_channels": 2, + "latent_dim": 128, + "d_model": 128, + "strides": [4, 4, 8, 8] + } + }, + "decoder": { + "type": "dac", + "config": { + "out_channels": 2, + "latent_dim": 64, + "channels": 1536, + "rates": [8, 8, 4, 4] + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 1024, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "clap_text", + "config": { + "audio_model_type": "HTSAT-base", + "enable_fusion": true, + "clap_ckpt_path": "/path/to/clap.ckpt", + "use_text_features": true, + "feature_layer_ix": -2 + } + }, + { + "id": "seconds_start", + "type": "int", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "int", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "type": "adp_cfg_1d", + "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], + "config": { + "in_channels": 64, + "context_embedding_features": 768, + "context_embedding_max_length": 79, + "channels": 256, + "resnet_groups": 16, + "kernel_multiplier_downsample": 2, + "multipliers": [4, 4, 4, 5, 5], + "factors": [1, 2, 2, 4], + "num_blocks": [2, 2, 2, 2], + "attentions": [1, 3, 3, 3, 3], + "attention_heads": 16, + "attention_multiplier": 4, + "use_nearest_upsample": false, + "use_skip_scale": true, + "use_context_time": true + } + }, + "io_channels": 64 + }, + "training": { + "learning_rate": 4e-5, + "demo": { + "demo_every": 2000, + "demo_steps": 250, + "num_demos": 4, + "demo_cond": [ + {"prompt": "A beautiful piano arpeggio", "seconds_start": 0, "seconds_total": 95}, + {"prompt": "A tropical house track with upbeat melodies, a driving bassline, and cheery vibes", "seconds_start": 0, "seconds_total": 90}, + {"prompt": "A cool 80s glam rock song with driving drums and distorted guitars", "seconds_start": 0, "seconds_total": 180}, + {"prompt": "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle.", "seconds_start": 0, "seconds_total": 60} + ], + "demo_cfg_scales": [3, 6, 9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/txt2audio/stable_audio_2_0.json b/think_sound/configs/model_configs/txt2audio/stable_audio_2_0.json new file mode 100644 index 0000000000000000000000000000000000000000..933e24307d322bcc2d5aa9a5d41c308706849791 --- /dev/null +++ b/think_sound/configs/model_configs/txt2audio/stable_audio_2_0.json @@ -0,0 +1,124 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 362496, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 
2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "prompt", + "type": "clap_text", + "config": { + "audio_model_type": "HTSAT-base", + "enable_fusion": true, + "clap_ckpt_path": "useful_ckpts/clap-htsat-fused/pytorch_model.bin", + "use_text_features": true, + "feature_layer_ix": -2 + } + }, + { + "id": "seconds_start", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], + "global_cond_ids": ["seconds_start", "seconds_total"], + "type": "dit", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 1, + "demo_steps": 100, + "num_demos": 1, + "demo_cond": [ + {"prompt": "children shouting", "seconds_start": 0, "seconds_total": 10} + ], + "demo_cfg_scales": [7] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit.json new file mode 100644 index 0000000000000000000000000000000000000000..6b18684a37f77e9193165ff24fe63dd9194783ef --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + 
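The fixed sequence lengths in the mmdit configs follow from the clip geometry rather than being arbitrary: 397312 audio samples through the 2048x VAE give 194 latent frames, and roughly 9 seconds of video at 8 fps gives 72 CLIP tokens. A hedged check of that arithmetic (the 216 sync tokens depend on Synchformer's segment/stride settings and are taken as given):

```python
# Hedged arithmetic check for the mmdit sequence lengths. Assumes ~9 s clips,
# CLIP features at 8 fps, and the 2048x audio VAE defined above; the sync token
# count (216) is not derived here.
sample_size, sample_rate, downsampling_ratio = 397_312, 44_100, 2048

latent_seq_len = sample_size // downsampling_ratio        # 194 -> "latent_seq_len"
duration_sec = sample_size / sample_rate                  # ~9.01 s
clip_seq_len = round(duration_sec) * 8                    # 72  -> "clip_seq_len" at 8 fps

print(latent_seq_len, round(duration_sec, 2), clip_seq_len)
```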
"io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow.json new file mode 100644 index 0000000000000000000000000000000000000000..4f9b8704492402b92b4eb2a6c0929755b2c4a8e2 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "timestep_sampler": "uniform", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": false + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": 
{ + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-3, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit.json new file mode 100644 index 0000000000000000000000000000000000000000..e724a8533f25332324a4a87b5715ec0ea7df3a11 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + 
"eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_addvideo.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_addvideo.json new file mode 100644 index 0000000000000000000000000000000000000000..34676ef598ac32a83269097c0404fe389aaade52 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_addvideo.json @@ -0,0 +1,140 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "cross_attend": false, + "add_video": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + 
"inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross.json new file mode 100644 index 0000000000000000000000000000000000000000..a6b208b02695845e892920d29c4c450f0b574dc4 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross.json @@ -0,0 +1,139 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "cross_attend": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + 
"inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross_gated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross_gated.json new file mode 100644 index 0000000000000000000000000000000000000000..3f0b545d3a7e906b68491e8e82a8a096e6ecd116 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_cross_gated.json @@ -0,0 +1,141 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "cross_attend": true, + "add_video": true, + "gated_video": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + 
}, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_text_latents/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_text_latents/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_text_latents/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_text_latents/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_text_latents/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_text_latents/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_text_latents/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_text_latents/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_text_latents/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_text_latents/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_gatedvideo.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_gatedvideo.json new file mode 100644 index 0000000000000000000000000000000000000000..32cae7b5fd5c922af7e35fe0c85a197499654215 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_gatedvideo.json @@ -0,0 +1,141 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "cross_attend": false, + "add_video": true, + "gated_video": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 
+ } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_inpaint.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_inpaint.json new file mode 100644 index 0000000000000000000000000000000000000000..687836326151c5a23a57b00e2bb7619e383db34c --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_inpaint.json @@ -0,0 +1,140 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "use_inpaint": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "max_mask_segments": 10, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + 
"power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_text_latents/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_text_latents/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_text_latents/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_text_latents/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_text_latents/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_text_latents/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_text_latents/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_text_latents/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_text_latents/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_text_latents/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5.json new file mode 100644 index 0000000000000000000000000000000000000000..46e5073708d33ee804d27c57aac7ee56e8580c3f --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5.json @@ -0,0 +1,146 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, 
+ "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_cross.json new file mode 100644 index 0000000000000000000000000000000000000000..87d5edb0be880070e3875ae70365c65612b47c55 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_cross.json @@ -0,0 +1,148 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "cross_attend": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } 
+ } + } + }, + "demo": { + "demo_every": 4000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated.json new file mode 100644 index 0000000000000000000000000000000000000000..b40d5a04c5e38ae9ef734ada069512df3ad887f7 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated.json @@ -0,0 +1,148 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + 
"inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 4000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross.json new file mode 100644 index 0000000000000000000000000000000000000000..c1e1c1dd37c57f47b0b07d1ce5d0b38d527d649f --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross.json @@ -0,0 +1,149 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true, + "cross_attend": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + 
"weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 100000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr1.json new file mode 100644 index 0000000000000000000000000000000000000000..3b1e8c61ec9dde8d78998865f84a94a4b71fed39 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr1.json @@ -0,0 +1,150 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true, + "cross_attend": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + 
"diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr2.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr2.json new file mode 100644 index 0000000000000000000000000000000000000000..3db01645031f702b73d2b8cc2a218bf2142128e9 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_cross_lr2.json @@ -0,0 +1,150 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true, + "cross_attend": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": 
false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_lr1.json new file mode 100644 index 0000000000000000000000000000000000000000..86f0c2dd1c52132155801da4d2b7c7a86bccb8cc --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_gated_lr1.json @@ -0,0 +1,149 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true + } + }, + 
"io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global.json new file mode 100644 index 0000000000000000000000000000000000000000..d001485e091baf2bed5222c3e8b81af8646f9b94 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global.json @@ -0,0 +1,154 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_global_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features", "metaclip_global_text_features", "t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, 
+ "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_latents_t5_clip/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_latents_t5_clip/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_latents_t5_clip/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_latents_t5_clip/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_latents_t5_clip/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_latents_t5_clip/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_latents_t5_clip/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_latents_t5_clip/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_latents_t5_clip/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated.json new file mode 100644 index 0000000000000000000000000000000000000000..133793bfac6acf03e467f0f21097a3439f4e2bbd --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated.json @@ -0,0 +1,156 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_global_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features", "metaclip_global_text_features", 
"t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated_cross.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated_cross.json new file mode 100644 index 0000000000000000000000000000000000000000..b713ba4bb782f451df268db799e8bc425beaf503 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_global_gated_cross.json @@ -0,0 +1,157 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_global_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": 
"t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features", "metaclip_global_text_features", "t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "add_video": true, + "gated_video": true, + "cross_attend": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_latents_t5_clip/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_latents_t5_clip/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_latents_t5_clip/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_latents_t5_clip/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_latents_t5_clip/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_latents_t5_clip/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_latents_t5_clip/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_latents_t5_clip/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_latents_t5_clip/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_inpaint.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_inpaint.json new file mode 100644 index 0000000000000000000000000000000000000000..03f4b9e9fdff78c510d41b7d6d8d6b9a530fb620 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_inpaint.json @@ -0,0 +1,148 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": 
"mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "use_inpaint": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "max_mask_segments": 10, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size1.json new file mode 100644 index 0000000000000000000000000000000000000000..beb2ee1864048b10f28fa43a20a9dbead84f644e --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size1.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 
1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 1 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json new file mode 100644 index 0000000000000000000000000000000000000000..5d466bef5a76c21149b14c8c998c55e737a4845a --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + 
"type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_medium.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_medium.json new file mode 100644 index 0000000000000000000000000000000000000000..5de7b392fc926c1b47367536c5f2602a7ad85b4c --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_medium.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + 
"output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":768 , + "depth":21, + "fused_depth":14, + "num_heads":12, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_small.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_small.json new file mode 100644 index 0000000000000000000000000000000000000000..ef13d4538940a0c37f35e57c897a2dbda4700863 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_small.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": 
"metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":768 , + "depth":18, + "fused_depth":12, + "num_heads":12, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel3.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel3.json new file mode 100644 index 0000000000000000000000000000000000000000..7085d6f0e3cef6b7bb95cf166269e02c89d380b2 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel3.json @@ -0,0 +1,148 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 
2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3, + "sync_kernel": 3 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel5.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel5.json new file mode 100644 index 0000000000000000000000000000000000000000..daf6e561b1cd4fc3d55f120afd24858cdb477b7d --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_sync_kernel5.json @@ -0,0 +1,148 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + 
"final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3, + "sync_kernel": 5 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_uniform.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_uniform.json new file mode 100644 index 0000000000000000000000000000000000000000..cc9027f71a4a3b3fea1ff44ac216f4acfbbb58a4 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3_uniform.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 
2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 3 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "uniform", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size5.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size5.json new file mode 100644 index 0000000000000000000000000000000000000000..0f862850f543c3fc4c722770d86cec304f53be08 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size5.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 
2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "kernel_size": 5 + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1.json new file mode 100644 index 0000000000000000000000000000000000000000..2c5e14bd5bd36ca93e37f105b293a2768c63a424 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + 
"out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1_drop0_1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1_drop0_1.json new file mode 100644 index 0000000000000000000000000000000000000000..d464136ceab1c1e36980f6bdd7dd2e3e1417553c --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr1_drop0_1.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { 
+ "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.1, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr2.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr2.json new file mode 100644 index 0000000000000000000000000000000000000000..460ac2ccae0d45e1d45c12e149cea199396a9384 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_lr2.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + 
"decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 5000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [5] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_mlp.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_mlp.json new file mode 100644 index 0000000000000000000000000000000000000000..7ea85ea36d77d13028cb3ddf90a23d6d692a7fa7 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_mlp.json @@ -0,0 +1,147 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true 
+ } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + }, + { + "id": "t5_features", + "type": "mm_unchang", + "config": { + "dim": 2048, + "output_dim": 2048 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":2048, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "use_mlp": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_t5_clip/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_latents_t5_clip/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_latents_t5_clip/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_latents_t5_clip/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_latents_t5_clip/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_latents_t5_clip/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_latents_t5_clip/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_latents_t5_clip/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_latents_t5_clip/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_latents_t5_clip/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_triplegated.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_triplegated.json new file mode 100644 index 0000000000000000000000000000000000000000..ef2c134f7c3c8f985a2d213dc2e477f9e2ea2ac2 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_triplegated.json @@ -0,0 +1,142 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + 
}, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true, + "cross_attend": false, + "add_video": true, + "gated_video": false, + "triple_fusion": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_text_latents/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_text_latents/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_text_latents/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_text_latents/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_text_latents/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_text_latents/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_text_latents/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_text_latents/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_text_latents/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_text_latents/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_large.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_large.json new file mode 100644 index 0000000000000000000000000000000000000000..2e0ae59455750209382b915acbc1b9d2df2e8914 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_large.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + 
"strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1280 , + "depth":24, + "fused_depth":16, + "num_heads":20, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr1.json new file mode 100644 index 0000000000000000000000000000000000000000..7049af0756c0e46125ac98fb8ce71c6b0a883e41 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr1.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + 
"io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-3, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr2.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr2.json new file mode 100644 index 0000000000000000000000000000000000000000..964e276c547328461bd9e5b6f2d3afcf97406c6f --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr2.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": 
"metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-3, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr3.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr3.json new file mode 100644 index 0000000000000000000000000000000000000000..902ee308129b20354af81f4ea68930d4560e0d4d --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr3.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + 
"output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr4.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr4.json new file mode 100644 index 0000000000000000000000000000000000000000..a989ec63b012163f887c2d152199d6b54a782837 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_lr4.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + 
"diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_old.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_old.json new file mode 100644 index 0000000000000000000000000000000000000000..3a464c0c08b2dd5814b9ef7cc4b187ad3ad37a9c --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_old.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":896 , + "depth":21, + 
"fused_depth":14, + "num_heads":14, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-3, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_oldlr.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_oldlr.json new file mode 100644 index 0000000000000000000000000000000000000000..a989ec63b012163f887c2d152199d6b54a782837 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_oldlr.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + 
"use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 1e-4, + "betas": [0.9, 0.95], + "weight_decay": 1e-6, + "fused": true, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v.json new file mode 100644 index 0000000000000000000000000000000000000000..c5743ed295023a20afe335ffd0748a05ae53fd05 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "v", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": false + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 
5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-3, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 100000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v1.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v1.json new file mode 100644 index 0000000000000000000000000000000000000000..a3b5400c23cf62568f84294fba5427f7caf19720 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_v1.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": false + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + 
"power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_vae.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_vae.json new file mode 100644 index 0000000000000000000000000000000000000000..cf32673a444143f40f2fe21e890eb317f01ad544 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_vae.json @@ -0,0 +1,138 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": false, + "timestep_sampler": "logit_normal", + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-4, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + 
"demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0Cu33yBwAPg_000060.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/bmKtI808DsU_000009.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/VC0c22cJTbM_000424.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/F3gsbUTdc2U_000090.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/WatvT8A8iug_000100.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/0nvBTp-q7tU_000112.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/3-PFuDkTM48_000080.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/luSAuu-BoPs_000232.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/__8UJxW0aOQ_000002.npz", + "dataset/vggsound/video_224_latents_text_sync_roi_npz/test/_0m_YMpQayA_000168.npz" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_weight.json b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_weight.json new file mode 100644 index 0000000000000000000000000000000000000000..5a3458f86a3cd4bd5103e74bc03233b553416c2a --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_weight.json @@ -0,0 +1,137 @@ +{ + "model_type": "mm_diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "metaclip_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "metaclip_text_features", + "type": "mm_unchang", + "config": { + "dim": 1024, + "output_dim": 1024 + } + }, + { + "id": "sync_features", + "type": "mm_unchang", + "config": { + "dim": 768, + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features"], + "type": "mmdit", + "diffusion_objective": "rectified_flow", + "config": { + "latent_dim":64, + "clip_dim":1024, + "sync_dim":768, + "text_dim":1024, + "hidden_dim":1024 , + "depth":21, + "fused_depth":14, + "num_heads":16, + "latent_seq_len":194, + "clip_seq_len":72, + "sync_seq_len":216, + "v2": true + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": true, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.95], + "weight_decay": 1e-3, + "eps": 1e-6 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 24, + "num_demos": 10, + "demo_cond": [ + 
"dataset/vggsound/video_224_pad_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_224_pad_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_224_pad_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_224_pad_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_224_pad_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_224_pad_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_224_pad_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_224_pad_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_224_pad_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_224_pad_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/multimodal_clip.json b/think_sound/configs/model_configs/vt2audio/multimodal_clip.json new file mode 100644 index 0000000000000000000000000000000000000000..0e47ef726ebca26a0c31848773e718c515164192 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/multimodal_clip.json @@ -0,0 +1,125 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "clip_features", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 1536 + } + }, + { + "id": "caption_t5", + "type": "t5", + "config": { + "t5_model_name": "t5-v1_1-xl", + "output_dim": 1536 + } + } + ], + "cond_dim": 1536 + }, + "diffusion": { + "add_cond_ids": ["clip_features"], + "cross_attention_cond_ids": ["caption_t5"], + "type": "dit", + "diffusion_objective": "rectified_flow", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 1536, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_latents_text/test/WatvT8A8iug_000100.pth", + 
"dataset/vggsound/video_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_clip.json b/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_clip.json new file mode 100644 index 0000000000000000000000000000000000000000..3502aa0b9145ac5886742d3a6c96727eef69e209 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_clip.json @@ -0,0 +1,124 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "clip_features", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 1536 + } + }, + { + "id": "caption", + "type": "clip_text", + "config": { + "output_dim": 768 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "add_cond_ids": ["clip_features"], + "cross_attention_cond_ids": ["caption"], + "type": "dit", + "diffusion_objective": "rectified_flow", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git 
a/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_metaclip.json b/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_metaclip.json new file mode 100644 index 0000000000000000000000000000000000000000..d9cdd9376a75d5bbcade31db3aed8f255c7c73da --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/multimodal_clip_text_metaclip.json @@ -0,0 +1,141 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 397312, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "clip_features", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 1536 + } + }, + { + "id": "caption", + "type": "metaclip_text", + "config": { + "output_dim": 768 + } + }, + { + "id": "seconds_start", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "add_cond_ids": ["clip_features"], + "cross_attention_cond_ids": ["caption", "seconds_start", "seconds_total"], + "global_cond_ids": ["seconds_start", "seconds_total"], + "type": "dit", + "diffusion_objective": "rectified_flow", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "pre_encoded": false, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 10, + "demo_cond": [ + "dataset/vggsound/video_latents_text/test/0Cu33yBwAPg_000060.pth", + "dataset/vggsound/video_latents_text/test/bmKtI808DsU_000009.pth", + "dataset/vggsound/video_latents_text/test/VC0c22cJTbM_000424.pth", + "dataset/vggsound/video_latents_text/test/F3gsbUTdc2U_000090.pth", + "dataset/vggsound/video_latents_text/test/WatvT8A8iug_000100.pth", + "dataset/vggsound/video_latents_text/test/0nvBTp-q7tU_000112.pth", + "dataset/vggsound/video_latents_text/test/3-PFuDkTM48_000080.pth", + "dataset/vggsound/video_latents_text/test/luSAuu-BoPs_000232.pth", + "dataset/vggsound/video_latents_text/test/__8UJxW0aOQ_000002.pth", + "dataset/vggsound/video_latents_text/test/_0m_YMpQayA_000168.pth" + ], + "demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/configs/model_configs/vt2audio/vt2a_stable_clip_load.json b/think_sound/configs/model_configs/vt2audio/vt2a_stable_clip_load.json new file mode 100644 index 
0000000000000000000000000000000000000000..2015675c4576bc88310e9f0f8b2a941584e7c8a7 --- /dev/null +++ b/think_sound/configs/model_configs/vt2audio/vt2a_stable_clip_load.json @@ -0,0 +1,139 @@ +{ + "model_type": "diffusion_cond", + "sample_size": 441000, + "sample_rate": 44100, + "audio_channels": 2, + "model": { + "pretransform": { + "type": "autoencoder", + "iterate_batch": true, + "config": { + "encoder": { + "type": "oobleck", + "config": { + "in_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 128, + "use_snake": true + } + }, + "decoder": { + "type": "oobleck", + "config": { + "out_channels": 2, + "channels": 128, + "c_mults": [1, 2, 4, 8, 16], + "strides": [2, 4, 4, 8, 8], + "latent_dim": 64, + "use_snake": true, + "final_tanh": false + } + }, + "bottleneck": { + "type": "vae" + }, + "latent_dim": 64, + "downsampling_ratio": 2048, + "io_channels": 2 + } + }, + "conditioning": { + "configs": [ + { + "id": "video", + "type": "video_linear", + "config": { + "dim": 1024, + "output_dim": 1536 + } + }, + { + "id": "prompt", + "type": "t5", + "config": { + "t5_model_name": "t5-v1_1-xl", + "output_dim": 768 + } + }, + { + "id": "seconds_start", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + }, + { + "id": "seconds_total", + "type": "number", + "config": { + "min_val": 0, + "max_val": 512 + } + } + ], + "cond_dim": 768 + }, + "diffusion": { + "cross_attention_cond_ids": ["prompt", "seconds_start", "seconds_total"], + "global_cond_ids": ["seconds_start", "seconds_total"], + "add_cond_ids": ["video"], + "type": "dit", + "diffusion_objective": "rectified_flow", + "config": { + "io_channels": 64, + "embed_dim": 1536, + "depth": 24, + "num_heads": 24, + "cond_token_dim": 768, + "global_cond_dim": 1536, + "project_cond_tokens": false, + "transformer_type": "continuous_transformer" + } + }, + "io_channels": 64 + }, + "training": { + "use_ema": true, + "log_loss_info": false, + "cfg_dropout_prob": 0.2, + "optimizer_configs": { + "diffusion": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": 5e-5, + "betas": [0.9, 0.999], + "weight_decay": 1e-3 + } + }, + "scheduler": { + "type": "InverseLR", + "config": { + "inv_gamma": 1000000, + "power": 0.5, + "warmup": 0.99 + } + } + } + }, + "demo": { + "demo_every": 2000, + "demo_steps": 64, + "num_demos": 8, + "demo_cond": [ + {"video": "data/VGGSOUND/MetaClip-Huge/test/0Cu33yBwAPg_000060.npy", "prompt": "church bell ringing", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/bmKtI808DsU_000009.npy", "prompt": "lions growling", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/VC0c22cJTbM_000424.npy", "prompt": "writing on blackboard with chalk", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/F3gsbUTdc2U_000090.npy", "prompt": "wind chime", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/WatvT8A8iug_000100.npy", "prompt": "car passing by", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/0nvBTp-q7tU_000112.npy", "prompt": "driving snowmobile", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/3-PFuDkTM48_000080.npy", "prompt": "playing accordion", "seconds_start": 0, "seconds_total": 10}, + {"video": "data/VGGSOUND/MetaClip-Huge/test/luSAuu-BoPs_000232.npy", "prompt": "squishing water", "seconds_start": 0, "seconds_total": 10} + ], + 
"demo_cfg_scales": [3,6,9] + } + } +} \ No newline at end of file diff --git a/think_sound/data/__init__.py b/think_sound/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/think_sound/data/__pycache__/__init__.cpython-310.pyc b/think_sound/data/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef50ce6b886a66ebe29fa889b95c859c2155374f Binary files /dev/null and b/think_sound/data/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/data/__pycache__/datamodule.cpython-310.pyc b/think_sound/data/__pycache__/datamodule.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcb3ed109244948edcca0ebfe0c32808ab02e88a Binary files /dev/null and b/think_sound/data/__pycache__/datamodule.cpython-310.pyc differ diff --git a/think_sound/data/__pycache__/dataset.cpython-310.pyc b/think_sound/data/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..070073611ef8a8915f747bd35074d502b5c57fbb Binary files /dev/null and b/think_sound/data/__pycache__/dataset.cpython-310.pyc differ diff --git a/think_sound/data/__pycache__/utils.cpython-310.pyc b/think_sound/data/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..037168b106b682650443f9d4399af0ce37889375 Binary files /dev/null and b/think_sound/data/__pycache__/utils.cpython-310.pyc differ diff --git a/think_sound/data/datamodule.py b/think_sound/data/datamodule.py new file mode 100644 index 0000000000000000000000000000000000000000..8a4733d3a8ffceb8ddd8e769a231a9f3a0019a20 --- /dev/null +++ b/think_sound/data/datamodule.py @@ -0,0 +1,192 @@ +import lightning as L +from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn +import importlib +from torch.utils.data import DataLoader + + +def get_configs(audio_configs): + configs = [] + for config in audio_configs: + data_dir_path = config.get("path", None) + audio_dir_path = config.get("audio_dir", None) + split_path = config.get("split_path", None) + assert data_dir_path is not None, "Path must be set for local audio directory configuration" + + custom_metadata_fn = None + custom_metadata_module_path = config.get("custom_metadata_module", None) + + if custom_metadata_module_path: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=config["id"], + path=data_dir_path, + split_path=split_path, + custom_metadata_fn=custom_metadata_fn, + audio_dir=audio_dir_path + ) + ) + return configs + +class DataModule(L.LightningDataModule): + def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5): + super().__init__() + dataset_type = dataset_config.get("dataset_type", None) + self.batch_size = batch_size + self.num_workers = num_workers + self.test_batch_size = test_batch_size + self.repeat_num = repeat_num + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + elif audio_channels == 2: + force_channels = "stereo" + else: + force_channels = "foa" + 
val_dir_configs = dataset_config.get("val_datasets", None) + test_dir_configs = dataset_config.get("test_datasets", None) + configs = [] + val_configs = [] + test_configs = [] + if dataset_type == "audio_dir": + audio_dir_configs = dataset_config.get("datasets", None) + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + configs = get_configs(audio_dir_configs) + val_configs = get_configs(val_dir_configs) + test_configs = get_configs(test_dir_configs) + elif dataset_type == "latent_dir" or dataset_type == "video_dataset": + audio_dir_configs = dataset_config.get("datasets", None) + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + for i, dataset in enumerate((audio_dir_configs, val_dir_configs, test_dir_configs)): + for config in dataset: + data_dir_path = config.get("path", None) + audio_dir_path = config.get("audio_dir", None) + split_path = config.get("split_path", None) + assert data_dir_path is not None, "Path must be set for local audio directory configuration" + + content = LocalDatasetConfig( + id=config["id"], + path=data_dir_path, + split_path=split_path, + audio_dir=audio_dir_path, + extra_cot=config.get("extra_cot", None) + ) + if i == 0: + configs.append(content) + elif i == 1: + val_configs.append(content) + else: + test_configs.append(content) + elif dataset_type == "multimodal_dir": + self.audio_configs = [] + self.video_configs = [] + audio_dir_configs = dataset_config.get("audio_datasets", None) + video_dir_configs = dataset_config.get("video_datasets", None) + assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets" + for i, dataset in enumerate((audio_dir_configs, video_dir_configs, val_dir_configs, test_dir_configs)): + for config in dataset: + data_dir_path = config.get("path", None) + audio_dir_path = config.get("audio_dir", None) + split_path = config.get("split_path", None) + assert data_dir_path is not None, "Path must be set for local audio directory configuration" + print(f'extra cot: {config.get("extra_cot", None)}') + content = LocalDatasetConfig( + id=config["id"], + path=data_dir_path, + split_path=split_path, + audio_dir=audio_dir_path, + extra_cot=config.get("extra_cot", None) + ) + if i == 0: + self.audio_configs.append(content) + elif i == 1: + self.video_configs.append(content) + elif i == 2: + val_configs.append(content) + else: + test_configs.append(content) + self.dataset_type = dataset_type + self.configs = configs + self.val_configs = val_configs + self.test_configs = test_configs + self.sample_rate = sample_rate + self.sample_size = sample_size + self.random_crop = dataset_config.get("random_crop", True) + self.input_type = dataset_config.get("input_type", "video") + self.fps = dataset_config.get("fps", 4) + self.force_channels = force_channels + + + def setup(self, stage: str): + if self.dataset_type == 'audio_dir': + dataset_class = SampleDataset + elif self.dataset_type == 'latent_dir': + dataset_class = LatentDataset + elif self.dataset_type == 'video_dataset': + dataset_class = VideoDataset + elif self.dataset_type == 'multimodal_dir': + dataset_class = VideoDataset + + def create_dataset(configs, random_crop): + return dataset_class( + configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + + if 
stage == 'fit': + if self.dataset_type != 'multimodal_dir': + self.train_set = create_dataset(self.configs, random_crop=self.random_crop) + else: + self.video_set = VideoDataset( + self.video_configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=self.random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + self.audio_set = AudioDataset( + self.audio_configs, + sample_rate=self.sample_rate, + sample_size=self.sample_size, + random_crop=self.random_crop, + input_type=self.input_type, + fps=self.input_type, + force_channels=self.force_channels + ) + self.train_set = MultiModalDataset([self.video_set]*self.repeat_num, [self.audio_set]) + self.val_set = create_dataset(self.val_configs, random_crop=False) + elif stage == 'validate': + self.val_set = create_dataset(self.val_configs, random_crop=False) + elif stage == 'predict': + self.test_set = create_dataset(self.test_configs, random_crop=False) + + def train_dataloader(self): + return DataLoader(self.train_set, self.batch_size, shuffle=True, + num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + + def val_dataloader(self): + return DataLoader(self.val_set, self.batch_size, shuffle=False, + num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn) + + def predict_dataloader(self): + return DataLoader(self.test_set, batch_size=self.test_batch_size, shuffle=False, + num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn) + + # def predict_dataloader(self): + # return DataLoader(self.mnist_predict, batch_size=self.batch_size) + + # def teardown(self, stage: str): + # # Used to clean-up when the run is finished + # ... \ No newline at end of file diff --git a/think_sound/data/dataset.py b/think_sound/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e4fadcf9ed0951928f29109214ac2bde5c773a --- /dev/null +++ b/think_sound/data/dataset.py @@ -0,0 +1,1268 @@ +import importlib +import numpy as np +import io +import os +import posixpath +import random +import re +import subprocess +import time +import torch +import torchaudio +import webdataset as wds +import pandas as pd +from aeiou.core import is_silence +from os import path +from pathlib import Path +from pedalboard.io import AudioFile +from torchaudio import transforms as T +from typing import Optional, Callable, List +import bisect + +from .utils import FOA, Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, PadCrop_Video_Normalized_T, PadCrop_Video_Hiera_Normalized_T, PadCrop_Video_Image_Normalized_T, PadCrop_DualVideo_Normalized_T + +AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus") + +# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def fast_scandir( + dir:str, # top-level directory at which to begin scanning + ext:list, # list of allowed file extensions, + #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB + ): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + ext = ['.'+x if x[0]!='.' 
else x for x in ext] # add starting period to extensions if needed + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + file_ext = os.path.splitext(f.name)[1].lower() + is_hidden = os.path.basename(f.path).startswith(".") + + if file_ext in ext and not is_hidden: + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = fast_scandir(dir, ext) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def keyword_scandir( + dir: str, # top-level directory at which to begin scanning + ext: list, # list of allowed file extensions + keywords: list, # list of keywords to search for in the file name +): + "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" + subfolders, files = [], [] + # make keywords case insensitive + keywords = [keyword.lower() for keyword in keywords] + # add starting period to extensions if needed + ext = ['.'+x if x[0] != '.' else x for x in ext] + banned_words = ["paxheader", "__macosx"] + try: # hope to avoid 'permission denied' by this try + for f in os.scandir(dir): + try: # 'hope to avoid too many levels of symbolic links' error + if f.is_dir(): + subfolders.append(f.path) + elif f.is_file(): + is_hidden = f.name.split("/")[-1][0] == '.' + has_ext = os.path.splitext(f.name)[1].lower() in ext + name_lower = f.name.lower() + has_keyword = any( + [keyword in name_lower for keyword in keywords]) + has_banned = any( + [banned_word in name_lower for banned_word in banned_words]) + if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"): + files.append(f.path) + except: + pass + except: + pass + + for dir in list(subfolders): + sf, f = keyword_scandir(dir, ext, keywords) + subfolders.extend(sf) + files.extend(f) + return subfolders, files + +def get_audio_filenames( + paths: list, # directories in which to search + keywords=None, + exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus'] +): + "recursively get a list of audio filenames" + filenames = [] + if type(paths) is str: + paths = [paths] + for path in paths: # get a list of relevant filenames + if keywords is not None: + subfolders, files = keyword_scandir(path, exts, keywords) + else: + subfolders, files = fast_scandir(path, exts) + filenames.extend(files) + return filenames + +class LocalDatasetConfig: + def __init__( + self, + id: str, + path: str, + split_path: str, + audio_dir: str = None, + extra_cot: str = None, + custom_metadata_fn: Optional[Callable[[str], str]] = None + ): + self.id = id + self.path = path + self.split_path = split_path + self.audio_dir = audio_dir + self.custom_metadata_fn = custom_metadata_fn + self.extra_cot = extra_cot +class SampleDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + if input_type == 'video': + self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + elif input_type == 'video_hiera': + self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + elif input_type == 'video_image': + self.pad_crop = 
PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + elif input_type == 'dual_video': + self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) + else: + self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.input_type = input_type + self.sr = sample_rate + self.custom_metadata_fns = {} + + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + if config.custom_metadata_fn is not None: + self.custom_metadata_fns[config.path] = config.custom_metadata_fn + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename): + ext = filename.split(".")[-1] + if ext == "mp3": + with AudioFile(filename) as f: + audio = f.read(f.frames) + audio = torch.from_numpy(audio) + in_sr = f.samplerate + else: + audio, in_sr = torchaudio.load(filename, format=ext) + + if in_sr != self.sr: + try: + resample_tf = T.Resample(in_sr, self.sr) + audio = resample_tf(audio) + except: + print(f'{filename} resample errors') + + assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' + return audio + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename), f'{audio_filename}: file not exists' + try: + start_time = time.time() + audio = self.load_file(audio_filename) + info = {} + info["path"] = audio_filename + + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + + for custom_md_path in self.custom_metadata_fns.keys(): + if custom_md_path in audio_filename: + custom_metadata_fn = self.custom_metadata_fns[custom_md_path] + custom_metadata = custom_metadata_fn(info, audio) + info.update(custom_metadata) + + if "__reject__" in info and info["__reject__"]: + return self[random.randrange(len(self))] + if self.input_type == 'video': + audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video']) + info['video'] = video + elif self.input_type == 'dual_video': + audio, video_360, video_fov, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'], info['video_fov']) + info['video_360'] = video_360 + info['video_fov'] = video_fov + else: + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio) + assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' 
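+            # Remaining steps for this sample: optional phase-flip augmentation, clamping to
+            # [-1, 1], channel forcing (FOA/stereo/mono), and packing timing metadata
+            # (timestamps, seconds_start, seconds_total, padding_mask) into `info`
+            # before returning (audio, info).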
+ # Run augmentations on this sample (including random crop) + if self.augs is not None: + audio = self.augs(audio) + + audio = audio.clamp(-1, 1) + + # Encode the file to assist in prediction + if self.encoding is not None: + audio = self.encoding(audio) + + + + info["timestamps"] = (t_start, t_end) + info["seconds_start"] = seconds_start + info["seconds_total"] = seconds_total + info["padding_mask"] = padding_mask + + end_time = time.time() + info["load_time"] = end_time - start_time + + + return (audio, info) + except Exception as e: + print(f'Couldn\'t load file {audio_filename}: {e}') + return self[random.randrange(len(self))] + +class LatentDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.input_type = input_type + self.sr = sample_rate + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename, info): + # try: + npz_file = filename.replace('.pth','.npz') + if os.path.exists(filename) and '.npz' not in filename: + data = torch.load(filename, weights_only=False) + elif os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + else: + raise ValueError(f'error load file: {filename}') + info.update(data) + audio = data['latent'] + # except: + # print(f'error load file: {filename}') + return audio, info['metaclip_features'] + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' + # try: + start_time = time.time() + info = {} + audio, video = self.load_file(audio_filename, info) + info["path"] = audio_filename + assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}' + assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' + info['id'] = Path(audio_filename).stem + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + return (audio, info) + +class AudioDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + 
force_channels="stereo" + ): + super().__init__() + self.filenames = [] + + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.fake_clip_features = torch.zeros(72, 1024) + self.fake_sync_features = torch.zeros(216, 768) + self.video_exist = torch.tensor(0, dtype=torch.bool) + self.input_type = input_type + self.sr = sample_rate + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename, info): + # try: + npz_file = filename.replace('.pth','.npz') + if os.path.exists(filename) and '.npz' not in filename: + data = torch.load(filename, weights_only=False) + elif os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + else: + raise ValueError(f'error load file: {filename}') + info.update(data) + audio = data['latent'] + info['metaclip_features'] = self.fake_clip_features + info['sync_features'] = self.fake_sync_features + info['video_exist'] = self.video_exist + # except: + # print(f'error load file: {filename}') + return audio, info['metaclip_features'] + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' + # try: + start_time = time.time() + info = {} + audio, video = self.load_file(audio_filename, info) + info["path"] = audio_filename + assert audio.shape == (64,194), f'{audio.shape} input error, id: {id}' + assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' + info['id'] = Path(audio_filename).stem + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + return (audio, info) + +class VideoDataset(torch.utils.data.Dataset): + def __init__( + self, + configs, + sample_size=65536, + sample_rate=48000, + keywords=None, + random_crop=True, + input_type="prompt", + fps=4, + force_channels="stereo", + ): + + super().__init__() + self.filenames = [] + print(f'configs: {configs[0]}') + if configs[0].extra_cot is not None: + self.extra_cot = configs[0].extra_cot + print(f'load extra cot from {self.extra_cot}') + else: + self.extra_cot = None + self.augs = torch.nn.Sequential( + PhaseFlipper(), + ) + + self.root_paths = [] + + self.force_channels = force_channels + print('######################') + print(f'input channels is: {force_channels}') + print('######################') + self.encoding = torch.nn.Sequential( + FOA() if 
self.force_channels == "foa" else torch.nn.Identity(), + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + ) + self.input_type = input_type + self.sr = sample_rate + self.video_exist = torch.tensor(1, dtype=torch.bool) + for config in configs: + self.root_paths.append(config.path) + def add_prefix(s): + return str(os.path.join(config.path,f'{s.strip()}')) + with open(config.split_path,'r') as f: + item_names = f.readlines() + filenames = list(map(add_prefix, item_names)) + self.filenames.extend(filenames) + # self.filenames.extend(get_audio_filenames(config.path, keywords)) + + + print(f'Found {len(self.filenames)} files') + + def load_file(self, filename, info): + # try: + npz_file = filename.replace('.pth','.npz') + if os.path.exists(filename) and '.npz' not in filename: + data = torch.load(filename, weights_only=False) + elif os.path.exists(npz_file): + # print(filename) + npz_data = np.load(npz_file,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + # print("data.keys()",data.keys()) + for key in data.keys(): + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number): + data[key] = torch.from_numpy(data[key]) + if self.extra_cot is not None: + extra_pth = filename.replace('.npz','.pth') + extra_pth = os.path.join(self.extra_cot, os.path.basename(extra_pth)) + if os.path.exists(extra_pth): + extra_data = torch.load(extra_pth, weights_only=False) + for key in extra_data.keys(): + if isinstance(extra_data[key], torch.Tensor): + # print(f'load extra cot {key}') + data[key] = extra_data[key] + else: + raise ValueError(f'error load file: {filename}') + info.update(data) + if 'latent' in data.keys(): + audio = data['latent'] + else: + audio = torch.zeros(64,194) + info['video_exist'] = self.video_exist + # except: + # print(f'error load file: {filename}') + return audio, info['metaclip_features'] + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, idx): + audio_filename = self.filenames[idx] + assert os.path.exists(audio_filename) or audio_filename.replace('.pth','.npz'), f'{audio_filename}: file not exists' + # try: + start_time = time.time() + info = {} + audio, video = self.load_file(audio_filename, info) + info["path"] = audio_filename + assert audio is None or audio.shape == (64,194), f'{audio.shape} input error, id: {id}' + assert video.shape == (72,1024), f'{video.shape} input error, id: {id}' + info['id'] = Path(audio_filename).stem + for root_path in self.root_paths: + if root_path in audio_filename: + info["relpath"] = path.relpath(audio_filename, root_path) + + return (audio, info) + +# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset +class MultiModalDataset(torch.utils.data.Dataset): + datasets: list[torch.utils.data.Dataset] + cumulative_sizes: list[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, video_datasets: list[torch.utils.data.Dataset], audio_datasets: list[torch.utils.data.Dataset]): + super().__init__() + self.video_datasets = list(video_datasets) + self.audio_datasets = list(audio_datasets) + self.datasets = self.video_datasets + self.audio_datasets + + self.cumulative_sizes = self.cumsum(self.datasets) + print(f'Found {self.cumulative_sizes[-1]} files') + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if 
idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.video_datasets[0].compute_latent_stats() + + +# class MultiModalDataset(torch.utils.data.Dataset): +# def __init__( +# self, +# configs, +# sample_size=65536, +# sample_rate=48000, +# keywords=None, +# random_crop=True, +# input_type="prompt", +# fps=4, +# force_channels="stereo" +# ): +# super().__init__() +# self.filenames = [] +# self.captions = [] +# self.caption_t5s = [] +# self.ids = [] +# self.augs = torch.nn.Sequential( +# PhaseFlipper(), +# ) + +# self.root_paths = [] +# if input_type == 'video': +# self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# elif input_type == 'video_hiera': +# self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# elif input_type == 'video_image': +# self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# elif input_type == 'dual_video': +# self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop) +# else: +# self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop) + +# self.force_channels = force_channels +# print('######################') +# print(f'input channels is: {force_channels}') +# print('######################') +# self.encoding = torch.nn.Sequential( +# FOA() if self.force_channels == "foa" else torch.nn.Identity(), +# Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), +# Mono() if self.force_channels == "mono" else torch.nn.Identity(), +# ) +# self.input_type = input_type +# self.sr = sample_rate +# self.custom_metadata_fns = {} + +# for config in configs: +# print(config.split_path) +# self.root_paths.append(config.path) +# def add_prefix(s): +# return str(os.path.join(config.path,f'{s.strip()}')) +# with open(config.split_path,'r') as f: +# item_names = f.readlines() +# csv_path = config.split_path.replace('.txt','.csv') +# df = pd.read_csv(csv_path) +# # 检查是否存在 'caption_t5' 列,如果不存在则创建并复制 'caption' 的值 +# if 'caption_t5' not in df.columns: +# df['caption_t5'] = df['caption'] + +# captions = df['caption'].tolist() +# caption_t5s = df['caption_t5'].tolist() +# filenames = list(map(add_prefix, item_names)) +# assert len(captions) == len(caption_t5s) and len(captions) == len(filenames), f'{config.path} has wrong filename and caption' +# if config.id == 'vggsound': +# self.filenames.extend(filenames*5) +# self.captions.extend(captions*5) +# self.caption_t5s.extend(caption_t5s*5) +# self.ids.extend(df['id'].tolist()*5) +# else: +# self.filenames.extend(filenames) +# self.captions.extend(captions) +# self.caption_t5s.extend(caption_t5s) +# self.ids.extend(df['id'].tolist()) +# # self.filenames.extend(get_audio_filenames(config.path, keywords)) +# if config.custom_metadata_fn is not None: +# self.custom_metadata_fns[config.path] = config.custom_metadata_fn + +# assert len(self.ids) == len(self.captions) and len(self.caption_t5s) == len(self.filenames), 'length need to be same' +# print(f'Found {len(self.filenames)} files') + + +# def load_file(self, filename): +# ext = 
filename.split(".")[-1] +# if ext == "mp3": +# with AudioFile(filename) as f: +# audio = f.read(f.frames) +# audio = torch.from_numpy(audio) +# in_sr = f.samplerate +# else: +# audio, in_sr = torchaudio.load(filename, format=ext) + +# if in_sr != self.sr: +# try: +# resample_tf = T.Resample(in_sr, self.sr) +# audio = resample_tf(audio) +# except: +# print(f'{filename} resample errors') + +# assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' +# return audio + +# def __len__(self): +# return len(self.filenames) + +# def __getitem__(self, idx): +# audio_filename = self.filenames[idx] +# id = self.ids[idx] +# assert str(id) == str(Path(audio_filename).stem), f'audio_file: {audio_filename} needs to be same as {id} ' +# assert os.path.exists(audio_filename), f'{audio_filename}: file not exists' +# try: +# start_time = time.time() +# audio = self.load_file(audio_filename) +# caption = self.captions[idx] +# caption_t5 = self.caption_t5s[idx] +# if pd.isna(caption_t5) or caption_t5 == '': +# caption_t5 = caption +# info = {} +# info["path"] = audio_filename +# info['caption'] = caption +# info['caption_t5'] = caption_t5 + +# for root_path in self.root_paths: +# if root_path in audio_filename: +# info["relpath"] = path.relpath(audio_filename, root_path) + + +# for custom_md_path in self.custom_metadata_fns.keys(): +# if custom_md_path in audio_filename: +# custom_metadata_fn = self.custom_metadata_fns[custom_md_path] +# custom_metadata = custom_metadata_fn(info, audio) +# info.update(custom_metadata) + +# if "__reject__" in info and info["__reject__"]: +# return self[random.randrange(len(self))] +# # if self.input_type == 'video': +# # audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['clip_features']) +# # info['clip_features'] = video +# # else: +# if info['flag']: +# audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=False) +# else: +# audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=True) +# assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!' +# # Run augmentations on this sample (including random crop) +# if self.augs is not None: +# audio = self.augs(audio) + +# audio = audio.clamp(-1, 1) + +# # Encode the file to assist in prediction +# if self.encoding is not None: +# audio = self.encoding(audio) + + + +# info["timestamps"] = (t_start, t_end) +# info["seconds_start"] = seconds_start +# info["seconds_total"] = seconds_total +# info["padding_mask"] = padding_mask + +# end_time = time.time() +# info["load_time"] = end_time - start_time + + +# return (audio, info) +# except Exception as e: +# print(f'Couldn\'t load file {audio_filename}: {e}') +# return self[random.randrange(len(self))] + +def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. 
+ :param keys: function that splits the key into key and extension (base_plus_ext) + :param lcase: convert suffixes to lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if wds.tariterators.trace: + print( + prefix, + suffix, + current_sample.keys() if isinstance(current_sample, dict) else None, + ) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + if current_sample is None or prefix != current_sample["__key__"]: + if wds.tariterators.valid_sample(current_sample): + yield current_sample + current_sample = dict(__key__=prefix, __url__=filesample["__url__"]) + if suffix in current_sample: + print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}") + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if wds.tariterators.valid_sample(current_sample): + yield current_sample + +wds.tariterators.group_by_keys = group_by_keys + +# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py + +def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None): + """ + Returns a list of full S3 paths to files in a given S3 bucket and directory path. + """ + # Ensure dataset_path ends with a trailing slash + if dataset_path != '' and not dataset_path.endswith('/'): + dataset_path += '/' + # Use posixpath to construct the S3 URL path + bucket_path = posixpath.join(s3_url_prefix or '', dataset_path) + # Construct the `aws s3 ls` command + cmd = ['aws', 's3', 'ls', bucket_path] + + if profile is not None: + cmd.extend(['--profile', profile]) + + if recursive: + # Add the --recursive flag if requested + cmd.append('--recursive') + + # Run the `aws s3 ls` command and capture the output + run_ls = subprocess.run(cmd, capture_output=True, check=True) + # Split the output into lines and strip whitespace from each line + contents = run_ls.stdout.decode('utf-8').split('\n') + contents = [x.strip() for x in contents if x] + # Remove the timestamp from lines that begin with a timestamp + contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x) + if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents] + # Construct a full S3 path for each file in the contents list + contents = [posixpath.join(s3_url_prefix or '', x) + for x in contents if not x.endswith('/')] + # Apply the filter, if specified + if filter: + contents = [x for x in contents if filter in x] + # Remove redundant directory names in the S3 URL + if recursive: + # Get the main directory name from the S3 URL + main_dir = "/".join(bucket_path.split('/')[3:]) + # Remove the redundant directory names from each file path + contents = [x.replace(f'{main_dir}', '').replace( + '//', '/') for x in contents] + # Print debugging information, if requested + if debug: + print("contents = \n", contents) + # Return the list of S3 paths to files + return contents + + +def get_all_s3_urls( + names=[], # list of all valid [LAION AudioDataset] dataset names + # list of subsets you want from those datasets, e.g. 
['train','valid'] + subsets=[''], + s3_url_prefix=None, # prefix for those dataset names + recursive=True, # recursively list all tar files in all subdirs + filter_str='tar', # only grab files with this substring + # print debugging info -- note: info displayed likely to change at dev's whims + debug=False, + profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'} +): + "get urls of shards (tar files) for multiple datasets in one s3 bucket" + urls = [] + for name in names: + # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list + if s3_url_prefix is None: + contents_str = name + else: + # Construct the S3 path using the s3_url_prefix and the current name value + contents_str = posixpath.join(s3_url_prefix, name) + if debug: + print(f"get_all_s3_urls: {contents_str}:") + for subset in subsets: + subset_str = posixpath.join(contents_str, subset) + if debug: + print(f"subset_str = {subset_str}") + # Get the list of tar files in the current subset directory + profile = profiles.get(name, None) + tar_list = get_s3_contents( + subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile) + for tar in tar_list: + # Escape spaces and parentheses in the tar filename for use in the shell command + tar = tar.replace(" ", "\ ").replace( + "(", "\(").replace(")", "\)") + # Construct the S3 path to the current tar file + s3_path = posixpath.join(name, subset, tar) + " -" + # Construct the AWS CLI command to download the current tar file + if s3_url_prefix is None: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}" + else: + request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}" + if profiles.get(name): + request_str += f" --profile {profiles.get(name)}" + if debug: + print("request_str = ", request_str) + # Add the constructed URL to the list of URLs + urls.append(request_str) + return urls + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + print(f"Handling webdataset error ({repr(exn)}). 
Ignoring.") + return True + + +def is_valid_sample(sample): + has_json = "json" in sample + has_audio = "audio" in sample + is_silent = is_silence(sample["audio"]) + is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"] + + return has_json and has_audio and not is_silent and not is_rejected + +class S3DatasetConfig: + def __init__( + self, + id: str, + s3_path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.path = s3_path + self.custom_metadata_fn = custom_metadata_fn + self.profile = profile + self.urls = [] + + def load_data_urls(self): + self.urls = get_all_s3_urls( + names=[self.path], + s3_url_prefix=None, + recursive=True, + profiles={self.path: self.profile} if self.profile else {}, + ) + + return self.urls + +class LocalWebDatasetConfig: + def __init__( + self, + id: str, + path: str, + custom_metadata_fn: Optional[Callable[[str], str]] = None, + profile: Optional[str] = None, + ): + self.id = id + self.path = path + self.custom_metadata_fn = custom_metadata_fn + self.urls = [] + + def load_data_urls(self): + + self.urls = fast_scandir(self.path, ["tar"])[1] + + return self.urls + +def audio_decoder(key, value): + # Get file extension from key + ext = key.split(".")[-1] + + if ext in AUDIO_KEYS: + return torchaudio.load(io.BytesIO(value)) + else: + return None + +def collation_fn(samples): + batched = list(zip(*samples)) + result = [] + for b in batched: + if isinstance(b[0], (int, float)): + b = np.array(b) + elif isinstance(b[0], torch.Tensor): + b = torch.stack(b) + elif isinstance(b[0], np.ndarray): + b = np.array(b) + else: + b = b + result.append(b) + return result + +class WebDatasetDataLoader(): + def __init__( + self, + datasets: List[S3DatasetConfig], + batch_size, + sample_size, + sample_rate=48000, + num_workers=8, + epoch_steps=1000, + random_crop=True, + force_channels="stereo", + augment_phase=True, + **data_loader_kwargs + ): + + self.datasets = datasets + + self.sample_size = sample_size + self.sample_rate = sample_rate + self.random_crop = random_crop + self.force_channels = force_channels + self.augment_phase = augment_phase + + urls = [dataset.load_data_urls() for dataset in datasets] + + # Flatten the list of lists of URLs + urls = [url for dataset_urls in urls for url in dataset_urls] + + # Shuffle the urls + random.shuffle(urls) + + self.dataset = wds.DataPipeline( + wds.ResampledShards(urls), + wds.tarfile_to_samples(handler=log_and_continue), + wds.decode(audio_decoder, handler=log_and_continue), + wds.map(self.wds_preprocess, handler=log_and_continue), + wds.select(is_valid_sample), + wds.to_tuple("audio", "json", handler=log_and_continue), + #wds.shuffle(bufsize=1000, initial=5000), + wds.batched(batch_size, partial=False, collation_fn=collation_fn), + ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) + + self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs) + + def wds_preprocess(self, sample): + + found_key, rewrite_key = '', '' + for k, v in sample.items(): # print the all entries in dict + for akey in AUDIO_KEYS: + if k.endswith(akey): + # to rename long/weird key with its simpler counterpart + found_key, rewrite_key = k, akey + break + if '' != found_key: + break + if '' == found_key: # got no audio! 
+ return None # try returning None to tell WebDataset to skip this one + + audio, in_sr = sample[found_key] + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate) + audio = resample_tf(audio) + + if self.sample_size is not None: + # Pad/crop and get the relative timestamp + pad_crop = PadCrop_Normalized_T( + self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate) + audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop( + audio) + sample["json"]["seconds_start"] = seconds_start + sample["json"]["seconds_total"] = seconds_total + sample["json"]["padding_mask"] = padding_mask + else: + t_start, t_end = 0, 1 + + # Check if audio is length zero, initialize to a single zero if so + if audio.shape[-1] == 0: + audio = torch.zeros(1, 1) + + # Make the audio stereo and augment by randomly inverting phase + augs = torch.nn.Sequential( + Stereo() if self.force_channels == "stereo" else torch.nn.Identity(), + Mono() if self.force_channels == "mono" else torch.nn.Identity(), + PhaseFlipper() if self.augment_phase else torch.nn.Identity() + ) + + audio = augs(audio) + + sample["json"]["timestamps"] = (t_start, t_end) + + if "text" in sample["json"]: + sample["json"]["prompt"] = sample["json"]["text"] + + # Check for custom metadata functions + for dataset in self.datasets: + if dataset.custom_metadata_fn is None: + continue + + if dataset.path in sample["__url__"]: + custom_metadata = dataset.custom_metadata_fn(sample["json"], audio) + sample["json"].update(custom_metadata) + + if found_key != rewrite_key: # rename long/weird key with its simpler counterpart + del sample[found_key] + + sample["audio"] = audio + + # Add audio to the metadata as well for conditioning + sample["json"]["audio"] = audio + + return sample + +def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4): + + dataset_type = dataset_config.get("dataset_type", None) + + assert dataset_type is not None, "Dataset type must be specified in dataset config" + + if audio_channels == 1: + force_channels = "mono" + elif audio_channels == 2: + force_channels = "stereo" + else: + force_channels = "foa" + + if dataset_type == "audio_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + split_path = audio_dir_config.get("split_path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + custom_metadata_fn = None + custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + split_path=split_path, + custom_metadata_fn=custom_metadata_fn + ) + ) + + train_set = SampleDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + input_type=dataset_config.get("input_type", "video"), + fps=dataset_config.get("fps", 4), + 
force_channels=force_channels + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + + elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility + + wds_configs = [] + + for wds_config in dataset_config["datasets"]: + + custom_metadata_fn = None + custom_metadata_module_path = wds_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + if "s3_path" in wds_config: + + wds_configs.append( + S3DatasetConfig( + id=wds_config["id"], + s3_path=wds_config["s3_path"], + custom_metadata_fn=custom_metadata_fn, + profile=wds_config.get("profile", None), + ) + ) + + elif "path" in wds_config: + + wds_configs.append( + LocalWebDatasetConfig( + id=wds_config["id"], + path=wds_config["path"], + custom_metadata_fn=custom_metadata_fn + ) + ) + + return WebDatasetDataLoader( + wds_configs, + sample_rate=sample_rate, + sample_size=sample_size, + batch_size=batch_size, + random_crop=dataset_config.get("random_crop", True), + num_workers=num_workers, + persistent_workers=True, + force_channels=force_channels, + epoch_steps=dataset_config.get("epoch_steps", 2000) + ).data_loader + + elif dataset_type == "latent_dir": + + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + split_path = audio_dir_config.get("split_path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + + configs.append( + LocalDatasetConfig( + id=audio_dir_config["id"], + path=audio_dir_path, + split_path=split_path, + ) + ) + + train_set = LatentDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + input_type=dataset_config.get("input_type", "video"), + fps=dataset_config.get("fps", 4), + force_channels=force_channels + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) + elif dataset_type == 'multimodal_dir': + audio_dir_configs = dataset_config.get("datasets", None) + + assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]" + + configs = [] + + for audio_dir_config in audio_dir_configs: + audio_dir_path = audio_dir_config.get("path", None) + split_path = audio_dir_config.get("split_path", None) + assert audio_dir_path is not None, "Path must be set for local audio directory configuration" + custom_metadata_fn = None + custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None) + + if custom_metadata_module_path is not None: + spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path) + metadata_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(metadata_module) + + custom_metadata_fn = metadata_module.get_custom_metadata + + configs.append( + LocalDatasetConfig( + 
id=audio_dir_config["id"], + path=audio_dir_path, + split_path=split_path, + custom_metadata_fn=custom_metadata_fn + ) + ) + + train_set = MultiModalDataset( + configs, + sample_rate=sample_rate, + sample_size=sample_size, + random_crop=dataset_config.get("random_crop", True), + input_type=dataset_config.get("input_type", "video"), + fps=dataset_config.get("fps", 4), + force_channels=force_channels + ) + + return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True, + num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn) \ No newline at end of file diff --git a/think_sound/data/utils.py b/think_sound/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a0f0a19a4f9a32ee3424cfd24efe7643ddde0fc --- /dev/null +++ b/think_sound/data/utils.py @@ -0,0 +1,378 @@ +import math +import random +import torch +import torch.nn.functional as F +from torch import nn +from typing import Tuple +import numpy as np + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + +class PadCrop_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + def __call__(self, source: torch.Tensor, randomize=True) -> Tuple[torch.Tensor, float, float, int, int]: + + n_channels, n_samples = source.shape + + # If the audio is shorter than the desired length, pad it + upper_bound = max(0, n_samples - self.n_samples) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(randomize and n_samples > self.n_samples): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + + # Create the chunk + chunk = source.new_zeros([n_channels, self.n_samples]) + + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] + + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_Video_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + n_channels, n_samples = audio.shape + # print(video.shape) + n_frames, dim = video.shape + if not 
torch.is_tensor(video): + video = torch.from_numpy(video) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_chunk = video.new_zeros([self.n_frames, video.shape[1]]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames,:] + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_Video_Image_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + n_channels, n_samples = audio.shape + # import ipdb + # ipdb.set_trace() + n_frames, channel, width, height= video.shape + video = torch.from_numpy(video) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_chunk = video.new_zeros([self.n_frames, channel, width, height]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames] + # Calculate the start and end times of the chunk in seconds + seconds_start = 
math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_Video_Hiera_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: + + n_channels, n_samples = audio.shape + n_frames, heigh, width, channel = video.shape + video = torch.from_numpy(video) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_chunk = video.new_zeros([self.n_frames, heigh, width, channel]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames] + # video_chunk = video_chunk[None].permute(0, 4, 1, 2, 3).contiguous() + # print(video_chunk.shape) + # video_chunk = F.interpolate( + # video_chunk[0], + # size=(224, 224, 3), # 输出的空间尺寸 + # scale_factor=(target_frames / video_tensor.shape[1], 1, 1), # 时间轴的缩放因子 + # mode='trilinear', # 使用三线性插值 + # align_corners=False + # ) + + # video_chunk = F.interpolate(video_chunk, size=(64, 224, 224), mode="trilinear")[0] + # video_chunk = video_chunk.view(3,4,16,224,224).transpose(0,1) + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PadCrop_DualVideo_Normalized_T(nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.fps = fps + self.n_frames = int(self.fps * self.n_samples / self.sample_rate) + + def __call__(self, audio: torch.Tensor, video_360: torch.Tensor, video_fov: torch.Tensor) -> 
Tuple[torch.Tensor, float, float, int, int]: + n_channels, n_samples = audio.shape + # print(video.shape) + n_frames, dim = video_360.shape + video_360 = torch.from_numpy(video_360) + video_fov = torch.from_numpy(video_fov) + # If the audio is shorter than the desired length, pad it + audio_upper_bound = max(0, n_samples - self.n_samples) + video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps) + upper_bound = min(audio_upper_bound,video_upper_bound) + + # If randomize is False, always start at the beginning of the audio + offset = 0 + if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames): + offset = random.randint(0, upper_bound) + + # Calculate the start and end times of the chunk + t_start = offset / (upper_bound + self.n_samples) + t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) + frame_offset = int(self.fps * offset / self.sample_rate) + # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate) + # Create the chunk + chunk = audio.new_zeros([n_channels, self.n_samples]) + video_360_chunk = video_360.new_zeros([self.n_frames, video_360.shape[1]]) + video_fov_chunk = video_fov.new_zeros([self.n_frames, video_fov.shape[1]]) + # Copy the audio into the chunk + chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples] + video_360_chunk[:min(n_frames, self.n_frames)] = video_360[frame_offset:frame_offset + self.n_frames,:] + video_fov_chunk[:min(n_frames, self.n_frames)] = video_fov[frame_offset:frame_offset + self.n_frames,:] + # Calculate the start and end times of the chunk in seconds + seconds_start = math.floor(offset / self.sample_rate) + seconds_total = math.ceil(n_samples / self.sample_rate) + + # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't + padding_mask = torch.zeros([self.n_samples]) + padding_mask[:min(n_samples, self.n_samples)] = 1 + + + return ( + chunk, + video_360_chunk, + video_fov_chunk, + t_start, + t_end, + seconds_start, + seconds_total, + padding_mask + ) + +class PhaseFlipper(nn.Module): + "Randomly invert the phase of a signal" + def __init__(self, p=0.5): + super().__init__() + self.p = p + def __call__(self, signal): + return -signal if (random.random() < self.p) else signal + +class Mono(nn.Module): + def __call__(self, signal): + return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal + +class Stereo(nn.Module): + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> 2, s + signal = signal.unsqueeze(0).repeat(2, 1) + elif len(signal_shape) == 2: + if signal_shape[0] == 1: #1, s -> 2, s + signal = signal.repeat(2, 1) + elif signal_shape[0] > 2: #?, s -> 2,s + signal = signal[:2, :] + + return signal + +class FOA(nn.Module): + def __call__(self, signal): + signal_shape = signal.shape + # Check if it's mono + if len(signal_shape) == 1: # s -> (4, s) + foa = torch.zeros(4, signal_shape[0], device=signal.device) # 与输入信号一致的设备类型 + foa[0, :] = signal # W通道: 全方位声源 + foa[1, :] = 0 # X通道 + foa[2, :] = 0 # Y通道 + foa[3, :] = 0 # Z通道 + elif len(signal_shape) == 2: + foa = torch.zeros(4, signal_shape[1], device=signal.device) # 与输入信号一致的设备类型 + if signal_shape[0] == 1: # (1, s) -> (4, s) + foa[0, :] = signal[0] # W通道: 全方位声源 + foa[1, :] = 0 # X通道 + foa[2, :] = 0 # Y通道 + foa[3, :] = 0 # Z通道 + elif signal_shape[0] == 2: # (2, s) -> (4, s) + left = signal[0] + right = signal[1] + # 将立体声信号映射到FOA信号通道 + foa[0, :] = (left + 
right) / np.sqrt(2) # W channel: omnidirectional component + foa[1, :] = (left - right) / np.sqrt(2) # X channel: front-back axis + foa[2, :] = 0 # Y channel: left-right axis, zeroed in this simple implementation + foa[3, :] = 0 # Z channel: vertical axis, zeroed here + else: + foa = signal + + else: + raise ValueError(f"Unsupported signal shape: {signal_shape}") + + assert foa.shape[0] == 4, 'signal could not be converted to FOA (4-channel) format' + + return foa \ No newline at end of file diff --git a/think_sound/inference/__init__.py b/think_sound/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/think_sound/inference/__pycache__/__init__.cpython-310.pyc b/think_sound/inference/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5023ae21e83129fdc7f23e439e637f624e7221a9 Binary files /dev/null and b/think_sound/inference/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/inference/__pycache__/__init__.cpython-38.pyc b/think_sound/inference/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..269dca61e7b70ba3213ad89ca48564e2e4b54d11 Binary files /dev/null and b/think_sound/inference/__pycache__/__init__.cpython-38.pyc differ diff --git a/think_sound/inference/__pycache__/__init__.cpython-39.pyc b/think_sound/inference/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad93a24373bcdf798ecd55b60e833a270e18fb7f Binary files /dev/null and b/think_sound/inference/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/inference/__pycache__/generation.cpython-310.pyc b/think_sound/inference/__pycache__/generation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..988d166a673c290635ee8dc5d50368e173a94500 Binary files /dev/null and b/think_sound/inference/__pycache__/generation.cpython-310.pyc differ diff --git a/think_sound/inference/__pycache__/generation.cpython-38.pyc b/think_sound/inference/__pycache__/generation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e8fa0fd739358e1117f01430fa5a77777b9869f Binary files /dev/null and b/think_sound/inference/__pycache__/generation.cpython-38.pyc differ diff --git a/think_sound/inference/__pycache__/generation.cpython-39.pyc b/think_sound/inference/__pycache__/generation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8aa87aa381c697f459aff50c5f5fdb4a358e579 Binary files /dev/null and b/think_sound/inference/__pycache__/generation.cpython-39.pyc differ diff --git a/think_sound/inference/__pycache__/sampling.cpython-310.pyc b/think_sound/inference/__pycache__/sampling.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3118560f043620b53fedaa9780b95c0a171f9924 Binary files /dev/null and b/think_sound/inference/__pycache__/sampling.cpython-310.pyc differ diff --git a/think_sound/inference/__pycache__/sampling.cpython-38.pyc b/think_sound/inference/__pycache__/sampling.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f97d01246f3cd254535df0889cc5baa10512190 Binary files /dev/null and b/think_sound/inference/__pycache__/sampling.cpython-38.pyc differ diff --git a/think_sound/inference/__pycache__/sampling.cpython-39.pyc b/think_sound/inference/__pycache__/sampling.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bf7c849f2e03b8d63fd51f085a4857040d28f6c Binary files /dev/null and
b/think_sound/inference/__pycache__/sampling.cpython-39.pyc differ diff --git a/think_sound/inference/__pycache__/utils.cpython-310.pyc b/think_sound/inference/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9aee8d08753eb0ed4929cb1750a212e4d6a0f53c Binary files /dev/null and b/think_sound/inference/__pycache__/utils.cpython-310.pyc differ diff --git a/think_sound/inference/__pycache__/utils.cpython-38.pyc b/think_sound/inference/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7ae3b2b62d814ed92d69bc8e8269bf6a6ee56fe Binary files /dev/null and b/think_sound/inference/__pycache__/utils.cpython-38.pyc differ diff --git a/think_sound/inference/__pycache__/utils.cpython-39.pyc b/think_sound/inference/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e84bec0b8fb14faae453194b6c6572f0d45a6c5 Binary files /dev/null and b/think_sound/inference/__pycache__/utils.cpython-39.pyc differ diff --git a/think_sound/inference/generation.py b/think_sound/inference/generation.py new file mode 100644 index 0000000000000000000000000000000000000000..8d6873fe9d19deba0497e5fb4d939472552a261c --- /dev/null +++ b/think_sound/inference/generation.py @@ -0,0 +1,274 @@ +import numpy as np +import torch +import typing as tp +import math +from torchaudio import transforms as T + +from .utils import prepare_audio +from .sampling import sample, sample_k, sample_rf +from ..data.utils import PadCrop + +def generate_diffusion_uncond( + model, + steps: int = 250, + batch_size: int = 1, + sample_size: int = 2097152, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. + seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 
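+        # With nothing to start from, both settings are cleared below and the sampler
+        # starts from pure noise.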
+ init_audio = None + init_noise_level = None + + # Inpainting mask + + if init_audio is not None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + # Now the generative AI part: + + diff_objective = model.diffusion_objective + + if diff_objective == "v": + # k-diffusion denoising process go! + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device) + elif diff_objective == "rectified_flow": + sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device) + + # Denoising process done. + # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + + +def generate_diffusion_cond( + model, + steps: int = 250, + cfg_scale=6, + conditioning: dict = None, + conditioning_tensors: tp.Optional[dict] = None, + negative_conditioning: dict = None, + negative_conditioning_tensors: tp.Optional[dict] = None, + batch_size: int = 1, + sample_size: int = 2097152, + sample_rate: int = 48000, + seed: int = -1, + device: str = "cuda", + init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None, + init_noise_level: float = 1.0, + mask_args: dict = None, + return_latents = False, + **sampler_kwargs + ) -> torch.Tensor: + """ + Generate audio from a prompt using a diffusion model. + + Args: + model: The diffusion model to use for generation. + steps: The number of diffusion steps to use. + cfg_scale: Classifier-free guidance scale + conditioning: A dictionary of conditioning parameters to use for generation. + conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation. + batch_size: The batch size to use for generation. + sample_size: The length of the audio to generate, in samples. + sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly) + seed: The random seed to use for generation, or -1 to use a random seed. + device: The device to use for generation. + init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation. + init_noise_level: The noise level to use when generating from an initial audio sample. + return_latents: Whether to return the latents used for generation instead of the decoded audio. + **sampler_kwargs: Additional keyword arguments to pass to the sampler. + """ + + # The length of the output in audio samples + audio_sample_size = sample_size + + # If this is latent diffusion, change sample_size instead to the downsampled latent size + if model.pretransform is not None: + sample_size = sample_size // model.pretransform.downsampling_ratio + + # Seed + # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed. 
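+    # A typical call, as a rough sketch only (the conditioning keys must match whatever
+    # the model's conditioning config defines; "prompt"/"seconds_start"/"seconds_total"
+    # are the ones used by the Gradio interface further below):
+    #   audio = generate_diffusion_cond(
+    #       model,
+    #       conditioning=[{"prompt": "rain on a tin roof", "seconds_start": 0, "seconds_total": 10}],
+    #       steps=100, cfg_scale=6.0, sample_size=sample_size, seed=-1, device="cuda")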
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32) + print(seed) + torch.manual_seed(seed) + # Define the initial noise immediately after setting the seed + noise = torch.randn([batch_size, model.io_channels, sample_size], device=device) + + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + torch.backends.cudnn.benchmark = False + import ipdb + # ipdb.set_trace() + # Conditioning + assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors" + if conditioning_tensors is None: + conditioning_tensors = model.conditioner(conditioning, device) + conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors) + + if negative_conditioning is not None or negative_conditioning_tensors is not None: + + if negative_conditioning_tensors is None: + negative_conditioning_tensors = model.conditioner(negative_conditioning, device) + + negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True) + else: + negative_conditioning_tensors = {} + + if init_audio is not None: + # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio. + in_sr, init_audio = init_audio + + io_channels = model.io_channels + + # For latent models, set the io_channels to the autoencoder's io_channels + if model.pretransform is not None: + io_channels = model.pretransform.io_channels + + # Prepare the initial audio for use by the model + init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device) + + # For latent models, encode the initial audio into latents + if model.pretransform is not None: + init_audio = model.pretransform.encode(init_audio) + + init_audio = init_audio.repeat(batch_size, 1, 1) + else: + # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch. 
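+        # As in the unconditional path, clear every init-audio-related argument
+        # (including mask_args) so the inpainting and variation branches below are skipped.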
+ init_audio = None + init_noise_level = None + mask_args = None + + # Inpainting mask + if init_audio is not None and mask_args is not None: + # Cut and paste init_audio according to cropfrom, pastefrom, pasteto + # This is helpful for forward and reverse outpainting + cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size) + pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size) + pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size) + assert pastefrom < pasteto, "Paste From should be less than Paste To" + croplen = pasteto - pastefrom + if cropfrom + croplen > sample_size: + croplen = sample_size - cropfrom + cropto = cropfrom + croplen + pasteto = pastefrom + croplen + cutpaste = init_audio.new_zeros(init_audio.shape) + cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto] + #print(cropfrom, cropto, pastefrom, pasteto) + init_audio = cutpaste + # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args + mask = build_mask(sample_size, mask_args) + mask = mask.to(device) + elif init_audio is not None and mask_args is None: + # variations + sampler_kwargs["sigma_max"] = init_noise_level + mask = None + else: + mask = None + + model_dtype = next(model.model.parameters()).dtype + noise = noise.type(model_dtype) + conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()} + # Now the generative AI part: + # k-diffusion denoising process go! + diff_objective = model.diffusion_objective + if diff_objective == "v": + # k-diffusion denoising process go! + # sampled = sample(model.model, noise, steps, 0, **conditioning_inputs) + sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + elif diff_objective == "rectified_flow": + + if "sigma_min" in sampler_kwargs: + del sampler_kwargs["sigma_min"] + + if "sampler_type" in sampler_kwargs: + del sampler_kwargs["sampler_type"] + + sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device) + + # v-diffusion: + #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale) + del noise + del conditioning_tensors + del conditioning_inputs + torch.cuda.empty_cache() + # Denoising process done. 
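+    # `sampled` keeps the shape of the initial noise, [batch_size, model.io_channels, sample_size];
+    # for models with a pretransform these are latents, otherwise raw audio.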
+ # If this is latent diffusion, decode latents back into audio + if model.pretransform is not None and not return_latents: + #cast sampled latents to pretransform dtype + sampled = sampled.to(next(model.pretransform.parameters()).dtype) + sampled = model.pretransform.decode(sampled) + + # Return audio + return sampled + +# builds a softmask given the parameters +# returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio, +# and anything between is a mixture of old/new +# ideally 0.5 is half/half mixture but i haven't figured this out yet +def build_mask(sample_size, mask_args): + maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size) + maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size) + softnessL = round(mask_args["softnessL"]/100.0 * sample_size) + softnessR = round(mask_args["softnessR"]/100.0 * sample_size) + marination = mask_args["marination"] + # use hann windows for softening the transition (i don't know if this is correct) + hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL] + hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:] + # build the mask. + mask = torch.zeros((sample_size)) + mask[maskstart:maskend] = 1 + mask[maskstart:maskstart+softnessL] = hannL + mask[maskend-softnessR:maskend] = hannR + # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds + if marination > 0: + mask = mask * (1-marination) + #print(mask) + return mask diff --git a/think_sound/inference/sampling.py b/think_sound/inference/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..2229e5089e3407a367df2d382ae039ca6364c489 --- /dev/null +++ b/think_sound/inference/sampling.py @@ -0,0 +1,232 @@ +import torch +import math +from tqdm import trange, tqdm + +import k_diffusion as K + +# Define the noise schedule and sampling loop +def get_alphas_sigmas(t): + """Returns the scaling factors for the clean image (alpha) and for the + noise (sigma), given a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + +def alpha_sigma_to_t(alpha, sigma): + """Returns a timestep, given the scaling factors for the clean image and for + the noise.""" + return torch.atan2(sigma, alpha) / math.pi * 2 + +def t_to_alpha_sigma(t): + """Returns the scaling factors for the clean image and for the noise, given + a timestep.""" + return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2) + + +@torch.no_grad() +def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args): + """Draws samples from a model given starting noise. Euler method""" + + # Make tensor of ones to broadcast the single t values + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(sigma_max, 0, steps + 1) + + #alphas, sigmas = 1-t, t + + for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])): + # Broadcast the current timestep to the correct shape + t_curr_tensor = t_curr * torch.ones( + (x.shape[0],), dtype=x.dtype, device=x.device + ) + dt = t_prev - t_curr # we solve backwards in our formulation + x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc) + + # If we are on the last timestep, output the denoised image + return x + +@torch.no_grad() +def sample(model, x, steps, eta, **extra_args): + """Draws samples from a model given starting noise. 
v-diffusion""" + ts = x.new_ones([x.shape[0]]) + + # Create the noise schedule + t = torch.linspace(1, 0, steps + 1)[:-1] + + alphas, sigmas = get_alphas_sigmas(t) + + # The sampling loop + for i in trange(steps): + + # Get the model output (v, the predicted velocity) + with torch.cuda.amp.autocast(): + v = model(x, ts * t[i], **extra_args).float() + + # Predict the noise and the denoised image + pred = x * alphas[i] - v * sigmas[i] + eps = x * sigmas[i] + v * alphas[i] + + # If we are not on the last timestep, compute the noisy image for the + # next timestep. + if i < steps - 1: + # If eta > 0, adjust the scaling factor for the predicted noise + # downward according to the amount of additional noise to add + ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \ + (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt() + adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt() + + # Recombine the predicted noise and predicted denoised image in the + # correct proportions for the next step + x = pred * alphas[i + 1] + eps * adjusted_sigma + + # Add the correct amount of fresh noise + if eta: + x += torch.randn_like(x) * ddim_sigma + + # If we are on the last timestep, output the denoised image + return pred + +# Soft mask inpainting is just shrinking hard (binary) mask inpainting +# Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step +def get_bmask(i, steps, mask): + strength = (i+1)/(steps) + # convert to binary mask + bmask = torch.where(mask<=strength,1,0) + return bmask + +def make_cond_model_fn(model, cond_fn): + def cond_model_fn(x, sigma, **kwargs): + with torch.enable_grad(): + x = x.detach().requires_grad_() + denoised = model(x, sigma, **kwargs) + cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach() + cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim) + return cond_denoised + return cond_model_fn + +# Uses k-diffusion from https://github.com/crowsonkb/k-diffusion +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_k( + model_fn, + noise, + init_data=None, + mask=None, + steps=100, + sampler_type="dpmpp-2m-sde", + sigma_min=0.5, + sigma_max=50, + rho=1.0, device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + denoiser = K.external.VDenoiser(model_fn) + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + # Make the list of sigmas. 
Sigma values are scalars related to the amount of noise each denoising step has + sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device) + # Scale the initial noise by sigma + noise = noise * sigmas[0] + + wrapped_callback = callback + + if mask is None and init_data is not None: + # VARIATION (no inpainting) + # set the initial latent to the init_data, and noise it with initial sigma + x = init_data + noise + elif mask is not None and init_data is not None: + # INPAINTING + bmask = get_bmask(0, steps, mask) + # initial noising + input_noised = init_data + noise + # set the initial latent to a mix of init_data and noise, based on step 0's binary mask + x = input_noised * bmask + noise * (1-bmask) + # define the inpainting callback function (Note: side effects, it mutates x) + # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105 + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)` + def inpainting_callback(args): + i = args["i"] + x = args["x"] + sigma = args["sigma"] + #denoised = args["denoised"] + # noise the init_data input with this step's appropriate amount of noise + input_noised = init_data + torch.randn_like(init_data) * sigma + # shrinking hard mask + bmask = get_bmask(i, steps, mask) + # mix input_noise with x, using binary mask + new_x = input_noised * bmask + x * (1-bmask) + # mutate x + x[:,:,:] = new_x[:,:,:] + # wrap together the inpainting callback and the user-submitted callback. + if callback is None: + wrapped_callback = inpainting_callback + else: + wrapped_callback = lambda args: (inpainting_callback(args), callback(args)) + else: + # SAMPLING + # set the initial latent to noise + x = noise + + + with torch.cuda.amp.autocast(): + if sampler_type == "k-heun": + return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-lms": + return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpmpp-2s-ancestral": + return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-2": + return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-fast": + return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "k-dpm-adaptive": + return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-2m-sde": + return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + elif sampler_type == "dpmpp-3m-sde": + return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args) + +# Uses discrete Euler sampling for rectified flow models +# init_data is init_audio as latents (if this is latent diffusion) +# For sampling, set both init_data and mask to None +# For variations, set init_data +# For inpainting, set both init_data & mask +def sample_rf( + model_fn, + noise, + 
init_data=None, + steps=100, + sigma_max=1, + device="cuda", + callback=None, + cond_fn=None, + **extra_args + ): + + if sigma_max > 1: + sigma_max = 1 + + if cond_fn is not None: + denoiser = make_cond_model_fn(denoiser, cond_fn) + + wrapped_callback = callback + + if init_data is not None: + # VARIATION (no inpainting) + # Interpolate the init data and the noise for init audio + x = init_data * (1 - sigma_max) + noise * sigma_max + else: + # SAMPLING + # set the initial latent to noise + x = noise + + with torch.cuda.amp.autocast(): + # TODO: Add callback support + #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args) + return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args) \ No newline at end of file diff --git a/think_sound/inference/utils.py b/think_sound/inference/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6c0a57609f68156ad244da9b5819666329772e --- /dev/null +++ b/think_sound/inference/utils.py @@ -0,0 +1,35 @@ +from ..data.utils import PadCrop + +from torchaudio import transforms as T + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/think_sound/interface/__init__.py b/think_sound/interface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/think_sound/interface/__pycache__/__init__.cpython-38.pyc b/think_sound/interface/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9585914d732a148c23ea9b8eed93b5b383238bbe Binary files /dev/null and b/think_sound/interface/__pycache__/__init__.cpython-38.pyc differ diff --git a/think_sound/interface/__pycache__/__init__.cpython-39.pyc b/think_sound/interface/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c791e8cb9797b74d6f44dadc2706615fc07e6d61 Binary files /dev/null and b/think_sound/interface/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/interface/__pycache__/gradio.cpython-38.pyc b/think_sound/interface/__pycache__/gradio.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae1bc0b0e4fe71946d0042d73b0e4857089f3b1 Binary files /dev/null and b/think_sound/interface/__pycache__/gradio.cpython-38.pyc differ diff --git a/think_sound/interface/__pycache__/gradio.cpython-39.pyc b/think_sound/interface/__pycache__/gradio.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cb07f162f318c799c027a72e8be8a92f940da2a Binary files /dev/null and b/think_sound/interface/__pycache__/gradio.cpython-39.pyc differ diff --git a/think_sound/interface/gradio.py 
b/think_sound/interface/gradio.py new file mode 100644 index 0000000000000000000000000000000000000000..f38468bc34b88ec6bbe5451a8b11b998430888f8 --- /dev/null +++ b/think_sound/interface/gradio.py @@ -0,0 +1,700 @@ +import gc +import platform + +import numpy as np +import gradio as gr +import json +import torch +import torchaudio + +from aeiou.viz import audio_spectrogram_image +from einops import rearrange +from safetensors.torch import load_file +from torch.nn import functional as F +from torchaudio import transforms as T + +from ..inference.generation import generate_diffusion_cond, generate_diffusion_uncond +from ..models.factory import create_model_from_config +from ..models.pretrained import get_pretrained_model +from ..models.utils import load_ckpt_state_dict +from ..inference.utils import prepare_audio +from ..training.utils import copy_state_dict + +model = None +sample_rate = 32000 +sample_size = 1920000 + +def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False): + global model, sample_rate, sample_size + + if pretrained_name is not None: + print(f"Loading pretrained model {pretrained_name}") + model, model_config = get_pretrained_model(pretrained_name) + + elif model_config is not None and model_ckpt_path is not None: + print(f"Creating model from config") + model = create_model_from_config(model_config) + + print(f"Loading model checkpoint from {model_ckpt_path}") + # Load checkpoint + copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path)) + #model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + sample_rate = model_config["sample_rate"] + sample_size = model_config["sample_size"] + + if pretransform_ckpt_path is not None: + print(f"Loading pretransform checkpoint from {pretransform_ckpt_path}") + model.pretransform.load_state_dict(load_ckpt_state_dict(pretransform_ckpt_path), strict=False) + print(f"Done loading pretransform") + + model.to(device).eval().requires_grad_(False) + + if model_half: + model.to(torch.float16) + + print(f"Done loading model") + + return model, model_config + +def generate_cond( + prompt, + negative_prompt=None, + seconds_start=0, + seconds_total=30, + cfg_scale=6.0, + steps=250, + preview_every=None, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + cfg_rescale=0.0, + use_init=False, + init_audio=None, + init_noise_level=1.0, + mask_cropfrom=None, + mask_pastefrom=None, + mask_pasteto=None, + mask_maskstart=None, + mask_maskend=None, + mask_softnessL=None, + mask_softnessR=None, + mask_marination=None, + batch_size=1 + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + print(f"Prompt: {prompt}") + + global preview_images + preview_images = [] + if preview_every == 0: + preview_every = None + + # Return fake stereo audio + conditioning = [{"prompt": prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size + + if negative_prompt: + negative_conditioning = [{"prompt": negative_prompt, "seconds_start": seconds_start, "seconds_total": seconds_total}] * batch_size + else: + negative_conditioning = None + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if 
init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + denoised = rearrange(denoised, "b d n -> d (b n)") + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + # If inpainting, send mask args + # This will definitely change in the future + if mask_cropfrom is not None: + mask_args = { + "cropfrom": mask_cropfrom, + "pastefrom": mask_pastefrom, + "pasteto": mask_pasteto, + "maskstart": mask_maskstart, + "maskend": mask_maskend, + "softnessL": mask_softnessL, + "softnessR": mask_softnessR, + "marination": mask_marination, + } + else: + mask_args = None + + # Do the audio generation + audio = generate_diffusion_cond( + model, + conditioning=conditioning, + negative_conditioning=negative_conditioning, + steps=steps, + cfg_scale=cfg_scale, + batch_size=batch_size, + sample_size=input_sample_size, + sample_rate=sample_rate, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + mask_args = mask_args, + callback = progress_callback if preview_every is not None else None, + scale_phi = cfg_rescale + ) + + # Convert to WAV file + audio = rearrange(audio, "b d n -> d (b n)") + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save("output.wav", audio, sample_rate) + + # Let's look at a nice spectrogram too + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram, *preview_images]) + +def generate_uncond( + steps=250, + seed=-1, + sampler_type="dpmpp-3m-sde", + sigma_min=0.03, + sigma_max=1000, + use_init=False, + init_audio=None, + init_noise_level=1.0, + batch_size=1, + preview_every=None + ): + + global preview_images + + preview_images = [] + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + seed = int(seed) + + if not use_init: + init_audio = None + + input_sample_size = sample_size + + if init_audio is not None: + in_sr, init_audio = init_audio + # Turn into torch tensor, converting from int16 to float32 + init_audio = torch.from_numpy(init_audio).float().div(32767) + + if init_audio.dim() == 1: + init_audio = init_audio.unsqueeze(0) # [1, n] + elif init_audio.dim() == 2: + init_audio = init_audio.transpose(0, 1) # [n, 2] -> [2, n] + + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(init_audio.device) + init_audio = resample_tf(init_audio) + + 
audio_length = init_audio.shape[-1] + + if audio_length > sample_size: + + input_sample_size = audio_length + (model.min_input_length - (audio_length % model.min_input_length)) % model.min_input_length + + init_audio = (sample_rate, init_audio) + + def progress_callback(callback_info): + global preview_images + denoised = callback_info["denoised"] + current_step = callback_info["i"] + sigma = callback_info["sigma"] + + if (current_step - 1) % preview_every == 0: + + if model.pretransform is not None: + denoised = model.pretransform.decode(denoised) + + denoised = rearrange(denoised, "b d n -> d (b n)") + + denoised = denoised.clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + audio_spectrogram = audio_spectrogram_image(denoised, sample_rate=sample_rate) + + preview_images.append((audio_spectrogram, f"Step {current_step} sigma={sigma:.3f})")) + + audio = generate_diffusion_uncond( + model, + steps=steps, + batch_size=batch_size, + sample_size=input_sample_size, + seed=seed, + device=device, + sampler_type=sampler_type, + sigma_min=sigma_min, + sigma_max=sigma_max, + init_audio=init_audio, + init_noise_level=init_noise_level, + callback = progress_callback if preview_every is not None else None + ) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram, *preview_images]) + +def generate_lm( + temperature=1.0, + top_p=0.95, + top_k=0, + batch_size=1, + ): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + audio = model.generate_audio( + batch_size=batch_size, + max_gen_len = sample_size//model.pretransform.downsampling_ratio, + conditioning=None, + temp=temperature, + top_p=top_p, + top_k=top_k, + use_cache=True + ) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + audio_spectrogram = audio_spectrogram_image(audio, sample_rate=sample_rate) + + return ("output.wav", [audio_spectrogram]) + + +def create_uncond_sampling_ui(model_config): + generate_button = gr.Button("Generate", variant='primary', scale=1) + + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(): + # Steps slider + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + + with gr.Accordion("Sampler params", open=False): + + # Seed + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + + # Sampler params + with gr.Row(): + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") + + with gr.Accordion("Init audio", open=False): + init_audio_checkbox = gr.Checkbox(label="Use init audio") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + + with gr.Column(): + audio_output = 
gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + send_to_init_button = gr.Button("Send to init audio", scale=1) + send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) + + generate_button.click(fn=generate_uncond, + inputs=[ + steps_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider, + ], + outputs=[ + audio_output, + audio_spectrogram_output + ], + api_name="generate") + +def create_sampling_ui(model_config, inpainting=False): + with gr.Row(): + with gr.Column(scale=6): + prompt = gr.Textbox(show_label=False, placeholder="Prompt") + negative_prompt = gr.Textbox(show_label=False, placeholder="Negative prompt") + generate_button = gr.Button("Generate", variant='primary', scale=1) + + model_conditioning_config = model_config["model"].get("conditioning", None) + + has_seconds_start = False + has_seconds_total = False + + if model_conditioning_config is not None: + for conditioning_config in model_conditioning_config["configs"]: + if conditioning_config["id"] == "seconds_start": + has_seconds_start = True + if conditioning_config["id"] == "seconds_total": + has_seconds_total = True + + with gr.Row(equal_height=False): + with gr.Column(): + with gr.Row(visible = has_seconds_start or has_seconds_total): + # Timing controls + seconds_start_slider = gr.Slider(minimum=0, maximum=512, step=1, value=0, label="Seconds start", visible=has_seconds_start) + seconds_total_slider = gr.Slider(minimum=0, maximum=512, step=1, value=sample_size//sample_rate, label="Seconds total", visible=has_seconds_total) + + with gr.Row(): + # Steps slider + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + + # Preview Every slider + preview_every_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Preview Every") + + # CFG scale + cfg_scale_slider = gr.Slider(minimum=0.0, maximum=25.0, step=0.1, value=7.0, label="CFG scale") + + with gr.Accordion("Sampler params", open=False): + + # Seed + seed_textbox = gr.Textbox(label="Seed (set to -1 for random seed)", value="-1") + + # Sampler params + with gr.Row(): + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") + cfg_rescale_slider = gr.Slider(minimum=0.0, maximum=1, step=0.01, value=0.0, label="CFG rescale amount") + + if inpainting: + # Inpainting Tab + with gr.Accordion("Inpainting", open=False): + sigma_max_slider.maximum=1000 + + init_audio_checkbox = gr.Checkbox(label="Do inpainting") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.1, value=80, label="Init audio noise level", visible=False) # hide this + + mask_cropfrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Crop From %") + mask_pastefrom_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Paste From %") + mask_pasteto_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Paste To %") + + mask_maskstart_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, 
value=50, label="Mask Start %") + mask_maskend_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=100, label="Mask End %") + mask_softnessL_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Left Crossfade Length %") + mask_softnessR_slider = gr.Slider(minimum=0.0, maximum=100.0, step=0.1, value=0, label="Softmask Right Crossfade Length %") + mask_marination_slider = gr.Slider(minimum=0.0, maximum=1, step=0.0001, value=0, label="Marination level", visible=False) # still working on the usefulness of this + + inputs = [prompt, + negative_prompt, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider, + mask_cropfrom_slider, + mask_pastefrom_slider, + mask_pasteto_slider, + mask_maskstart_slider, + mask_maskend_slider, + mask_softnessL_slider, + mask_softnessR_slider, + mask_marination_slider + ] + else: + # Default generation tab + with gr.Accordion("Init audio", open=False): + init_audio_checkbox = gr.Checkbox(label="Use init audio") + init_audio_input = gr.Audio(label="Init audio") + init_noise_level_slider = gr.Slider(minimum=0.1, maximum=100.0, step=0.01, value=0.1, label="Init noise level") + + inputs = [prompt, + negative_prompt, + seconds_start_slider, + seconds_total_slider, + cfg_scale_slider, + steps_slider, + preview_every_slider, + seed_textbox, + sampler_type_dropdown, + sigma_min_slider, + sigma_max_slider, + cfg_rescale_slider, + init_audio_checkbox, + init_audio_input, + init_noise_level_slider + ] + + with gr.Column(): + audio_output = gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + send_to_init_button = gr.Button("Send to init audio", scale=1) + send_to_init_button.click(fn=lambda audio: audio, inputs=[audio_output], outputs=[init_audio_input]) + + generate_button.click(fn=generate_cond, + inputs=inputs, + outputs=[ + audio_output, + audio_spectrogram_output + ], + api_name="generate") + + +def create_txt2audio_ui(model_config): + with gr.Blocks() as ui: + with gr.Tab("Generation"): + create_sampling_ui(model_config) + with gr.Tab("Inpainting"): + create_sampling_ui(model_config, inpainting=True) + return ui + +def create_diffusion_uncond_ui(model_config): + with gr.Blocks() as ui: + create_uncond_sampling_ui(model_config) + + return ui + +def autoencoder_process(audio, latent_noise, n_quantizers): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) + else: + audio = audio.transpose(0, 1) + + audio = model.preprocess_audio_for_encoder(audio, in_sr) + # Note: If you need to do chunked encoding, to reduce VRAM, + # then add these arguments to encode_audio and decode_audio: chunked=True, overlap=32, chunk_size=128 + # To turn it off, do chunked=False + # Optimal overlap and chunk_size values will depend on the model. 
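        # Illustrative sketch (added comment, not part of the original code): with chunked
        # encoding enabled, the round-trip below would look roughly like
        #   latents = model.encode_audio(audio, chunked=True, overlap=32, chunk_size=128)
        #   audio = model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)
        # trading a little extra compute at the chunk overlaps for a much lower peak VRAM use.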
+ # See encode_audio & decode_audio in autoencoders.py for more info + # Get dtype of model + dtype = next(model.parameters()).dtype + + audio = audio.to(dtype) + + if n_quantizers > 0: + latents = model.encode_audio(audio, chunked=False, n_quantizers=n_quantizers) + else: + latents = model.encode_audio(audio, chunked=False) + + if latent_noise > 0: + latents = latents + torch.randn_like(latents) * latent_noise + + audio = model.decode_audio(latents, chunked=False) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def create_autoencoder_ui(model_config): + + is_dac_rvq = "model" in model_config and "bottleneck" in model_config["model"] and model_config["model"]["bottleneck"]["type"] in ["dac_rvq","dac_rvq_vae"] + + if is_dac_rvq: + n_quantizers = model_config["model"]["bottleneck"]["config"]["n_codebooks"] + else: + n_quantizers = 0 + + with gr.Blocks() as ui: + input_audio = gr.Audio(label="Input audio") + output_audio = gr.Audio(label="Output audio", interactive=False) + n_quantizers_slider = gr.Slider(minimum=1, maximum=n_quantizers, step=1, value=n_quantizers, label="# quantizers", visible=is_dac_rvq) + latent_noise_slider = gr.Slider(minimum=0.0, maximum=10.0, step=0.001, value=0.0, label="Add latent noise") + process_button = gr.Button("Process", variant='primary', scale=1) + process_button.click(fn=autoencoder_process, inputs=[input_audio, latent_noise_slider, n_quantizers_slider], outputs=output_audio, api_name="process") + + return ui + +def diffusion_prior_process(audio, steps, sampler_type, sigma_min, sigma_max): + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + #Get the device from the model + device = next(model.parameters()).device + + in_sr, audio = audio + + audio = torch.from_numpy(audio).float().div(32767).to(device) + + if audio.dim() == 1: + audio = audio.unsqueeze(0) # [1, n] + elif audio.dim() == 2: + audio = audio.transpose(0, 1) # [n, 2] -> [2, n] + + audio = audio.unsqueeze(0) + + audio = model.stereoize(audio, in_sr, steps, sampler_kwargs={"sampler_type": sampler_type, "sigma_min": sigma_min, "sigma_max": sigma_max}) + + audio = rearrange(audio, "b d n -> d (b n)") + + audio = audio.to(torch.float32).div(torch.max(torch.abs(audio))).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + + torchaudio.save("output.wav", audio, sample_rate) + + return "output.wav" + +def create_diffusion_prior_ui(model_config): + with gr.Blocks() as ui: + input_audio = gr.Audio(label="Input audio") + output_audio = gr.Audio(label="Output audio", interactive=False) + # Sampler params + with gr.Row(): + steps_slider = gr.Slider(minimum=1, maximum=500, step=1, value=100, label="Steps") + sampler_type_dropdown = gr.Dropdown(["dpmpp-2m-sde", "dpmpp-3m-sde", "k-heun", "k-lms", "k-dpmpp-2s-ancestral", "k-dpm-2", "k-dpm-fast"], label="Sampler type", value="dpmpp-3m-sde") + sigma_min_slider = gr.Slider(minimum=0.0, maximum=2.0, step=0.01, value=0.03, label="Sigma min") + sigma_max_slider = gr.Slider(minimum=0.0, maximum=1000.0, step=0.1, value=500, label="Sigma max") + process_button = gr.Button("Process", variant='primary', scale=1) + process_button.click(fn=diffusion_prior_process, inputs=[input_audio, steps_slider, sampler_type_dropdown, sigma_min_slider, sigma_max_slider], outputs=output_audio, api_name="process") + + return ui + +def create_lm_ui(model_config): + with gr.Blocks() as ui: + output_audio = 
gr.Audio(label="Output audio", interactive=False) + audio_spectrogram_output = gr.Gallery(label="Output spectrogram", show_label=False) + + # Sampling params + with gr.Row(): + temperature_slider = gr.Slider(minimum=0, maximum=5, step=0.01, value=1.0, label="Temperature") + top_p_slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.95, label="Top p") + top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=0, label="Top k") + + generate_button = gr.Button("Generate", variant='primary', scale=1) + generate_button.click( + fn=generate_lm, + inputs=[ + temperature_slider, + top_p_slider, + top_k_slider + ], + outputs=[output_audio, audio_spectrogram_output], + api_name="generate" + ) + + return ui + +def create_ui(model_config_path=None, ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, model_half=False): + + assert (pretrained_name is not None) ^ (model_config_path is not None and ckpt_path is not None), "Must specify either pretrained name or provide a model config and checkpoint, but not both" + + if model_config_path is not None: + # Load config from json file + with open(model_config_path) as f: + model_config = json.load(f) + else: + model_config = None + + try: + has_mps = platform.system() == "Darwin" and torch.backends.mps.is_available() + except Exception: + # In case this version of Torch doesn't even have `torch.backends.mps`... + has_mps = False + + if has_mps: + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + + print("Using device:", device) + + _, model_config = load_model(model_config, ckpt_path, pretrained_name=pretrained_name, pretransform_ckpt_path=pretransform_ckpt_path, model_half=model_half, device=device) + + model_type = model_config["model_type"] + + if model_type == "diffusion_cond": + ui = create_txt2audio_ui(model_config) + elif model_type == "diffusion_uncond": + ui = create_diffusion_uncond_ui(model_config) + elif model_type == "autoencoder" or model_type == "diffusion_autoencoder": + ui = create_autoencoder_ui(model_config) + elif model_type == "diffusion_prior": + ui = create_diffusion_prior_ui(model_config) + elif model_type == "lm": + ui = create_lm_ui(model_config) + + return ui diff --git a/think_sound/models/__init__.py b/think_sound/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e27bbcb19a00a93e05ed6cf2a3a38895f26975d --- /dev/null +++ b/think_sound/models/__init__.py @@ -0,0 +1 @@ +from .factory import create_model_from_config, create_model_from_config_path \ No newline at end of file diff --git a/think_sound/models/__pycache__/__init__.cpython-310.pyc b/think_sound/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..616ae52c684375967146708cda52f0a6f0c85f02 Binary files /dev/null and b/think_sound/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/models/__pycache__/__init__.cpython-38.pyc b/think_sound/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..122d8b0c38a5a0e883935b14e41a460b3fca5ca8 Binary files /dev/null and b/think_sound/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/think_sound/models/__pycache__/__init__.cpython-39.pyc b/think_sound/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86ac628c21b54b80ad6388a7afc2c2df45622c72 Binary files /dev/null and 
and b/think_sound/models/__pycache__/utils.cpython-38.pyc differ diff --git a/think_sound/models/__pycache__/utils.cpython-39.pyc b/think_sound/models/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bffbdbbc37c5cdf27e9e02467b7b2a4ce215a50 Binary files /dev/null and b/think_sound/models/__pycache__/utils.cpython-39.pyc differ diff --git a/think_sound/models/adp.py b/think_sound/models/adp.py new file mode 100644 index 0000000000000000000000000000000000000000..49eb526ab02d16eb4952d346401b1ad2b7e5cb7c --- /dev/null +++ b/think_sound/models/adp.py @@ -0,0 +1,1588 @@ +# Copied and modified from https://github.com/archinetai/audio-diffusion-pytorch/blob/v0.0.94/audio_diffusion_pytorch/modules.py under MIT License +# License can be found in LICENSES/LICENSE_ADP.txt + +import math +from inspect import isfunction +from math import ceil, floor, log, pi, log2 +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union +from packaging import version + +import torch +import torch.nn as nn +from einops import rearrange, reduce, repeat +from einops.layers.torch import Rearrange +from einops_exts import rearrange_many +from torch import Tensor, einsum +from torch.backends.cuda import sdp_kernel +from torch.nn import functional as F +from dac.nn.layers import Snake1d + +""" +Utils +""" + + +class ConditionedSequential(nn.Module): + def __init__(self, *modules): + super().__init__() + self.module_list = nn.ModuleList(*modules) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None): + for module in self.module_list: + x = module(x, mapping) + return x + +T = TypeVar("T") + +def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T: + if exists(val): + return val + return d() if isfunction(d) else d + +def exists(val: Optional[T]) -> T: + return val is not None + +def closest_power_2(x: float) -> int: + exponent = log2(x) + distance_fn = lambda z: abs(x - 2 ** z) # noqa + exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn) + return 2 ** int(exponent_closest) + +def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]: + return_dicts: Tuple[Dict, Dict] = ({}, {}) + for key in d.keys(): + no_prefix = int(not key.startswith(prefix)) + return_dicts[no_prefix][key] = d[key] + return return_dicts + +def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]: + kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d) + if keep_prefix: + return kwargs_with_prefix, kwargs + kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()} + return kwargs_no_prefix, kwargs + +""" +Convolutional Blocks +""" +import typing as tp + +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conv.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """See `pad_for_conv1d`.""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0): + """Pad for a convolution to make sure that the last window is full. + Extra padding is added at the end. 
This is required to ensure that we can rebuild + an output of the same length, as otherwise, even with padding, some time steps + might get removed. + For instance, with total padding = 4, kernel size = 4, stride = 2: + 0 0 1 2 3 4 5 0 0 # (0s are padding) + 1 2 3 # (output frames of a convolution, last 0 is never used) + 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding) + 1 2 3 4 # once you removed padding, we are missing one time step ! + """ + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + return F.pad(x, (0, extra_padding)) + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.): + """Tiny wrapper around F.pad, just to allow for reflect padding on small input. + If this is the case, we insert extra 0 padding to the right before the reflection happen. + """ + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class Conv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + dilation = self.dilation[0] + kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations + padding_total = kernel_size - stride + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + if causal: + # Left padding for causal + x = pad1d(x, (padding_total, extra_padding)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding)) + return super().forward(x) + +class ConvTranspose1d(nn.ConvTranspose1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x: Tensor, causal=False) -> Tensor: + kernel_size = self.kernel_size[0] + stride = self.stride[0] + padding_total = kernel_size - stride + + y = super().forward(x) + + # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be + # removed at the very end, when keeping only the right length for the output, + # as removing it here would require also passing the length at the matching layer + # in the encoder. 
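        # Added comment (not in the original): e.g. with kernel_size=4 and stride=2 the
        # transposed convolution leaves padding_total = 4 - 2 = 2 surplus samples to trim.
        # In the causal branch below the whole trim is taken from the right end
        # (padding_right = padding_total, padding_left = 0), mirroring the left-sided fixed
        # padding used by the causal Conv1d above; the non-causal branch splits the trim
        # as evenly as possible between both ends.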
+ if causal: + padding_right = ceil(padding_total) + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + else: + # Asymmetric padding required for odd strides + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + y = unpad1d(y, (padding_left, padding_right)) + return y + + +def Downsample1d( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor + ) + + +def Upsample1d( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3 + ), + ) + else: + return ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor + ) + + +class ConvBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + num_groups: int = 8, + use_norm: bool = True, + use_snake: bool = False + ) -> None: + super().__init__() + + self.groupnorm = ( + nn.GroupNorm(num_groups=num_groups, num_channels=in_channels) + if use_norm + else nn.Identity() + ) + + if use_snake: + self.activation = Snake1d(in_channels) + else: + self.activation = nn.SiLU() + + self.project = Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + ) + + def forward( + self, x: Tensor, scale_shift: Optional[Tuple[Tensor, Tensor]] = None, causal=False + ) -> Tensor: + x = self.groupnorm(x) + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + x = self.activation(x) + return self.project(x, causal=causal) + + +class MappingToScaleShift(nn.Module): + def __init__( + self, + features: int, + channels: int, + ): + super().__init__() + + self.to_scale_shift = nn.Sequential( + nn.SiLU(), + nn.Linear(in_features=features, out_features=channels * 2), + ) + + def forward(self, mapping: Tensor) -> Tuple[Tensor, Tensor]: + scale_shift = self.to_scale_shift(mapping) + scale_shift = rearrange(scale_shift, "b c -> b c 1") + scale, shift = scale_shift.chunk(2, dim=1) + return scale, shift + + +class ResnetBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + kernel_size: int = 3, + stride: int = 1, + dilation: int = 1, + use_norm: bool = True, + use_snake: bool = False, + num_groups: int = 8, + context_mapping_features: Optional[int] = None, + ) -> None: + super().__init__() + + self.use_mapping = exists(context_mapping_features) + + self.block1 = ConvBlock1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + use_norm=use_norm, + num_groups=num_groups, + use_snake=use_snake + ) + + if self.use_mapping: + assert exists(context_mapping_features) + self.to_scale_shift = MappingToScaleShift( + features=context_mapping_features, channels=out_channels + ) + + self.block2 = ConvBlock1d( + in_channels=out_channels, + out_channels=out_channels, + use_norm=use_norm, + 
num_groups=num_groups, + use_snake=use_snake + ) + + self.to_out = ( + Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1) + if in_channels != out_channels + else nn.Identity() + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + assert_message = "context mapping required if context_mapping_features > 0" + assert not (self.use_mapping ^ exists(mapping)), assert_message + + h = self.block1(x, causal=causal) + + scale_shift = None + if self.use_mapping: + scale_shift = self.to_scale_shift(mapping) + + h = self.block2(h, scale_shift=scale_shift, causal=causal) + + return h + self.to_out(x) + + +class Patcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + assert_message = f"out_channels must be divisible by patch_size ({patch_size})" + assert out_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels, + out_channels=out_channels // patch_size, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.block(x, mapping, causal=causal) + x = rearrange(x, "b c (l p) -> b (c p) l", p=self.patch_size) + return x + + +class Unpatcher(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + patch_size: int, + context_mapping_features: Optional[int] = None, + use_snake: bool = False + ): + super().__init__() + assert_message = f"in_channels must be divisible by patch_size ({patch_size})" + assert in_channels % patch_size == 0, assert_message + self.patch_size = patch_size + + self.block = ResnetBlock1d( + in_channels=in_channels // patch_size, + out_channels=out_channels, + num_groups=1, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward(self, x: Tensor, mapping: Optional[Tensor] = None, causal=False) -> Tensor: + x = rearrange(x, " b (c p) l -> b c (l p) ", p=self.patch_size) + x = self.block(x, mapping, causal=causal) + return x + + +""" +Attention Components +""" +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +def add_mask(sim: Tensor, mask: Tensor) -> Tensor: + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + +def causal_mask(q: Tensor, k: Tensor) -> Tensor: + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + mask = ~torch.ones((i, j), dtype=torch.bool, device=device).triu(j - i + 1) + mask = repeat(mask, "n m -> b n m", b=b) + return mask + +class AttentionBase(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + ): + super().__init__() + self.scale = head_features**-0.5 + self.num_heads = num_heads + mid_features = head_features * num_heads + out_features = default(out_features, features) + + self.to_out = nn.Linear( + in_features=mid_features, out_features=out_features + ) 
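        # Added comment (not in the original): the config tuple chosen below is unpacked into
        # torch.backends.cuda.sdp_kernel(enable_flash, enable_math, enable_mem_efficient) in
        # forward(). (True, False, False) pins the flash kernel on compute-capability-8.0
        # (A100-class) GPUs, while (False, True, True) falls back to the math and
        # memory-efficient kernels elsewhere; use_flash additionally requires CUDA and
        # PyTorch >= 2.0.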
+ + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False + ) -> Tensor: + # Split heads + q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads) + + if not self.use_flash: + if is_causal and not mask: + # Mask out future tokens for causal attention + mask = causal_mask(q, k) + + # Compute similarity matrix and add eventual mask + sim = einsum("... n d, ... m d -> ... n m", q, k) * self.scale + sim = add_mask(sim, mask) if exists(mask) else sim + + # Get attention matrix with softmax + attn = sim.softmax(dim=-1, dtype=torch.float32) + + # Compute values + out = einsum("... n m, ... m d -> ... n d", attn, v) + else: + with sdp_kernel(*self.sdp_kernel_config): + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=is_causal) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + +class Attention(nn.Module): + def __init__( + self, + features: int, + *, + head_features: int, + num_heads: int, + out_features: Optional[int] = None, + context_features: Optional[int] = None, + causal: bool = False, + ): + super().__init__() + self.context_features = context_features + self.causal = causal + mid_features = head_features * num_heads + context_features = default(context_features, features) + + self.norm = nn.LayerNorm(features) + self.norm_context = nn.LayerNorm(context_features) + self.to_q = nn.Linear( + in_features=features, out_features=mid_features, bias=False + ) + self.to_kv = nn.Linear( + in_features=context_features, out_features=mid_features * 2, bias=False + ) + self.attention = AttentionBase( + features, + num_heads=num_heads, + head_features=head_features, + out_features=out_features, + ) + + def forward( + self, + x: Tensor, # [b, n, c] + context: Optional[Tensor] = None, # [b, m, d] + context_mask: Optional[Tensor] = None, # [b, m], false is masked, + causal: Optional[bool] = False, + ) -> Tensor: + assert_message = "You must provide a context when using context_features" + assert not self.context_features or exists(context), assert_message + # Use context if provided + context = default(context, x) + # Normalize then compute q from input and k,v from context + x, context = self.norm(x), self.norm_context(context) + + q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1)) + + if exists(context_mask): + # Mask out cross-attention for padding tokens + mask = repeat(context_mask, "b m -> b m d", d=v.shape[-1]) + k, v = k * mask, v * mask + + # Compute and return attention + return self.attention(q, k, v, is_causal=self.causal or causal) + + +def FeedForward(features: int, multiplier: int) -> nn.Module: + mid_features = features * multiplier + return nn.Sequential( + nn.Linear(in_features=features, out_features=mid_features), + nn.GELU(), + nn.Linear(in_features=mid_features, out_features=features), + ) + +""" +Transformer Blocks +""" + + +class TransformerBlock(nn.Module): + def __init__( + self, + features: int, + num_heads: int, + head_features: int, + multiplier: int, + 
context_features: Optional[int] = None, + ): + super().__init__() + + self.use_cross_attention = exists(context_features) and context_features > 0 + + self.attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features + ) + + if self.use_cross_attention: + self.cross_attention = Attention( + features=features, + num_heads=num_heads, + head_features=head_features, + context_features=context_features + ) + + self.feed_forward = FeedForward(features=features, multiplier=multiplier) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal: Optional[bool] = False) -> Tensor: + x = self.attention(x, causal=causal) + x + if self.use_cross_attention: + x = self.cross_attention(x, context=context, context_mask=context_mask) + x + x = self.feed_forward(x) + x + return x + + +""" +Transformers +""" + + +class Transformer1d(nn.Module): + def __init__( + self, + num_layers: int, + channels: int, + num_heads: int, + head_features: int, + multiplier: int, + context_features: Optional[int] = None, + ): + super().__init__() + + self.to_in = nn.Sequential( + nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + Rearrange("b c t -> b t c"), + ) + + self.blocks = nn.ModuleList( + [ + TransformerBlock( + features=channels, + head_features=head_features, + num_heads=num_heads, + multiplier=multiplier, + context_features=context_features, + ) + for i in range(num_layers) + ] + ) + + self.to_out = nn.Sequential( + Rearrange("b t c -> b c t"), + Conv1d( + in_channels=channels, + out_channels=channels, + kernel_size=1, + ), + ) + + def forward(self, x: Tensor, *, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, causal=False) -> Tensor: + x = self.to_in(x) + for block in self.blocks: + x = block(x, context=context, context_mask=context_mask, causal=causal) + x = self.to_out(x) + return x + + +""" +Time Embeddings +""" + + +class SinusoidalEmbedding(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device, half_dim = x.device, self.dim // 2 + emb = torch.tensor(log(10000) / (half_dim - 1), device=device) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j") + return torch.cat((emb.sin(), emb.cos()), dim=-1) + + +class LearnedPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x: Tensor) -> Tensor: + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: + return nn.Sequential( + LearnedPositionalEmbedding(dim), + nn.Linear(in_features=dim + 1, out_features=out_features), + ) + + +""" +Encoder/Decoder Components +""" + + +class DownsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_groups: int, + num_layers: int, + kernel_multiplier: int = 2, + use_pre_downsample: bool = True, + use_skip: bool = False, + use_snake: bool = False, + extract_channels: int = 0, + 
context_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + self.use_pre_downsample = use_pre_downsample + self.use_skip = use_skip + self.use_transformer = num_transformer_blocks > 0 + self.use_extract = extract_channels > 0 + self.use_context = context_channels > 0 + + channels = out_channels if use_pre_downsample else in_channels + + self.downsample = Downsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + kernel_multiplier=kernel_multiplier, + ) + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + context_channels if i == 0 else channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for i in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + channels: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, List[Tensor]], Tensor]: + + if self.use_pre_downsample: + x = self.downsample(x) + + if self.use_context and exists(channels): + x = torch.cat([x, channels], dim=1) + + skips = [] + for block in self.blocks: + x = block(x, mapping=mapping, causal=causal) + skips += [x] if self.use_skip else [] + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + skips += [x] if self.use_skip else [] + + if not self.use_pre_downsample: + x = self.downsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return (x, skips) if self.use_skip else x + + +class UpsampleBlock1d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + *, + factor: int, + num_layers: int, + num_groups: int, + use_nearest: bool = False, + use_pre_upsample: bool = False, + use_skip: bool = False, + use_snake: bool = False, + skip_channels: int = 0, + use_skip_scale: bool = False, + extract_channels: int = 0, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + ): + super().__init__() + + self.use_extract = extract_channels > 0 + self.use_pre_upsample = use_pre_upsample + 
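        # Added comment (not in the original): as in DownsampleBlock1d above, only one of
        # attention_heads / attention_features needs to be given; the other is inferred
        # from the channel count (channels // heads or channels // features).
        # The skip_scale of 2 ** -0.5 set just below rescales skip tensors by 1/sqrt(2)
        # before they are concatenated back in add_skip(), a scaling convention commonly
        # used in diffusion U-Nets to keep skip activations on a comparable scale.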
self.use_transformer = num_transformer_blocks > 0 + self.use_skip = use_skip + self.skip_scale = 2 ** -0.5 if use_skip_scale else 1.0 + + channels = out_channels if use_pre_upsample else in_channels + + self.blocks = nn.ModuleList( + [ + ResnetBlock1d( + in_channels=channels + skip_channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + for _ in range(num_layers) + ] + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + context_features=context_embedding_features, + ) + + self.upsample = Upsample1d( + in_channels=in_channels, + out_channels=out_channels, + factor=factor, + use_nearest=use_nearest, + ) + + if self.use_extract: + num_extract_groups = min(num_groups, extract_channels) + self.to_extracted = ResnetBlock1d( + in_channels=out_channels, + out_channels=extract_channels, + num_groups=num_extract_groups, + use_snake=use_snake + ) + + def add_skip(self, x: Tensor, skip: Tensor) -> Tensor: + return torch.cat([x, skip * self.skip_scale], dim=1) + + def forward( + self, + x: Tensor, + *, + skips: Optional[List[Tensor]] = None, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Union[Tuple[Tensor, Tensor], Tensor]: + + if self.use_pre_upsample: + x = self.upsample(x) + + for block in self.blocks: + x = self.add_skip(x, skip=skips.pop()) if exists(skips) else x + x = block(x, mapping=mapping, causal=causal) + + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + + if not self.use_pre_upsample: + x = self.upsample(x) + + if self.use_extract: + extracted = self.to_extracted(x) + return x, extracted + + return x + + +class BottleneckBlock1d(nn.Module): + def __init__( + self, + channels: int, + *, + num_groups: int, + num_transformer_blocks: int = 0, + attention_heads: Optional[int] = None, + attention_features: Optional[int] = None, + attention_multiplier: Optional[int] = None, + context_mapping_features: Optional[int] = None, + context_embedding_features: Optional[int] = None, + use_snake: bool = False, + ): + super().__init__() + self.use_transformer = num_transformer_blocks > 0 + + self.pre_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + if self.use_transformer: + assert ( + (exists(attention_heads) or exists(attention_features)) + and exists(attention_multiplier) + ) + + if attention_features is None and attention_heads is not None: + attention_features = channels // attention_heads + + if attention_heads is None and attention_features is not None: + attention_heads = channels // attention_features + + self.transformer = Transformer1d( + num_layers=num_transformer_blocks, + channels=channels, + num_heads=attention_heads, + head_features=attention_features, + multiplier=attention_multiplier, + 
context_features=context_embedding_features, + ) + + self.post_block = ResnetBlock1d( + in_channels=channels, + out_channels=channels, + num_groups=num_groups, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def forward( + self, + x: Tensor, + *, + mapping: Optional[Tensor] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False + ) -> Tensor: + x = self.pre_block(x, mapping=mapping, causal=causal) + if self.use_transformer: + x = self.transformer(x, context=embedding, context_mask=embedding_mask, causal=causal) + x = self.post_block(x, mapping=mapping, causal=causal) + return x + + +""" +UNet +""" + + +class UNet1d(nn.Module): + def __init__( + self, + in_channels: int, + channels: int, + multipliers: Sequence[int], + factors: Sequence[int], + num_blocks: Sequence[int], + attentions: Sequence[int], + patch_size: int = 1, + resnet_groups: int = 8, + use_context_time: bool = True, + kernel_multiplier_downsample: int = 2, + use_nearest_upsample: bool = False, + use_skip_scale: bool = True, + use_snake: bool = False, + use_stft: bool = False, + use_stft_context: bool = False, + out_channels: Optional[int] = None, + context_features: Optional[int] = None, + context_features_multiplier: int = 4, + context_channels: Optional[Sequence[int]] = None, + context_embedding_features: Optional[int] = None, + **kwargs, + ): + super().__init__() + out_channels = default(out_channels, in_channels) + context_channels = list(default(context_channels, [])) + num_layers = len(multipliers) - 1 + use_context_features = exists(context_features) + use_context_channels = len(context_channels) > 0 + context_mapping_features = None + + attention_kwargs, kwargs = groupby("attention_", kwargs, keep_prefix=True) + + self.num_layers = num_layers + self.use_context_time = use_context_time + self.use_context_features = use_context_features + self.use_context_channels = use_context_channels + self.use_stft = use_stft + self.use_stft_context = use_stft_context + + self.context_features = context_features + context_channels_pad_length = num_layers + 1 - len(context_channels) + context_channels = context_channels + [0] * context_channels_pad_length + self.context_channels = context_channels + self.context_embedding_features = context_embedding_features + + if use_context_channels: + has_context = [c > 0 for c in context_channels] + self.has_context = has_context + self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + + assert ( + len(factors) == num_layers + and len(attentions) >= num_layers + and len(num_blocks) == num_layers + ) + + if use_context_time or use_context_features: + context_mapping_features = channels * context_features_multiplier + + self.to_mapping = nn.Sequential( + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + nn.Linear(context_mapping_features, context_mapping_features), + nn.GELU(), + ) + + if use_context_time: + assert exists(context_mapping_features) + self.to_time = nn.Sequential( + TimePositionalEmbedding( + dim=channels, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_context_features: + assert exists(context_features) and exists(context_mapping_features) + self.to_features = nn.Sequential( + nn.Linear( + in_features=context_features, out_features=context_mapping_features + ), + nn.GELU(), + ) + + if use_stft: + stft_kwargs, kwargs = groupby("stft_", kwargs) + assert "num_fft" in stft_kwargs, "stft_num_fft required if 
use_stft=True" + stft_channels = (stft_kwargs["num_fft"] // 2 + 1) * 2 + in_channels *= stft_channels + out_channels *= stft_channels + context_channels[0] *= stft_channels if use_stft_context else 1 + assert exists(in_channels) and exists(out_channels) + self.stft = STFT(**stft_kwargs) + + assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}" + + self.to_in = Patcher( + in_channels=in_channels + context_channels[0], + out_channels=channels * multipliers[0], + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + self.downsamples = nn.ModuleList( + [ + DownsampleBlock1d( + in_channels=channels * multipliers[i], + out_channels=channels * multipliers[i + 1], + context_mapping_features=context_mapping_features, + context_channels=context_channels[i + 1], + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i], + factor=factors[i], + kernel_multiplier=kernel_multiplier_downsample, + num_groups=resnet_groups, + use_pre_downsample=True, + use_skip=True, + use_snake=use_snake, + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in range(num_layers) + ] + ) + + self.bottleneck = BottleneckBlock1d( + channels=channels * multipliers[-1], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_groups=resnet_groups, + num_transformer_blocks=attentions[-1], + use_snake=use_snake, + **attention_kwargs, + ) + + self.upsamples = nn.ModuleList( + [ + UpsampleBlock1d( + in_channels=channels * multipliers[i + 1], + out_channels=channels * multipliers[i], + context_mapping_features=context_mapping_features, + context_embedding_features=context_embedding_features, + num_layers=num_blocks[i] + (1 if attentions[i] else 0), + factor=factors[i], + use_nearest=use_nearest_upsample, + num_groups=resnet_groups, + use_skip_scale=use_skip_scale, + use_pre_upsample=False, + use_skip=True, + use_snake=use_snake, + skip_channels=channels * multipliers[i + 1], + num_transformer_blocks=attentions[i], + **attention_kwargs, + ) + for i in reversed(range(num_layers)) + ] + ) + + self.to_out = Unpatcher( + in_channels=channels * multipliers[0], + out_channels=out_channels, + patch_size=patch_size, + context_mapping_features=context_mapping_features, + use_snake=use_snake + ) + + def get_channels( + self, channels_list: Optional[Sequence[Tensor]] = None, layer: int = 0 + ) -> Optional[Tensor]: + """Gets context channels at `layer` and checks that shape is correct""" + use_context_channels = self.use_context_channels and self.has_context[layer] + if not use_context_channels: + return None + assert exists(channels_list), "Missing context" + # Get channels index (skipping zero channel contexts) + channels_id = self.channels_ids[layer] + # Get channels + channels = channels_list[channels_id] + message = f"Missing context for layer {layer} at index {channels_id}" + assert exists(channels), message + # Check channels + num_channels = self.context_channels[layer] + message = f"Expected context with {num_channels} channels at idx {channels_id}" + assert channels.shape[1] == num_channels, message + # STFT channels if requested + channels = self.stft.encode1d(channels) if self.use_stft_context else channels # type: ignore # noqa + return channels + + def get_mapping( + self, time: Optional[Tensor] = None, features: Optional[Tensor] = None + ) -> Optional[Tensor]: + """Combines context time features and features into mapping""" + items, mapping = [], None + # 
Compute time features + if self.use_context_time: + assert_message = "use_context_time=True but no time features provided" + assert exists(time), assert_message + items += [self.to_time(time)] + # Compute features + if self.use_context_features: + assert_message = "context_features exists but no features provided" + assert exists(features), assert_message + items += [self.to_features(features)] + # Compute joint mapping + if self.use_context_time or self.use_context_features: + mapping = reduce(torch.stack(items), "n b m -> b m", "sum") + mapping = self.to_mapping(mapping) + return mapping + + def forward( + self, + x: Tensor, + time: Optional[Tensor] = None, + *, + features: Optional[Tensor] = None, + channels_list: Optional[Sequence[Tensor]] = None, + embedding: Optional[Tensor] = None, + embedding_mask: Optional[Tensor] = None, + causal: Optional[bool] = False, + ) -> Tensor: + channels = self.get_channels(channels_list, layer=0) + # Apply stft if required + x = self.stft.encode1d(x) if self.use_stft else x # type: ignore + # Concat context channels at layer 0 if provided + x = torch.cat([x, channels], dim=1) if exists(channels) else x + # Compute mapping from time and features + mapping = self.get_mapping(time, features) + x = self.to_in(x, mapping, causal=causal) + skips_list = [x] + + for i, downsample in enumerate(self.downsamples): + channels = self.get_channels(channels_list, layer=i + 1) + x, skips = downsample( + x, mapping=mapping, channels=channels, embedding=embedding, embedding_mask=embedding_mask, causal=causal + ) + skips_list += [skips] + + x = self.bottleneck(x, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + for i, upsample in enumerate(self.upsamples): + skips = skips_list.pop() + x = upsample(x, skips=skips, mapping=mapping, embedding=embedding, embedding_mask=embedding_mask, causal=causal) + + x += skips_list.pop() + x = self.to_out(x, mapping, causal=causal) + x = self.stft.decode1d(x) if self.use_stft else x + + return x + + +""" Conditioning Modules """ + + +class FixedEmbedding(nn.Module): + def __init__(self, max_length: int, features: int): + super().__init__() + self.max_length = max_length + self.embedding = nn.Embedding(max_length, features) + + def forward(self, x: Tensor) -> Tensor: + batch_size, length, device = *x.shape[0:2], x.device + assert_message = "Input sequence length must be <= max_length" + assert length <= self.max_length, assert_message + position = torch.arange(length, device=device) + fixed_embedding = self.embedding(position) + fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size) + return fixed_embedding + + +def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor: + if proba == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif proba == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool) + + +class UNetCFG1d(UNet1d): + + """UNet1d with Classifier-Free Guidance""" + + def __init__( + self, + context_embedding_max_length: int, + context_embedding_features: int, + use_xattn_time: bool = False, + **kwargs, + ): + super().__init__( + context_embedding_features=context_embedding_features, **kwargs + ) + + self.use_xattn_time = use_xattn_time + + if use_xattn_time: + assert exists(context_embedding_features) + self.to_time_embedding = nn.Sequential( + TimePositionalEmbedding( + dim=kwargs["channels"], out_features=context_embedding_features + ), + 
nn.GELU(), + ) + + context_embedding_max_length += 1 # Add one for time embedding + + self.fixed_embedding = FixedEmbedding( + max_length=context_embedding_max_length, features=context_embedding_features + ) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + embedding: Tensor, + embedding_mask: Optional[Tensor] = None, + embedding_scale: float = 1.0, + embedding_mask_proba: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + scale_phi: float = 0.4, + negative_embedding: Optional[Tensor] = None, + negative_embedding_mask: Optional[Tensor] = None, + **kwargs, + ) -> Tensor: + b, device = embedding.shape[0], embedding.device + + if self.use_xattn_time: + embedding = torch.cat([embedding, self.to_time_embedding(time).unsqueeze(1)], dim=1) + + if embedding_mask is not None: + embedding_mask = torch.cat([embedding_mask, torch.ones((b, 1), device=device)], dim=1) + + fixed_embedding = self.fixed_embedding(embedding) + + if embedding_mask_proba > 0.0: + # Randomly mask embedding + batch_mask = rand_bool( + shape=(b, 1, 1), proba=embedding_mask_proba, device=device + ) + embedding = torch.where(batch_mask, fixed_embedding, embedding) + + if embedding_scale != 1.0: + if batch_cfg: + batch_x = torch.cat([x, x], dim=0) + batch_time = torch.cat([time, time], dim=0) + + if negative_embedding is not None: + if negative_embedding_mask is not None: + negative_embedding_mask = negative_embedding_mask.to(torch.bool).unsqueeze(2) + + negative_embedding = torch.where(negative_embedding_mask, negative_embedding, fixed_embedding) + + batch_embed = torch.cat([embedding, negative_embedding], dim=0) + + else: + batch_embed = torch.cat([embedding, fixed_embedding], dim=0) + + batch_mask = None + if embedding_mask is not None: + batch_mask = torch.cat([embedding_mask, embedding_mask], dim=0) + + batch_features = None + features = kwargs.pop("features", None) + if self.use_context_features: + batch_features = torch.cat([features, features], dim=0) + + batch_channels = None + channels_list = kwargs.pop("channels_list", None) + if self.use_context_channels: + batch_channels = [] + for channels in channels_list: + batch_channels += [torch.cat([channels, channels], dim=0)] + + # Compute both normal and fixed embedding outputs + batch_out = super().forward(batch_x, batch_time, embedding=batch_embed, embedding_mask=batch_mask, features=batch_features, channels_list=batch_channels, **kwargs) + out, out_masked = batch_out.chunk(2, dim=0) + + else: + # Compute both normal and fixed embedding outputs + out = super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + out_masked = super().forward(x, time, embedding=fixed_embedding, embedding_mask=embedding_mask, **kwargs) + + out_cfg = out_masked + (out - out_masked) * embedding_scale + + if rescale_cfg: + + out_std = out.std(dim=1, keepdim=True) + out_cfg_std = out_cfg.std(dim=1, keepdim=True) + + return scale_phi * (out_cfg * (out_std/out_cfg_std)) + (1-scale_phi) * out_cfg + + else: + + return out_cfg + + else: + return super().forward(x, time, embedding=embedding, embedding_mask=embedding_mask, **kwargs) + + +class UNetNCCA1d(UNet1d): + + """UNet1d with Noise Channel Conditioning Augmentation""" + + def __init__(self, context_features: int, **kwargs): + super().__init__(context_features=context_features, **kwargs) + self.embedder = NumberEmbedder(features=context_features) + + def expand(self, x: Any, shape: Tuple[int, ...]) -> Tensor: + x = x if torch.is_tensor(x) else torch.tensor(x) + return 
x.expand(shape) + + def forward( # type: ignore + self, + x: Tensor, + time: Tensor, + *, + channels_list: Sequence[Tensor], + channels_augmentation: Union[ + bool, Sequence[bool], Sequence[Sequence[bool]], Tensor + ] = False, + channels_scale: Union[ + float, Sequence[float], Sequence[Sequence[float]], Tensor + ] = 0, + **kwargs, + ) -> Tensor: + b, n = x.shape[0], len(channels_list) + channels_augmentation = self.expand(channels_augmentation, shape=(b, n)).to(x) + channels_scale = self.expand(channels_scale, shape=(b, n)).to(x) + + # Augmentation (for each channel list item) + for i in range(n): + scale = channels_scale[:, i] * channels_augmentation[:, i] + scale = rearrange(scale, "b -> b 1 1") + item = channels_list[i] + channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale) # type: ignore # noqa + + # Scale embedding (sum reduction if more than one channel list item) + channels_scale_emb = self.embedder(channels_scale) + channels_scale_emb = reduce(channels_scale_emb, "b n d -> b d", "sum") + + return super().forward( + x=x, + time=time, + channels_list=channels_list, + features=channels_scale_emb, + **kwargs, + ) + + +class UNetAll1d(UNetCFG1d, UNetNCCA1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, *args, **kwargs): # type: ignore + return UNetCFG1d.forward(self, *args, **kwargs) + + +def XUNet1d(type: str = "base", **kwargs) -> UNet1d: + if type == "base": + return UNet1d(**kwargs) + elif type == "all": + return UNetAll1d(**kwargs) + elif type == "cfg": + return UNetCFG1d(**kwargs) + elif type == "ncca": + return UNetNCCA1d(**kwargs) + else: + raise ValueError(f"Unknown XUNet1d type: {type}") + +class NumberEmbedder(nn.Module): + def __init__( + self, + features: int, + dim: int = 256, + ): + super().__init__() + self.features = features + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + + def forward(self, x: Union[List[float], Tensor]) -> Tensor: + if not torch.is_tensor(x): + device = next(self.embedding.parameters()).device + x = torch.tensor(x, device=device) + assert isinstance(x, Tensor) + shape = x.shape + x = rearrange(x, "... 
-> (...)") + embedding = self.embedding(x) + x = embedding.view(*shape, self.features) + return x # type: ignore + + +""" +Audio Transforms +""" + + +class STFT(nn.Module): + """Helper for torch stft and istft""" + + def __init__( + self, + num_fft: int = 1023, + hop_length: int = 256, + window_length: Optional[int] = None, + length: Optional[int] = None, + use_complex: bool = False, + ): + super().__init__() + self.num_fft = num_fft + self.hop_length = default(hop_length, floor(num_fft // 4)) + self.window_length = default(window_length, num_fft) + self.length = length + self.register_buffer("window", torch.hann_window(self.window_length)) + self.use_complex = use_complex + + def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]: + b = wave.shape[0] + wave = rearrange(wave, "b c t -> (b c) t") + + stft = torch.stft( + wave, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + return_complex=True, + normalized=True, + ) + + if self.use_complex: + # Returns real and imaginary + stft_a, stft_b = stft.real, stft.imag + else: + # Returns magnitude and phase matrices + magnitude, phase = torch.abs(stft), torch.angle(stft) + stft_a, stft_b = magnitude, phase + + return rearrange_many((stft_a, stft_b), "(b c) f l -> b c f l", b=b) + + def decode(self, stft_a: Tensor, stft_b: Tensor) -> Tensor: + b, l = stft_a.shape[0], stft_a.shape[-1] # noqa + length = closest_power_2(l * self.hop_length) + + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> (b c) f l") + + if self.use_complex: + real, imag = stft_a, stft_b + else: + magnitude, phase = stft_a, stft_b + real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase) + + stft = torch.stack([real, imag], dim=-1) + + wave = torch.istft( + stft, + n_fft=self.num_fft, + hop_length=self.hop_length, + win_length=self.window_length, + window=self.window, # type: ignore + length=default(self.length, length), + normalized=True, + ) + + return rearrange(wave, "(b c) t -> b c t", b=b) + + def encode1d( + self, wave: Tensor, stacked: bool = True + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + stft_a, stft_b = self.encode(wave) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b c f l -> b (c f) l") + return torch.cat((stft_a, stft_b), dim=1) if stacked else (stft_a, stft_b) + + def decode1d(self, stft_pair: Tensor) -> Tensor: + f = self.num_fft // 2 + 1 + stft_a, stft_b = stft_pair.chunk(chunks=2, dim=1) + stft_a, stft_b = rearrange_many((stft_a, stft_b), "b (c f) l -> b c f l", f=f) + return self.decode(stft_a, stft_b) diff --git a/think_sound/models/autoencoders.py b/think_sound/models/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..853fdc9712c12284b2073c1cce8c384baa258337 --- /dev/null +++ b/think_sound/models/autoencoders.py @@ -0,0 +1,800 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from dac.nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +from ..inference.sampling import sample +from ..inference.utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform + +def 
checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + 
WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if 
out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + # import ipdb + # ipdb.set_trace() + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + latents = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. 
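+
+        A minimal usage sketch (illustrative only; the file path, the `autoencoder` variable
+        name and the torchaudio import are assumptions, not part of this class):
+
+            import torchaudio
+            waveform, sr = torchaudio.load("some_clip.wav")  # placeholder path; returns (channels, samples)
+            batch = autoencoder.preprocess_audio_for_encoder(waveform, sr)
+            latents = autoencoder.encode_audio(batch)
+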
+        The output will have batch size 1 and be shape (1 x Channels x Length)
+        '''
+        return self.preprocess_audio_list_for_encoder([audio], [in_sr])
+
+    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
+        '''
+        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
+        The audio in that list can be of different lengths and channels.
+        in_sr can be an integer or a list. If it is an integer, it is assumed to be the input sample rate of every audio.
+        All audio will be resampled to the model's sample rate.
+        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
+        If the model is mono, all audio will be converted to mono.
+        The output will be a tensor of shape (Batch x Channels x Length)
+        '''
+        batch_size = len(audio_list)
+        if isinstance(in_sr_list, int):
+            in_sr_list = [in_sr_list]*batch_size
+        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
+        new_audio = []
+        max_length = 0
+        # resample & find the max length
+        for i in range(batch_size):
+            audio = audio_list[i]
+            in_sr = in_sr_list[i]
+            if len(audio.shape) == 3 and audio.shape[0] == 1:
+                # batch size 1 was given by accident; just squeeze it
+                audio = audio.squeeze(0)
+            elif len(audio.shape) == 1:
+                # mono signal with the channel dimension missing; unsqueeze it in
+                audio = audio.unsqueeze(0)
+            assert len(audio.shape) == 2, "Audio should be shape (Channels x Length) with no batch dimension"
+            # Resample audio
+            if in_sr != self.sample_rate:
+                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
+                audio = resample_tf(audio)
+            new_audio.append(audio)
+            if audio.shape[-1] > max_length:
+                max_length = audio.shape[-1]
+        # Pad every audio to the same length, a multiple of the model's downsampling ratio
+        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
+        for i in range(batch_size):
+            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
+            new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
+                target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
+        # convert to tensor
+        return torch.stack(new_audio)
+
+    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
+        '''
+        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
+        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with the given overlap.
+        Overlap and chunk_size params are both measured in number of latents (not audio samples),
+        and therefore you can likely use the same values with decode_audio.
+        An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
+        Every autoencoder will have a different receptive field size, and thus an ideal overlap.
+        You can determine it empirically by diffing unchunked vs chunked output and looking at the maximum diff.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        Smaller chunk_size uses less memory, but more compute.
+        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
+        For example, on an A6000, chunk_size 128 is overall faster than 256 and 512, even though it creates more chunks.
+        '''
+        if not chunked:
+            # default behavior. 
Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + # import ipdb + # ipdb.set_trace() + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + print(f'audio shape: {audio.shape}') + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + print(f'y_final shape: {y_final.shape}') + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + print(f'y_chunk shape: {y_chunk.shape}') + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. 
Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +class DiffusionAutoencoder(AudioAutoencoder): + def __init__( + self, + diffusion: ConditionedDiffusionModel, + diffusion_downsampling_ratio, + *args, + **kwargs + ): + super().__init__(*args, **kwargs) + + self.diffusion = diffusion + + self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio + + if self.encoder is not None: + # Shrink the initial encoder parameters to avoid saturated latents + with torch.no_grad(): + for param in self.encoder.parameters(): + param *= 0.5 + + def decode(self, latents, steps=100): + + upsampled_length = latents.shape[2] * self.downsampling_ratio + + if self.bottleneck is not None: + latents = self.bottleneck.decode(latents) + + if self.decoder is not None: + latents = self.decode(latents) + + # Upsample latents to match diffusion length + if latents.shape[2] != upsampled_length: + latents = F.interpolate(latents, size=upsampled_length, mode='nearest') + + noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device) + decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + decoded = self.pretransform.decode(decoded) + + return decoded + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 
2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) + +def create_diffAE_from_config(config: Dict[str, Any]): + + diffae_config = config["model"] + + if "encoder" in diffae_config: + encoder = create_encoder_from_config(diffae_config["encoder"]) + else: + encoder = None + + if "decoder" in diffae_config: + decoder = create_decoder_from_config(diffae_config["decoder"]) + else: + decoder = None + + diffusion_model_type = 
diffae_config["diffusion"]["type"] + + if diffusion_model_type == "DAU1d": + diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "adp_1d": + diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"]) + elif diffusion_model_type == "dit": + diffusion = DiTWrapper(**diffae_config["diffusion"]["config"]) + + latent_dim = diffae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = diffae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = diffae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + bottleneck = diffae_config.get("bottleneck", None) + + pretransform = diffae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + diffusion_downsampling_ratio = None, + + if diffusion_model_type == "DAU1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"]) + elif diffusion_model_type == "adp_1d": + diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"]) + elif diffusion_model_type == "dit": + diffusion_downsampling_ratio = 1 + + return DiffusionAutoencoder( + encoder=encoder, + decoder=decoder, + diffusion=diffusion, + io_channels=io_channels, + sample_rate=sample_rate, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + diffusion_downsampling_ratio=diffusion_downsampling_ratio, + bottleneck=bottleneck, + pretransform=pretransform + ) diff --git a/think_sound/models/blocks.py b/think_sound/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..3c827fd2441e643717d123847236d3d6c003ef4f --- /dev/null +++ b/think_sound/models/blocks.py @@ -0,0 +1,339 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from dac.nn.layers import Snake1d + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, 
c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = 
torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def 
forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/think_sound/models/bottleneck.py b/think_sound/models/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..5e81cab4bfb16b615ee21d5e9248e3b455f7eb5b --- /dev/null +++ b/think_sound/models/bottleneck.py @@ -0,0 +1,355 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + +class VAEBottleneck(Bottleneck): + def 
__init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["kl"] = kl + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + self.bypass_mmd = bypass_mmd + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + if self.bypass_mmd: + mmd = torch.tensor(0.0) + else: + mmd = compute_mmd(x) + + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = 
self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, noise_augment_dim=0, **kwargs): + super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices") + + self.noise_augment_dim = noise_augment_dim + + self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64]) + + def encode(self, x, return_info=False): + info = {} + + orig_dtype = x.dtype + x = x.float() + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + x = x.to(orig_dtype) + + # 
Reorder indices to match the expected format + indices = rearrange(indices, "b n q -> b q n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/think_sound/models/codebook_patterns.py b/think_sound/models/codebook_patterns.py new file mode 100644 index 0000000000000000000000000000000000000000..f9bd2a9b837bd77cb40f3b500b02ea491dfb9da0 --- /dev/null +++ b/think_sound/models/codebook_patterns.py @@ -0,0 +1,545 @@ +# Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License +# License available in LICENSES/LICENSE_META.txt + +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates +logger = logging.getLogger(__name__) + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. + ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. 
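+    # For example (purely illustrative, not a layout constructed in this file), a one-step
+    # "delay" pattern over n_q=2 codebooks and timesteps=3 could be laid out as:
+    #   [[],
+    #    [LayoutCoord(0, 0)],
+    #    [LayoutCoord(1, 0), LayoutCoord(0, 1)],
+    #    [LayoutCoord(2, 0), LayoutCoord(1, 1)],
+    #    [LayoutCoord(2, 1)]]
+    # i.e. codebook q is delayed by q sequence steps, and step 0 is kept empty for the
+    # special token.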
+ layout: PatternLayout + timesteps: int + n_q: int + + def __post_init__(self): + assert len(self.layout) > 0 + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.n_q)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + assert coord.t >= last_q_timestep, \ + f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def starts_with_special_token(self): + return self.layout[0] == [] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. + """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (torch.device or str): Device for created tensors. 
+ Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is n_q * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: n_q * timesteps + indexes[:] = n_q * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.view(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + n_q (int): Number of codebooks. 
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # TODO(jade): Do we want to further truncate to only valid timesteps here as well? + timesteps = self.timesteps + assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output and self.starts_with_special_token(): + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = n_q * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. + + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. + special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. 
+ + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `n_q`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + n_q (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. + """ + def __init__(self, n_q: int, cached: bool = True): + assert n_q > 0 + self.n_q = n_q + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. 
+ """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(n_q) + if delays is None: + delays = list(range(n_q)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.n_q + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + omit_special_token = self.empty_initial < 0 + out: PatternLayout = [] if omit_special_token else [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.n_q): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class ParallelPatternProvider(DelayedPatternProvider): + """Provider for parallel pattern across codebooks. + This pattern provider is a special case of the delayed pattern with actually no delay, + hence delays=repeat(0, n_q). + + Args: + n_q (int): Number of codebooks. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, n_q: int, empty_initial: int = 0): + super().__init__(n_q, [0] * n_q, empty_initial=empty_initial) + + +class UnrolledPatternProvider(CodebooksPatternProvider): + """Provider for unrolling codebooks pattern. + This pattern provider enables to represent the codebook flattened completely or only to some extend + while also specifying a given delay between the flattened codebooks representation, allowing to + unroll the codebooks in the sequence. + + Example: + 1. Flattening of the codebooks. + By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q), + taking n_q = 3 and timesteps = 4: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, 1, S, S, 2, S, S, 3, S, S, 4], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step + for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example + taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [S, 1, S, S, 2, S, S, 3, S, S, 4, S], + [1, S, S, 2, S, S, 3, S, S, 4, S, S]] + 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks + allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the + same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1] + and delays = [0, 3, 3]: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + will result into: + [[S, S, S, 1, S, 2, S, 3, S, 4], + [S, S, S, 1, S, 2, S, 3, S, 4], + [1, 2, 3, S, 4, S, 5, S, 6, S]] + + Args: + n_q (int): Number of codebooks. + flattening (list of int, optional): Flattening schema over the codebooks. If not defined, + the codebooks will be flattened to 1 codebook per step, meaning that the sequence will + have n_q extra steps for each timestep. 
+ delays (list of int, optional): Delay for each of the codebooks. If not defined, + no delay is added and therefore will default to [0] * ``n_q``. + Note that two codebooks that will be flattened to the same inner step + should have the same delay, otherwise the pattern is considered as invalid. + """ + FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay']) + + def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None, + delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if flattening is None: + flattening = list(range(n_q)) + if delays is None: + delays = [0] * n_q + assert len(flattening) == n_q + assert len(delays) == n_q + assert sorted(flattening) == flattening + assert sorted(delays) == delays + self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening) + self.max_delay = max(delays) + + def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]): + """Build a flattened codebooks representation as a dictionary of inner step + and the actual codebook indices corresponding to the flattened codebook. For convenience, we + also store the delay associated to the flattened codebook to avoid maintaining an extra mapping. + """ + flattened_codebooks: dict = {} + for q, (inner_step, delay) in enumerate(zip(flattening, delays)): + if inner_step not in flattened_codebooks: + flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay) + else: + flat_codebook = flattened_codebooks[inner_step] + assert flat_codebook.delay == delay, ( + "Delay and flattening between codebooks is inconsistent: ", + "two codebooks flattened to the same position should have the same delay." + ) + flat_codebook.codebooks.append(q) + flattened_codebooks[inner_step] = flat_codebook + return flattened_codebooks + + @property + def _num_inner_steps(self): + """Number of inner steps to unroll between timesteps in order to flatten the codebooks. + """ + return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1 + + def num_virtual_steps(self, timesteps: int) -> int: + return timesteps * self._num_inner_steps + 1 + + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern for delay across codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + # the PatternLayout is built as a tuple of sequence position and list of coordinates + # so that it can be reordered properly given the required delay between codebooks of given timesteps + indexed_out: list = [(-1, [])] + max_timesteps = timesteps + self.max_delay + for t in range(max_timesteps): + # for each timestep, we unroll the flattened codebooks, + # emitting the sequence step with the corresponding delay + for step in range(self._num_inner_steps): + if step in self._flattened_codebooks: + # we have codebooks at this virtual step to emit + step_codebooks = self._flattened_codebooks[step] + t_for_q = t + step_codebooks.delay + coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks] + if t_for_q < max_timesteps and t < max_timesteps: + indexed_out.append((t_for_q, coords)) + else: + # there is no codebook in this virtual step so we emit an empty list + indexed_out.append((t, [])) + out = [coords for _, coords in sorted(indexed_out)] + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class CoarseFirstPattern(CodebooksPatternProvider): + """First generates all the codebooks #1 (e.g. coarser), then the remaining ones, + potentially with delays. 
+ + ..Warning:: You must always generate the full training duration at test time, for instance, + 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected + location. This is due to the non causality of the remaining codebooks with respect to + the first ones. + + Args: + n_q (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + """ + def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None): + super().__init__(n_q) + if delays is None: + delays = [0] * (n_q - 1) + self.delays = delays + assert len(self.delays) == self.n_q - 1 + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for t in range(timesteps): + out.append([LayoutCoord(t, 0)]) + max_delay = max(self.delays) + for t in range(timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= 0: + v.append(LayoutCoord(t_for_q, q + 1)) + out.append(v) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) + + +class MusicLMPattern(CodebooksPatternProvider): + """Almost MusicLM style pattern. This is equivalent to full flattening + but in a different order. + + Args: + n_q (int): Number of codebooks. + group_by (int): Number of codebooks to group together. + """ + def __init__(self, n_q: int, group_by: int = 2): + super().__init__(n_q) + self.group_by = group_by + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + for offset in range(0, self.n_q, self.group_by): + for t in range(timesteps): + for q in range(offset, offset + self.group_by): + out.append([LayoutCoord(t, q)]) + return Pattern(out, n_q=self.n_q, timesteps=timesteps) \ No newline at end of file diff --git a/think_sound/models/conditioners.py b/think_sound/models/conditioners.py new file mode 100644 index 0000000000000000000000000000000000000000..915c6fd753d69ed16e0366b9b8051df231d7e5ef --- /dev/null +++ b/think_sound/models/conditioners.py @@ -0,0 +1,1006 @@ +#Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py + +import torch +import logging, warnings +import string +import typing as tp +import gc +from typing import Literal, Optional +import os +from .adp import NumberEmbedder +from ..inference.utils import set_audio_channels +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..training.utils import copy_state_dict +from .utils import load_ckpt_state_dict +import numpy as np +from einops import rearrange +from transformers import AutoProcessor, AutoModel +from torch import nn + +class Conditioner(nn.Module): + def __init__( + self, + dim: int, + output_dim: int, + project_out: bool = False + ): + + super().__init__() + + self.dim = dim + self.output_dim = output_dim + self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity() + + def forward(self, x: tp.Any) -> tp.Any: + raise NotImplementedError() + +class VideoHieraConditioner(Conditioner): + def __init__(self, + output_dim: int, + hiera_ckpt_path, + project_out: bool = False, + finetune: bool = False): + super().__init__(768, output_dim, project_out=project_out) + + self.finetune = finetune + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + 
warnings.simplefilter("ignore") + try: + from hiera import Hiera + import hiera + # model = hiera.hiera_base_16x224(pretrained=True, checkpoint="useful_ckpts/hiera_base_224.mae_in1k_ft_in1k") + model = Hiera( + num_classes=400, # K400 has 400 classes + input_size=(64, 224, 224), + q_stride=[(1, 4, 4),(1,7,7),(1,2,2)], + mask_unit_size=(1, 8, 8), + patch_kernel=(3, 7, 7), + patch_stride=(2, 4, 4), + patch_padding=(1, 3, 3), + sep_pos_embed=True, + ) + state_dict = torch.load(hiera_ckpt_path)['model_state'] + state_dict.pop('pos_embed_temporal', None) # 如果不需要这个参数 + model.load_state_dict(state_dict,strict=False) + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = model.state_dict() + self.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.requires_grad_(True) + self.model.train() + else: + self.model.requires_grad_(False) + self.model.train() + + finally: + logging.disable(previous_level) + + + gc.collect() + torch.cuda.empty_cache() + + def forward(self, x: tp.List[str], device: tp.Any = "cuda") -> tp.Any: + self.model.to(device) + import ipdb + ipdb.set_trace() + output, interm = model(x,return_intermediates=True) + + video_features = interm[-1] + return [self.proj_out(video_features), torch.ones(video_features.shape[0], 1).to(device)] + +class Video_Linear(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Video_Global(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim, global_dim=1536): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + data = torch.load(path) + video_feats.append(data['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + global_x = self.global_proj(x.mean(dim=1)) + return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)] + +class Video_Sync(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not 
isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['sync_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + +class Text_Linear(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + self.embedder = nn.Sequential(nn.Linear(dim, output_dim)) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_text_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + + x = self.embedder(x) # B x 117 x C + return [x, torch.ones(x.shape[0], 1).to(device)] + + +class mm_unchang(Conditioner): + """ Transform the video feat encoder""" + + def __init__(self, dim, output_dim): + super().__init__(dim, output_dim) + + def forward(self, x, device: tp.Any = "cuda"): + # import ipdb + # ipdb.set_trace() + if not isinstance(x[0], torch.Tensor): + video_feats = [] + for path in x: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + elif '.pth' in path: + video_feats.append(torch.load(path)['metaclip_features'].to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device)) + x = torch.stack(video_feats, dim=0).to(device) + else: + # Revise the shape here: + x = torch.stack(x, dim=0).to(device) + return [x] + +class CLIPConditioner(Conditioner): + + CLIP_MODELS = ["metaclip-base", "metaclip-b16", "metaclip-large", "metaclip-huge"] + + CLIP_MODEL_DIMS = { + "metaclip-base": 512, + "metaclip-b16": 512, + "metaclip-large": 768, + "metaclip-huge": 1024, + } + + def __init__( + self, + dim: int, + output_dim: int, + clip_model_name: str = "metaclip-huge", + enable_grad: bool = False, + project_out: bool = False + ): + assert clip_model_name in self.CLIP_MODELS, f"Unknown CLIP model name: {clip_model_name}" + super().__init__(self.CLIP_MODEL_DIMS[clip_model_name], output_dim, project_out=project_out) + + self.enable_grad = enable_grad + model = AutoModel.from_pretrained(f"useful_ckpts/{clip_model_name}").train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + + + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, images: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + # import ipdb + # ipdb.set_trace() + + self.model.eval() + if not isinstance(images[0], torch.Tensor): + video_feats = [] + for path in images: + if '.npy' in path: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + else: + video_feats.append(torch.from_numpy(np.load(path)).to(device)) + images = torch.stack(video_feats, dim=0).to(device) + else: + images = torch.stack(images, 
dim=0).to(device) + bsz, t, c, h, w = images.shape + # 使用 rearrange 进行维度合并 + images = rearrange(images, 'b t c h w -> (b t) c h w') + with torch.set_grad_enabled(self.enable_grad): + image_features = self.model.get_image_features(images) + image_features = rearrange(image_features, '(b t) d -> b t d', b=bsz, t=t) + image_features = self.proj_out(image_features) + + + return [image_features, torch.ones(image_features.shape[0], 1).to(device)] + +class IntConditioner(Conditioner): + def __init__(self, + output_dim: int, + min_val: int=0, + max_val: int=512 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True) + + def forward(self, ints: tp.List[int], device=None) -> tp.Any: + + #self.int_embedder.to(device) + + ints = torch.tensor(ints).to(device) + ints = ints.clamp(self.min_val, self.max_val) + + int_embeds = self.int_embedder(ints).unsqueeze(1) + + return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)] + +class NumberConditioner(Conditioner): + ''' + Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings + ''' + def __init__(self, + output_dim: int, + min_val: float=0, + max_val: float=1 + ): + super().__init__(output_dim, output_dim) + + self.min_val = min_val + self.max_val = max_val + + self.embedder = NumberEmbedder(features=output_dim) + + def forward(self, floats: tp.List[float], device=None) -> tp.Any: + + # Cast the inputs to floats + floats = [float(x) for x in floats] + + floats = torch.tensor(floats).to(device) + + floats = floats.clamp(self.min_val, self.max_val) + + normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val) + + # Cast floats to same type as embedder + embedder_dtype = next(self.embedder.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + float_embeds = self.embedder(normalized_floats).unsqueeze(1) + + return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)] + +class CLAPTextConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + use_text_features = False, + feature_layer_ix: int = -1, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False, + finetune: bool = False): + super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out) + + self.use_text_features = use_text_features + self.feature_layer_ix = feature_layer_ix + self.finetune = finetune + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.text_branch.requires_grad_(True) + self.model.model.text_branch.train() + else: + self.model.model.text_branch.requires_grad_(False) + self.model.model.text_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.audio_branch + + gc.collect() + torch.cuda.empty_cache() + + def get_clap_features(self, prompts, 
layer_ix=-2, device: tp.Any = "cuda"): + prompt_tokens = self.model.tokenizer(prompts) + attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True) + prompt_features = self.model.model.text_branch( + input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True), + attention_mask=attention_mask, + output_hidden_states=True + )["hidden_states"][layer_ix] + + return prompt_features, attention_mask + + def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any: + self.model.to(device) + + if self.use_text_features: + if len(texts) == 1: + text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device) + text_features = text_features[:1, ...] + text_attention_mask = text_attention_mask[:1, ...] + else: + text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device) + return [self.proj_out(text_features), text_attention_mask] + + # Fix for CLAP bug when only one text is passed + if len(texts) == 1: + text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...] + else: + text_embedding = self.model.get_text_embedding(texts, use_tensor=True) + + text_embedding = text_embedding.unsqueeze(1).to(device) + + return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)] + +class CLAPAudioConditioner(Conditioner): + def __init__(self, + output_dim: int, + clap_ckpt_path, + audio_model_type="HTSAT-base", + enable_fusion=True, + project_out: bool = False): + super().__init__(512, output_dim, project_out=project_out) + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + import laion_clap + from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict + + model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu') + + if self.finetune: + self.model = model + else: + self.__dict__["model"] = model + + state_dict = clap_load_state_dict(clap_ckpt_path) + self.model.model.load_state_dict(state_dict, strict=False) + + if self.finetune: + self.model.model.audio_branch.requires_grad_(True) + self.model.model.audio_branch.train() + else: + self.model.model.audio_branch.requires_grad_(False) + self.model.model.audio_branch.eval() + + finally: + logging.disable(previous_level) + + del self.model.model.text_branch + + gc.collect() + torch.cuda.empty_cache() + + def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any: + + self.model.to(device) + + if isinstance(audios, list) or isinstance(audios, tuple): + audios = torch.cat(audios, dim=0) + + # Convert to mono + mono_audios = audios.mean(dim=1) + + with torch.cuda.amp.autocast(enabled=False): + audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True) + + audio_embedding = audio_embedding.unsqueeze(1).to(device) + + return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)] + +class T5Conditioner(Conditioner): + + T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", + "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", + "google/flan-t5-xl", "google/flan-t5-xxl", "t5-v1_1-xl", "google/t5-v1_1-xxl"] 
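The single-prompt workaround used by the CLAP text conditioner above (padding the batch with an empty string, then slicing the first row back out) can be expressed as a small helper. A minimal sketch; `embed_texts_safely` is our own name, not anything defined in this repo:

```python
import typing as tp
import torch

def embed_texts_safely(embed_fn: tp.Callable[[tp.List[str]], torch.Tensor],
                       texts: tp.List[str]) -> torch.Tensor:
    """Pad a single-item batch with an empty prompt and drop the extra row,
    mirroring the CLAP fix above for encoders that misbehave on batch size 1."""
    if len(texts) == 1:
        return embed_fn([texts[0], ""])[:1, ...]
    return embed_fn(texts)

# e.g. embed_texts_safely(lambda t: clap.get_text_embedding(t, use_tensor=True), ["dog barking"])
```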
+ + T5_MODEL_DIMS = { + "t5-small": 512, + "t5-base": 768, + "t5-large": 1024, + "t5-3b": 1024, + "t5-11b": 1024, + "t5-v1_1-xl": 2048, + "google/t5-v1_1-xxl": 4096, + "google/flan-t5-small": 512, + "google/flan-t5-base": 768, + "google/flan-t5-large": 1024, + "google/flan-t5-3b": 1024, + "google/flan-t5-11b": 1024, + "google/flan-t5-xl": 2048, + "google/flan-t5-xxl": 4096, + } + + def __init__( + self, + output_dim: int, + t5_model_name: str = "t5-base", + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}" + super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length) + # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad) + self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('useful_ckpts', t5_model_name)) + model = T5EncoderModel.from_pretrained(os.path.join('useful_ckpts', t5_model_name)).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model( + input_ids=input_ids, attention_mask=attention_mask + )["last_hidden_state"] + + embeddings = self.proj_out(embeddings.float()) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_encode_text(self, text, normalize: bool = False): + cast_dtype = self.transformer.get_cast_dtype() + + x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.to(cast_dtype) + x = self.transformer(x, attn_mask=self.attn_mask) + x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] + return F.normalize(x, dim=-1) if normalize else x + + clip_model.encode_text = new_encode_text.__get__(clip_model) + return clip_model + +class CLIPTextConditioner(Conditioner): + def __init__( + self, + output_dim: int, + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + super().__init__(1024, output_dim, project_out=project_out) + + from transformers import T5EncoderModel, AutoTokenizer + import open_clip + from open_clip import create_model_from_pretrained + + self.max_length = max_length + 
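Stepping back from the class above for a moment: the T5 conditioning path defined earlier reduces to tokenize, encode, project, and mask. A standalone sketch, using the public `t5-base` checkpoint and an illustrative 1536-d conditioning width rather than the local `useful_ckpts` copy and repo config:

```python
import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained("t5-base")
encoder = T5EncoderModel.from_pretrained("t5-base").eval().requires_grad_(False)
proj = torch.nn.Linear(768, 1536)  # t5-base hidden size -> conditioning dim (illustrative)

enc = tokenizer(["water dripping in a cave"], truncation=True, max_length=77,
                padding="max_length", return_tensors="pt")
with torch.no_grad():
    hidden = encoder(input_ids=enc["input_ids"],
                     attention_mask=enc["attention_mask"]).last_hidden_state
cond = proj(hidden) * enc["attention_mask"].unsqueeze(-1).float()  # zero out padded positions
mask = enc["attention_mask"].bool()                                # returned alongside the embeddings
```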
self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',cache_dir='useful_ckpts/DFN5B-CLIP-ViT-H-14-384', + return_transform=False).train(enable_grad).requires_grad_(enable_grad).to(torch.float16) + model = patch_clip(model) + self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14' + finally: + logging.disable(previous_level) + + if self.enable_grad: + self.model = model + else: + self.__dict__["model"] = model + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + + encoded = self.tokenizer( + texts + ).to(device) + + # input_ids = encoded["input_ids"].to(device) + # attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad): + embeddings = self.model.encode_text( + encoded + ) + + embeddings = self.proj_out(embeddings.float()) + + # embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, torch.ones(embeddings.shape[0], 1).to(device) + +def patch_clip(clip_model): + # a hack to make it output last hidden states + # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269 + def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + last_hidden_state = text_outputs[0] + # pooled_output = text_outputs[1] + # text_features = self.text_projection(pooled_output) + + return last_hidden_state + + clip_model.get_text_features = new_get_text_features.__get__(clip_model) + return clip_model + +class MetaCLIPTextConditioner(Conditioner): + def __init__( + self, + output_dim: int, + max_length: str = 77, + enable_grad: bool = False, + project_out: bool = False + ): + super().__init__(1024, output_dim, project_out=project_out) + + from transformers import AutoModel + from transformers import AutoProcessor + + self.max_length = max_length + self.enable_grad = enable_grad + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.model = AutoModel.from_pretrained("useful_ckpts/metaclip-huge") + self.model = patch_clip(self.model) + self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge") + finally: + logging.disable(previous_level) + + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> 
tp.Tuple[torch.Tensor, torch.Tensor]: + + self.model.to(device) + self.proj_out.to(device) + encoded = self.clip_processor(text=texts, return_tensors="pt", padding=True).to(device) + + # input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + self.model.eval() + + with torch.set_grad_enabled(self.enable_grad): + embeddings = self.model.get_text_features( + **encoded + ) + + embeddings = self.proj_out(embeddings.float()) + + # embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, torch.ones(embeddings.shape[0],1).to(device) + +class PhonemeConditioner(Conditioner): + """ + A conditioner that turns text into phonemes and embeds them using a lookup table + Only works for English text + + Args: + output_dim: the dimension of the output embeddings + max_length: the maximum number of phonemes to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from g2p_en import G2p + + self.max_length = max_length + + self.g2p = G2p() + + # Reserving 0 for padding, 1 for ignored + self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.phoneme_embedder.to(device) + self.proj_out.to(device) + + batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length] + + phoneme_ignore = [" ", *string.punctuation] + + # Remove ignored phonemes and cut to max length + batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes] + + # Convert to ids + phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes] + + #Pad to match longest and make a mask tensor for the padding + longest = max([len(ids) for ids in phoneme_ids]) + phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids] + + phoneme_ids = torch.tensor(phoneme_ids).to(device) + + # Convert to embeddings + phoneme_embeds = self.phoneme_embedder(phoneme_ids) + + phoneme_embeds = self.proj_out(phoneme_embeds) + + return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device) + +class TokenizerLUTConditioner(Conditioner): + """ + A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary + + Args: + tokenizer_name: the name of the tokenizer from the Hugging Face transformers library + output_dim: the dimension of the output embeddings + max_length: the maximum length of the text to embed + project_out: whether to add another linear projection to the output embeddings + """ + + def __init__( + self, + tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library + output_dim: int, + max_length: int = 1024, + project_out: bool = False, + ): + super().__init__(output_dim, output_dim, project_out=project_out) + + from transformers import AutoTokenizer + + # Suppress logging from transformers + previous_level = logging.root.manager.disable + logging.disable(logging.ERROR) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + finally: + logging.disable(previous_level) + + self.max_length = max_length + + 
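As an aside before the lookup-table conditioner is completed below: the phoneme path above boils down to grapheme-to-phoneme conversion, id lookup with two reserved slots, and an embedding lookup. A minimal sketch assuming the `g2p_en` package and an illustrative 1536-d width:

```python
import string
import torch
from g2p_en import G2p

g2p = G2p()
embedder = torch.nn.Embedding(len(g2p.phonemes) + 2, 1536)  # 0 = padding, 1 = ignored

ignore = [" ", *string.punctuation]
phonemes = [p if p not in ignore else "_" for p in g2p("glass shattering")]
ids = [g2p.p2idx[p] + 2 if p in g2p.p2idx else 1 for p in phonemes]

cond = embedder(torch.tensor([ids]))       # [1, num_phonemes, 1536]
mask = torch.ones(1, cond.shape[1])        # no padding needed for a batch of one
```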
self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim) + + def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + self.proj_out.to(device) + + encoded = self.tokenizer( + texts, + truncation=True, + max_length=self.max_length, + padding="max_length", + return_tensors="pt", + ) + + input_ids = encoded["input_ids"].to(device) + attention_mask = encoded["attention_mask"].to(device).to(torch.bool) + + embeddings = self.token_embedder(input_ids) + + embeddings = self.proj_out(embeddings) + + embeddings = embeddings * attention_mask.unsqueeze(-1).float() + + return embeddings, attention_mask + +class PretransformConditioner(Conditioner): + """ + A conditioner that uses a pretransform's encoder for conditioning + + Args: + pretransform: an instantiated pretransform to use for conditioning + output_dim: the dimension of the output embeddings + """ + def __init__(self, pretransform: Pretransform, output_dim: int): + super().__init__(pretransform.encoded_channels, output_dim) + + self.pretransform = pretransform + + def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]: + + self.pretransform.to(device) + self.proj_out.to(device) + + if isinstance(audio, list) or isinstance(audio, tuple): + audio = torch.cat(audio, dim=0) + + # Convert audio to pretransform input channels + audio = set_audio_channels(audio, self.pretransform.io_channels) + + latents = self.pretransform.encode(audio) + + latents = self.proj_out(latents) + + return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)] + +class MultiConditioner(nn.Module): + """ + A module that applies multiple conditioners to an input dictionary based on the keys + + Args: + conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt") + default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. 
{"prompt_t5": "prompt"}) + """ + def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}): + super().__init__() + + self.conditioners = nn.ModuleDict(conditioners) + self.default_keys = default_keys + + def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]: + output = {} + + for key, conditioner in self.conditioners.items(): + condition_key = key + + conditioner_inputs = [] + + for x in batch_metadata: + + if condition_key not in x: + if condition_key in self.default_keys: + condition_key = self.default_keys[condition_key] + else: + raise ValueError(f"Conditioner key {condition_key} not found in batch metadata") + + #Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list + if isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple) and len(x[condition_key]) == 1: + conditioner_input = x[condition_key][0] + + else: + conditioner_input = x[condition_key] + + conditioner_inputs.append(conditioner_input) + + cond_output = conditioner(conditioner_inputs, device) + if len(cond_output) == 1: + output[key] = cond_output[0] + elif len(cond_output) == 2: + output[key] = cond_output + elif len(cond_output) == 4: + output[key] = cond_output[:2] + output[f'{key}_g'] = cond_output[2:] + + return output + +def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner: + """ + Create a MultiConditioner from a conditioning config dictionary + + Args: + config: the conditioning config dictionary + device: the device to put the conditioners on + """ + conditioners = {} + cond_dim = config["cond_dim"] + + default_keys = config.get("default_keys", {}) + + for conditioner_info in config["configs"]: + id = conditioner_info["id"] + + conditioner_type = conditioner_info["type"] + + conditioner_config = {"output_dim": cond_dim} + + conditioner_config.update(conditioner_info["config"]) + if conditioner_type == "t5": + conditioners[id] = T5Conditioner(**conditioner_config) + elif conditioner_type == "clap_text": + conditioners[id] = CLAPTextConditioner(**conditioner_config) + elif conditioner_type == "clip_text": + conditioners[id] = CLIPTextConditioner(**conditioner_config) + elif conditioner_type == "metaclip_text": + conditioners[id] = MetaCLIPTextConditioner(**conditioner_config) + elif conditioner_type == "clap_audio": + conditioners[id] = CLAPAudioConditioner(**conditioner_config) + elif conditioner_type == "video_linear": + conditioners[id] = Video_Linear(**conditioner_config) + elif conditioner_type == "video_global": + conditioners[id] = Video_Global(**conditioner_config) + elif conditioner_type == "video_sync": + conditioners[id] = Video_Sync(**conditioner_config) + elif conditioner_type == "text_linear": + conditioners[id] = Text_Linear(**conditioner_config) + elif conditioner_type == "video_clip": + conditioners[id] = CLIPConditioner(**conditioner_config) + elif conditioner_type == "video_hiera": + conditioners[id] = VideoHieraConditioner(**conditioner_config) + elif conditioner_type == "int": + conditioners[id] = IntConditioner(**conditioner_config) + elif conditioner_type == "number": + conditioners[id] = NumberConditioner(**conditioner_config) + elif conditioner_type == "phoneme": + conditioners[id] = PhonemeConditioner(**conditioner_config) + elif conditioner_type == "lut": + conditioners[id] = TokenizerLUTConditioner(**conditioner_config) + elif 
conditioner_type == "pretransform": + sample_rate = conditioner_config.pop("sample_rate", None) + assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners" + + pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate) + + if conditioner_config.get("pretransform_ckpt_path", None) is not None: + pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path"))) + + conditioners[id] = PretransformConditioner(pretransform, **conditioner_config) + elif conditioner_type == "mm_unchang": + conditioners[id] = mm_unchang(**conditioner_config) + else: + raise ValueError(f"Unknown conditioner type: {conditioner_type}") + + return MultiConditioner(conditioners, default_keys=default_keys) \ No newline at end of file diff --git a/think_sound/models/diffusion.py b/think_sound/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..7364e87040ac0fe59a14486b442d194abf5e0f71 --- /dev/null +++ b/think_sound/models/diffusion.py @@ -0,0 +1,922 @@ +import torch +from torch import nn +from torch.nn import functional as F +from functools import partial +import numpy as np +import typing as tp + +from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .dit import DiffusionTransformer +from .mmdit import MMAudio +from .factory import create_pretransform_from_config +from .pretransforms import Pretransform +from ..inference.generation import generate_diffusion_cond + +from .adp import UNetCFG1d, UNet1d + +from time import time + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class DiffusionModel(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, t, **kwargs): + raise NotImplementedError() + +class DiffusionModelWrapper(nn.Module): + def __init__( + self, + model: DiffusionModel, + io_channels, + sample_size, + sample_rate, + min_input_length, + pretransform: tp.Optional[Pretransform] = None, + ): + super().__init__() + self.io_channels = io_channels + self.sample_size = sample_size + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.model = model + + if pretransform is not None: + self.pretransform = pretransform + else: + self.pretransform = None + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class ConditionedDiffusionModel(nn.Module): + def __init__(self, + *args, + supports_cross_attention: bool = False, + supports_input_concat: bool = False, + supports_global_cond: bool = False, + supports_prepend_cond: bool = False, + **kwargs): + super().__init__(*args, **kwargs) + self.supports_cross_attention = supports_cross_attention + self.supports_input_concat = supports_input_concat + self.supports_global_cond = supports_global_cond + self.supports_prepend_cond = supports_prepend_cond + + def forward(self, + x: torch.Tensor, + t: torch.Tensor, + cross_attn_cond: torch.Tensor = None, + cross_attn_mask: 
torch.Tensor = None, + input_concat_cond: torch.Tensor = None, + global_embed: torch.Tensor = None, + prepend_cond: torch.Tensor = None, + prepend_cond_mask: torch.Tensor = None, + cfg_scale: float = 1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + **kwargs): + raise NotImplementedError() + +class ConditionedDiffusionModelWrapper(nn.Module): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model: ConditionedDiffusionModel, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + add_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.add_cond_ids = add_cond_ids + self.min_input_length = min_input_length + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + add_input = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = [] + cross_attention_masks = [] + + for key in self.cross_attn_cond_ids: + cross_attn_in, cross_attn_mask = conditioning_tensors[key] + + # Add sequence dimension if it's not there + if len(cross_attn_in.shape) == 2: + cross_attn_in = cross_attn_in.unsqueeze(1) + # cross_attn_mask = cross_attn_mask.unsqueeze(1) + + cross_attention_input.append(cross_attn_in) + cross_attention_masks.append(cross_attn_mask) + # import ipdb + # ipdb.set_trace() + cross_attention_input = torch.cat(cross_attention_input, dim=1) + cross_attention_masks = torch.cat(cross_attention_masks, dim=1) + + if len(self.add_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + add_input = [] + + for key in self.add_cond_ids: + add_in, _ = conditioning_tensors[key] + + # Add sequence dimension if it's not there + if len(add_in.shape) == 2: + add_in = add_in.unsqueeze(1) + + add_input.append(add_in) + + add_input = torch.cat(add_input, dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_conds = [] + # import ipdb + # ipdb.set_trace() + for key in self.global_cond_ids: + global_cond_input = conditioning_tensors[key][0] + + global_conds.append(global_cond_input) + + # Concatenate over the channel dimension + if global_conds[0].shape[-1] == 768: + global_cond = torch.cat(global_conds, dim=-1) + else: + global_cond = sum(global_conds) + + # global_cond = torch.cat(global_conds, dim=-1) + + if len(global_cond.shape) 
== 3: + global_cond = global_cond.squeeze(1) + + if len(self.input_concat_ids) > 0: + # Concatenate all input concat conditioning inputs over the channel dimension + # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq) + input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_conds = [] + prepend_cond_masks = [] + + for key in self.prepend_cond_ids: + prepend_cond_input, prepend_cond_mask = conditioning_tensors[key] + prepend_conds.append(prepend_cond_input) + prepend_cond_masks.append(prepend_cond_mask) + + prepend_cond = torch.cat(prepend_conds, dim=1) + prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_cross_attn_mask": cross_attention_masks, + "negative_global_cond": global_cond, + "negative_input_concat_cond": input_concat_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "cross_attn_mask": cross_attention_masks, + "global_cond": global_cond, + "input_concat_cond": input_concat_cond, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "add_cond": add_input + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + return generate_diffusion_cond(self, *args, **kwargs) + +class UNetCFG1DWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True) + + self.model = UNetCFG1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + input_concat_cond=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + **kwargs): + p = Profiler() + + p.tick("start") + + channels_list = None + if input_concat_cond is not None: + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + embedding=cross_attn_cond, + embedding_mask=cross_attn_mask, + features=global_cond, + channels_list=channels_list, + embedding_scale=cfg_scale, + embedding_mask_proba=cfg_dropout_prob, + batch_cfg=batch_cfg, + rescale_cfg=rescale_cfg, + negative_embedding=negative_cross_attn_cond, + negative_embedding_mask=negative_cross_attn_mask, + **kwargs) + + p.tick("UNetCFG1D forward") + + #print(f"Profiler: {p}") + return outputs + +class UNet1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True) + + self.model = UNet1d(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + global_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + prepend_cond=None, + prepend_cond_mask=None, + 
cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + **kwargs): + + channels_list = None + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest') + + channels_list = [input_concat_cond] + + outputs = self.model( + x, + t, + features=global_cond, + channels_list=channels_list, + **kwargs) + + return outputs + +class UNet1DUncondWrapper(DiffusionModel): + def __init__( + self, + in_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = UNet1d(in_channels=in_channels, *args, **kwargs) + + self.io_channels = in_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +class DAU1DCondWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True) + + self.model = DiffusionAttnUnet1D(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + input_concat_cond=None, + cross_attn_cond=None, + cross_attn_mask=None, + global_cond=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = False, + rescale_cfg: bool = False, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + negative_global_cond=None, + negative_input_concat_cond=None, + prepend_cond=None, + **kwargs): + + return self.model(x, t, cond = input_concat_cond) + +class DiffusionAttnUnet1D(nn.Module): + def __init__( + self, + io_channels = 2, + depth=14, + n_attn_layers = 6, + channels = [128, 128, 256, 256] + [512] * 10, + cond_dim = 0, + cond_noise_aug = False, + kernel_size = 5, + learned_resample = False, + strides = [2] * 13, + conv_bias = True, + use_snake = False + ): + super().__init__() + + self.cond_noise_aug = cond_noise_aug + + self.io_channels = io_channels + + if self.cond_noise_aug: + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_embed = FourierFeatures(1, 16) + + attn_layer = depth - n_attn_layers + + strides = [1] + strides + + block = nn.Identity() + + conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake) + + for i in range(depth, 0, -1): + c = channels[i - 1] + stride = strides[i-1] + if stride > 2 and not learned_resample: + raise ValueError("Must have stride 2 without learned resampling") + + if i > 1: + c_prev = channels[i - 2] + add_attn = i >= attn_layer and n_attn_layers > 0 + block = SkipBlock( + Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"), + conv_block(c_prev, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + block, + conv_block(c * 2 if i != depth else c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c), + SelfAttention1d( + c, c // 32) if add_attn else nn.Identity(), + conv_block(c, c, c_prev), + SelfAttention1d(c_prev, c_prev 
// + 32) if add_attn else nn.Identity(), + Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic") + ) + else: + cond_embed_dim = 16 if not self.cond_noise_aug else 32 + block = nn.Sequential( + conv_block((io_channels + cond_dim) + cond_embed_dim, c, c), + conv_block(c, c, c), + conv_block(c, c, c), + block, + conv_block(c * 2, c, c), + conv_block(c, c, c), + conv_block(c, c, io_channels, is_last=True), + ) + self.net = block + + with torch.no_grad(): + for param in self.net.parameters(): + param *= 0.5 + + def forward(self, x, t, cond=None, cond_aug_scale=None): + + timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape) + + inputs = [x, timestep_embed] + + if cond is not None: + if cond.shape[2] != x.shape[2]: + cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False) + + if self.cond_noise_aug: + # Get a random number between 0 and 1, uniformly sampled + if cond_aug_scale is None: + aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond) + else: + aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond) + + # Add noise to the conditioning signal + cond = cond + torch.randn_like(cond) * aug_level[:, None, None] + + # Get embedding for noise cond level, reusing timestamp_embed + aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape) + + inputs.append(aug_level_embed) + + inputs.append(cond) + + outputs = self.net(torch.cat(inputs, dim=1)) + + return outputs + +class DiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = DiffusionTransformer(*args, **kwargs) + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, + x, + t, + cross_attn_cond=None, + cross_attn_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + negative_input_concat_cond=None, + global_cond=None, + negative_global_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_mask, + negative_cross_attn_cond=negative_cross_attn_cond, + negative_cross_attn_mask=negative_cross_attn_mask, + input_concat_cond=input_concat_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + global_embed=global_cond, + **kwargs) + +class MMDiTWrapper(ConditionedDiffusionModel): + def __init__( + self, + *args, + **kwargs + ): + super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False) + + self.model = MMAudio(*args, **kwargs) + + # with torch.no_grad(): + # for param in self.model.parameters(): + # param *= 0.5 + + def forward(self, + x, + t, + clip_f, + sync_f, + text_f, + inpaint_masked_input=None, + t5_features=None, + metaclip_global_text_features=None, + cfg_scale=1.0, + cfg_dropout_prob: float = 0.0, + batch_cfg: bool = True, + rescale_cfg: bool = False, + scale_phi: float = 0.0, + **kwargs): + + # breakpoint() + 
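+ # Descriptive note (comment added for clarity, not present in the original code):
+ # this wrapper only routes the multimodal conditioning to the underlying MMAudio
+ # model. clip_f / sync_f / text_f are the MetaCLIP frame, Synchformer and MetaCLIP
+ # text features produced by the conditioner, and cfg_scale / cfg_dropout_prob /
+ # scale_phi are forwarded unchanged, so classifier-free guidance is applied inside
+ # MMAudio itself -- presumably the reason the assertion below requires batch_cfg.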
assert batch_cfg, "batch_cfg must be True for DiTWrapper" + #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper" + + return self.model( + latent=x, + t=t, + clip_f=clip_f, + sync_f=sync_f, + text_f=text_f, + inpaint_masked_input=inpaint_masked_input, + t5_features=t5_features, + metaclip_global_text_features=metaclip_global_text_features, + cfg_scale=cfg_scale, + cfg_dropout_prob=cfg_dropout_prob, + scale_phi=scale_phi, + **kwargs) + +class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel): + """ + A diffusion model that takes in conditioning + """ + def __init__( + self, + model: MMAudio, + conditioner: MultiConditioner, + io_channels, + sample_rate, + min_input_length: int, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + pretransform: tp.Optional[Pretransform] = None, + cross_attn_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [], + input_concat_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + add_cond_ids: tp.List[str] = [], + mm_cond_ids: tp.List[str] = [], + ): + super().__init__() + + self.model = model + self.conditioner = conditioner + self.io_channels = io_channels + self.sample_rate = sample_rate + self.diffusion_objective = diffusion_objective + self.pretransform = pretransform + self.cross_attn_cond_ids = cross_attn_cond_ids + self.global_cond_ids = global_cond_ids + self.input_concat_ids = input_concat_ids + self.prepend_cond_ids = prepend_cond_ids + self.add_cond_ids = add_cond_ids + self.min_input_length = min_input_length + self.mm_cond_ids = mm_cond_ids + + assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper" + assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper" + assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper" + assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper" + assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper" + assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper" + assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper" + assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper" + assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper" + # assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper" + + def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False): + assert negative == False, "negative conditioning is not supported for MMDiTWrapper" + cross_attention_input = None + cross_attention_masks = None + global_cond = None + input_concat_cond = None + prepend_cond = None + prepend_cond_mask = None + add_input = None + inpaint_masked_input = None + t5_features = None + metaclip_global_text_features = None + clip_f = conditioning_tensors["metaclip_features"] + sync_f = conditioning_tensors["sync_features"] + text_f = conditioning_tensors["metaclip_text_features"] + if 'inpaint_masked_input' in conditioning_tensors.keys(): + inpaint_masked_input = conditioning_tensors["inpaint_masked_input"] + if 't5_features' in conditioning_tensors.keys(): + t5_features = conditioning_tensors["t5_features"] + if 'metaclip_global_text_features' in conditioning_tensors.keys(): + 
metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"] + return { + "clip_f": clip_f, + "sync_f": sync_f, + "text_f": text_f, + "inpaint_masked_input": inpaint_masked_input, + "t5_features": t5_features, + "metaclip_global_text_features": metaclip_global_text_features + } + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs): + # breakpoint() + # print(kwargs) + return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs) + + def generate(self, *args, **kwargs): + return generate_diffusion_cond(self, *args, **kwargs) + +class DiTUncondWrapper(DiffusionModel): + def __init__( + self, + io_channels, + *args, + **kwargs + ): + super().__init__() + + self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs) + + self.io_channels = io_channels + + with torch.no_grad(): + for param in self.model.parameters(): + param *= 0.5 + + def forward(self, x, t, **kwargs): + return self.model(x, t, **kwargs) + +def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + model_type = diffusion_uncond_config.get('type', None) + + diffusion_config = diffusion_uncond_config.get('config', {}) + + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **diffusion_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + **diffusion_config + ) + + elif model_type == "dit": + model = DiTUncondWrapper( + **diffusion_config + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]): + diffusion_uncond_config = config["model"] + + + diffusion_config = diffusion_uncond_config.get('diffusion', {}) + model_type = diffusion_config.get('type', None) + model_config = diffusion_config.get("config",{}) + assert model_type is not None, "Must specify model type in config" + + pretransform = diffusion_uncond_config.get("pretransform", None) + + sample_size = config.get("sample_size", None) + assert sample_size is not None, "Must specify sample size in config" + + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "Must specify sample rate in config" + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if model_type == 'DAU1d': + + model = DiffusionAttnUnet1D( + **model_config + ) + + elif model_type == "adp_uncond_1d": + + model = UNet1DUncondWrapper( + io_channels = io_channels, + **model_config + ) + elif model_type == "dit": + model = DiTUncondWrapper( + **model_config + ) + + else: + raise 
NotImplementedError(f'Unknown model type: {model_type}') + + return DiffusionModelWrapper(model, + io_channels=model.io_channels, + sample_size=sample_size, + sample_rate=sample_rate, + pretransform=pretransform, + min_input_length=min_input_length) + +def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]): + + model_config = config["model"] + + model_type = config["model_type"] + + diffusion_config = model_config.get('diffusion', None) + assert diffusion_config is not None, "Must specify diffusion config" + + diffusion_model_type = diffusion_config.get('type', None) + assert diffusion_model_type is not None, "Must specify diffusion model type" + + diffusion_model_config = diffusion_config.get('config', None) + assert diffusion_model_config is not None, "Must specify diffusion model config" + + if diffusion_model_type == 'adp_cfg_1d': + diffusion_model = UNetCFG1DWrapper(**diffusion_model_config) + elif diffusion_model_type == 'adp_1d': + diffusion_model = UNet1DCondWrapper(**diffusion_model_config) + elif diffusion_model_type == 'dit': + diffusion_model = DiTWrapper(**diffusion_model_config) + elif diffusion_model_type == 'mmdit': + diffusion_model = MMDiTWrapper(**diffusion_model_config) + + io_channels = model_config.get('io_channels', None) + assert io_channels is not None, "Must specify io_channels in model config" + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + diffusion_objective = diffusion_config.get('diffusion_objective', 'v') + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + add_cond_ids = diffusion_config.get('add_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + mm_cond_ids = diffusion_config.get('mm_cond_ids', []) + + pretransform = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d": + min_input_length *= np.prod(diffusion_model_config["factors"]) + elif diffusion_model_type == "dit": + min_input_length *= diffusion_model.model.patch_size + + # Get the proper wrapper class + + extra_kwargs = {} + + if model_type == "mm_diffusion_cond": + wrapper_fn = MMConditionedDiffusionModelWrapper + extra_kwargs["diffusion_objective"] = diffusion_objective + extra_kwargs["mm_cond_ids"] = mm_cond_ids + + if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill': + wrapper_fn = ConditionedDiffusionModelWrapper + extra_kwargs["diffusion_objective"] = diffusion_objective + + elif model_type == "diffusion_prior": + prior_type = model_config.get("prior_type", None) + assert prior_type is not None, "Must specify prior_type in diffusion prior model config" + + if prior_type == "mono_stereo": + from .diffusion_prior import MonoToStereoDiffusionPrior + wrapper_fn = MonoToStereoDiffusionPrior + + return wrapper_fn( + diffusion_model, + conditioner, + min_input_length=min_input_length, + 
sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + add_cond_ids=add_cond_ids, + pretransform=pretransform, + io_channels=io_channels, + **extra_kwargs + ) \ No newline at end of file diff --git a/think_sound/models/diffusion_prior.py b/think_sound/models/diffusion_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb15258d7656fb85ee763910dc9b500331de603 --- /dev/null +++ b/think_sound/models/diffusion_prior.py @@ -0,0 +1,82 @@ +from enum import Enum +import typing as tp + +from .diffusion import ConditionedDiffusionModelWrapper +from ..inference.generation import generate_diffusion_cond +from ..inference.utils import prepare_audio + +import torch +from torch.nn import functional as F +from torchaudio import transforms as T + +# Define prior types enum +class PriorType(Enum): + MonoToStereo = 1 + +class DiffusionPrior(ConditionedDiffusionModelWrapper): + def __init__(self, *args, prior_type: PriorType=None, **kwargs): + super().__init__(*args, **kwargs) + self.prior_type = prior_type + +class MonoToStereoDiffusionPrior(DiffusionPrior): + def __init__(self, *args, **kwargs): + super().__init__(*args, prior_type=PriorType.MonoToStereo, **kwargs) + + def stereoize( + self, + audio: torch.Tensor, # (batch, channels, time) + video: torch.Tensor, + in_sr: int, + steps: int, + sampler_kwargs: dict = {}, + ): + """ + Generate stereo audio from mono audio using a pre-trained diffusion prior + + Args: + audio: The mono audio to convert to stereo + in_sr: The sample rate of the input audio + steps: The number of diffusion steps to run + sampler_kwargs: Keyword arguments to pass to the diffusion sampler + """ + + device = audio.device + + sample_rate = self.sample_rate + + # Resample input audio if necessary + if in_sr != sample_rate: + resample_tf = T.Resample(in_sr, sample_rate).to(audio.device) + audio = resample_tf(audio) + + audio_length = audio.shape[-1] + + # # Pad input audio to be compatible with the model + # min_length = self.min_input_length + # padded_input_length = audio_length + (min_length - (audio_length % min_length)) % min_length + + # # Pad input audio to be compatible with the model + # if padded_input_length > audio_length: + # audio = F.pad(audio, (0, padded_input_length - audio_length)) + + # Make audio mono, duplicate to stereo + dual_mono = audio.mean(1, keepdim=True).repeat(1, 2, 1) + + if self.pretransform is not None: + dual_mono = self.pretransform.encode(dual_mono) + + conditioning = self.conditioner([{'video':video}], device) + # Return fake stereo audio + conditioning["source"] = [dual_mono] + stereo_audio = generate_diffusion_cond( + self, + conditioning_tensors=conditioning, + steps=steps, + sample_size=audio_length, + sample_rate=sample_rate, + device=device, + cfg_scale=1, + **sampler_kwargs, + ) + + return stereo_audio \ No newline at end of file diff --git a/think_sound/models/discriminators.py b/think_sound/models/discriminators.py new file mode 100644 index 0000000000000000000000000000000000000000..b593168df965bb1f57881ea79edbc2f66478c6c2 --- /dev/null +++ b/think_sound/models/discriminators.py @@ -0,0 +1,546 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from functools import reduce +import typing as tp +from einops import rearrange +from audiotools import AudioSignal, STFTParams +from dac.model.discriminator import WNConv1d, WNConv2d + +def 
get_hinge_losses(score_real, score_fake): + gen_loss = -score_fake.mean() + dis_loss = torch.relu(1 - score_real).mean() + torch.relu(1 + score_fake).mean() + return dis_loss, gen_loss + +class EncodecDiscriminator(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + from encodec.msstftd import MultiScaleSTFTDiscriminator + + self.discriminators = MultiScaleSTFTDiscriminator(*args, **kwargs) + + def forward(self, x): + logits, features = self.discriminators(x) + return logits, features + + def loss(self, x, y): + feature_matching_distance = 0. + logits_true, feature_true = self.forward(x) + logits_fake, feature_fake = self.forward(y) + + dis_loss = torch.tensor(0.) + adv_loss = torch.tensor(0.) + + for i, (scale_true, scale_fake) in enumerate(zip(feature_true, feature_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda x, y: abs(x - y).mean(), + scale_true, + scale_fake, + )) / len(scale_true) + + _dis, _adv = get_hinge_losses( + logits_true[i], + logits_fake[i], + ) + + dis_loss = dis_loss + _dis + adv_loss = adv_loss + _adv + + return dis_loss, adv_loss, feature_matching_distance + +# Discriminators from oobleck + +IndividualDiscriminatorOut = tp.Tuple[torch.Tensor, tp.Sequence[torch.Tensor]] + +TensorDict = tp.Dict[str, torch.Tensor] + +class SharedDiscriminatorConvNet(nn.Module): + + def __init__( + self, + in_size: int, + convolution: tp.Union[nn.Conv1d, nn.Conv2d], + out_size: int = 1, + capacity: int = 32, + n_layers: int = 4, + kernel_size: int = 15, + stride: int = 4, + activation: tp.Callable[[], nn.Module] = lambda: nn.SiLU(), + normalization: tp.Callable[[nn.Module], nn.Module] = torch.nn.utils.weight_norm, + ) -> None: + super().__init__() + channels = [in_size] + channels += list(capacity * 2**np.arange(n_layers)) + + if isinstance(stride, int): + stride = n_layers * [stride] + + net = [] + for i in range(n_layers): + if isinstance(kernel_size, int): + pad = kernel_size // 2 + s = stride[i] + else: + pad = kernel_size[0] // 2 + s = (stride[i], 1) + + net.append( + normalization( + convolution( + channels[i], + channels[i + 1], + kernel_size, + stride=s, + padding=pad, + ))) + net.append(activation()) + + net.append(convolution(channels[-1], out_size, 1)) + + self.net = nn.ModuleList(net) + + def forward(self, x) -> IndividualDiscriminatorOut: + features = [] + for layer in self.net: + x = layer(x) + if isinstance(layer, nn.modules.conv._ConvNd): + features.append(x) + score = x.reshape(x.shape[0], -1).mean(-1) + return score, features + + +class MultiScaleDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + n_scales: int, + **conv_kwargs) -> None: + super().__init__() + layers = [] + for _ in range(n_scales): + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv1d, **conv_kwargs)) + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + features = [] + for layer in self.layers: + s, f = layer(x) + score = score + s + features.extend(f) + x = nn.functional.avg_pool1d(x, 2) + return score, features + +class MultiPeriodDiscriminator(nn.Module): + + def __init__(self, + in_channels: int, + periods: tp.Sequence[int], + **conv_kwargs) -> None: + super().__init__() + layers = [] + self.periods = periods + + for _ in periods: + layers.append(SharedDiscriminatorConvNet(in_channels, nn.Conv2d, **conv_kwargs)) + + self.layers = nn.ModuleList(layers) + + def forward(self, x: torch.Tensor) -> IndividualDiscriminatorOut: + score = 0 + 
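+ # Descriptive note (comment added for clarity, not present in the original code):
+ # each period branch folds the waveform into (batch, channels, time // n, n) via
+ # `fold`, the per-branch scores are summed into a single score per sample, and the
+ # intermediate feature maps of all branches are collected for feature-matching losses.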
features = [] + for layer, n in zip(self.layers, self.periods): + s, f = layer(self.fold(x, n)) + score = score + s + features.extend(f) + return score, features + + def fold(self, x: torch.Tensor, n: int) -> torch.Tensor: + pad = (n - (x.shape[-1] % n)) % n + x = nn.functional.pad(x, (0, pad)) + return x.reshape(*x.shape[:2], -1, n) + + +class MultiDiscriminator(nn.Module): + """ + Individual discriminators should take a single tensor as input (NxB C T) and + return a tuple composed of a score tensor (NxB) and a Sequence of Features + Sequence[NxB C' T']. + """ + + def __init__(self, discriminator_list: tp.Sequence[nn.Module], + keys: tp.Sequence[str]) -> None: + super().__init__() + self.discriminators = nn.ModuleList(discriminator_list) + self.keys = keys + + def unpack_tensor_to_dict(self, features: torch.Tensor) -> TensorDict: + features = features.chunk(len(self.keys), 0) + return {k: features[i] for i, k in enumerate(self.keys)} + + @staticmethod + def concat_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = [] + if k in dict_a: + if isinstance(dict_a[k], list): + out_dict[k].extend(dict_a[k]) + else: + out_dict[k].append(dict_a[k]) + if k in dict_b: + if isinstance(dict_b[k], list): + out_dict[k].extend(dict_b[k]) + else: + out_dict[k].append(dict_b[k]) + return out_dict + + @staticmethod + def sum_dicts(dict_a, dict_b): + out_dict = {} + keys = set(list(dict_a.keys()) + list(dict_b.keys())) + for k in keys: + out_dict[k] = 0. + if k in dict_a: + out_dict[k] = out_dict[k] + dict_a[k] + if k in dict_b: + out_dict[k] = out_dict[k] + dict_b[k] + return out_dict + + def forward(self, inputs: TensorDict) -> TensorDict: + discriminator_input = torch.cat([inputs[k] for k in self.keys], 0) + all_scores = [] + all_features = [] + + for discriminator in self.discriminators: + score, features = discriminator(discriminator_input) + scores = self.unpack_tensor_to_dict(score) + scores = {f"score_{k}": scores[k] for k in scores.keys()} + all_scores.append(scores) + + features = map(self.unpack_tensor_to_dict, features) + features = reduce(self.concat_dicts, features) + features = {f"features_{k}": features[k] for k in features.keys()} + all_features.append(features) + + all_scores = reduce(self.sum_dicts, all_scores) + all_features = reduce(self.concat_dicts, all_features) + + inputs.update(all_scores) + inputs.update(all_features) + + return inputs + +class OobleckDiscriminator(nn.Module): + + def __init__( + self, + in_channels=1, + ): + super().__init__() + + multi_scale_discriminator = MultiScaleDiscriminator( + in_channels=in_channels, + n_scales=3, + ) + + multi_period_discriminator = MultiPeriodDiscriminator( + in_channels=in_channels, + periods=[2, 3, 5, 7, 11] + ) + + # multi_resolution_discriminator = MultiScaleSTFTDiscriminator( + # filters=32, + # in_channels = in_channels, + # out_channels = 1, + # n_ffts = [2048, 1024, 512, 256, 128], + # hop_lengths = [512, 256, 128, 64, 32], + # win_lengths = [2048, 1024, 512, 256, 128] + # ) + + self.multi_discriminator = MultiDiscriminator( + [multi_scale_discriminator, multi_period_discriminator], #, multi_resolution_discriminator], + ["reals", "fakes"] + ) + + def loss(self, reals, fakes): + inputs = { + "reals": reals, + "fakes": fakes, + } + + inputs = self.multi_discriminator(inputs) + + scores_real = inputs["score_reals"] + scores_fake = inputs["score_fakes"] + + features_real = inputs["features_reals"] + features_fake = inputs["features_fakes"] + + dis_loss, 
gen_loss = get_hinge_losses(scores_real, scores_fake) + + feature_matching_distance = torch.tensor(0.) + + for _, (scale_real, scale_fake) in enumerate(zip(features_real, features_fake)): + + feature_matching_distance = feature_matching_distance + sum( + map( + lambda real, fake: abs(real - fake).mean(), + scale_real, + scale_fake, + )) / len(scale_real) + + return dis_loss, gen_loss, feature_matching_distance + + +## Discriminators from Descript Audio Codec repo +## Copied and modified under MIT license, see LICENSES/LICENSE_DESCRIPT.txt +class MPD(nn.Module): + def __init__(self, period, channels=1): + super().__init__() + + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(channels, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100, channels=1): + super().__init__() + + self.convs = nn.ModuleList( + [ + WNConv1d(channels, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, + window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + channels: int = 1 + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. 
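+ channels : int, optional
+ Number of audio channels in the input waveform, by default 1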
+ """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + self.channels = channels + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b ch f t c -> (b ch) c t f", ch=self.channels) + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class DACDiscriminator(nn.Module): + def __init__( + self, + channels: int = 1, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. + periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p, channels=channels) for p in periods] + discs += [MSD(r, sample_rate=sample_rate, channels=channels) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands, channels=channels) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + +class DACGANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. 
+ """ + + def __init__(self, **discriminator_kwargs): + super().__init__() + self.discriminator = DACDiscriminator(**discriminator_kwargs) + + def forward(self, fake, real): + d_fake = self.discriminator(fake) + d_real = self.discriminator(real) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature + + def loss(self, fake, real): + gen_loss, feature_distance = self.generator_loss(fake, real) + dis_loss = self.discriminator_loss(fake, real) + + return dis_loss, gen_loss, feature_distance \ No newline at end of file diff --git a/think_sound/models/dit.py b/think_sound/models/dit.py new file mode 100644 index 0000000000000000000000000000000000000000..00ce7c6f2d4b59796e129d397f9d5067bf855c56 --- /dev/null +++ b/think_sound/models/dit.py @@ -0,0 +1,439 @@ +import typing as tp + +import torch +# from beartype.typing import Tuple +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from x_transformers import ContinuousTransformerWrapper, Encoder +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from .blocks import FourierFeatures +from .transformer import ContinuousTransformer +from .utils import mask_from_frac_lengths, resample +class DiffusionTransformer(nn.Module): + def __init__(self, + io_channels=32, + patch_size=1, + embed_dim=768, + cond_token_dim=0, + project_cond_tokens=True, + global_cond_dim=0, + project_global_cond=True, + input_concat_dim=0, + prepend_cond_dim=0, + cond_ctx_dim=0, + depth=12, + num_heads=8, + transformer_type: tp.Literal["x-transformers", "continuous_transformer","mm_transformer"] = "x-transformers", + global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + frac_lengths_mask = (0.7, 1.), + ctx_drop: float = 0.1, + add_token_dim=0, + use_mlp=False, + **kwargs): + + super().__init__() + + self.cond_token_dim = cond_token_dim + + # Timestep embeddings + timestep_features_dim = 256 + + self.timestep_features = FourierFeatures(1, timestep_features_dim) + + self.to_timestep_embed = nn.Sequential( + nn.Linear(timestep_features_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True), + ) + self.use_mlp = use_mlp + if cond_token_dim > 0: + # Conditioning tokens + cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim + self.to_cond_embed = nn.Sequential( + nn.Linear(cond_token_dim, cond_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(cond_embed_dim, cond_embed_dim, bias=False) + ) + else: + cond_embed_dim = 0 + + if global_cond_dim > 0: + # Global conditioning + global_embed_dim = global_cond_dim if not project_global_cond else embed_dim + self.to_global_embed = nn.Sequential( + nn.Linear(global_cond_dim, global_embed_dim, bias=False), + nn.SiLU(), + nn.Linear(global_embed_dim, global_embed_dim, bias=False) + ) + + if add_token_dim > 0: + # Conditioning tokens + + add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim + self.to_add_embed = nn.Sequential( + 
nn.SiLU(), + ConvMLP(add_embed_dim, add_embed_dim * 4, kernel_size=3, padding=1), + ) + else: + add_embed_dim = 0 + + if cond_ctx_dim > 0: + self.ctx_linear = nn.Linear(cond_ctx_dim*2, cond_ctx_dim, bias=True) + self.frac_lengths_mask = frac_lengths_mask + self.ctx_drop = ctx_drop + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + self.input_concat_dim = input_concat_dim + + dim_in = io_channels + self.input_concat_dim + + self.patch_size = patch_size + + # Transformer + + self.transformer_type = transformer_type + + self.global_cond_type = global_cond_type + print("######################") + print(f'global type: {global_cond_type}') + print("######################") + if self.transformer_type == "x-transformers": + self.transformer = ContinuousTransformerWrapper( + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + max_seq_len=0, #Not relevant without absolute positional embeds + attn_layers = Encoder( + dim=embed_dim, + depth=depth, + heads=num_heads, + attn_flash = True, + cross_attend = cond_token_dim > 0, + dim_context=None if cond_embed_dim == 0 else cond_embed_dim, + zero_init_branch_output=True, + use_abs_pos_emb = False, + rotary_pos_emb=True, + ff_swish = True, + ff_glu = True, + **kwargs + ) + ) + + elif self.transformer_type == "continuous_transformer": + + global_dim = None + + if self.global_cond_type == "adaLN": + # The global conditioning is projected to the embed_dim already at this point + global_dim = embed_dim + + self.transformer = ContinuousTransformer( + dim=embed_dim, + depth=depth, + dim_heads=embed_dim // num_heads, + dim_in=dim_in * patch_size, + dim_out=io_channels * patch_size, + cross_attend = cond_token_dim > 0, + cond_token_dim = cond_embed_dim, + global_cond_dim=global_dim, + **kwargs + ) + + else: + raise ValueError(f"Unknown transformer type: {self.transformer_type}") + + self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False) + nn.init.zeros_(self.preprocess_conv.weight) + self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False) + nn.init.zeros_(self.postprocess_conv.weight) + + def _forward( + self, + x, + t, + mask=None, + cross_attn_cond=None, + cross_attn_cond_mask=None, + input_concat_cond=None, + global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + add_cond=None, + add_masks=None, + # x_ctx=None, + return_info=False, + **kwargs): + + if cross_attn_cond is not None: + cross_attn_cond = self.to_cond_embed(cross_attn_cond) + if global_embed is not None: + # Project the global conditioning to the embedding dimension + global_embed = self.to_global_embed(global_embed) + if len(global_embed.shape) == 3: + global_embed = torch.max(global_embed, dim=1).values + + prepend_inputs = None + prepend_mask = None + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + + prepend_inputs = prepend_cond + if prepend_cond_mask is not None: + prepend_mask = prepend_cond_mask + + if input_concat_cond is not None: + + # Interpolate input_concat_cond to the same length as x + if input_concat_cond.shape[2] != x.shape[2]: + input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest-exact') + + x = torch.cat([x, input_concat_cond], dim=1) + + if add_cond is not None: + # Interpolate input_concat_cond to the same length as 
x + + if self.use_mlp: + add_cond = self.to_add_embed(add_cond) + if add_cond.shape[1] != x.shape[2]: + # add_cond = add_cond.transpose(1,2) + # add_cond = F.interpolate(add_cond, (x.shape[2], ), mode='nearest-exact') + # add_cond = add_cond.transpose(1,2) + add_cond = resample(add_cond, x) + + # Get the batch of timestep embeddings + timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim) + # import ipdb + # ipdb.set_trace() + # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists + if global_embed is not None: + global_embed = global_embed + timestep_embed + else: + global_embed = timestep_embed + + # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer + if self.global_cond_type == "prepend": + if prepend_inputs is None: + # Prepend inputs are just the global embed, and the mask is all ones + prepend_inputs = global_embed.unsqueeze(1) + prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool) + else: + # Prepend inputs are the prepend conditioning + the global embed + prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1) + prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1) + + prepend_length = prepend_inputs.shape[1] + + x = self.preprocess_conv(x) + x + x = rearrange(x, "b c t -> b t c") + + + extra_args = {} + + if self.global_cond_type == "adaLN": + extra_args["global_cond"] = global_embed + + if self.patch_size > 1: + x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size) + + if self.transformer_type == "x-transformers": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, add_cond=add_cond, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) + elif self.transformer_type == "continuous_transformer": + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + + if return_info: + output, info = output + elif self.transformer_type == "mm_transformer": + output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs) + + output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:] + + if self.patch_size > 1: + output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size) + + output = self.postprocess_conv(output) + output + + if return_info: + return output, info + + return output + + def forward( + self, + x, + t, + cross_attn_cond=None, + cross_attn_cond_mask=None, + negative_cross_attn_cond=None, + negative_cross_attn_mask=None, + input_concat_cond=None, + global_embed=None, + negative_global_embed=None, + prepend_cond=None, + prepend_cond_mask=None, + add_cond=None, + cfg_scale=1.0, + cfg_dropout_prob=0.0, + causal=False, + scale_phi=0.0, + mask=None, + x_ctx=None, + ctx_mask=None, + return_info=False, + **kwargs): + + assert causal == False, "Causal mode is not supported for DiffusionTransformer" + bsz, a, b = x.shape + + if cross_attn_cond_mask is not None: + cross_attn_cond_mask = cross_attn_cond_mask.bool() + + cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.bool() + + # CFG 
dropout + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if add_cond is not None: + null_embed = torch.zeros_like(add_cond, device=add_cond.device) + dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool) + add_cond = torch.where(dropout_mask, null_embed, add_cond) + + if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None): + # Classifier-free guidance + # Concatenate conditioned and unconditioned inputs on the batch dimension + batch_inputs = torch.cat([x, x], dim=0) + batch_timestep = torch.cat([t, t], dim=0) + + if global_embed is not None: + batch_global_cond = torch.cat([global_embed, global_embed], dim=0) + else: + batch_global_cond = None + + if input_concat_cond is not None: + batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0) + else: + batch_input_concat_cond = None + + batch_cond = None + batch_cond_masks = None + + # Handle CFG for cross-attention conditioning + if cross_attn_cond is not None: + + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning + if negative_cross_attn_cond is not None: + + # If there's a negative cross-attention mask, set the masked tokens to the null embed + if negative_cross_attn_mask is not None: + negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2) + + negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed) + + batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0) + + else: + batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if cross_attn_cond_mask is not None: + batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0) + + batch_prepend_cond = None + batch_prepend_cond_mask = None + + if prepend_cond is not None: + + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + batch_add_cond = None + + # Handle CFG for cross-attention conditioning + if add_cond is not None: + + null_embed = torch.zeros_like(add_cond, device=add_cond.device) + + + batch_add_cond = torch.cat([add_cond, null_embed], dim=0) + + + if mask is not None: + batch_masks = torch.cat([mask, mask], dim=0) + else: + batch_masks = None + + batch_output = self._forward( + batch_inputs, + batch_timestep, + cross_attn_cond=batch_cond, + cross_attn_cond_mask=batch_cond_masks, + mask = batch_masks, + # x_ctx=x_ctx, + input_concat_cond=batch_input_concat_cond, + global_embed = batch_global_cond, + prepend_cond = batch_prepend_cond, + 
prepend_cond_mask = batch_prepend_cond_mask, + add_cond = batch_add_cond, + return_info = return_info, + **kwargs) + + if return_info: + batch_output, info = batch_output + + cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + + # CFG Rescale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + output = cfg_output + + if return_info: + return output, info + + return output + + else: + return self._forward( + x, + t, + cross_attn_cond=cross_attn_cond, + cross_attn_cond_mask=cross_attn_cond_mask, + input_concat_cond=input_concat_cond, + global_embed=global_embed, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + add_cond=add_cond, + # x_ctx=x_ctx, + mask=mask, + return_info=return_info, + **kwargs + ) \ No newline at end of file diff --git a/think_sound/models/factory.py b/think_sound/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..20a44e84aa6f58e7fc128a5d82f97e41ccc809b6 --- /dev/null +++ b/think_sound/models/factory.py @@ -0,0 +1,156 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + # elif model_type == 'diffusion_infill': + # from .diffusion import create_diffusion_infill_from_config + # return create_diffusion_infill_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond": + from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = 
pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') + + requires_grad = bottleneck_config.get('requires_grad', True) + 
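For orientation, a minimal usage sketch of the bottleneck factory defined in this file. The config keys (`type`, `config`, `requires_grad`) are the ones read by the dispatch code above; the import path simply mirrors this file's location in the diff, and the example is illustrative rather than repository code.

```python
# Illustrative sketch, assuming think_sound is importable from the Space root.
from think_sound.models.factory import create_bottleneck_from_config

cfg = {"type": "tanh", "requires_grad": False}   # TanhBottleneck needs no extra "config" block
bottleneck = create_bottleneck_from_config(cfg)  # the loop that follows then disables
                                                 # requires_grad on every bottleneck parameter
```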
if not requires_grad: + for param in bottleneck.parameters(): + param.requires_grad = False + + return bottleneck diff --git a/think_sound/models/lm.py b/think_sound/models/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..1897fa72ab716f69e0c6d71236e47cc50f78592e --- /dev/null +++ b/think_sound/models/lm.py @@ -0,0 +1,541 @@ +from dataclasses import dataclass +import torch +from tqdm.auto import trange +import typing as tp +from einops import rearrange +from torch import nn + +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .factory import create_pretransform_from_config +from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone +from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform +from .utils import multinomial, sample_top_k, sample_top_p + +from .codebook_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider +) + +# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + +# Wrapper for a multi-codebook language model +# Handles patterns and quantizer heads +class AudioLanguageModel(nn.Module): + def __init__( + self, + pattern_provider: CodebooksPatternProvider, + backbone: AudioLMBackbone, + num_quantizers: int, + codebook_size: int + ): + super().__init__() + + self.pattern_provider = pattern_provider + self.backbone = backbone + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + + self.masked_token_id = codebook_size + + # Per-quantizer embedders + # Add one for the mask embed + self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)]) + + # Per-quantizer output heads + self.quantizer_heads = nn.ModuleList([ + nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers) + ]) + + def forward(self, + sequence: torch.Tensor, #[batch, seq_len, + prepend_cond=None, #[batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, #[batch, seq, channels], + **kwargs + ): + + batch, num_quantizers, seq_len = sequence.shape + + assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model" + + backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)]) # [batch, seq_len, embed_dim] + + dtype = next(self.parameters()).dtype + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(dtype) + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.to(dtype) + + backbone_input = backbone_input.to(dtype) + + output = self.backbone( + backbone_input, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + **kwargs + ) # [batch, seq_len, embed_dim] + + # Run output through quantizer heads + logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1) # [batch, num_quantizers, seq_len, codebook_size] 
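The line above completes the core multi-codebook pattern of `AudioLanguageModel`: per-quantizer embeddings are summed on the way into the backbone, and one linear head per quantizer is stacked on the way out. A stripped-down, self-contained sketch of just that shape flow, with toy sizes, no backbone, and purely illustrative names:

```python
import torch
import torch.nn as nn

B, K, T, C, D = 2, 4, 10, 1024, 512  # batch, quantizers, steps, codebook size, embed dim
embeds = nn.ModuleList([nn.Embedding(C + 1, D) for _ in range(K)])  # +1 slot for the mask token
heads = nn.ModuleList([nn.Linear(D, C) for _ in range(K)])

codes = torch.randint(0, C, (B, K, T))             # [batch, num_quantizers, seq_len]
x = sum(embeds[k](codes[:, k]) for k in range(K))  # [B, T, D] backbone input
# ... the transformer backbone would transform x here ...
logits = torch.stack([heads[k](x) for k in range(K)], dim=1)
assert logits.shape == (B, K, T, C)                # [batch, num_quantizers, seq_len, codebook_size]
```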
+ + return logits + + def compute_logits( + self, + codes, #[batch, num_quantizers, seq_len] + **kwargs): + """ + Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning + Handles translation between input sequence and pattern-shifted sequence + Only used during training + """ + + batch, _, seq_len = codes.shape + + pattern = self.pattern_provider.get_pattern(seq_len) + + # Apply the token pattern to the codes, shifting the codes as needed and masking out invalid steps + shifted_codes, _, _ = pattern.build_pattern_sequence( + codes, + self.masked_token_id, + keep_only_valid_steps=True + ) + + # Run the model to get logits for each quantizer [batch, num_quantizers, seq_len, codebook_size] + logits = self(shifted_codes, **kwargs) + + # Rearrange logits to prepare to revert pattern + logits = rearrange(logits, "b n s c -> b c n s") + + # Revert sequence logits back to original sequence length, removing masked steps + logits, _, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=True + ) + + logits = rearrange(logits, "b c n t -> b n t c") + + logits_mask = logits_mask[None, :, :].expand(batch, -1, -1) # [batch, num_quantizers, seq_len] + + return LMOutput(logits=logits, mask=logits_mask) + +# Conditioning and generation wrapper for a multi-codebook language model +# Handles conditioning, CFG, generation, and encoding/decoding +class AudioLanguageModelWrapper(nn.Module): + def __init__( + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [] + ): + super().__init__() + + assert pretransform.is_discrete, "Pretransform must be discrete" + self.pretransform = pretransform + + self.pretransform.requires_grad_(False) + self.pretransform.eval() + + if isinstance(self.pretransform, AutoencoderPretransform): + self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers + self.codebook_size = self.pretransform.model.bottleneck.codebook_size + elif isinstance(self.pretransform, PretrainedDACPretransform): + self.num_quantizers = self.pretransform.model.num_quantizers + self.codebook_size = self.pretransform.model.codebook_size + elif isinstance(self.pretransform, AudiocraftCompressionPretransform): + self.num_quantizers = self.pretransform.num_quantizers + self.codebook_size = self.pretransform.codebook_size + else: + raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") + + self.conditioner = conditioner + + self.lm = lm + + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.cross_attn_cond_ids = cross_attn_cond_ids + self.prepend_cond_ids = prepend_cond_ids + self.global_cond_ids = global_cond_ids + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + prepend_cond = None + prepend_cond_mask = None + global_cond = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of 
shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_prepend_cond": prepend_cond, + "negative_prepend_cond_mask": prepend_cond_mask, + "negative_global_cond": global_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "global_cond": global_cond + } + + def compute_logits( + self, + codes, + condition_tensors=None, + cfg_dropout_prob=0.0, + **kwargs + ): + """ + Compute logits for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG dropout + """ + + if condition_tensors is None: + condition_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(condition_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if global_cond is not None: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + global_cond = torch.where(dropout_mask, null_embed, global_cond) + + return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + def _sample_next_token( + self, + sequence, #[batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs + ): + """ + Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG inference + """ + + if conditioning_tensors is None: + conditioning_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_scale != 1.0: + + # Batch size is doubled to account for negative samples + sequence = 
torch.cat([sequence, sequence], dim=0) + + if cross_attn_cond is not None and cross_attn_use_cfg: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if prepend_cond is not None and prepend_use_cfg: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if global_cond is not None and global_use_cfg: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + + global_cond = torch.cat([global_cond, null_embed], dim=0) + + logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + if cfg_scale != 1.0: + cond_logits, uncond_logits = logits.chunk(2, dim=0) + + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + + # Grab the logits for the last step + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + + # Apply top-k or top-p sampling + + if temp > 0: + probs = torch.softmax(logits / temp, dim=-1) + + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = sample_top_k(probs, k=top_k) + else: + next_token = multinomial(probs, num_samples=1) + + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + + return next_token + + @torch.no_grad() + def generate( + self, + max_gen_len: int = 256, + batch_size: tp.Optional[int] = None, + init_data: tp.Optional[torch.Tensor] = None, + conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, + callback: tp.Optional[tp.Callable[[int, int], None]] = None, + use_cache: bool = True, + cfg_scale: float = 1.0, + **kwargs + ): + device = next(self.parameters()).device + + if conditioning_tensors is None and conditioning is not None: + # Convert conditioning inputs to conditioning tensors + conditioning_tensors = self.conditioner(conditioning, device) + + # Check that batch size is consistent across inputs + possible_batch_sizes = [] + + if batch_size is not None: + possible_batch_sizes.append(batch_size) + elif init_data is not None: + possible_batch_sizes.append(init_data.shape[0]) + elif conditioning_tensors is not None: + # Assume that the first conditioning tensor has the batch dimension + possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) + else: + possible_batch_sizes.append(1) + + assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + + batch_size = possible_batch_sizes[0] + + if init_data is None: + # Initialize with zeros + assert batch_size > 0 + init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) + + batch_size, num_quantizers, seq_len = init_data.shape + + start_offset = seq_len + assert start_offset < max_gen_len, "init data longer than max gen length" + + pattern = self.lm.pattern_provider.get_pattern(max_gen_len) + + unknown_token = -1 + + # Initialize the generated codes with the init data, padded with unknown tokens + gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, 
device=device, dtype=torch.long) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + + gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + # Generation + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] + + # Reset generation cache + if use_cache and self.lm.backbone.use_generation_cache: + self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) + + for offset in trange(start_offset_sequence, gen_sequence_len): + + # Get the full sequence up to the current offset + curr_sequence = gen_sequence[..., prev_offset:offset] + + next_token = self._sample_next_token( + curr_sequence, + conditioning_tensors=conditioning_tensors, + use_cache=use_cache, + cfg_scale=cfg_scale, + **kwargs + ) + + valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + next_token[~valid_mask] = self.lm.masked_token_id + + # Update the generated sequence with the next token + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, + gen_sequence[..., offset:offset+1] + ) + + if use_cache and self.lm.backbone.use_generation_cache: + # Only update the offset if caching is being used + prev_offset = offset + + self.lm.backbone.update_generation_cache(offset) + + if callback is not None: + # Callback to report progress + # Pass in the offset relative to the start of the sequence, and the length of the current sequence + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + + assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" + + out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + #out_codes = out_codes[..., 0:max_gen_len] + + return out_codes + + + def generate_audio( + self, + **kwargs + ): + """ + Generate audio from a batch of codes + """ + + codes = self.generate(**kwargs) + + audio = self.pretransform.decode_tokens(codes) + + return audio + + +def create_audio_lm_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + codebook_pattern = lm_config.get("codebook_pattern", "delay") + + pattern_providers = { + 'parallel': ParallelPatternProvider, + 'delay': DelayedPatternProvider, + 'unroll': UnrolledPatternProvider, + 'musiclm': MusicLMPattern, + } + + pretransform_config = model_config.get("pretransform", None) + + pretransform = create_pretransform_from_config(pretransform_config, sample_rate) + + assert pretransform.is_discrete, "Pretransform must be discrete" + + min_input_length = pretransform.downsampling_ratio + + pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers) + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is 
not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) + prepend_cond_ids = lm_config.get('prepend_cond_ids', []) + global_cond_ids = lm_config.get('global_cond_ids', []) + + lm_type = lm_config.get("type", None) + lm_model_config = lm_config.get("config", None) + + assert lm_type is not None, "Must specify lm type in lm config" + assert lm_model_config is not None, "Must specify lm model config in lm config" + + if lm_type == "x-transformers": + backbone = XTransformersAudioLMBackbone(**lm_model_config) + elif lm_type == "continuous_transformer": + backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) + else: + raise NotImplementedError(f"Unrecognized lm type {lm_type}") + + lm = AudioLanguageModel( + pattern_provider=pattern_provider, + backbone=backbone, + num_quantizers=pretransform.num_quantizers, + codebook_size=pretransform.codebook_size + ) + + model = AudioLanguageModelWrapper( + pretransform=pretransform, + lm=lm, + conditioner=conditioner, + sample_rate=sample_rate, + min_input_length=min_input_length, + cross_attn_cond_ids=cross_attn_cond_ids, + prepend_cond_ids=prepend_cond_ids, + global_cond_ids=global_cond_ids + ) + + return model \ No newline at end of file diff --git a/think_sound/models/lm_backbone.py b/think_sound/models/lm_backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..c80cce60b06d9b367b114188444b0890a1990b61 --- /dev/null +++ b/think_sound/models/lm_backbone.py @@ -0,0 +1,159 @@ +from torch import nn +from x_transformers import ContinuousTransformerWrapper, Decoder + +from .transformer import ContinuousTransformer + +# Interface for backbone of a language model +# Handles conditioning and cross-attention +# Does not have to deal with patterns or quantizer heads +class AudioLMBackbone(nn.Module): + def __init__(self, embed_dim: int, use_generation_cache=False, **kwargs): + super().__init__() + + self.embed_dim = embed_dim + self.use_generation_cache = use_generation_cache + + def forward( + self, + x, + cross_attn_cond=None, + prepend_cond=None, + prepend_cond_mask=None, + global_cond=None, + use_cache=False, + **kwargs + ): + raise NotImplementedError + + def reset_generation_cache( + self, + max_seq_len, + batch_size, + dtype=None + ): + pass + + def update_generation_cache( + self, + seqlen_offset + ): + pass + +class XTransformersAudioLMBackbone(AudioLMBackbone): + def __init__(self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + **kwargs): + super().__init__(embed_dim=embed_dim) + + # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer + self.model = ContinuousTransformerWrapper( + dim_in=embed_dim, + dim_out=embed_dim, + max_seq_len=0, #Not relevant without absolute positional embeds, + attn_layers=Decoder( + dim=embed_dim, + attn_flash = True, + cross_attend = cross_attn_cond_dim > 0, + zero_init_branch_output=True, + use_abs_pos_emb = False, + rotary_pos_emb=True, + ff_swish = True, + ff_glu = True, + **kwargs + ) + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if cross_attn_cond_dim > 0: + # Cross-attention conditioning + self.to_cross_attn_embed = nn.Sequential( + nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, 
embed_dim, bias=False) + ) + + def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + prepend_length = prepend_cond.shape[1] + + if prepend_cond_mask is not None: + # Cast mask to bool + prepend_cond_mask = prepend_cond_mask.bool() + + if cross_attn_cond is not None: + # Project the cross-attention conditioning to the embedding dimension + cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) + + return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] + +class ContinuousTransformerAudioLMBackbone(AudioLMBackbone): + def __init__(self, + embed_dim: int, + cross_attn_cond_dim: int = 0, + prepend_cond_dim: int = 0, + project_cross_attn_cond: bool = False, + **kwargs): + super().__init__(embed_dim=embed_dim) + + # Embeddings are done in the AudioLanguageModel, so we use the continuous-input transformer + self.model = ContinuousTransformer( + dim=embed_dim, + dim_in=embed_dim, + dim_out=embed_dim, + cross_attend = cross_attn_cond_dim > 0, + cond_token_dim = embed_dim if project_cross_attn_cond else cross_attn_cond_dim, + causal=True, + **kwargs + ) + + if prepend_cond_dim > 0: + # Prepend conditioning + self.to_prepend_embed = nn.Sequential( + nn.Linear(prepend_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + + if cross_attn_cond_dim > 0 and project_cross_attn_cond: + # Cross-attention conditioning + self.to_cross_attn_embed = nn.Sequential( + nn.Linear(cross_attn_cond_dim, embed_dim, bias=False), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=False) + ) + else: + self.to_cross_attn_embed = nn.Identity() + + def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, global_cond=None, use_cache=False): + + prepend_length = 0 + if prepend_cond is not None: + # Project the prepend conditioning to the embedding dimension + prepend_cond = self.to_prepend_embed(prepend_cond) + prepend_length = prepend_cond.shape[1] + + if prepend_cond_mask is not None: + # Cast mask to bool + prepend_cond_mask = prepend_cond_mask.bool() + + if cross_attn_cond is not None: + # Cast cross_attn_cond to same dtype as self.to_cross_attn_embed + cross_attn_cond = cross_attn_cond.to(self.to_cross_attn_embed[0].weight.dtype) + + # Project the cross-attention conditioning to the embedding dimension + cross_attn_cond = self.to_cross_attn_embed(cross_attn_cond) + + return self.model(x, mask=mask, context=cross_attn_cond, prepend_embeds=prepend_cond, prepend_mask=prepend_cond_mask)[:, prepend_length:, :] \ No newline at end of file diff --git a/think_sound/models/lm_continuous.py b/think_sound/models/lm_continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..469bb49f32492794345cf76dafbb377778eca81e --- /dev/null +++ b/think_sound/models/lm_continuous.py @@ -0,0 +1,525 @@ +from dataclasses import dataclass +import torch +from tqdm.auto import trange +import typing as tp +from einops import rearrange +from torch import nn + +from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config +from .factory import create_pretransform_from_config +from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone 
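Both backbones above share the same prepend-conditioning mechanics: project the conditioning to the embedding width, prepend it to the token sequence, run the transformer, then slice the prepended positions off the output. A small self-contained sketch of that pattern with toy shapes (shapes and variable names are illustrative, not taken from the repo):

```python
import torch
import torch.nn as nn

B, T_pre, T, D_cond, D = 2, 8, 32, 768, 512
to_prepend_embed = nn.Sequential(                 # same projection stack as in the backbones
    nn.Linear(D_cond, D, bias=False), nn.SiLU(), nn.Linear(D, D, bias=False)
)

x = torch.randn(B, T, D)                          # token embeddings coming from the LM
prepend_cond = torch.randn(B, T_pre, D_cond)
prepend = to_prepend_embed(prepend_cond)          # (B, T_pre, D)
h = torch.cat([prepend, x], dim=1)                # conditioning tokens go first
# ... attention layers would process h here ...
out = h[:, prepend.shape[1]:, :]                  # drop the prepended positions, as the backbones do
assert out.shape == (B, T, D)
```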
+from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform +from .utils import multinomial, sample_top_k, sample_top_p +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper, create_diffusion_cond_from_config + +from .codebook_patterns import ( + CodebooksPatternProvider, + DelayedPatternProvider, + MusicLMPattern, + ParallelPatternProvider, + UnrolledPatternProvider +) + +# Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/models/lm.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +@dataclass +class LMContinuousOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + +# Wrapper for a multi-codebook language model +# Handles patterns and quantizer heads +class AudioLMContinuousModel(nn.Module): + def __init__( + self, + backbone: AudioLMBackbone, + ): + super().__init__() + + self.backbone = backbone + + def sample_orders(self, bsz): + # generate a batch of random generation orders + orders = [] + for _ in range(bsz): + order = np.array(list(range(self.seq_len))) + np.random.shuffle(order) + orders.append(order) + orders = torch.Tensor(np.array(orders)).cuda().long() + return orders + + def random_masking(self, x, orders): + # generate token mask + bsz, seq_len, embed_dim = x.shape + mask_rate = self.mask_ratio_generator.rvs(1)[0] + num_masked_tokens = int(np.ceil(seq_len * mask_rate)) + mask = torch.zeros(bsz, seq_len, device=x.device) + mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens], + src=torch.ones(bsz, seq_len, device=x.device)) + return mask + + def forward(self, + sequence: torch.Tensor, #[batch, seq_len, + prepend_cond=None, #[batch, seq, channels] + prepend_cond_mask=None, + cross_attn_cond=None, #[batch, seq, channels], + **kwargs + ): + + + batch, seq_len, dim = sequence.shape + + dtype = next(self.parameters()).dtype + + if cross_attn_cond is not None: + cross_attn_cond = cross_attn_cond.to(dtype) + + if prepend_cond is not None: + prepend_cond = prepend_cond.to(dtype) + + if prepend_cond_mask is not None: + prepend_cond_mask = prepend_cond_mask.to(dtype) + + x = sequence.to(dtype) + orders = self.sample_orders(bsz=batch) + mask = self.random_masking(x, orders) + + output = self.backbone( + x, + mask = mask, + cross_attn_cond=cross_attn_cond, + prepend_cond=prepend_cond, + prepend_cond_mask=prepend_cond_mask, + **kwargs + ) # [batch, seq_len, embed_dim] + + + return output + +# Conditioning and generation wrapper for a multi-codebook language model +# Handles conditioning, CFG, generation, and encoding/decoding +class AudioLanguageModelWrapper(nn.Module): + def __init__( + self, + pretransform: Pretransform, + lm: AudioLanguageModel, + diff: ConditionedDiffusionModelWrapper, + sample_rate: int, + min_input_length: int, + conditioner: MultiConditioner = None, + diffusion_objective: tp.Literal["v", "rectified_flow"] = "v", + cross_attn_cond_ids: tp.List[str] = [], + prepend_cond_ids: tp.List[str] = [], + global_cond_ids: tp.List[str] = [] + ): + super().__init__() + + assert pretransform.is_discrete, "Pretransform must be discrete" + self.pretransform = pretransform + + self.pretransform.requires_grad_(False) + self.pretransform.eval() + self.diffusion_objective = diffusion_objective + print(f'Training in the {diffusion_objective} 
formulation') + if isinstance(self.pretransform, AutoencoderPretransform): + self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers + self.codebook_size = self.pretransform.model.bottleneck.codebook_size + elif isinstance(self.pretransform, PretrainedDACPretransform): + self.num_quantizers = self.pretransform.model.num_quantizers + self.codebook_size = self.pretransform.model.codebook_size + elif isinstance(self.pretransform, AudiocraftCompressionPretransform): + self.num_quantizers = self.pretransform.num_quantizers + self.codebook_size = self.pretransform.codebook_size + else: + raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}") + + self.conditioner = conditioner + + self.lm = lm + + self.sample_rate = sample_rate + self.min_input_length = min_input_length + + self.cross_attn_cond_ids = cross_attn_cond_ids + self.prepend_cond_ids = prepend_cond_ids + self.global_cond_ids = global_cond_ids + + def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False): + cross_attention_input = None + prepend_cond = None + prepend_cond_mask = None + global_cond = None + + if len(self.cross_attn_cond_ids) > 0: + # Concatenate all cross-attention inputs over the sequence dimension + # Assumes that the cross-attention inputs are of shape (batch, seq, channels) + cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1) + + if len(self.prepend_cond_ids) > 0: + # Concatenate all prepend conditioning inputs over the sequence dimension + # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels) + prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1) + prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1) + + if len(self.global_cond_ids) > 0: + # Concatenate all global conditioning inputs over the channel dimension + # Assumes that the global conditioning inputs are of shape (batch, channels) + global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1) + if len(global_cond.shape) == 3: + global_cond = global_cond.squeeze(1) + + if negative: + return { + "negative_cross_attn_cond": cross_attention_input, + "negative_prepend_cond": prepend_cond, + "negative_prepend_cond_mask": prepend_cond_mask, + "negative_global_cond": global_cond + } + else: + return { + "cross_attn_cond": cross_attention_input, + "prepend_cond": prepend_cond, + "prepend_cond_mask": prepend_cond_mask, + "global_cond": global_cond + } + + def compute_logits( + self, + audios, + condition_tensors=None, + cfg_dropout_prob=0.0, + **kwargs + ): + """ + Compute logits for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG dropout + """ + + if condition_tensors is None: + condition_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(condition_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_dropout_prob > 0.0: + if cross_attn_cond is not None: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool) + cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond) + + if prepend_cond is not None: + null_embed = 
torch.zeros_like(prepend_cond, device=prepend_cond.device) + dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool) + prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond) + + if global_cond is not None: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool) + global_cond = torch.where(dropout_mask, null_embed, global_cond) + + return self.lm.forward(audios, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + def _sample_next_token( + self, + sequence, #[batch, num_quantizers, seq_len] + conditioning_tensors=None, + cross_attn_use_cfg=True, + prepend_use_cfg=True, + global_use_cfg=True, + cfg_scale=1.0, + top_k=250, + top_p=0.0, + temp=1.0, + **kwargs + ): + """ + Sample the next token for a batch of codes, and translates from conditioning inputs to model inputs + Handles CFG inference + """ + + if conditioning_tensors is None: + conditioning_tensors = {} + + conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors) + + cross_attn_cond = conditioning_inputs["cross_attn_cond"] + prepend_cond = conditioning_inputs["prepend_cond"] + prepend_cond_mask = conditioning_inputs["prepend_cond_mask"] + global_cond = conditioning_inputs["global_cond"] + + if cfg_scale != 1.0: + + # Batch size is doubled to account for negative samples + sequence = torch.cat([sequence, sequence], dim=0) + + if cross_attn_cond is not None and cross_attn_use_cfg: + null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device) + + cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0) + + if prepend_cond is not None and prepend_use_cfg: + null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device) + + prepend_cond = torch.cat([prepend_cond, null_embed], dim=0) + + if prepend_cond_mask is not None: + prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0) + + if global_cond is not None and global_use_cfg: + null_embed = torch.zeros_like(global_cond, device=global_cond.device) + + global_cond = torch.cat([global_cond, null_embed], dim=0) + + logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs) + + if cfg_scale != 1.0: + cond_logits, uncond_logits = logits.chunk(2, dim=0) + + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale + + logits = rearrange(logits, "b n s c -> b n c s") # [batch, num_quantizers, codebook_size, seq_len] + + # Grab the logits for the last step + logits = logits[:, :, :, -1] # [batch, num_quantizers, codebook_size] + + # Apply top-k or top-p sampling + + if temp > 0: + probs = torch.softmax(logits / temp, dim=-1) + + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token = sample_top_k(probs, k=top_k) + else: + next_token = multinomial(probs, num_samples=1) + + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) # [batch, num_quantizers, 1] + + return next_token + + @torch.no_grad() + def generate( + self, + max_gen_len: int = 256, + batch_size: tp.Optional[int] = None, + init_data: tp.Optional[torch.Tensor] = None, + conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None, + 
callback: tp.Optional[tp.Callable[[int, int], None]] = None, + use_cache: bool = True, + cfg_scale: float = 1.0, + **kwargs + ): + device = next(self.parameters()).device + + if conditioning_tensors is None and conditioning is not None: + # Convert conditioning inputs to conditioning tensors + conditioning_tensors = self.conditioner(conditioning, device) + + # Check that batch size is consistent across inputs + possible_batch_sizes = [] + + if batch_size is not None: + possible_batch_sizes.append(batch_size) + elif init_data is not None: + possible_batch_sizes.append(init_data.shape[0]) + elif conditioning_tensors is not None: + # Assume that the first conditioning tensor has the batch dimension + possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0]) + else: + possible_batch_sizes.append(1) + + assert [x == possible_batch_sizes[0] for x in possible_batch_sizes], "Batch size must be consistent across inputs" + + batch_size = possible_batch_sizes[0] + + if init_data is None: + # Initialize with zeros + assert batch_size > 0 + init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long) + + batch_size, num_quantizers, seq_len = init_data.shape + + start_offset = seq_len + assert start_offset < max_gen_len, "init data longer than max gen length" + + pattern = self.lm.pattern_provider.get_pattern(max_gen_len) + + unknown_token = -1 + + # Initialize the generated codes with the init data, padded with unknown tokens + gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long) + gen_codes[:, :, :start_offset] = init_data # [batch, num_quantizers, max_gen_len] + + gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id) # [batch, num_quantizers, gen_sequence_len] + + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + + # Generation + prev_offset = 0 + gen_sequence_len = gen_sequence.shape[-1] + + # Reset generation cache + if use_cache and self.lm.backbone.use_generation_cache: + self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2) + + for offset in trange(start_offset_sequence, gen_sequence_len): + + # Get the full sequence up to the current offset + curr_sequence = gen_sequence[..., prev_offset:offset] + + next_token = self._sample_next_token( + curr_sequence, + conditioning_tensors=conditioning_tensors, + use_cache=use_cache, + cfg_scale=cfg_scale, + **kwargs + ) + + valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1) + next_token[~valid_mask] = self.lm.masked_token_id + + # Update the generated sequence with the next token + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, + gen_sequence[..., offset:offset+1] + ) + + if use_cache and self.lm.backbone.use_generation_cache: + # Only update the offset if caching is being used + prev_offset = offset + + self.lm.backbone.update_generation_cache(offset) + + if callback is not None: + # Callback to report progress + # Pass in the offset relative to the start of the sequence, and the length of the current sequence + callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) + + assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence" + + out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, 
special_token=unknown_token) + + # sanity checks over the returned codes and corresponding masks + assert (out_codes[..., :max_gen_len] != unknown_token).all() + assert (out_mask[..., :max_gen_len] == 1).all() + + #out_codes = out_codes[..., 0:max_gen_len] + + return out_codes + + + def generate_audio( + self, + **kwargs + ): + """ + Generate audio from a batch of codes + """ + + codes = self.generate(**kwargs) + + audio = self.pretransform.decode_tokens(codes) + + return audio + + +def create_audio_lm_continuous_from_config(config): + model_config = config.get('model', None) + assert model_config is not None, 'model config must be specified in config' + + sample_rate = config.get('sample_rate', None) + assert sample_rate is not None, "Must specify sample_rate in config" + + lm_config = model_config.get('lm', None) + assert lm_config is not None, 'lm config must be specified in model config' + + + + pretransform_config = model_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + min_input_length = pretransform.downsampling_ratio + else: + min_input_length = 1 + + + conditioning_config = model_config.get('conditioning', None) + + conditioner = None + if conditioning_config is not None: + conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config) + + cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', []) + prepend_cond_ids = lm_config.get('prepend_cond_ids', []) + global_cond_ids = lm_config.get('global_cond_ids', []) + + lm_type = lm_config.get("type", None) + lm_model_config = lm_config.get("config", None) + + assert lm_type is not None, "Must specify lm type in lm config" + assert lm_model_config is not None, "Must specify lm model config in lm config" + + if lm_type == "x-transformers": + backbone = XTransformersAudioLMBackbone(**lm_model_config) + elif lm_type == "continuous_transformer": + backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config) + else: + raise NotImplementedError(f"Unrecognized lm type {lm_type}") + + lm = AudioLanguageModel( + pattern_provider=pattern_provider, + backbone=backbone, + num_quantizers=pretransform.num_quantizers, + codebook_size=pretransform.codebook_size + ) + + diff_config = model_config.get("diffusion", None) + diffusion_model = DiTWrapper(**diff_config) + + cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', []) + add_cond_ids = diffusion_config.get('add_cond_ids', []) + global_cond_ids = diffusion_config.get('global_cond_ids', []) + input_concat_ids = diffusion_config.get('input_concat_ids', []) + prepend_cond_ids = diffusion_config.get('prepend_cond_ids', []) + + diff = ConditionedDiffusionModelWrapper( + diffusion_model, + conditioner=None, + min_input_length=min_input_length, + sample_rate=sample_rate, + cross_attn_cond_ids=cross_attention_ids, + global_cond_ids=global_cond_ids, + input_concat_ids=input_concat_ids, + prepend_cond_ids=prepend_cond_ids, + add_cond_ids=add_cond_ids, + pretransform=pretransform, + io_channels=2, + ) + + + model = AudioLanguageModelWrapper( + pretransform=pretransform, + lm=lm, + diff=diff, + conditioner=conditioner, + sample_rate=sample_rate, + min_input_length=min_input_length, + cross_attn_cond_ids=cross_attn_cond_ids, + prepend_cond_ids=prepend_cond_ids, + global_cond_ids=global_cond_ids + ) + + return model \ No newline at end of file diff --git a/think_sound/models/local_attention.py b/think_sound/models/local_attention.py new file mode 100644 index 
0000000000000000000000000000000000000000..893ce11fce1f263dd02ff2a2ebe8b5e67426f83f --- /dev/null +++ b/think_sound/models/local_attention.py @@ -0,0 +1,278 @@ +import torch + +from einops import rearrange +from torch import nn + +from .blocks import AdaRMSNorm +from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +# Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py +class ContinuousLocalTransformer(nn.Module): + def __init__( + self, + *, + dim, + depth, + dim_in = None, + dim_out = None, + causal = False, + local_attn_window_size = 64, + heads = 8, + ff_mult = 2, + cond_dim = 0, + cross_attn_cond_dim = 0, + **kwargs + ): + super().__init__() + + dim_head = dim//heads + + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity() + + self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity() + + self.local_attn_window_size = local_attn_window_size + + self.cond_dim = cond_dim + + self.cross_attn_cond_dim = cross_attn_cond_dim + + self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32)) + + for _ in range(depth): + + self.layers.append(nn.ModuleList([ + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + Attention( + dim=dim, + dim_heads=dim_head, + causal=causal, + zero_init_output=True, + natten_kernel_size=local_attn_window_size, + ), + Attention( + dim=dim, + dim_heads=dim_head, + dim_context = cross_attn_cond_dim, + zero_init_output=True + ) if self.cross_attn_cond_dim > 0 else nn.Identity(), + AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim), + FeedForward(dim = dim, mult = ff_mult, no_bias=True) + ])) + + def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None): + + x = checkpoint(self.project_in, x) + + if prepend_cond is not None: + x = torch.cat([prepend_cond, x], dim=1) + + pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + + for attn_norm, attn, xattn, ff_norm, ff in self.layers: + + residual = x + if cond is not None: + x = checkpoint(attn_norm, x, cond) + else: + x = checkpoint(attn_norm, x) + + x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual + + if cross_attn_cond is not None: + x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x + + residual = x + + if cond is not None: + x = checkpoint(ff_norm, x, cond) + else: + x = checkpoint(ff_norm, x) + + x = checkpoint(ff, x) + residual + + return checkpoint(self.project_out, x) + +class TransformerDownsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim = 768, + depth = 3, + heads = 12, + downsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.downsample_ratio = downsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size=local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False) + + + def forward(self, x): + + x = checkpoint(self.project_in, x) + + # Compute + x = self.transformer(x) + + # Trade sequence 
length for channels + x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio) + + # Project back to embed dim + x = checkpoint(self.project_down, x) + + return x + +class TransformerUpsampleBlock1D(nn.Module): + def __init__( + self, + in_channels, + embed_dim, + depth = 3, + heads = 12, + upsample_ratio = 2, + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + self.upsample_ratio = upsample_ratio + + self.transformer = ContinuousLocalTransformer( + dim=embed_dim, + depth=depth, + heads=heads, + local_attn_window_size = local_attn_window_size, + **kwargs + ) + + self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity() + + self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False) + + def forward(self, x): + + # Project to embed dim + x = checkpoint(self.project_in, x) + + # Project to increase channel dim + x = checkpoint(self.project_up, x) + + # Trade channels for sequence length + x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio) + + # Compute + x = self.transformer(x) + + return x + + +class TransformerEncoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [96, 192, 384, 768], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerDownsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + downsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + + return x + + +class TransformerDecoder1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dims = [768, 384, 192, 96], + heads = [12, 12, 12, 12], + depths = [3, 3, 3, 3], + ratios = [2, 2, 2, 2], + local_attn_window_size = 64, + **kwargs + ): + + super().__init__() + + layers = [] + + for layer in range(len(depths)): + prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0] + + layers.append( + TransformerUpsampleBlock1D( + in_channels = prev_dim, + embed_dim = embed_dims[layer], + heads = heads[layer], + depth = depths[layer], + upsample_ratio = ratios[layer], + local_attn_window_size = local_attn_window_size, + **kwargs + ) + ) + + self.layers = nn.Sequential(*layers) + + self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False) + self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False) + + def forward(self, x): + x = rearrange(x, "b c n -> b n c") + x = checkpoint(self.project_in, x) + x = self.layers(x) + x = checkpoint(self.project_out, x) + x = rearrange(x, "b n c -> b c n") + return x \ No newline at end of file diff --git a/think_sound/models/mmdit.py b/think_sound/models/mmdit.py new file mode 100644 index 0000000000000000000000000000000000000000..9ec9ba3d7eba4b4af2c15daf593e3393a664034e --- /dev/null +++ b/think_sound/models/mmdit.py @@ -0,0 +1,531 @@ +import logging +from dataclasses 
import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import sys +from .mmmodules.ext.rotary_embeddings import compute_rope_rotations +from .mmmodules.model.embeddings import TimestepEmbedder +from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from .mmmodules.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) +from .utils import resample + +log = logging.getLogger() + + +@dataclass +class PreprocessedConditions: + clip_f: torch.Tensor + sync_f: torch.Tensor + text_f: torch.Tensor + clip_f_c: torch.Tensor + text_f_c: torch.Tensor + + +# Partially from https://github.com/facebookresearch/DiT +class MMAudio(nn.Module): + + def __init__(self, + *, + latent_dim: int, + clip_dim: int, + sync_dim: int, + text_dim: int, + hidden_dim: int, + depth: int, + fused_depth: int, + num_heads: int, + mlp_ratio: float = 4.0, + latent_seq_len: int, + clip_seq_len: int, + sync_seq_len: int, + text_seq_len: int = 77, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None, + empty_string_feat: Optional[torch.Tensor] = None, + v2: bool = False, + kernel_size: int = 7, + sync_kernel: int = 7, + use_inpaint: bool = False, + use_mlp: bool = False, + cross_attend: bool = False, + add_video: bool = False, + triple_fusion: bool = False, + gated_video: bool = False) -> None: + super().__init__() + + self.v2 = v2 + self.latent_dim = latent_dim + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self._text_seq_len = text_seq_len + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.cross_attend = cross_attend + self.add_video = add_video + self.gated_video = gated_video + self.triple_fusion = triple_fusion + self.use_inpaint = use_inpaint + if self.gated_video: + self.gated_mlp = nn.Sequential( + nn.LayerNorm(hidden_dim * 2), + nn.Linear(hidden_dim*2, hidden_dim * 4, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim * 4, hidden_dim, bias=False), + nn.Sigmoid() + ) + # 初始化最后一层权重为零,促进初始均匀融合 + nn.init.zeros_(self.gated_mlp[3].weight) + if self.triple_fusion: + self.gated_mlp_v = nn.Sequential( + nn.LayerNorm(hidden_dim * 3), + nn.Linear(hidden_dim*3, hidden_dim * 4, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim * 4, hidden_dim, bias=False), + nn.Sigmoid() + ) + self.gated_mlp_t = nn.Sequential( + nn.LayerNorm(hidden_dim * 3), + nn.Linear(hidden_dim*3, hidden_dim * 4, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim * 4, hidden_dim, bias=False), + nn.Sigmoid() + ) + # 初始化最后一层权重为零,促进初始均匀融合 + nn.init.zeros_(self.gated_mlp_v[3].weight) + nn.init.zeros_(self.gated_mlp_t[3].weight) + if v2: + padding_size = (kernel_size - 1) // 2 + if use_inpaint: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim*2, hidden_dim, kernel_size=kernel_size, padding=padding_size), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=kernel_size, padding=padding_size), + ) + else: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=kernel_size, padding=padding_size), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=kernel_size, padding=padding_size), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + sync_pad = (sync_kernel - 1) // 2 + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, 
kernel_size=sync_kernel, padding=sync_pad), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + nn.SiLU(), + MLP(hidden_dim, hidden_dim * 4), + ) + else: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + MLP(hidden_dim, hidden_dim * 4), + ) + + self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim) + if use_mlp: + self.text_cond_proj = nn.Sequential( + nn.Linear(1024, hidden_dim), + MLP(hidden_dim, hidden_dim * 4), + ) + else: + self.text_cond_proj = nn.Linear(1024, hidden_dim) + self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4) + # each synchformer output segment has 8 feature frames + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim))) + + self.final_layer = FinalBlock(hidden_dim, latent_dim) + + if v2: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=hidden_dim, + max_period=1) + else: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=256, + max_period=10000) + self.joint_blocks = nn.ModuleList([ + JointBlock(hidden_dim, + num_heads, + mlp_ratio=mlp_ratio, + pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) + ]) + + self.fused_blocks = nn.ModuleList([ + MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=kernel_size, padding=padding_size, cross_attend=cross_attend) + for i in range(fused_depth) + ]) + + if empty_string_feat is None: + empty_string_feat = torch.zeros((77, 1024)) + + empty_t5_feat = torch.zeros((77, 2048)) + + self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) + self.empty_t5_feat = nn.Parameter(empty_t5_feat, requires_grad=False) + self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True) + + self.initialize_weights() + self.initialize_rotations() + + def initialize_rotations(self): + base_freq = 1.0 + latent_rot = compute_rope_rotations(self._latent_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq, + device=self.device) + clip_rot = compute_rope_rotations(self._clip_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq * self._latent_seq_len / + self._clip_seq_len, + device=self.device) + + self.latent_rot = nn.Buffer(latent_rot, persistent=False) + self.clip_rot = nn.Buffer(clip_rot, persistent=False) + + def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None: + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self.initialize_rotations() + + def initialize_weights(self): + + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep 
embedding MLP: + nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) + for block in self.fused_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.conv.weight, 0) + nn.init.constant_(self.final_layer.conv.bias, 0) + + # empty string feat shall be initialized by a CLIP encoder + nn.init.constant_(self.sync_pos_emb, 0) + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor, t5_features: torch.Tensor, metaclip_global_text_features: torch.Tensor) -> PreprocessedConditions: + """ + cache computations that do not depend on the latent/time step + i.e., the features are reused over steps during inference + """ + # breakpoint() + assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}' + assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}' + assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' + + bs = clip_f.shape[0] + + # B * num_segments (24) * 8 * 768 + num_sync_segments = self._sync_seq_len // 8 + sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb + sync_f = sync_f.flatten(1, 2) # (B, VN, D) + + # extend vf to match x + clip_f = self.clip_input_proj(clip_f) # (B, VN, D) + sync_f = self.sync_input_proj(sync_f) # (B, VN, D) + + if t5_features is not None: + + if metaclip_global_text_features is not None: + text_f_c = self.text_cond_proj(metaclip_global_text_features) # (B, D) + else: + text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) + # 计算填充长度 + padding_size = t5_features.size(2) - text_f.size(2) # 渴望填充的数量 + # 当确实需要填充的时候,确保填充是正数 + if padding_size > 0: + # 填充 text_f 的特征维度两侧 + text_f = F.pad(text_f, pad=(0, padding_size), mode='constant', value=0) # 在最后一个维度上进行填充 + else: + text_f = text_f # 如果填充长度不是正数,则不需要填充 + text_concat = torch.cat((text_f, t5_features), dim=1) + text_f = self.text_input_proj(text_concat) # (B, VN, D) + else: + text_f = self.text_input_proj(text_f) # (B, VN, D) + if metaclip_global_text_features is not None: + text_f_c = self.text_cond_proj(metaclip_global_text_features) # (B, D) + else: + text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) + + # upsample the sync features to match the audio + sync_f = sync_f.transpose(1, 2) # (B, D, VN) + # sync_f = resample(sync_f, self._latent_seq_len) + sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact') + sync_f = sync_f.transpose(1, 2) # (B, N, D) + + # get conditional features from the clip side + clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1)) # (B, D) + + return PreprocessedConditions(clip_f=clip_f, + sync_f=sync_f, + 
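A small sketch of how `preprocess_conditions` above treats the Synchformer stream: the per-segment positional embedding is added on a (segments, 8 frames) view, then the sequence is upsampled with `nearest-exact` interpolation so every audio latent frame has a sync feature. The batch size and lengths below (24 segments, 345 latent frames) are illustrative, and the hidden projection is omitted for brevity.

```python
import torch
import torch.nn.functional as F

B, num_segments, frames_per_segment, sync_dim = 2, 24, 8, 768
latent_seq_len = 345

sync_f = torch.randn(B, num_segments * frames_per_segment, sync_dim)
sync_pos_emb = torch.zeros(1, 1, frames_per_segment, sync_dim)  # a learned nn.Parameter in the model

# add the per-segment positional embedding, then flatten back to (B, VN, D)
sync_f = sync_f.view(B, num_segments, frames_per_segment, sync_dim) + sync_pos_emb
sync_f = sync_f.flatten(1, 2)

# upsample along time so the sync features align with the audio latent frames
sync_f = sync_f.transpose(1, 2)                                        # (B, D, VN)
sync_f = F.interpolate(sync_f, size=latent_seq_len, mode="nearest-exact")
sync_f = sync_f.transpose(1, 2)                                        # (B, N, D)
print(sync_f.shape)  # torch.Size([2, 345, 768])
```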
text_f=text_f, + clip_f_c=clip_f_c, + text_f_c=text_f_c) + + def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, + conditions: PreprocessedConditions, inpaint_masked_input=None, cfg_scale:float=1.0,cfg_dropout_prob:float=0.0,scale_phi:float=0.0 + ) -> torch.Tensor: + """ + for non-cacheable computations + """ + # print(f'cfg_scale: {cfg_scale}, cfg_dropout_prob: {cfg_dropout_prob}, scale_phi: {scale_phi}') + assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' + empty_conditions = None + if inpaint_masked_input is not None: + inpaint_masked_input = inpaint_masked_input.transpose(1,2) + clip_f = conditions.clip_f + sync_f = conditions.sync_f + text_f = conditions.text_f + clip_f_c = conditions.clip_f_c + text_f_c = conditions.text_f_c + + # breakpoint() + if inpaint_masked_input is not None: + latent = torch.cat([latent,inpaint_masked_input],dim=2) + latent = self.audio_input_proj(latent) # (B, N, D) + global_c = self.global_cond_mlp(clip_f_c + text_f_c) # (B, D) + # global_c = text_f_c + global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1) # (B, D) + extended_c = global_c + sync_f + + for block in self.joint_blocks: + latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c, + self.latent_rot, self.clip_rot) # (B, N, D) + if self.add_video: + if clip_f.shape[1] != latent.shape[1]: + clip_f = resample(clip_f, latent) + + if self.triple_fusion: + text_f = torch.mean(text_f, dim=1, keepdim=True) # (bsz, 1, D) + text_f = text_f.expand(-1,latent.shape[1], -1) # (T_audio, D) + fusion = torch.concat((latent, clip_f, text_f),dim=-1) + gate_v = self.gated_mlp_v(fusion) + gate_t = self.gated_mlp_t(fusion) + # modulated_latent = gate * latent # 非对称设计 + latent = latent + gate_v * clip_f + gate_t * text_f + elif self.gated_video: + fusion = torch.concat((latent, clip_f),dim=-1) + gate = self.gated_mlp(fusion) + modulated_latent = gate * latent # 非对称设计 + latent = latent + modulated_latent + else: + latent = latent + clip_f + + for block in self.fused_blocks: + if self.cross_attend: + latent = block(latent, extended_c, self.latent_rot, context=text_f) + else: + latent = block(latent, extended_c, self.latent_rot) + + # should be extended_c; this is a minor implementation error #55 + flow = self.final_layer(latent, extended_c) # (B, N, out_dim), remove t + return flow + + def forward(self, latent: torch.Tensor, t: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor, inpaint_masked_input, t5_features, metaclip_global_text_features, cfg_scale:float,cfg_dropout_prob:float,scale_phi:float) -> torch.Tensor: + """ + latent: (B, N, C) + vf: (B, T, C_V) + t: (B,) + """ + # breakpoint() + # print(f'cfg_scale: {cfg_scale}, cfg_dropout_prob: {cfg_dropout_prob}, scale_phi: {scale_phi}') + if self.use_inpaint and inpaint_masked_input is None: + inpaint_masked_input = torch.zeros_like(latent, device=latent.device) + latent = latent.permute(0, 2, 1) + + if cfg_dropout_prob > 0.0: + if inpaint_masked_input is not None: + null_embed = torch.zeros_like(inpaint_masked_input,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((inpaint_masked_input.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + inpaint_masked_input = torch.where(dropout_mask, null_embed, inpaint_masked_input) + + null_embed = torch.zeros_like(clip_f,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((clip_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # clip_f = 
torch.where(dropout_mask, null_embed, clip_f) + clip_f = torch.where(dropout_mask, self.empty_clip_feat, clip_f) + null_embed = torch.zeros_like(sync_f,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((sync_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # sync_f = torch.where(dropout_mask, null_embed, sync_f) + sync_f = torch.where(dropout_mask, self.empty_sync_feat, sync_f) + null_embed = torch.zeros_like(text_f,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((text_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # text_f = torch.where(dropout_mask, null_embed, text_f) + text_f = torch.where(dropout_mask, self.empty_string_feat, text_f) + if t5_features is not None: + null_embed = torch.zeros_like(t5_features,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((t5_features.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # t5_features = torch.where(dropout_mask, null_embed, t5_features) + t5_features = torch.where(dropout_mask, self.empty_t5_feat, t5_features) + if metaclip_global_text_features is not None: + null_embed = torch.zeros_like(metaclip_global_text_features,device=latent.device) + dropout_mask = torch.bernoulli(torch.full((metaclip_global_text_features.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + metaclip_global_text_features = torch.where(dropout_mask, null_embed, metaclip_global_text_features) + # null_embed = torch.zeros_like(clip_f_c,device=latent.device) + # dropout_mask = torch.bernoulli(torch.full((clip_f_c.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # clip_f_c = torch.where(dropout_mask, null_embed, clip_f_c) + # null_embed = torch.zeros_like(text_f_c,device=latent.device) + # dropout_mask = torch.bernoulli(torch.full((text_f_c.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool) + # text_f_c = torch.where(dropout_mask, null_embed, text_f_c) + + if cfg_scale != 1.0: + # empty_conditions = self.get_empty_conditions(latent.shape[0]) + # breakpoint() + bsz = latent.shape[0] + latent = torch.cat([latent,latent], dim=0) + if inpaint_masked_input is not None: + empty_inpaint_masked_input = torch.zeros_like(inpaint_masked_input, device=latent.device) + inpaint_masked_input = torch.cat([inpaint_masked_input,empty_inpaint_masked_input], dim=0) + t = torch.cat([t, t], dim=0) + empty_clip_f = torch.zeros_like(clip_f, device=latent.device) + empty_sync_f = torch.zeros_like(sync_f, device=latent.device) + empty_text_f = torch.zeros_like(text_f, device=latent.device) + + # clip_f = torch.cat([clip_f,empty_clip_f], dim=0) + # sync_f = torch.cat([sync_f,empty_sync_f], dim=0) + # text_f = torch.cat([text_f,empty_text_f], dim=0) + clip_f = torch.cat([clip_f,self.get_empty_clip_sequence(bsz)], dim=0) + sync_f = torch.cat([sync_f,self.get_empty_sync_sequence(bsz)], dim=0) + text_f = torch.cat([text_f,self.get_empty_string_sequence(bsz)], dim=0) + if t5_features is not None: + empty_t5_features = torch.zeros_like(t5_features, device=latent.device) + # t5_features = torch.cat([t5_features,empty_t5_features], dim=0) + t5_features = torch.cat([t5_features,self.get_empty_t5_sequence(bsz)], dim=0) + if metaclip_global_text_features is not None: + empty_metaclip_global_text_features = torch.zeros_like(metaclip_global_text_features, device=latent.device) + metaclip_global_text_features = torch.cat([metaclip_global_text_features,empty_metaclip_global_text_features], dim=0) + # 
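The dropout branch above implements classifier-free-guidance training by replacing whole conditioning streams with learned "empty" features at probability `cfg_dropout_prob`, one Bernoulli draw per batch element. A compact sketch of that pattern with a generic feature tensor (shapes illustrative):

```python
import torch


def drop_condition(feat: torch.Tensor, null_feat: torch.Tensor, p: float) -> torch.Tensor:
    """Replace a sample's conditioning with a null embedding with probability p."""
    # one Bernoulli draw per batch element, broadcast over sequence and channel dims
    mask = torch.bernoulli(torch.full((feat.shape[0], 1, 1), p, device=feat.device)).bool()
    return torch.where(mask, null_feat, feat)


clip_f = torch.randn(4, 64, 1024)   # (B, N, D) video features
empty_clip = torch.zeros(1, 1024)   # stands in for the learned empty_clip_feat parameter
print(drop_condition(clip_f, empty_clip, p=0.2).shape)  # torch.Size([4, 64, 1024])
```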
metaclip_global_text_features = torch.cat([metaclip_global_text_features,metaclip_global_text_features], dim=0) + # clip_f_c = torch.cat([clip_f_c,empty_clip_f_c], dim=0) + # text_f_c = torch.cat([text_f_c,empty_text_f_c], dim=0) + + conditions = self.preprocess_conditions(clip_f, sync_f, text_f, t5_features, metaclip_global_text_features) + flow = self.predict_flow(latent, t, conditions, inpaint_masked_input, cfg_scale,cfg_dropout_prob,scale_phi) + if cfg_scale != 1.0: + cond_output, uncond_output = torch.chunk(flow, 2, dim=0) + cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale + if scale_phi != 0.0: + cond_out_std = cond_output.std(dim=1, keepdim=True) + out_cfg_std = cfg_output.std(dim=1, keepdim=True) + flow = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output + else: + flow = cfg_output + flow = flow.permute(0, 2, 1) + return flow + + def get_empty_string_sequence(self, bs: int) -> torch.Tensor: + return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_t5_sequence(self, bs: int) -> torch.Tensor: + return self.empty_t5_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_clip_sequence(self, bs: int) -> torch.Tensor: + return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1) + + def get_empty_sync_sequence(self, bs: int) -> torch.Tensor: + return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1) + + def get_empty_conditions( + self, + bs: int, + *, + negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions: + if negative_text_features is not None: + empty_text = negative_text_features + else: + empty_text = self.get_empty_string_sequence(1) + + empty_clip = self.get_empty_clip_sequence(1) + empty_sync = self.get_empty_sync_sequence(1) + conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text) + conditions.clip_f = conditions.clip_f.expand(bs, -1, -1) + conditions.sync_f = conditions.sync_f.expand(bs, -1, -1) + conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1) + if negative_text_features is None: + conditions.text_f = conditions.text_f.expand(bs, -1, -1) + conditions.text_f_c = conditions.text_f_c.expand(bs, -1) + + return conditions + + def load_weights(self, src_dict) -> None: + if 't_embed.freqs' in src_dict: + del src_dict['t_embed.freqs'] + if 'latent_rot' in src_dict: + del src_dict['latent_rot'] + if 'clip_rot' in src_dict: + del src_dict['clip_rot'] + + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) -> torch.device: + return self.empty_clip_feat.device + + @property + def latent_seq_len(self) -> int: + return self._latent_seq_len + + @property + def clip_seq_len(self) -> int: + return self._clip_seq_len + + @property + def sync_seq_len(self) -> int: + return self._sync_seq_len + diff --git a/think_sound/models/mmmodules/__init__.py b/think_sound/models/mmmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/think_sound/models/mmmodules/__pycache__/__init__.cpython-310.pyc b/think_sound/models/mmmodules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9c6dc76969b971214d33f622aef5e6c05c78e23 Binary files /dev/null and b/think_sound/models/mmmodules/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/__pycache__/__init__.cpython-39.pyc b/think_sound/models/mmmodules/__pycache__/__init__.cpython-39.pyc new file 
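The guidance path above duplicates the batch, runs the conditional and unconditional halves together, and then blends the two predictions; `scale_phi` optionally rescales the guided output toward the conditional prediction's standard deviation. A standalone sketch of just that arithmetic:

```python
import torch


def apply_cfg(flow: torch.Tensor, cfg_scale: float, scale_phi: float = 0.0) -> torch.Tensor:
    """flow stacks [conditional; unconditional] predictions along the batch dimension."""
    cond, uncond = torch.chunk(flow, 2, dim=0)
    guided = uncond + (cond - uncond) * cfg_scale
    if scale_phi != 0.0:
        # pull the guided output back toward the conditional statistics to
        # counteract the amplification introduced by large cfg_scale values
        cond_std = cond.std(dim=1, keepdim=True)
        guided_std = guided.std(dim=1, keepdim=True)
        guided = scale_phi * (guided * (cond_std / guided_std)) + (1.0 - scale_phi) * guided
    return guided


doubled_flow = torch.randn(8, 345, 64)  # a batch of 4 duplicated to 8
print(apply_cfg(doubled_flow, cfg_scale=5.0, scale_phi=0.75).shape)  # torch.Size([4, 345, 64])
```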
mode 100644 index 0000000000000000000000000000000000000000..f4bde87364e1f4176f489cc3015cbd1cdcf7446c Binary files /dev/null and b/think_sound/models/mmmodules/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/ext/__init__.py b/think_sound/models/mmmodules/ext/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/think_sound/models/mmmodules/ext/__init__.py @@ -0,0 +1 @@ + diff --git a/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-310.pyc b/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b85fdf9a28490bf9b437551f7a1634092204900e Binary files /dev/null and b/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-39.pyc b/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42fe0fad8c356a3962d193a463180a237f127c90 Binary files /dev/null and b/think_sound/models/mmmodules/ext/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-310.pyc b/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae3f4750b54aa26378a8a542c4d11355c1024790 Binary files /dev/null and b/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-39.pyc b/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a19ebe09f3893f90cb4f1ddfc74af285a8a1cb1 Binary files /dev/null and b/think_sound/models/mmmodules/ext/__pycache__/rotary_embeddings.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/ext/rotary_embeddings.py b/think_sound/models/mmmodules/ext/rotary_embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea9d56278cb68b7577ed13148227c30ed98fd02 --- /dev/null +++ b/think_sound/models/mmmodules/ext/rotary_embeddings.py @@ -0,0 +1,35 @@ +from typing import Union + +import torch +from einops import rearrange +from torch import Tensor + +# Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py +# Ref: https://github.com/lucidrains/rotary-embedding-torch + + +def compute_rope_rotations(length: int, + dim: int, + theta: int, + *, + freq_scaling: float = 1.0, + device: Union[torch.device, str] = 'cpu') -> Tensor: + assert dim % 2 == 0 + + with torch.amp.autocast(device_type='cuda', enabled=False): + pos = torch.arange(length, dtype=torch.float32, device=device) + freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) + freqs *= freq_scaling + + rot = torch.einsum('..., f -> ... 
f', pos, freqs) + rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1) + rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2) + return rot + + +def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]: + with torch.amp.autocast(device_type='cuda', enabled=False): + _x = x.float() + _x = _x.view(*_x.shape[:-1], -1, 1, 2) + x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1] + return x_out.reshape(*x.shape).to(dtype=x.dtype) diff --git a/think_sound/models/mmmodules/ext/stft_converter.py b/think_sound/models/mmmodules/ext/stft_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..62922067ef3b1d3b8727ec39e7d664ccb304d9fe --- /dev/null +++ b/think_sound/models/mmmodules/ext/stft_converter.py @@ -0,0 +1,183 @@ +# Reference: # https://github.com/bytedance/Make-An-Audio-2 + +import torch +import torch.nn as nn +import torchaudio +from einops import rearrange +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): + return norm_fn(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes, norm_fn): + output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) + return output + + +class STFTConverter(nn.Module): + + def __init__( + self, + *, + sampling_rate: float = 16_000, + n_fft: int = 1024, + num_mels: int = 128, + hop_size: int = 256, + win_size: int = 1024, + fmin: float = 0, + fmax: float = 8_000, + norm_fn=torch.log, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.norm_fn = norm_fn + + mel = librosa_mel_fn(sr=self.sampling_rate, + n_fft=self.n_fft, + n_mels=self.num_mels, + fmin=self.fmin, + fmax=self.fmax) + mel_basis = torch.from_numpy(mel).float() + hann_window = torch.hann_window(self.win_size) + + self.register_buffer('mel_basis', mel_basis) + self.register_buffer('hann_window', hann_window) + + @property + def device(self): + return self.hann_window.device + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + # input: batch_size * length + bs = waveform.shape[0] + waveform = waveform.clamp(min=-1., max=1.) 
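The rotary-embedding helpers added above can be exercised on their own; a usage sketch, assuming the repository root is on `PYTHONPATH` so the `think_sound` package from this diff is importable, with an illustrative sequence length and head dimension:

```python
import torch

from think_sound.models.mmmodules.ext.rotary_embeddings import (apply_rope,
                                                                compute_rope_rotations)

seq_len, n_heads, head_dim = 345, 7, 64
rot = compute_rope_rotations(seq_len, head_dim, 10000)  # (1, seq_len, head_dim // 2, 2, 2)

q = torch.randn(2, n_heads, seq_len, head_dim)          # (B, H, N, D_head), as in SelfAttention
q_rot = apply_rope(q, rot)                              # rotated queries, same shape as q
print(rot.shape, q_rot.shape)
```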
+ + spec = torch.stft(waveform, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + + spec = torch.view_as_real(spec) + # print('After stft', spec.shape, spec.min(), spec.max(), spec.mean()) + + power = spec.pow(2).sum(-1) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power', power.shape, power.min(), power.max(), power.mean()) + print('angle', angle.shape, angle.min(), angle.max(), angle.mean()) + + # print('mel', self.mel_basis.shape, self.mel_basis.min(), self.mel_basis.max(), + # self.mel_basis.mean()) + + # spec = rearrange(spec, 'b f t c -> (b c) f t') + + # spec = self.mel_transform(spec) + + # spec = torch.matmul(self.mel_basis, spec) + + # print('After mel', spec.shape, spec.min(), spec.max(), spec.mean()) + + # spec = spectral_normalize_torch(spec, self.norm_fn) + + # print('After norm', spec.shape, spec.min(), spec.max(), spec.mean()) + + # compute magnitude + # magnitude = torch.sqrt((spec**2).sum(-1)) + # normalize by magnitude + # scaled_magnitude = torch.log10(magnitude.clamp(min=1e-5)) * 10 + # spec = spec / magnitude.unsqueeze(-1) * scaled_magnitude.unsqueeze(-1) + + # power = torch.log10(power.clamp(min=1e-5)) * 10 + power = torch.log10(power.clamp(min=1e-5)) + + print('After scaling', power.shape, power.min(), power.max(), power.mean()) + + spec = torch.stack([power, angle], dim=-1) + + # spec = rearrange(spec, '(b c) f t -> b c f t', b=bs) + spec = rearrange(spec, 'b f t c -> b c f t', b=bs) + + # spec[:, :, 400:] = 0 + + return spec + + def invert(self, spec: torch.Tensor, length: int) -> torch.Tensor: + bs = spec.shape[0] + + # spec = rearrange(spec, 'b c f t -> (b c) f t') + # print(spec.shape, self.mel_basis.shape) + # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution + # spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec + + # spec = self.invmel_transform(spec) + + spec = rearrange(spec, 'b c f t -> b f t c', b=bs).contiguous() + + # spec[..., 0] = 10**(spec[..., 0] / 10) + + power = spec[..., 0] + power = 10**power + + # print('After unscaling', spec[..., 0].shape, spec[..., 0].min(), spec[..., 0].max(), + # spec[..., 0].mean()) + + unit_vector = torch.stack([ + torch.cos(spec[..., 1]), + torch.sin(spec[..., 1]), + ], dim=-1) + + spec = torch.sqrt(power) * unit_vector + + # spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() + spec = torch.view_as_complex(spec) + + waveform = torch.istft( + spec, + self.n_fft, + length=length, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + normalized=False, + onesided=True, + return_complex=False, + ) + + return waveform + + +if __name__ == '__main__': + + converter = STFTConverter(sampling_rate=16000) + + signal = torchaudio.load('./output/ZZ6GRocWW38_000090.wav')[0] + # resample signal at 44100 Hz + # signal = torchaudio.transforms.Resample(16_000, 44_100)(signal) + + L = signal.shape[1] + print('Input signal', signal.shape) + spec = converter(signal) + + print('Final spec', spec.shape) + + signal_recon = converter.invert(spec, length=L) + print('Output signal', signal_recon.shape, signal_recon.min(), signal_recon.max(), + signal_recon.mean()) + + print('MSE', torch.nn.functional.mse_loss(signal, signal_recon)) + torchaudio.save('./output/ZZ6GRocWW38_000090_recon.wav', signal_recon, 16000) diff --git a/think_sound/models/mmmodules/ext/stft_converter_mel.py 
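The converter above stores the STFT as a (log-power, phase) pair and rebuilds the complex spectrogram in `invert` from `sqrt(power)` and the phase's unit vector. A tiny numerical check of that representation, independent of the class and of the file paths used in its `__main__` block:

```python
import torch

spec = torch.randn(3, 513, 100, dtype=torch.complex64)  # any complex STFT-like tensor

power = spec.abs().pow(2)   # squared magnitude, as in forward()
angle = torch.angle(spec)   # phase, equivalent to atan2(imag, real)

# invert: sqrt(power) * (cos(angle) + i * sin(angle)) recovers the complex values
recon = torch.sqrt(power) * torch.complex(torch.cos(angle), torch.sin(angle))
print(torch.allclose(spec, recon, atol=1e-5))  # True
```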
b/think_sound/models/mmmodules/ext/stft_converter_mel.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b32d4cb9a23cd74f723e7d8307fd82fa1abba0 --- /dev/null +++ b/think_sound/models/mmmodules/ext/stft_converter_mel.py @@ -0,0 +1,234 @@ +# Reference: # https://github.com/bytedance/Make-An-Audio-2 + +import torch +import torch.nn as nn +import torchaudio +from einops import rearrange +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): + return norm_fn(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes, norm_fn): + output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn) + return output + + +class STFTConverter(nn.Module): + + def __init__( + self, + *, + sampling_rate: float = 16_000, + n_fft: int = 1024, + num_mels: int = 128, + hop_size: int = 256, + win_size: int = 1024, + fmin: float = 0, + fmax: float = 8_000, + norm_fn=torch.log, + ): + super().__init__() + self.sampling_rate = sampling_rate + self.n_fft = n_fft + self.num_mels = num_mels + self.hop_size = hop_size + self.win_size = win_size + self.fmin = fmin + self.fmax = fmax + self.norm_fn = norm_fn + + mel = librosa_mel_fn(sr=self.sampling_rate, + n_fft=self.n_fft, + n_mels=self.num_mels, + fmin=self.fmin, + fmax=self.fmax) + mel_basis = torch.from_numpy(mel).float() + hann_window = torch.hann_window(self.win_size) + + self.register_buffer('mel_basis', mel_basis) + self.register_buffer('hann_window', hann_window) + + @property + def device(self): + return self.hann_window.device + + def forward(self, waveform: torch.Tensor) -> torch.Tensor: + # input: batch_size * length + bs = waveform.shape[0] + waveform = waveform.clamp(min=-1., max=1.) + + spec = torch.stft(waveform, + self.n_fft, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + pad_mode='reflect', + normalized=False, + onesided=True, + return_complex=True) + + spec = torch.view_as_real(spec) + # print('After stft', spec.shape, spec.min(), spec.max(), spec.mean()) + + power = (spec.pow(2).sum(-1))**(0.5) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power 1', power.shape, power.min(), power.max(), power.mean()) + print('angle 1', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) + + # print('mel', self.mel_basis.shape, self.mel_basis.min(), self.mel_basis.max(), + # self.mel_basis.mean()) + + # spec = self.mel_transform(spec) + + # power = torch.matmul(self.mel_basis, power) + + spec = rearrange(spec, 'b f t c -> (b c) f t') + spec = self.mel_basis.unsqueeze(0) @ spec + spec = rearrange(spec, '(b c) f t -> b f t c', b=bs) + + power = (spec.pow(2).sum(-1))**(0.5) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power', power.shape, power.min(), power.max(), power.mean()) + print('angle', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) + + # print('After mel', spec.shape, spec.min(), spec.max(), spec.mean()) + + # spec = spectral_normalize_torch(spec, self.norm_fn) + + # print('After norm', spec.shape, spec.min(), spec.max(), spec.mean()) + + # compute magnitude + # magnitude = torch.sqrt((spec**2).sum(-1)) + # normalize by magnitude + # scaled_magnitude = torch.log10(magnitude.clamp(min=1e-5)) * 10 + # spec = spec / magnitude.unsqueeze(-1) * scaled_magnitude.unsqueeze(-1) + + # power = torch.log10(power.clamp(min=1e-5)) * 10 + power = torch.log10(power.clamp(min=1e-8)) + + print('After scaling', power.shape, 
power.min(), power.max(), power.mean()) + + # spec = torch.stack([power, angle], dim=-1) + + # spec = rearrange(spec, '(b c) f t -> b c f t', b=bs) + # spec = rearrange(spec, 'b f t c -> b c f t', b=bs) + + # spec[:, :, 400:] = 0 + + return power, angle + # return spec[..., 0], spec[..., 1] + + def invert(self, spec: torch.Tensor, length: int) -> torch.Tensor: + + power, angle = spec + + bs = power.shape[0] + + # spec = rearrange(spec, 'b c f t -> (b c) f t') + # print(spec.shape, self.mel_basis.shape) + # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution + # spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec + + # spec = self.invmel_transform(spec) + + # spec = rearrange(spec, 'b c f t -> b f t c', b=bs).contiguous() + + # spec[..., 0] = 10**(spec[..., 0] / 10) + + # power = spec[..., 0] + power = 10**power + + # print('After unscaling', spec[..., 0].shape, spec[..., 0].min(), spec[..., 0].max(), + # spec[..., 0].mean()) + + unit_vector = torch.stack([ + torch.cos(angle), + torch.sin(angle), + ], dim=-1) + + spec = power.unsqueeze(-1) * unit_vector + + # power = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), power).solution + spec = rearrange(spec, 'b f t c -> (b c) f t') + spec = torch.linalg.pinv(self.mel_basis.unsqueeze(0)) @ spec + # spec = torch.linalg.lstsq(self.mel_basis.unsqueeze(0), spec).solution + spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() + + power = (spec.pow(2).sum(-1))**(0.5) + angle = torch.atan2(spec[..., 1], spec[..., 0]) + + print('power 2', power.shape, power.min(), power.max(), power.mean()) + print('angle 2', angle.shape, angle.min(), angle.max(), angle.mean(), angle[:, :2, :2]) + + # spec = rearrange(spec, '(b c) f t -> b f t c', b=bs).contiguous() + spec = torch.view_as_complex(spec) + + waveform = torch.istft( + spec, + self.n_fft, + length=length, + hop_length=self.hop_size, + win_length=self.win_size, + window=self.hann_window, + center=True, + normalized=False, + onesided=True, + return_complex=False, + ) + + return waveform + + +if __name__ == '__main__': + + converter = STFTConverter(sampling_rate=16000) + + signal = torchaudio.load('./output/ZZ6GRocWW38_000090.wav')[0] + # resample signal at 44100 Hz + # signal = torchaudio.transforms.Resample(16_000, 44_100)(signal) + + L = signal.shape[1] + print('Input signal', signal.shape) + spec = converter(signal) + + power, angle = spec + + # print(power.shape, angle.shape) + # print(power, power.min(), power.max(), power.mean()) + # power = power.clamp(-1, 1) + # angle = angle.clamp(-1, 1) + + import matplotlib.pyplot as plt + + # Visualize power + plt.figure() + plt.imshow(power[0].detach().numpy(), aspect='auto', origin='lower') + plt.colorbar() + plt.title('Power') + plt.xlabel('Time') + plt.ylabel('Frequency') + plt.savefig('./output/power.png') + + # Visualize angle + plt.figure() + plt.imshow(angle[0].detach().numpy(), aspect='auto', origin='lower') + plt.colorbar() + plt.title('Angle') + plt.xlabel('Time') + plt.ylabel('Frequency') + plt.savefig('./output/angle.png') + + # print('Final spec', spec.shape) + + signal_recon = converter.invert(spec, length=L) + print('Output signal', signal_recon.shape, signal_recon.min(), signal_recon.max(), + signal_recon.mean()) + + print('MSE', torch.nn.functional.mse_loss(signal, signal_recon)) + torchaudio.save('./output/ZZ6GRocWW38_000090_recon.wav', signal_recon, 16000) diff --git a/think_sound/models/mmmodules/model/__init__.py b/think_sound/models/mmmodules/model/__init__.py new file mode 100644 index 
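The mel variant above compresses the real/imaginary STFT channels through the mel filterbank and undoes that with `torch.linalg.pinv`, which is only an approximate inverse because the mel matrix is rectangular. A minimal sketch of that projection and its lossy inversion, using the same default sizes (1024-point FFT, 128 mel bins) and a single real channel for brevity:

```python
import torch
import torch.nn.functional as F
from librosa.filters import mel as librosa_mel_fn

n_fft, num_mels, sr = 1024, 128, 16_000
mel_basis = torch.from_numpy(
    librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=num_mels, fmin=0, fmax=8_000)).float()  # (128, 513)

spec = torch.randn(2, n_fft // 2 + 1, 100)                     # (batch, freq_bins, frames)
mel_spec = mel_basis.unsqueeze(0) @ spec                       # (2, 128, 100), compressed
approx = torch.linalg.pinv(mel_basis).unsqueeze(0) @ mel_spec  # (2, 513, 100), approximate

print(mel_spec.shape, approx.shape)
print(F.mse_loss(approx, spec))  # nonzero: the mel projection discards information
```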
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4401799caf1e7efffce5ba20eee3e4e6c4cf18ed Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe6460de12d696b4d8a4f8f9f6af4713a12b8d2e Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8182934cbcf7d20c9145b17a9295f62bf0b03bf Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9f111166db2dd4442d2227103547d9cf522f6d7 Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/embeddings.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca349e5522944bf0a9fe9a1d162fe9bd3cbb3eb5 Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb6982d89c721052d73cef49f074fbb58456d388 Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/low_level.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-310.pyc b/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89842e8912466aa6e1659f3a4c77cc975da0f0f8 Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-310.pyc differ diff --git a/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-39.pyc b/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1aa3d468c975f0e977391704a453303d82305f7 Binary files /dev/null and b/think_sound/models/mmmodules/model/__pycache__/transformer_layers.cpython-39.pyc differ diff --git a/think_sound/models/mmmodules/model/embeddings.py b/think_sound/models/mmmodules/model/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..297feb4d2c79d306771f5436dbd4ada1a976b3bc --- /dev/null +++ b/think_sound/models/mmmodules/model/embeddings.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn + +# 
https://github.com/facebookresearch/DiT + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, dim, frequency_embedding_size, max_period): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, dim), + nn.SiLU(), + nn.Linear(dim, dim), + ) + self.dim = dim + self.max_period = max_period + assert dim % 2 == 0, 'dim must be even.' + + with torch.autocast('cuda', enabled=False): + self.freqs = nn.Buffer( + 1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) / + frequency_embedding_size)), + persistent=False) + freq_scale = 10000 / max_period + self.freqs = freq_scale * self.freqs + + def timestep_embedding(self, t): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + + args = t[:, None].float() * self.freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t).to(t.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/think_sound/models/mmmodules/model/flow_matching.py b/think_sound/models/mmmodules/model/flow_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c65dece6dec746db999092606f4384d084d119 --- /dev/null +++ b/think_sound/models/mmmodules/model/flow_matching.py @@ -0,0 +1,71 @@ +import logging +from typing import Callable, Optional + +import torch +from torchdiffeq import odeint + +log = logging.getLogger() + + +# Partially from https://github.com/gle-bellier/flow-matching +class FlowMatching: + + def __init__(self, min_sigma: float = 0.0, inference_mode='euler', num_steps: int = 25): + # inference_mode: 'euler' or 'adaptive' + # num_steps: number of steps in the euler inference mode + super().__init__() + self.min_sigma = min_sigma + self.inference_mode = inference_mode + self.num_steps = num_steps + + # self.fm = ExactOptimalTransportConditionalFlowMatcher(sigma=min_sigma) + + assert self.inference_mode in ['euler', 'adaptive'] + if self.inference_mode == 'adaptive' and num_steps > 0: + log.info('The number of steps is ignored in adaptive inference mode ') + + def get_conditional_flow(self, x0: torch.Tensor, x1: torch.Tensor, + t: torch.Tensor) -> torch.Tensor: + # which is psi_t(x), eq 22 in flow matching for generative models + t = t[:, None, None].expand_as(x0) + return (1 - (1 - self.min_sigma) * t) * x0 + t * x1 + + def loss(self, predicted_v: torch.Tensor, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: + # return the mean error without reducing the batch dimension + reduce_dim = list(range(1, len(predicted_v.shape))) + target_v = x1 - (1 - self.min_sigma) * x0 + return (predicted_v - target_v).pow(2).mean(dim=reduce_dim) + + def get_x0_xt_c( + self, + x1: torch.Tensor, + t: torch.Tensor, + Cs: list[torch.Tensor], + generator: Optional[torch.Generator] = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + x0 = torch.empty_like(x1).normal_(generator=generator) + + xt = self.get_conditional_flow(x0, x1, t) + return x0, x1, xt, Cs + + def to_prior(self, fn: Callable, x1: torch.Tensor) -> torch.Tensor: + return self.run_t0_to_t1(fn, x1, 1, 0) + + def 
to_data(self, fn: Callable, x0: torch.Tensor) -> torch.Tensor: + return self.run_t0_to_t1(fn, x0, 0, 1) + + def run_t0_to_t1(self, fn: Callable, x0: torch.Tensor, t0: float, t1: float) -> torch.Tensor: + # fn: a function that takes (t, x) and returns the direction x0->x1 + + if self.inference_mode == 'adaptive': + return odeint(fn, x0, torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype)) + elif self.inference_mode == 'euler': + x = x0 + steps = torch.linspace(t0, t1 - self.min_sigma, self.num_steps + 1) + for ti, t in enumerate(steps[:-1]): + flow = fn(t, x) + next_t = steps[ti + 1] + dt = next_t - t + x = x + dt * flow + + return x diff --git a/think_sound/models/mmmodules/model/low_level.py b/think_sound/models/mmmodules/model/low_level.py new file mode 100644 index 0000000000000000000000000000000000000000..c8326a8bec99f1be08b92e76fda4b59e777b39d2 --- /dev/null +++ b/think_sound/models/mmmodules/model/low_level.py @@ -0,0 +1,95 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class ChannelLastConv1d(nn.Conv1d): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 2, 1) + x = super().forward(x) + x = x.permute(0, 2, 1) + return x + + +# https://github.com/Stability-AI/sd3-ref +class MLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class ConvMLP(nn.Module): + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int = 256, + kernel_size: int = 3, + padding: int = 1, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. 
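`FlowMatching.run_t0_to_t1` above integrates the learned velocity field with a fixed-step Euler loop over `torch.linspace(t0, t1, num_steps + 1)`. A self-contained toy run of the same scheme, where an oracle velocity for the straight path (with `min_sigma = 0`) stands in for the network; the names and the 2-dimensional target are illustrative:

```python
import torch


def to_data_euler(fn, x0: torch.Tensor, num_steps: int = 25) -> torch.Tensor:
    """Fixed-step Euler integration from t=0 (noise) to t=1 (data), as in FlowMatching."""
    x = x0
    steps = torch.linspace(0.0, 1.0, num_steps + 1)
    for i, t in enumerate(steps[:-1]):
        dt = steps[i + 1] - t
        x = x + dt * fn(t, x)
    return x


x1 = torch.tensor([3.0, -1.0])  # pretend "data" sample
x0 = torch.randn(2)             # prior noise

# oracle velocity for the path x_t = (1 - t) * x0 + t * x1: v(t, x) = (x1 - x) / (1 - t)
velocity = lambda t, x: (x1 - x) / (1.0 - t)

print(to_data_euler(velocity, x0))  # approximately tensor([ 3., -1.])
```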
+ + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w2 = ChannelLastConv1d(hidden_dim, + dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + self.w3 = ChannelLastConv1d(dim, + hidden_dim, + bias=False, + kernel_size=kernel_size, + padding=padding) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) diff --git a/think_sound/models/mmmodules/model/networks.py b/think_sound/models/mmmodules/model/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..8272a896b358f5db681d1462c4189d671b916d76 --- /dev/null +++ b/think_sound/models/mmmodules/model/networks.py @@ -0,0 +1,470 @@ +import logging +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from mmaudio.ext.rotary_embeddings import compute_rope_rotations +from mmaudio.model.embeddings import TimestepEmbedder +from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP +from mmaudio.model.transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock) + +log = logging.getLogger() + + +@dataclass +class PreprocessedConditions: + clip_f: torch.Tensor + sync_f: torch.Tensor + text_f: torch.Tensor + clip_f_c: torch.Tensor + text_f_c: torch.Tensor + + +# Partially from https://github.com/facebookresearch/DiT +class MMAudio(nn.Module): + + def __init__(self, + *, + latent_dim: int, + clip_dim: int, + sync_dim: int, + text_dim: int, + hidden_dim: int, + depth: int, + fused_depth: int, + num_heads: int, + mlp_ratio: float = 4.0, + latent_seq_len: int, + clip_seq_len: int, + sync_seq_len: int, + text_seq_len: int = 77, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None, + empty_string_feat: Optional[torch.Tensor] = None, + v2: bool = False) -> None: + super().__init__() + + self.v2 = v2 + self.latent_dim = latent_dim + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self._text_seq_len = text_seq_len + self.hidden_dim = hidden_dim + self.num_heads = num_heads + + if v2: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), + nn.SiLU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + nn.SiLU(), + MLP(hidden_dim, hidden_dim * 4), + ) + else: + self.audio_input_proj = nn.Sequential( + ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3), + ) + + self.clip_input_proj = nn.Sequential( + nn.Linear(clip_dim, hidden_dim), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + self.sync_input_proj = nn.Sequential( + ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3), + nn.SELU(), + ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1), + ) + + 
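Both `MLP` and `ConvMLP` above use the SwiGLU-style gating `w2(SiLU(w1 x) * w3 x)` with the hidden width rounded up to a multiple of 256, and `ChannelLastConv1d` simply permutes to channel-first around a regular `nn.Conv1d`. A quick shape check, assuming the `think_sound` package from this diff is importable; the batch, length and width values are illustrative:

```python
import torch

from think_sound.models.mmmodules.model.low_level import ChannelLastConv1d, ConvMLP, MLP

x = torch.randn(2, 345, 64)  # (batch, time, channels), channel-last like the audio latents

conv = ChannelLastConv1d(64, 64, kernel_size=7, padding=3)
conv_mlp = ConvMLP(64, 64 * 4, kernel_size=3, padding=1)
mlp = MLP(64, 64 * 4)

print(conv(x).shape)      # torch.Size([2, 345, 64]); permutes to (B, C, T) internally
print(conv_mlp(x).shape)  # torch.Size([2, 345, 64]); gated hidden width rounded up to 256
print(mlp(x).shape)       # torch.Size([2, 345, 64])
```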
self.text_input_proj = nn.Sequential( + nn.Linear(text_dim, hidden_dim), + MLP(hidden_dim, hidden_dim * 4), + ) + + self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim) + self.text_cond_proj = nn.Linear(hidden_dim, hidden_dim) + self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4) + # each synchformer output segment has 8 feature frames + self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim))) + + self.final_layer = FinalBlock(hidden_dim, latent_dim) + + if v2: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=hidden_dim, + max_period=1) + else: + self.t_embed = TimestepEmbedder(hidden_dim, + frequency_embedding_size=256, + max_period=10000) + self.joint_blocks = nn.ModuleList([ + JointBlock(hidden_dim, + num_heads, + mlp_ratio=mlp_ratio, + pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth) + ]) + + self.fused_blocks = nn.ModuleList([ + MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=3, padding=1) + for i in range(fused_depth) + ]) + + if latent_mean is None: + # these values are not meant to be used + # if you don't provide mean/std here, we should load them later from a checkpoint + assert latent_std is None + latent_mean = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) + latent_std = torch.ones(latent_dim).view(1, 1, -1).fill_(float('nan')) + else: + assert latent_std is not None + assert latent_mean.numel() == latent_dim, f'{latent_mean.numel()=} != {latent_dim=}' + if empty_string_feat is None: + empty_string_feat = torch.zeros((text_seq_len, text_dim)) + self.latent_mean = nn.Parameter(latent_mean.view(1, 1, -1), requires_grad=False) + self.latent_std = nn.Parameter(latent_std.view(1, 1, -1), requires_grad=False) + + self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False) + self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True) + self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True) + + self.initialize_weights() + self.initialize_rotations() + + def initialize_rotations(self): + base_freq = 1.0 + latent_rot = compute_rope_rotations(self._latent_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq, + device=self.device) + clip_rot = compute_rope_rotations(self._clip_seq_len, + self.hidden_dim // self.num_heads, + 10000, + freq_scaling=base_freq * self._latent_seq_len / + self._clip_seq_len, + device=self.device) + + self.latent_rot = nn.Buffer(latent_rot, persistent=False) + self.clip_rot = nn.Buffer(clip_rot, persistent=False) + + def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None: + self._latent_seq_len = latent_seq_len + self._clip_seq_len = clip_seq_len + self._sync_seq_len = sync_seq_len + self.initialize_rotations() + + def initialize_weights(self): + + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.joint_blocks: + nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.clip_block.adaLN_modulation[-1].weight, 0) + 
nn.init.constant_(block.clip_block.adaLN_modulation[-1].bias, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0) + for block in self.fused_blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.conv.weight, 0) + nn.init.constant_(self.final_layer.conv.bias, 0) + + # empty string feat shall be initialized by a CLIP encoder + nn.init.constant_(self.sync_pos_emb, 0) + nn.init.constant_(self.empty_clip_feat, 0) + nn.init.constant_(self.empty_sync_feat, 0) + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + # return (x - self.latent_mean) / self.latent_std + return x.sub_(self.latent_mean).div_(self.latent_std) + + def unnormalize(self, x: torch.Tensor) -> torch.Tensor: + # return x * self.latent_std + self.latent_mean + return x.mul_(self.latent_std).add_(self.latent_mean) + + def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor) -> PreprocessedConditions: + """ + cache computations that do not depend on the latent/time step + i.e., the features are reused over steps during inference + """ + assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}' + assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}' + assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}' + + bs = clip_f.shape[0] + + # B * num_segments (24) * 8 * 768 + num_sync_segments = self._sync_seq_len // 8 + sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb + sync_f = sync_f.flatten(1, 2) # (B, VN, D) + + # extend vf to match x + clip_f = self.clip_input_proj(clip_f) # (B, VN, D) + sync_f = self.sync_input_proj(sync_f) # (B, VN, D) + text_f = self.text_input_proj(text_f) # (B, VN, D) + + # upsample the sync features to match the audio + sync_f = sync_f.transpose(1, 2) # (B, D, VN) + sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact') + sync_f = sync_f.transpose(1, 2) # (B, N, D) + + # get conditional features from the clip side + clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1)) # (B, D) + text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D) + + return PreprocessedConditions(clip_f=clip_f, + sync_f=sync_f, + text_f=text_f, + clip_f_c=clip_f_c, + text_f_c=text_f_c) + + def predict_flow(self, latent: torch.Tensor, t: torch.Tensor, + conditions: PreprocessedConditions) -> torch.Tensor: + """ + for non-cacheable computations + """ + assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}' + + clip_f = conditions.clip_f + sync_f = conditions.sync_f + text_f = conditions.text_f + clip_f_c = conditions.clip_f_c + text_f_c = conditions.text_f_c + + latent = self.audio_input_proj(latent) # (B, N, D) + global_c = self.global_cond_mlp(clip_f_c + text_f_c) # (B, D) + + global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1) # (B, D) + extended_c = global_c + sync_f + + for block in self.joint_blocks: + latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c, + self.latent_rot, self.clip_rot) # (B, N, D) + + for block in self.fused_blocks: + latent = block(latent, extended_c, self.latent_rot) + + # should be extended_c; this is a 
minor implementation error #55 + flow = self.final_layer(latent, extended_c) # (B, N, out_dim), remove t + return flow + + def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor, + text_f: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + latent: (B, N, C) + vf: (B, T, C_V) + t: (B,) + """ + conditions = self.preprocess_conditions(clip_f, sync_f, text_f) + flow = self.predict_flow(latent, t, conditions) + return flow + + def get_empty_string_sequence(self, bs: int) -> torch.Tensor: + return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1) + + def get_empty_clip_sequence(self, bs: int) -> torch.Tensor: + return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1) + + def get_empty_sync_sequence(self, bs: int) -> torch.Tensor: + return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1) + + def get_empty_conditions( + self, + bs: int, + *, + negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions: + if negative_text_features is not None: + empty_text = negative_text_features + else: + empty_text = self.get_empty_string_sequence(1) + + empty_clip = self.get_empty_clip_sequence(1) + empty_sync = self.get_empty_sync_sequence(1) + conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text) + conditions.clip_f = conditions.clip_f.expand(bs, -1, -1) + conditions.sync_f = conditions.sync_f.expand(bs, -1, -1) + conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1) + if negative_text_features is None: + conditions.text_f = conditions.text_f.expand(bs, -1, -1) + conditions.text_f_c = conditions.text_f_c.expand(bs, -1) + + return conditions + + def ode_wrapper(self, t: torch.Tensor, latent: torch.Tensor, conditions: PreprocessedConditions, + empty_conditions: PreprocessedConditions, cfg_strength: float) -> torch.Tensor: + t = t * torch.ones(len(latent), device=latent.device, dtype=latent.dtype) + + if cfg_strength < 1.0: + return self.predict_flow(latent, t, conditions) + else: + return (cfg_strength * self.predict_flow(latent, t, conditions) + + (1 - cfg_strength) * self.predict_flow(latent, t, empty_conditions)) + + def load_weights(self, src_dict) -> None: + if 't_embed.freqs' in src_dict: + del src_dict['t_embed.freqs'] + if 'latent_rot' in src_dict: + del src_dict['latent_rot'] + if 'clip_rot' in src_dict: + del src_dict['clip_rot'] + + self.load_state_dict(src_dict, strict=True) + + @property + def device(self) -> torch.device: + return self.latent_mean.device + + @property + def latent_seq_len(self) -> int: + return self._latent_seq_len + + @property + def clip_seq_len(self) -> int: + return self._clip_seq_len + + @property + def sync_seq_len(self) -> int: + return self._sync_seq_len + + +def small_16k(**kwargs) -> MMAudio: + num_heads = 7 + return MMAudio(latent_dim=20, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=12, + fused_depth=8, + num_heads=num_heads, + latent_seq_len=250, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def small_44k(**kwargs) -> MMAudio: + num_heads = 7 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=12, + fused_depth=8, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def medium_44k(**kwargs) -> MMAudio: + num_heads = 14 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=12, + fused_depth=8, + num_heads=num_heads, + 
latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def large_44k(**kwargs) -> MMAudio: + num_heads = 14 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=21, + fused_depth=14, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + **kwargs) + + +def large_44k_v2(**kwargs) -> MMAudio: + num_heads = 14 + return MMAudio(latent_dim=40, + clip_dim=1024, + sync_dim=768, + text_dim=1024, + hidden_dim=64 * num_heads, + depth=21, + fused_depth=14, + num_heads=num_heads, + latent_seq_len=345, + clip_seq_len=64, + sync_seq_len=192, + v2=True, + **kwargs) + + +def get_my_mmaudio(name: str, **kwargs) -> MMAudio: + if name == 'small_16k': + return small_16k(**kwargs) + if name == 'small_44k': + return small_44k(**kwargs) + if name == 'medium_44k': + return medium_44k(**kwargs) + if name == 'large_44k': + return large_44k(**kwargs) + if name == 'large_44k_v2': + return large_44k_v2(**kwargs) + + raise ValueError(f'Unknown model name: {name}') + + +if __name__ == '__main__': + network = get_my_mmaudio('small_16k') + + # print the number of parameters in terms of millions + num_params = sum(p.numel() for p in network.parameters()) / 1e6 + print(f'Number of parameters: {num_params:.2f}M') diff --git a/think_sound/models/mmmodules/model/sequence_config.py b/think_sound/models/mmmodules/model/sequence_config.py new file mode 100644 index 0000000000000000000000000000000000000000..14269014dc401b4751d172466813a935fddda6c1 --- /dev/null +++ b/think_sound/models/mmmodules/model/sequence_config.py @@ -0,0 +1,58 @@ +import dataclasses +import math + + +@dataclasses.dataclass +class SequenceConfig: + # general + duration: float + + # audio + sampling_rate: int + spectrogram_frame_rate: int + latent_downsample_rate: int = 2 + + # visual + clip_frame_rate: int = 8 + sync_frame_rate: int = 25 + sync_num_frames_per_segment: int = 16 + sync_step_size: int = 8 + sync_downsample_rate: int = 2 + + @property + def num_audio_frames(self) -> int: + # we need an integer number of latents + return self.latent_seq_len * self.spectrogram_frame_rate * self.latent_downsample_rate + + @property + def latent_seq_len(self) -> int: + return int( + math.ceil(self.duration * self.sampling_rate / self.spectrogram_frame_rate / + self.latent_downsample_rate)) + + @property + def clip_seq_len(self) -> int: + return int(self.duration * self.clip_frame_rate) + + @property + def sync_seq_len(self) -> int: + num_frames = self.duration * self.sync_frame_rate + num_segments = (num_frames - self.sync_num_frames_per_segment) // self.sync_step_size + 1 + return int(num_segments * self.sync_num_frames_per_segment / self.sync_downsample_rate) + + +CONFIG_16K = SequenceConfig(duration=8.0, sampling_rate=16000, spectrogram_frame_rate=256) +CONFIG_44K = SequenceConfig(duration=8.0, sampling_rate=44100, spectrogram_frame_rate=512) + +if __name__ == '__main__': + assert CONFIG_16K.latent_seq_len == 250 + assert CONFIG_16K.clip_seq_len == 64 + assert CONFIG_16K.sync_seq_len == 192 + assert CONFIG_16K.num_audio_frames == 128000 + + assert CONFIG_44K.latent_seq_len == 345 + assert CONFIG_44K.clip_seq_len == 64 + assert CONFIG_44K.sync_seq_len == 192 + assert CONFIG_44K.num_audio_frames == 353280 + + print('Passed') diff --git a/think_sound/models/mmmodules/model/transformer_layers.py b/think_sound/models/mmmodules/model/transformer_layers.py new file mode 100644 index 
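The values asserted in `sequence_config.py` above follow directly from the 8-second duration and the fixed frame rates; the arithmetic for the 44.1 kHz configuration, spelled out:

```python
import math

duration, sr, spec_frame_rate, latent_ds = 8.0, 44_100, 512, 2

latent_seq_len = math.ceil(duration * sr / spec_frame_rate / latent_ds)  # ceil(344.53) = 345
clip_seq_len = int(duration * 8)                                         # 8 fps CLIP frames -> 64
sync_frames = duration * 25                                              # 25 fps -> 200
sync_segments = (sync_frames - 16) // 8 + 1                              # 16-frame windows, stride 8 -> 24
sync_seq_len = int(sync_segments * 16 / 2)                               # halved by downsampling -> 192
num_audio_frames = latent_seq_len * spec_frame_rate * latent_ds          # 353280 samples

print(latent_seq_len, clip_seq_len, sync_seq_len, num_audio_frames)      # 345 64 192 353280
```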
0000000000000000000000000000000000000000..6b06bc9850543f87ca9eb4217899674609d1620b --- /dev/null +++ b/think_sound/models/mmmodules/model/transformer_layers.py @@ -0,0 +1,271 @@ +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from einops.layers.torch import Rearrange + +from ..ext.rotary_embeddings import apply_rope +from ..model.low_level import MLP, ChannelLastConv1d, ConvMLP +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func + print('flash_attn installed, using Flash Attention') +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor): + return x * (1 + scale) + shift + + +def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + # training will crash without these contiguous calls and the CUDNN limitation + # I believe this is related to https://github.com/pytorch/pytorch/issues/133974 + # unresolved at the time of writing + fa_dtype_in = q.dtype + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = F.scaled_dot_product_attention(q, k, v) + out = rearrange(out, 'b h n d -> b n (h d)').contiguous() + return out + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.bfloat16), (q, k, v)) + # print(f"q dtype: {q.dtype}") + # print(f"k dtype: {k.dtype}") + # print(f"v dtype: {v.dtype}") + # breakpoint() + out = flash_attn_func(q, k, v) + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b n (h d)') + # out = rearrange(out.to(fa_dtype_in), 'b h n d -> b n (h d)').contiguous() + return out + + +class SelfAttention(nn.Module): + + def __init__(self, dim: int, nheads: int): + super().__init__() + self.dim = dim + self.nheads = nheads + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.q_norm = nn.RMSNorm(dim // nheads) + self.k_norm = nn.RMSNorm(dim // nheads) + + self.split_into_heads = Rearrange('b n (h d j) -> b h n d j', + h=nheads, + d=dim // nheads, + j=3) + + def pre_attention( + self, x: torch.Tensor, + rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # x: batch_size * n_tokens * n_channels + qkv = self.qkv(x) + q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1) + q = q.squeeze(-1) + k = k.squeeze(-1) + v = v.squeeze(-1) + q = self.q_norm(q) + k = self.k_norm(k) + + if rot is not None: + q = apply_rope(q, rot) + k = apply_rope(k, rot) + + return q, k, v + + def forward( + self, + x: torch.Tensor, # batch_size * n_tokens * n_channels + ) -> torch.Tensor: + q, v, k = self.pre_attention(x) + out = attention(q, k, v) + return out + +class CrossAttention(nn.Module): + + def __init__(self, dim: int, nheads: int): + super().__init__() + self.dim = dim + self.nheads = nheads + + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim, dim * 2, bias=False) + self.q_norm = nn.RMSNorm(dim // nheads) + self.k_norm = nn.RMSNorm(dim // nheads) + + self.split_q_into_heads = Rearrange('b n (h d) -> b h n d', + h=nheads, + d=dim // nheads) + self.split_kv_into_heads = Rearrange('b n (h d j) -> b h n d j', + h=nheads, + d=dim // nheads, + j=2) + + def pre_attention( + self, x: torch.Tensor, + context: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # x: batch_size * n_tokens * n_channels + q = self.to_q(x) + kv = self.to_kv(context) + q = self.split_q_into_heads(q) + k, v = 
self.split_kv_into_heads(kv).chunk(2, dim=-1) + q = q.squeeze(-1) + k = k.squeeze(-1) + v = v.squeeze(-1) + q = self.q_norm(q) + k = self.k_norm(k) + + + return q, k, v + + def forward( + self, + x: torch.Tensor, context=None + ) -> torch.Tensor: + q, v, k = self.pre_attention(x, context=context) + out = attention(q, k, v) + return out + + +class MMDitSingleBlock(nn.Module): + + def __init__(self, + dim: int, + nhead: int, + mlp_ratio: float = 4.0, + pre_only: bool = False, + kernel_size: int = 7, + padding: int = 3, + cross_attend: bool = False): + super().__init__() + self.norm1 = nn.LayerNorm(dim, elementwise_affine=False) + self.attn = SelfAttention(dim, nhead) + if cross_attend: + self.cross_attn = CrossAttention(dim, nhead) + self.pre_only = pre_only + if pre_only: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True)) + else: + if kernel_size == 1: + self.linear1 = nn.Linear(dim, dim) + else: + self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding) + self.norm2 = nn.LayerNorm(dim, elementwise_affine=False) + + if kernel_size == 1: + self.ffn = MLP(dim, int(dim * mlp_ratio)) + else: + self.ffn = ConvMLP(dim, + int(dim * mlp_ratio), + kernel_size=kernel_size, + padding=padding) + + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True)) + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]): + # x: BS * N * D + # cond: BS * D + modulation = self.adaLN_modulation(c) + if self.pre_only: + (shift_msa, scale_msa) = modulation.chunk(2, dim=-1) + gate_msa = shift_mlp = scale_mlp = gate_mlp = None + else: + (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, + gate_mlp) = modulation.chunk(6, dim=-1) + + x = modulate(self.norm1(x), shift_msa, scale_msa) + q, k, v = self.attn.pre_attention(x, rot) + return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp) + + def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor], context=None): + if self.pre_only: + return x + + (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c + x = x + self.linear1(attn_out) * gate_msa + + if context is not None: + x = x + self.cross_attn(x, context=context) + + r = modulate(self.norm2(x), shift_mlp, scale_mlp) + x = x + self.ffn(r) * gate_mlp + + return x + + def forward(self, x: torch.Tensor, cond: torch.Tensor, + rot: Optional[torch.Tensor], context: torch.Tensor = None) -> torch.Tensor: + # x: BS * N * D + # cond: BS * D + x_qkv, x_conditions = self.pre_attention(x, cond, rot) + attn_out = attention(*x_qkv) + x = self.post_attention(x, attn_out, x_conditions, context = context) + + return x + + +class JointBlock(nn.Module): + + def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False): + super().__init__() + self.pre_only = pre_only + self.latent_block = MMDitSingleBlock(dim, + nhead, + mlp_ratio, + pre_only=False, + kernel_size=3, + padding=1) + self.clip_block = MMDitSingleBlock(dim, + nhead, + mlp_ratio, + pre_only=pre_only, + kernel_size=3, + padding=1) + self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1) + + def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor, + global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor, + clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # latent: BS * N1 * D + # clip_f: BS * N2 * D + # c: BS * (1/N) * D + x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot) + 
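+        # The latent, clip, and text streams each produce their own modulated q/k/v here;
+        # they are concatenated along the token axis below, attended over jointly, and the
+        # attention output is split back per stream before each block's post_attention.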
c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot) + t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None) + + latent_len = latent.shape[1] + clip_len = clip_f.shape[1] + text_len = text_f.shape[1] + + joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)] + + attn_out = attention(*joint_qkv) + x_attn_out = attn_out[:, :latent_len] + c_attn_out = attn_out[:, latent_len:latent_len + clip_len] + t_attn_out = attn_out[:, latent_len + clip_len:] + + latent = self.latent_block.post_attention(latent, x_attn_out, x_mod) + if not self.pre_only: + clip_f = self.clip_block.post_attention(clip_f, c_attn_out, c_mod) + text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod) + + return latent, clip_f, text_f + + +class FinalBlock(nn.Module): + + def __init__(self, dim, out_dim): + super().__init__() + self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True)) + self.norm = nn.LayerNorm(dim, elementwise_affine=False) + self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3) + + def forward(self, latent, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + latent = modulate(self.norm(latent), shift, scale) + latent = self.conv(latent) + return latent diff --git a/think_sound/models/mmmodules/runner.py b/think_sound/models/mmmodules/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..755ee76bea7de3f31a14a5512710c39743dc9239 --- /dev/null +++ b/think_sound/models/mmmodules/runner.py @@ -0,0 +1,609 @@ +""" +trainer.py - wrapper and utility functions for network training +Compute loss, back-prop, update parameters, logging, etc. +""" +import os +from pathlib import Path +from typing import Optional, Union + +import torch +import torch.distributed +import torch.optim as optim +from av_bench.evaluate import evaluate +from av_bench.extract import extract +from nitrous_ema import PostHocEMA +from omegaconf import DictConfig +from torch.nn.parallel import DistributedDataParallel as DDP + +from mmaudio.model.flow_matching import FlowMatching +from mmaudio.model.networks import get_my_mmaudio +from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K +from mmaudio.model.utils.features_utils import FeaturesUtils +from mmaudio.model.utils.parameter_groups import get_parameter_groups +from mmaudio.model.utils.sample_utils import log_normal_sample +from mmaudio.utils.dist_utils import (info_if_rank_zero, local_rank, string_if_rank_zero) +from mmaudio.utils.log_integrator import Integrator +from mmaudio.utils.logger import TensorboardLogger +from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator +from mmaudio.utils.video_joiner import VideoJoiner + + +class Runner: + + def __init__(self, + cfg: DictConfig, + log: TensorboardLogger, + run_path: Union[str, Path], + for_training: bool = True, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None): + self.exp_id = cfg.exp_id + self.use_amp = cfg.amp + self.enable_grad_scaler = cfg.enable_grad_scaler + self.for_training = for_training + self.cfg = cfg + + if cfg.model.endswith('16k'): + self.seq_cfg = CONFIG_16K + mode = '16k' + elif cfg.model.endswith('44k'): + self.seq_cfg = CONFIG_44K + mode = '44k' + else: + raise ValueError(f'Unknown model: {cfg.model}') + + self.sample_rate = self.seq_cfg.sampling_rate + self.duration_sec = self.seq_cfg.duration + + # setting up the model + empty_string_feat = torch.load('./ext_weights/empty_string.pth', 
weights_only=True)[0] + self.network = DDP(get_my_mmaudio(cfg.model, + latent_mean=latent_mean, + latent_std=latent_std, + empty_string_feat=empty_string_feat).cuda(), + device_ids=[local_rank], + broadcast_buffers=False) + if cfg.compile: + # NOTE: though train_fn and val_fn are very similar + # (early on they are implemented as a single function) + # keeping them separate and compiling them separately are CRUCIAL for high performance + self.train_fn = torch.compile(self.train_fn) + self.val_fn = torch.compile(self.val_fn) + + self.fm = FlowMatching(cfg.sampling.min_sigma, + inference_mode=cfg.sampling.method, + num_steps=cfg.sampling.num_steps) + + # ema profile + if for_training and cfg.ema.enable and local_rank == 0: + self.ema = PostHocEMA(self.network.module, + sigma_rels=cfg.ema.sigma_rels, + update_every=cfg.ema.update_every, + checkpoint_every_num_steps=cfg.ema.checkpoint_every, + checkpoint_folder=cfg.ema.checkpoint_folder, + step_size_correction=True).cuda() + self.ema_start = cfg.ema.start + else: + self.ema = None + + self.rng = torch.Generator(device='cuda') + self.rng.manual_seed(cfg['seed'] + local_rank) + + # setting up feature extractors and VAEs + if mode == '16k': + self.features = FeaturesUtils( + tod_vae_ckpt=cfg['vae_16k_ckpt'], + bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'], + synchformer_ckpt=cfg['synchformer_ckpt'], + enable_conditions=True, + mode=mode, + need_vae_encoder=False, + ) + elif mode == '44k': + self.features = FeaturesUtils( + tod_vae_ckpt=cfg['vae_44k_ckpt'], + synchformer_ckpt=cfg['synchformer_ckpt'], + enable_conditions=True, + mode=mode, + need_vae_encoder=False, + ) + self.features = self.features.cuda().eval() + + if cfg.compile: + self.features.compile() + + # hyperparameters + self.log_normal_sampling_mean = cfg.sampling.mean + self.log_normal_sampling_scale = cfg.sampling.scale + self.null_condition_probability = cfg.null_condition_probability + self.cfg_strength = cfg.cfg_strength + + # setting up logging + self.log = log + self.run_path = Path(run_path) + vgg_cfg = cfg.data.VGGSound + if for_training: + self.val_video_joiner = VideoJoiner(vgg_cfg.root, self.run_path / 'val-sampled-videos', + self.sample_rate, self.duration_sec) + else: + self.test_video_joiner = VideoJoiner(vgg_cfg.root, + self.run_path / 'test-sampled-videos', + self.sample_rate, self.duration_sec) + string_if_rank_zero(self.log, 'model_size', + f'{sum([param.nelement() for param in self.network.parameters()])}') + string_if_rank_zero( + self.log, 'number_of_parameters_that_require_gradient: ', + str( + sum([ + param.nelement() + for param in filter(lambda p: p.requires_grad, self.network.parameters()) + ]))) + info_if_rank_zero(self.log, 'torch version: ' + torch.__version__) + self.train_integrator = Integrator(self.log, distributed=True) + self.val_integrator = Integrator(self.log, distributed=True) + + # setting up optimizer and loss + if for_training: + self.enter_train() + parameter_groups = get_parameter_groups(self.network, cfg, print_log=(local_rank == 0)) + self.optimizer = optim.AdamW(parameter_groups, + lr=cfg['learning_rate'], + weight_decay=cfg['weight_decay'], + betas=[0.9, 0.95], + eps=1e-6 if self.use_amp else 1e-8, + fused=True) + if self.enable_grad_scaler: + self.scaler = torch.amp.GradScaler(init_scale=2048) + self.clip_grad_norm = cfg['clip_grad_norm'] + + # linearly warmup learning rate + linear_warmup_steps = cfg['linear_warmup_steps'] + + def warmup(currrent_step: int): + return (currrent_step + 1) / (linear_warmup_steps + 1) + + 
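+            # warmup() returns a multiplier in (0, 1] that scales the base learning rate;
+            # the SequentialLR created below hands over from this warmup schedule to the
+            # main lr_schedule once `linear_warmup_steps` iterations have elapsed.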
warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup) + + # setting up learning rate scheduler + if cfg['lr_schedule'] == 'constant': + next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1) + elif cfg['lr_schedule'] == 'poly': + total_num_iter = cfg['iterations'] + next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, + lr_lambda=lambda x: + (1 - (x / total_num_iter))**0.9) + elif cfg['lr_schedule'] == 'step': + next_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, + cfg['lr_schedule_steps'], + cfg['lr_schedule_gamma']) + else: + raise NotImplementedError + + self.scheduler = optim.lr_scheduler.SequentialLR(self.optimizer, + [warmup_scheduler, next_scheduler], + [linear_warmup_steps]) + + # Logging info + self.log_text_interval = cfg['log_text_interval'] + self.log_extra_interval = cfg['log_extra_interval'] + self.save_weights_interval = cfg['save_weights_interval'] + self.save_checkpoint_interval = cfg['save_checkpoint_interval'] + self.save_copy_iterations = cfg['save_copy_iterations'] + self.num_iterations = cfg['num_iterations'] + if cfg['debug']: + self.log_text_interval = self.log_extra_interval = 1 + + # update() is called when we log metrics, within the logger + self.log.batch_timer = TimeEstimator(self.num_iterations, self.log_text_interval) + # update() is called every iteration, in this script + self.log.data_timer = PartialTimeEstimator(self.num_iterations, 1, ema_alpha=0.9) + else: + self.enter_val() + + def train_fn( + self, + clip_f: torch.Tensor, + sync_f: torch.Tensor, + text_f: torch.Tensor, + a_mean: torch.Tensor, + a_std: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # sample + a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) + x1 = a_mean + a_std * a_randn + bs = x1.shape[0] # batch_size * seq_len * num_channels + + # normalize the latents + x1 = self.network.module.normalize(x1) + + t = log_normal_sample(x1, + generator=self.rng, + m=self.log_normal_sampling_mean, + s=self.log_normal_sampling_scale) + x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, + t, + Cs=[clip_f, sync_f, text_f], + generator=self.rng) + + # classifier-free training + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_video = (samples < self.null_condition_probability) + clip_f[null_video] = self.network.module.empty_clip_feat + sync_f[null_video] = self.network.module.empty_sync_feat + + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_text = (samples < self.null_condition_probability) + text_f[null_text] = self.network.module.empty_string_feat + + pred_v = self.network(xt, clip_f, sync_f, text_f, t) + loss = self.fm.loss(pred_v, x0, x1) + mean_loss = loss.mean() + return x1, loss, mean_loss, t + + def val_fn( + self, + clip_f: torch.Tensor, + sync_f: torch.Tensor, + text_f: torch.Tensor, + x1: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + bs = x1.shape[0] # batch_size * seq_len * num_channels + # normalize the latents + x1 = self.network.module.normalize(x1) + t = log_normal_sample(x1, + generator=self.rng, + m=self.log_normal_sampling_mean, + s=self.log_normal_sampling_scale) + x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, + t, + Cs=[clip_f, sync_f, text_f], + generator=self.rng) + + # classifier-free training + samples = torch.rand(bs, device=x1.device, generator=self.rng) + # null mask is for when a video is provided but we decided to ignore it + null_video 
= (samples < self.null_condition_probability) + # complete mask is for when a video is not provided or we decided to ignore it + clip_f[null_video] = self.network.module.empty_clip_feat + sync_f[null_video] = self.network.module.empty_sync_feat + + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_text = (samples < self.null_condition_probability) + text_f[null_text] = self.network.module.empty_string_feat + + pred_v = self.network(xt, clip_f, sync_f, text_f, t) + + loss = self.fm.loss(pred_v, x0, x1) + mean_loss = loss.mean() + return loss, mean_loss, t + + def train_pass(self, data, it: int = 0): + + if not self.for_training: + raise ValueError('train_pass() should not be called when not training.') + + self.enter_train() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) + a_std = data['a_std'].cuda(non_blocking=True) + + # these masks are for non-existent data; masking for CFG training is in train_fn + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + + self.log.data_timer.end() + if it % self.log_extra_interval == 0: + unmasked_clip_f = clip_f.clone() + unmasked_sync_f = sync_f.clone() + unmasked_text_f = text_f.clone() + x1, loss, mean_loss, t = self.train_fn(clip_f, sync_f, text_f, a_mean, a_std) + + self.train_integrator.add_dict({'loss': mean_loss}) + + if it % self.log_text_interval == 0 and it != 0: + self.train_integrator.add_scalar('lr', self.scheduler.get_last_lr()[0]) + self.train_integrator.add_binned_tensor('binned_loss', loss, t) + self.train_integrator.finalize('train', it) + self.train_integrator.reset_except_hooks() + + # Backward pass + self.optimizer.zero_grad(set_to_none=True) + if self.enable_grad_scaler: + self.scaler.scale(mean_loss).backward() + self.scaler.unscale_(self.optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.clip_grad_norm) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + mean_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.clip_grad_norm) + self.optimizer.step() + + if self.ema is not None and it >= self.ema_start: + self.ema.update() + self.scheduler.step() + self.integrator.add_scalar('grad_norm', grad_norm) + + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, + dtype=torch.bfloat16), torch.inference_mode(): + try: + if it % self.log_extra_interval == 0: + # save GT audio + # unnormalize the latents + x1 = self.network.module.unnormalize(x1[0:1]) + mel = self.features.decode(x1) + audio = self.features.vocode(mel).cpu()[0] # 1 * num_samples + self.log.log_spectrogram('train', f'spec-gt-r{local_rank}', mel.cpu()[0], it) + self.log.log_audio('train', + f'audio-gt-r{local_rank}', + audio, + it, + sample_rate=self.sample_rate) + + # save audio from sampling + x0 = torch.empty_like(x1[0:1]).normal_(generator=self.rng) + clip_f = unmasked_clip_f[0:1] + sync_f = unmasked_sync_f[0:1] + text_f = unmasked_text_f[0:1] + conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) + empty_conditions = 
self.network.module.get_empty_conditions(x0.shape[0]) + cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( + t, x, conditions, empty_conditions, self.cfg_strength) + x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) + x1_hat = self.network.module.unnormalize(x1_hat) + mel = self.features.decode(x1_hat) + audio = self.features.vocode(mel).cpu()[0] + self.log.log_spectrogram('train', f'spec-r{local_rank}', mel.cpu()[0], it) + self.log.log_audio('train', + f'audio-r{local_rank}', + audio, + it, + sample_rate=self.sample_rate) + except Exception as e: + self.log.warning(f'Error in extra logging: {e}') + if self.cfg.debug: + raise + + # Save network weights and checkpoint if needed + save_copy = it in self.save_copy_iterations + + if (it % self.save_weights_interval == 0 and it != 0) or save_copy: + self.save_weights(it) + + if it % self.save_checkpoint_interval == 0 and it != 0: + self.save_checkpoint(it, save_copy=save_copy) + + self.log.data_timer.start() + + @torch.inference_mode() + def validation_pass(self, data, it: int = 0): + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) + a_std = data['a_std'].cuda(non_blocking=True) + + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) + x1 = a_mean + a_std * a_randn + + self.log.data_timer.end() + loss, mean_loss, t = self.val_fn(clip_f.clone(), sync_f.clone(), text_f.clone(), x1) + + self.val_integrator.add_binned_tensor('binned_loss', loss, t) + self.val_integrator.add_dict({'loss': mean_loss}) + + self.log.data_timer.start() + + @torch.inference_mode() + def inference_pass(self, + data, + it: int, + data_cfg: DictConfig, + *, + save_eval: bool = True) -> Path: + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) # for the shape only + + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + + # sample + x0 = torch.empty_like(a_mean).normal_(generator=self.rng) + conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) + empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) + cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( + t, x, conditions, empty_conditions, self.cfg_strength) + x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) + x1_hat = self.network.module.unnormalize(x1_hat) + mel = self.features.decode(x1_hat) + audio = self.features.vocode(mel).cpu() + for i in range(audio.shape[0]): + video_id = data['id'][i] + if (not self.for_training) and i == 0: + # save very few videos + 
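+                    # (only the first clip of each batch is muxed back into a video here,
+                    # and only in test mode, to keep inference-time I/O small)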
self.test_video_joiner.join(video_id, f'{video_id}', audio[i].transpose(0, 1)) + + if data_cfg.output_subdir is not None: + # validation + if save_eval: + iter_naming = f'{it:09d}' + else: + iter_naming = 'val-cache' + audio_dir = self.log.log_audio(iter_naming, + f'{video_id}', + audio[i], + it=None, + sample_rate=self.sample_rate, + subdir=Path(data_cfg.output_subdir)) + if save_eval and i == 0: + self.val_video_joiner.join(video_id, f'{iter_naming}-{video_id}', + audio[i].transpose(0, 1)) + else: + # full test set, usually + audio_dir = self.log.log_audio(f'{data_cfg.tag}-sampled', + f'{video_id}', + audio[i], + it=None, + sample_rate=self.sample_rate) + + return Path(audio_dir) + + @torch.inference_mode() + def eval(self, audio_dir: Path, it: int, data_cfg: DictConfig) -> dict[str, float]: + with torch.amp.autocast('cuda', enabled=False): + if local_rank == 0: + extract(audio_path=audio_dir, + output_path=audio_dir / 'cache', + device='cuda', + batch_size=32, + audio_length=8) + output_metrics = evaluate(gt_audio_cache=Path(data_cfg.gt_cache), + pred_audio_cache=audio_dir / 'cache') + for k, v in output_metrics.items(): + # pad k to 10 characters + # pad v to 10 decimal places + self.log.log_scalar(f'{data_cfg.tag}/{k}', v, it) + self.log.info(f'{data_cfg.tag}/{k:<10}: {v:.10f}') + else: + output_metrics = None + + return output_metrics + + def save_weights(self, it, save_copy=False): + if local_rank != 0: + return + + os.makedirs(self.run_path, exist_ok=True) + if save_copy: + model_path = self.run_path / f'{self.exp_id}_{it}.pth' + torch.save(self.network.module.state_dict(), model_path) + self.log.info(f'Network weights saved to {model_path}.') + + # if last exists, move it to a shadow copy + model_path = self.run_path / f'{self.exp_id}_last.pth' + if model_path.exists(): + shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) + model_path.replace(shadow_path) + self.log.info(f'Network weights shadowed to {shadow_path}.') + + torch.save(self.network.module.state_dict(), model_path) + self.log.info(f'Network weights saved to {model_path}.') + + def save_checkpoint(self, it, save_copy=False): + if local_rank != 0: + return + + checkpoint = { + 'it': it, + 'weights': self.network.module.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict(), + 'ema': self.ema.state_dict() if self.ema is not None else None, + } + + os.makedirs(self.run_path, exist_ok=True) + if save_copy: + model_path = self.run_path / f'{self.exp_id}_ckpt_{it}.pth' + torch.save(checkpoint, model_path) + self.log.info(f'Checkpoint saved to {model_path}.') + + # if ckpt_last exists, move it to a shadow copy + model_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' + if model_path.exists(): + shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) + model_path.replace(shadow_path) # moves the file + self.log.info(f'Checkpoint shadowed to {shadow_path}.') + + torch.save(checkpoint, model_path) + self.log.info(f'Checkpoint saved to {model_path}.') + + def get_latest_checkpoint_path(self): + ckpt_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' + if not ckpt_path.exists(): + info_if_rank_zero(self.log, f'No checkpoint found at {ckpt_path}.') + return None + return ckpt_path + + def get_latest_weight_path(self): + weight_path = self.run_path / f'{self.exp_id}_last.pth' + if not weight_path.exists(): + self.log.info(f'No weight found at {weight_path}.') + return None + return weight_path + + def get_final_ema_weight_path(self): 
+ weight_path = self.run_path / f'{self.exp_id}_ema_final.pth' + if not weight_path.exists(): + self.log.info(f'No weight found at {weight_path}.') + return None + return weight_path + + def load_checkpoint(self, path): + # This method loads everything and should be used to resume training + map_location = 'cuda:%d' % local_rank + checkpoint = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) + + it = checkpoint['it'] + weights = checkpoint['weights'] + optimizer = checkpoint['optimizer'] + scheduler = checkpoint['scheduler'] + if self.ema is not None: + self.ema.load_state_dict(checkpoint['ema']) + self.log.info(f'EMA states loaded from step {self.ema.step}') + + map_location = 'cuda:%d' % local_rank + self.network.module.load_state_dict(weights) + self.optimizer.load_state_dict(optimizer) + self.scheduler.load_state_dict(scheduler) + + self.log.info(f'Global iteration {it} loaded.') + self.log.info('Network weights, optimizer states, and scheduler states loaded.') + + return it + + def load_weights_in_memory(self, src_dict): + self.network.module.load_weights(src_dict) + self.log.info('Network weights loaded from memory.') + + def load_weights(self, path): + # This method loads only the network weight and should be used to load a pretrained model + map_location = 'cuda:%d' % local_rank + src_dict = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) + + self.log.info(f'Importing network weights from {path}...') + self.load_weights_in_memory(src_dict) + + def weights(self): + return self.network.module.state_dict() + + def enter_train(self): + self.integrator = self.train_integrator + self.network.train() + return self + + def enter_val(self): + self.network.eval() + return self diff --git a/think_sound/models/mmmodules/sample.py b/think_sound/models/mmmodules/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..72b83389d7dbb55bed02991f51731b0d1e346a2b --- /dev/null +++ b/think_sound/models/mmmodules/sample.py @@ -0,0 +1,90 @@ +import json +import logging +import os +import random + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, open_dict +from tqdm import tqdm + +from mmaudio.data.data_setup import setup_test_datasets +from mmaudio.runner import Runner +from mmaudio.utils.dist_utils import info_if_rank_zero +from mmaudio.utils.logger import TensorboardLogger + +local_rank = int(os.environ['LOCAL_RANK']) +world_size = int(os.environ['WORLD_SIZE']) + + +def sample(cfg: DictConfig): + # initial setup + num_gpus = world_size + run_dir = HydraConfig.get().run.dir + + # wrap python logger with a tensorboard logger + log = TensorboardLogger(cfg.exp_id, + run_dir, + logging.getLogger(), + is_rank0=(local_rank == 0), + enable_email=cfg.enable_email and not cfg.debug) + + info_if_rank_zero(log, f'All configuration: {cfg}') + info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}') + + # cuda setup + torch.cuda.set_device(local_rank) + torch.backends.cudnn.benchmark = cfg.cudnn_benchmark + + # number of dataloader workers + info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}') + + # Set seeds to ensure the same initialization + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + random.seed(cfg.seed) + + # setting up configurations + info_if_rank_zero(log, f'Configuration: {cfg}') + info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}') + + # construct the trainer + runner = Runner(cfg, log=log, 
run_path=run_dir, for_training=False).enter_val() + + # load the last weights if needed + if cfg['weights'] is not None: + info_if_rank_zero(log, f'Loading weights from the disk: {cfg["weights"]}') + runner.load_weights(cfg['weights']) + cfg['weights'] = None + else: + weights = runner.get_final_ema_weight_path() + if weights is not None: + info_if_rank_zero(log, f'Automatically finding weight: {weights}') + runner.load_weights(weights) + + # setup datasets + dataset, sampler, loader = setup_test_datasets(cfg) + data_cfg = cfg.data.ExtractedVGG_test + with open_dict(data_cfg): + if cfg.output_name is not None: + # append to the tag + data_cfg.tag = f'{data_cfg.tag}-{cfg.output_name}' + + # loop + audio_path = None + for curr_iter, data in enumerate(tqdm(loader)): + new_audio_path = runner.inference_pass(data, curr_iter, data_cfg) + if audio_path is None: + audio_path = new_audio_path + else: + assert audio_path == new_audio_path, 'Different audio path detected' + + info_if_rank_zero(log, f'Inference completed. Audio path: {audio_path}') + output_metrics = runner.eval(audio_path, curr_iter, data_cfg) + + if local_rank == 0: + # write the output metrics to run_dir + output_metrics_path = os.path.join(run_dir, f'{data_cfg.tag}-output_metrics.json') + with open(output_metrics_path, 'w') as f: + json.dump(output_metrics, f, indent=4) diff --git a/think_sound/models/pqmf.py b/think_sound/models/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..007fdb51ec797554c1cdd4d9363894d743d970bf --- /dev/null +++ b/think_sound/models/pqmf.py @@ -0,0 +1,393 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from scipy.optimize import fmin +from scipy.signal import firwin, kaiser, kaiser_beta, kaiserord + +class PQMF(nn.Module): + """ + Pseudo Quadrature Mirror Filter (PQMF) for multiband signal decomposition and reconstruction. + Uses polyphase representation which is computationally more efficient for real-time. + + Parameters: + - attenuation (int): Desired attenuation of the rejected frequency bands, usually between 80 and 120 dB. + - num_bands (int): Number of desired frequency bands. It must be a power of 2. + """ + + def __init__(self, attenuation, num_bands): + super(PQMF, self).__init__() + + # Ensure num_bands is a power of 2 + is_power_of_2 = (math.log2(num_bands) == int(math.log2(num_bands))) + assert is_power_of_2, "'num_bands' must be a power of 2." + + # Create the prototype filter + prototype_filter = design_prototype_filter(attenuation, num_bands) + filter_bank = generate_modulated_filter_bank(prototype_filter, num_bands) + padded_filter_bank = pad_to_nearest_power_of_two(filter_bank) + + # Register filters and settings + self.register_buffer("filter_bank", padded_filter_bank) + self.register_buffer("prototype", prototype_filter) + self.num_bands = num_bands + + def forward(self, signal): + """Decompose the signal into multiple frequency bands.""" + # If signal is not a pytorch tensor of Batch x Channels x Length, convert it + signal = prepare_signal_dimensions(signal) + # The signal length must be a multiple of num_bands. Pad it with zeros. 
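+        # Illustrative usage (a minimal sketch; exact band length depends on the padding):
+        #   pqmf = PQMF(attenuation=100, num_bands=16)
+        #   bands = pqmf(signal)         # (Batch, Channels, Bands, Length // Bands)
+        #   recon = pqmf.inverse(bands)  # (Batch, Channels, Length'), trimmed by the filter delay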
+ signal = pad_signal(signal, self.num_bands) + # run it + signal = polyphase_analysis(signal, self.filter_bank) + return apply_alias_cancellation(signal) + + def inverse(self, bands): + """Reconstruct the original signal from the frequency bands.""" + bands = apply_alias_cancellation(bands) + return polyphase_synthesis(bands, self.filter_bank) + + +def prepare_signal_dimensions(signal): + """ + Rearrange signal into Batch x Channels x Length. + + Parameters + ---------- + signal : torch.Tensor or numpy.ndarray + The input signal. + + Returns + ------- + torch.Tensor + Preprocessed signal tensor. + """ + # Convert numpy to torch tensor + if isinstance(signal, np.ndarray): + signal = torch.from_numpy(signal) + + # Ensure tensor + if not isinstance(signal, torch.Tensor): + raise ValueError("Input should be either a numpy array or a PyTorch tensor.") + + # Modify dimension of signal to Batch x Channels x Length + if signal.dim() == 1: + # This is just a mono signal. Unsqueeze to 1 x 1 x Length + signal = signal.unsqueeze(0).unsqueeze(0) + elif signal.dim() == 2: + # This is a multi-channel signal (e.g. stereo) + # Rearrange so that larger dimension (Length) is last + if signal.shape[0] > signal.shape[1]: + signal = signal.T + # Unsqueeze to 1 x Channels x Length + signal = signal.unsqueeze(0) + return signal + +def pad_signal(signal, num_bands): + """ + Pads the signal to make its length divisible by the given number of bands. + + Parameters + ---------- + signal : torch.Tensor + The input signal tensor, where the last dimension represents the signal length. + + num_bands : int + The number of bands by which the signal length should be divisible. + + Returns + ------- + torch.Tensor + The padded signal tensor. If the original signal length was already divisible + by num_bands, returns the original signal unchanged. + """ + remainder = signal.shape[-1] % num_bands + if remainder > 0: + padding_size = num_bands - remainder + signal = nn.functional.pad(signal, (0, padding_size)) + return signal + +def generate_modulated_filter_bank(prototype_filter, num_bands): + """ + Generate a QMF bank of cosine modulated filters based on a given prototype filter. + + Parameters + ---------- + prototype_filter : torch.Tensor + The prototype filter used as the basis for modulation. + num_bands : int + The number of desired subbands or filters. + + Returns + ------- + torch.Tensor + A bank of cosine modulated filters. + """ + + # Initialize indices for modulation. + subband_indices = torch.arange(num_bands).reshape(-1, 1) + + # Calculate the length of the prototype filter. + filter_length = prototype_filter.shape[-1] + + # Generate symmetric time indices centered around zero. + time_indices = torch.arange(-(filter_length // 2), (filter_length // 2) + 1) + + # Calculate phase offsets to ensure orthogonality between subbands. + phase_offsets = (-1)**subband_indices * np.pi / 4 + + # Compute the cosine modulation function. + modulation = torch.cos( + (2 * subband_indices + 1) * np.pi / (2 * num_bands) * time_indices + phase_offsets + ) + + # Apply modulation to the prototype filter. + modulated_filters = 2 * prototype_filter * modulation + + return modulated_filters + + +def design_kaiser_lowpass(angular_cutoff, attenuation, filter_length=None): + """ + Design a lowpass filter using the Kaiser window. + + Parameters + ---------- + angular_cutoff : float + The angular frequency cutoff of the filter. + attenuation : float + The desired stopband attenuation in decibels (dB). 
+ filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The designed lowpass filter coefficients. + """ + + estimated_length, beta = kaiserord(attenuation, angular_cutoff / np.pi) + + # Ensure the estimated length is odd. + estimated_length = 2 * (estimated_length // 2) + 1 + + if filter_length is None: + filter_length = estimated_length + + return firwin(filter_length, angular_cutoff, window=('kaiser', beta), scale=False, nyq=np.pi) + + +def evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length): + """ + Evaluate the filter's objective value based on the criteria from https://ieeexplore.ieee.org/document/681427 + + Parameters + ---------- + angular_cutoff : float + Angular frequency cutoff of the filter. + attenuation : float + Desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. + + Returns + ------- + float + The computed objective (loss) value for the given filter specs. + """ + + filter_coeffs = design_kaiser_lowpass(angular_cutoff, attenuation, filter_length) + convolved_filter = np.convolve(filter_coeffs, filter_coeffs[::-1], "full") + + return np.max(np.abs(convolved_filter[convolved_filter.shape[-1] // 2::2 * num_bands][1:])) + + +def design_prototype_filter(attenuation, num_bands, filter_length=None): + """ + Design the optimal prototype filter for a multiband system given the desired specs. + + Parameters + ---------- + attenuation : float + The desired stopband attenuation in dB. + num_bands : int + Number of bands for the multiband filter system. + filter_length : int, optional + Desired length of the filter. If not provided, it's computed based on the given specs. + + Returns + ------- + ndarray + The optimal prototype filter coefficients. + """ + + optimal_angular_cutoff = fmin(lambda angular_cutoff: evaluate_filter_objective(angular_cutoff, attenuation, num_bands, filter_length), + 1 / num_bands, disp=0)[0] + + prototype_filter = design_kaiser_lowpass(optimal_angular_cutoff, attenuation, filter_length) + return torch.tensor(prototype_filter, dtype=torch.float32) + +def pad_to_nearest_power_of_two(x): + """ + Pads the input tensor 'x' on both sides such that its last dimension + becomes the nearest larger power of two. + + Parameters: + ----------- + x : torch.Tensor + The input tensor to be padded. + + Returns: + -------- + torch.Tensor + The padded tensor. + """ + current_length = x.shape[-1] + target_length = 2**math.ceil(math.log2(current_length)) + + total_padding = target_length - current_length + left_padding = total_padding // 2 + right_padding = total_padding - left_padding + + return nn.functional.pad(x, (left_padding, right_padding)) + +def apply_alias_cancellation(x): + """ + Applies alias cancellation by inverting the sign of every + second element of every second row, starting from the second + row's first element in a tensor. + + This operation helps ensure that the aliasing introduced in + each band during the decomposition will be counteracted during + the reconstruction. + + Parameters: + ----------- + x : torch.Tensor + The input tensor. + + Returns: + -------- + torch.Tensor + Tensor with specific elements' sign inverted for alias cancellation. 
+ """ + + # Create a mask of the same shape as 'x', initialized with all ones + mask = torch.ones_like(x) + + # Update specific elements in the mask to -1 to perform inversion + mask[..., 1::2, ::2] = -1 + + # Apply the mask to the input tensor 'x' + return x * mask + +def ensure_odd_length(tensor): + """ + Pads the last dimension of a tensor to ensure its size is odd. + + Parameters: + ----------- + tensor : torch.Tensor + Input tensor whose last dimension might need padding. + + Returns: + -------- + torch.Tensor + The original tensor if its last dimension was already odd, + or the padded tensor with an odd-sized last dimension. + """ + + last_dim_size = tensor.shape[-1] + + if last_dim_size % 2 == 0: + tensor = nn.functional.pad(tensor, (0, 1)) + + return tensor + +def polyphase_analysis(signal, filter_bank): + """ + Applies the polyphase method to efficiently analyze the signal using a filter bank. + + Parameters: + ----------- + signal : torch.Tensor + Input signal tensor with shape (Batch x Channels x Length). + + filter_bank : torch.Tensor + Filter bank tensor with shape (Bands x Length). + + Returns: + -------- + torch.Tensor + Signal split into sub-bands. (Batch x Channels x Bands x Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange signal for polyphase processing. + # Also combine Batch x Channel into one dimension for now. + #signal = rearrange(signal, "b c (t n) -> b (c n) t", n=num_bands) + signal = rearrange(signal, "b c (t n) -> (b c) n t", n=num_bands) + + # Rearrange the filter bank for matching signal shape + filter_bank = rearrange(filter_bank, "c (t n) -> c n t", n=num_bands) + + # Apply convolution with appropriate padding to maintain spatial dimensions + padding = filter_bank.shape[-1] // 2 + filtered_signal = nn.functional.conv1d(signal, filter_bank, padding=padding) + + # Truncate the last dimension post-convolution to adjust the output shape + filtered_signal = filtered_signal[..., :-1] + # Rearrange the first dimension back into Batch x Channels + filtered_signal = rearrange(filtered_signal, "(b c) n t -> b c n t", c=num_channels) + + return filtered_signal + +def polyphase_synthesis(signal, filter_bank): + """ + Polyphase Inverse: Apply polyphase filter bank synthesis to reconstruct a signal. + + Parameters + ---------- + signal : torch.Tensor + Decomposed signal to be reconstructed (shape: Batch x Channels x Bands x Length). + + filter_bank : torch.Tensor + Analysis filter bank (shape: Bands x Length). + + should_rearrange : bool, optional + Flag to determine if the filters should be rearranged for polyphase synthesis. Default is True. + + Returns + ------- + torch.Tensor + Reconstructed signal (shape: Batch x Channels X Length) + """ + + num_bands = filter_bank.shape[0] + num_channels = signal.shape[1] + + # Rearrange the filter bank + filter_bank = filter_bank.flip(-1) + filter_bank = rearrange(filter_bank, "c (t n) -> n c t", n=num_bands) + + # Combine Batch x Channels into one dimension for now. 
+ signal = rearrange(signal, "b c n t -> (b c) n t") + + # Apply convolution with appropriate padding + padding_amount = filter_bank.shape[-1] // 2 + 1 + reconstructed_signal = nn.functional.conv1d(signal, filter_bank, padding=int(padding_amount)) + + # Scale the result + reconstructed_signal = reconstructed_signal[..., :-1] * num_bands + + # Reorganize the output and truncate + reconstructed_signal = reconstructed_signal.flip(1) + reconstructed_signal = rearrange(reconstructed_signal, "(b c) n t -> b c (t n)", c=num_channels, n=num_bands) + reconstructed_signal = reconstructed_signal[..., 2 * filter_bank.shape[1]:] + + return reconstructed_signal \ No newline at end of file diff --git a/think_sound/models/pretrained.py b/think_sound/models/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..e83af343587da91af92218f309c969c5a975b5ed --- /dev/null +++ b/think_sound/models/pretrained.py @@ -0,0 +1,25 @@ +import json + +from .factory import create_model_from_config +from .utils import load_ckpt_state_dict + +from huggingface_hub import hf_hub_download + +def get_pretrained_model(name: str): + + model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model') + + with open(model_config_path) as f: + model_config = json.load(f) + + model = create_model_from_config(model_config) + + # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file + try: + model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model') + except Exception as e: + model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model') + + model.load_state_dict(load_ckpt_state_dict(model_ckpt_path)) + + return model, model_config \ No newline at end of file diff --git a/think_sound/models/pretransforms.py b/think_sound/models/pretransforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c9942db59908ce8135f44e45090e928f6c60393a --- /dev/null +++ b/think_sound/models/pretransforms.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if 
self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + 
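+        # With quantize_on_decode=True the continuous encoder latents are returned as-is and
+        # RVQ quantization is deferred to decode(); otherwise the quantizer runs here.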
if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/think_sound/models/transformer.py b/think_sound/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5602fbbab9a2c4a5e79aa14daaf0cd5a337db61e --- /dev/null +++ b/think_sound/models/transformer.py @@ -0,0 +1,821 @@ +from functools import reduce, partial +from packaging import version + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +import torch +import torch.nn.functional as F +from torch import nn, einsum +from torch.cuda.amp import autocast +from typing import Callable, Literal + +try: + from flash_attn import flash_attn_func, flash_attn_kvpacked_func +except ImportError as e: + print(e) + print('flash_attn not installed, disabling Flash Attention') + flash_attn_kvpacked_func = None + flash_attn_func = None + +try: + import natten +except ImportError: + natten = None + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + + +# Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License +# License can be found in 
LICENSES/LICENSE_XTRANSFORMERS.txt + +def create_causal_mask(i, j, device): + return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + +def or_reduce(masks): + head, *body = masks + for rest in body: + head = head | rest + return head + +# positional embeddings + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.scale = dim ** -0.5 + self.max_seq_len = max_seq_len + self.emb = nn.Embedding(max_seq_len, dim) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = (pos - seq_start_pos[..., None]).clamp(min = 0) + + pos_emb = self.emb(pos) + pos_emb = pos_emb * self.scale + return pos_emb + +class ScaledSinusoidalEmbedding(nn.Module): + def __init__(self, dim, theta = 10000): + super().__init__() + assert (dim % 2) == 0, 'dimension must be divisible by 2' + self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + + half_dim = dim // 2 + freq_seq = torch.arange(half_dim).float() / half_dim + inv_freq = theta ** -freq_seq + self.register_buffer('inv_freq', inv_freq, persistent = False) + + def forward(self, x, pos = None, seq_start_pos = None): + seq_len, device = x.shape[1], x.device + + if pos is None: + pos = torch.arange(seq_len, device = device) + + if seq_start_pos is not None: + pos = pos - seq_start_pos[..., None] + + emb = einsum('i, j -> i j', pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + return emb * self.scale + +class RotaryEmbedding(nn.Module): + def __init__( + self, + dim, + use_xpos = False, + scale_base = 512, + interpolation_factor = 1., + base = 10000, + base_rescale_factor = 1. + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + base *= base_rescale_factor ** (dim / (dim - 2)) + + inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + assert interpolation_factor >= 1. + self.interpolation_factor = interpolation_factor + + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + + self.scale_base = scale_base + self.register_buffer('scale', scale) + + def forward_from_seq_len(self, seq_len): + device = self.inv_freq.device + + t = torch.arange(seq_len, device = device) + return self.forward(t) + + @autocast(enabled = False) + def forward(self, t): + device = self.inv_freq.device + + t = t.to(torch.float32) + + t = t / self.interpolation_factor + + freqs = torch.einsum('i , j -> i j', t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim = -1) + + if self.scale is None: + return freqs, 1. + + power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base + scale = self.scale ** rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim = -1) + + return freqs, scale + +def rotate_half(x): + x = rearrange(x, '... (j d) -> ... 
j d', j = 2) + x1, x2 = x.unbind(dim = -2) + return torch.cat((-x2, x1), dim = -1) + +@autocast(enabled = False) +def apply_rotary_pos_emb(t, freqs, scale = 1): + out_dtype = t.dtype + + # cast to float32 if necessary for numerical stability + dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32)) + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs, t = freqs.to(dtype), t.to(dtype) + freqs = freqs[-seq_len:, :] + + if t.ndim == 4 and freqs.ndim == 3: + freqs = rearrange(freqs, 'b n d -> b 1 n d') + + # partial rotary embeddings, Wang et al. GPT-J + t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + + t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype) + + return torch.cat((t, t_unrotated), dim = -1) + +# norms +class LayerNorm(nn.Module): + def __init__(self, dim, bias=False, fix_scale=False): + """ + bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less + """ + super().__init__() + + if fix_scale: + self.register_buffer("gamma", torch.ones(dim)) + else: + self.gamma = nn.Parameter(torch.ones(dim)) + + if bias: + self.beta = nn.Parameter(torch.zeros(dim)) + else: + self.register_buffer("beta", torch.zeros(dim)) + + + def forward(self, x): + return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta) + +# feedforward + +class GLU(nn.Module): + def __init__( + self, + dim_in, + dim_out, + activation: Callable, + use_conv = False, + conv_kernel_size = 3, + ): + super().__init__() + self.act = activation + self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2)) + self.use_conv = use_conv + + def forward(self, x): + if self.use_conv: + x = rearrange(x, 'b n d -> b d n') + x = self.proj(x) + x = rearrange(x, 'b d n -> b n d') + else: + x = self.proj(x) + + x, gate = x.chunk(2, dim = -1) + return x * self.act(gate) + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out = None, + mult = 4, + no_bias = False, + glu = True, + use_conv = False, + conv_kernel_size = 3, + zero_init_output = True, + ): + super().__init__() + inner_dim = int(dim * mult) + + # Default to SwiGLU + + activation = nn.SiLU() + + dim_out = dim if dim_out is None else dim_out + + if glu: + linear_in = GLU(dim, inner_dim, activation) + else: + linear_in = nn.Sequential( + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias), + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + activation + ) + + linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias) + + # init last linear layer to 0 + if zero_init_output: + nn.init.zeros_(linear_out.weight) + if not no_bias: + nn.init.zeros_(linear_out.bias) + + + self.ff = nn.Sequential( + linear_in, + Rearrange('b d n -> b n d') if use_conv else nn.Identity(), + linear_out, + Rearrange('b n d -> b d n') if use_conv else nn.Identity(), + ) + + def forward(self, x): + return self.ff(x) + +class Attention(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + dim_context = None, + causal = False, + zero_init_output=True, + qk_norm: Literal['l2', 'ln', 'none'] = 'none', + 
natten_kernel_size = None + ): + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.causal = causal + + dim_kv = dim_context if dim_context is not None else dim + + self.num_heads = dim // dim_heads + self.kv_heads = dim_kv // dim_heads + + if dim_context is not None: + self.to_q = nn.Linear(dim, dim, bias=False) + self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False) + else: + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.to_out = nn.Linear(dim, dim, bias=False) + + if zero_init_output: + nn.init.zeros_(self.to_out.weight) + + self.qk_norm = qk_norm + + if self.qk_norm == "ln": + self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6) + elif self.qk_norm == 'rns': + self.q_norm = nn.RMSNorm(dim_heads) + self.k_norm = nn.RMSNorm(dim_heads) + + # Using 1d neighborhood attention + self.natten_kernel_size = natten_kernel_size + if natten_kernel_size is not None: + return + + self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None + + self.sdp_kwargs = dict( + enable_flash = True, + enable_math = True, + enable_mem_efficient = True + ) + + def flash_attn( + self, + q, + k, + v, + mask = None, + causal = None + ): + batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device + kv_heads = k.shape[1] + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if heads != kv_heads: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = heads // kv_heads + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + if k.ndim == 3: + k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + + if v.ndim == 3: + v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + + causal = self.causal if causal is None else causal + + if q_len == 1 and causal: + causal = False + + if mask is not None: + assert mask.ndim == 4 + mask = mask.expand(batch, heads, q_len, k_len) + + # handle kv cache - this should be bypassable in updated flash attention 2 + + if k_len > q_len and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + if mask is None: + mask = ~causal_mask + else: + mask = mask & ~causal_mask + causal = False + + # manually handle causal mask, if another mask was given + + row_is_entirely_masked = None + + if mask is not None and causal: + causal_mask = self.create_causal_mask(q_len, k_len, device = device) + mask = mask & ~causal_mask + + # protect against an entire row being masked out + + row_is_entirely_masked = ~mask.any(dim = -1) + mask[..., 0] = mask[..., 0] | row_is_entirely_masked + + causal = False + + with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs): + out = F.scaled_dot_product_attention( + q, k, v, + attn_mask = mask, + is_causal = causal + ) + + # for a row that is entirely masked out, should zero out the output of that row token + + if row_is_entirely_masked is not None: + out = out.masked_fill(row_is_entirely_masked[..., None], 0.) 
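As a point of reference, the head matching and SDPA dispatch above can be exercised on its own. The snippet below is a minimal, self-contained sketch (assuming PyTorch 2.0 or newer, and not part of this class) of repeating k/v heads for grouped-query attention and then calling F.scaled_dot_product_attention with a causal flag:

import torch
import torch.nn.functional as F

def gqa_sdpa(q, k, v, causal=False):
    # q: (batch, q_heads, seq, dim_head); k, v: (batch, kv_heads, seq, dim_head)
    q_heads, kv_heads = q.shape[1], k.shape[1]
    if q_heads != kv_heads:
        # repeat-interleave k/v heads so every query head has a matching k/v head
        repeat = q_heads // kv_heads
        k = k.repeat_interleave(repeat, dim=1)
        v = v.repeat_interleave(repeat, dim=1)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)

q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 2, 16, 64)
v = torch.randn(1, 2, 16, 64)
out = gqa_sdpa(q, k, v, causal=True)  # -> (1, 8, 16, 64)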
+ + return out + + def forward( + self, + x, + context = None, + mask = None, + context_mask = None, + rotary_pos_emb = None, + causal = None + ): + h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None + kv_input = context if has_context else x + + if hasattr(self, 'to_q'): + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) + + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + + # Normalize q and k for cosine sim attention + if self.qk_norm == "l2": + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + elif self.qk_norm == "ln": + q = self.q_norm(q) + k = self.k_norm(k) + elif self.qk_norm == "rns": + q = self.q_norm(q) + k = self.k_norm(k) + + if rotary_pos_emb is not None and not has_context: + freqs, _ = rotary_pos_emb + + q_dtype = q.dtype + k_dtype = k.dtype + + q = q.to(torch.float32) + k = k.to(torch.float32) + freqs = freqs.to(torch.float32) + + q = apply_rotary_pos_emb(q, freqs) + k = apply_rotary_pos_emb(k, freqs) + + q = q.to(q_dtype) + k = k.to(k_dtype) + + input_mask = context_mask + + if input_mask is None and not has_context: + input_mask = mask + + # determine masking + masks = [] + final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account + + if input_mask is not None: + input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + masks.append(~input_mask) + + # Other masks will be added here later + + if len(masks) > 0: + final_attn_mask = ~or_reduce(masks) + + n, device = q.shape[-2], q.device + + causal = self.causal if causal is None else causal + + if n == 1 and causal: + causal = False + + if self.natten_kernel_size is not None: + if natten is None: + raise ImportError('natten not installed, please install natten to use neighborhood attention') + + dtype_in = q.dtype + q, k, v = map(lambda t: t.to(torch.float32), (q, k, v)) + + attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1) + + if final_attn_mask is not None: + attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max) + + attn = F.softmax(attn, dim=-1, dtype=torch.float32) + + out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in) + + # Prioritize Flash Attention 2 + elif self.use_fa_flash: + assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2' + # Flash Attention 2 requires FP16 inputs + fa_dtype_in = q.dtype + q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v)) + + out = flash_attn_func(q, k, v, causal = causal) + + out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d') + + # Fall back to PyTorch implementation + elif self.use_pt_flash: + out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask) + + else: + # Fall back to custom implementation + + if h != kv_h: + # Repeat interleave kv_heads to match q_heads + heads_per_kv_head = h // kv_h + k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) + + scale = 1. 
/ (q.shape[-1] ** 0.5) + + kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + + dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale + + i, j, dtype = *dots.shape[-2:], dots.dtype + + mask_value = -torch.finfo(dots.dtype).max + + if final_attn_mask is not None: + dots = dots.masked_fill(~final_attn_mask, mask_value) + + if causal: + causal_mask = self.create_causal_mask(i, j, device = device) + dots = dots.masked_fill(causal_mask, mask_value) + + attn = F.softmax(dots, dim=-1, dtype=torch.float32) + attn = attn.type(dtype) + + out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + + # merge heads + out = rearrange(out, ' b h n d -> b n (h d)') + + # Communicate between heads + + # with autocast(enabled = False): + # out_dtype = out.dtype + # out = out.to(torch.float32) + # out = self.to_out(out).to(out_dtype) + out = self.to_out(out) + + if mask is not None: + mask = rearrange(mask, 'b n -> b n 1') + out = out.masked_fill(~mask, 0.) + + return out + +class ConformerModule(nn.Module): + def __init__( + self, + dim, + norm_kwargs = {}, + ): + + super().__init__() + + self.dim = dim + + self.in_norm = LayerNorm(dim, **norm_kwargs) + self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + self.glu = GLU(dim, dim, nn.SiLU()) + self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False) + self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False) + + def forward(self, x): + x = self.in_norm(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.glu(x) + x = rearrange(x, 'b n d -> b d n') + x = self.depthwise_conv(x) + x = rearrange(x, 'b d n -> b n d') + x = self.mid_norm(x) + x = self.swish(x) + x = rearrange(x, 'b n d -> b d n') + x = self.pointwise_conv_2(x) + x = rearrange(x, 'b d n -> b n d') + + return x + +class TransformerBlock(nn.Module): + def __init__( + self, + dim, + dim_heads = 64, + cross_attend = False, + dim_context = None, + global_cond_dim = None, + causal = False, + zero_init_branch_outputs = True, + conformer = False, + layer_ix = -1, + remove_norms = False, + attn_kwargs = {}, + ff_kwargs = {}, + norm_kwargs = {} + ): + + super().__init__() + self.dim = dim + self.dim_heads = dim_heads + self.cross_attend = cross_attend + self.dim_context = dim_context + self.causal = causal + + self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + + self.self_attn = Attention( + dim, + dim_heads = dim_heads, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + if cross_attend: + self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attn = Attention( + dim, + dim_heads = dim_heads, + dim_context=dim_context, + causal = causal, + zero_init_output=zero_init_branch_outputs, + **attn_kwargs + ) + + self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs) + + self.layer_ix = layer_ix + + self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None + + self.global_cond_dim = global_cond_dim + + if global_cond_dim is not None: + self.to_scale_shift_gate = nn.Sequential( + nn.SiLU(), + nn.Linear(global_cond_dim, dim * 6, bias=False) + ) + + 
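The to_scale_shift_gate projection defined here feeds the adaLN-style conditioning applied in the block's forward pass below. A self-contained sketch of that modulation pattern, using hypothetical sizes and written outside this module, might look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, global_cond_dim = 512, 768
to_scale_shift_gate = nn.Sequential(nn.SiLU(), nn.Linear(global_cond_dim, dim * 6, bias=False))
nn.init.zeros_(to_scale_shift_gate[1].weight)  # zero init: learned scale and shift start at identity

x = torch.randn(2, 100, dim)                   # (batch, seq, dim)
global_cond = torch.randn(2, global_cond_dim)

# one (scale, shift, gate) triple per branch: self-attention and feed-forward
scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = \
    to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim=-1)

residual = x
h = F.layer_norm(x, (dim,))                    # stand-in for the block's pre_norm
h = h * (1 + scale_self) + shift_self          # conditioned scale and shift
# ... self-attention would run on h here ...
h = h * torch.sigmoid(1 - gate_self)           # gate the branch output
x = residual + h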
nn.init.zeros_(self.to_scale_shift_gate[1].weight) + #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + + def forward( + self, + x, + context = None, + global_cond=None, + mask = None, + context_mask = None, + rotary_pos_emb = None + ): + if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + + # self-attention with adaLN + residual = x + x = self.pre_norm(x) + x = x * (1 + scale_self) + shift_self + x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb) + x = x * torch.sigmoid(1 - gate_self) + x = x + residual + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + # feedforward with adaLN + residual = x + x = self.ff_norm(x) + x = x * (1 + scale_ff) + shift_ff + x = self.ff(x) + x = x * torch.sigmoid(1 - gate_ff) + x = x + residual + + else: + x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb) + + if context is not None: + x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask) + + if self.conformer is not None: + x = x + self.conformer(x) + + x = x + self.ff(self.ff_norm(x)) + + return x + +class ContinuousTransformer(nn.Module): + def __init__( + self, + dim, + depth, + *, + dim_in = None, + dim_out = None, + dim_heads = 64, + cross_attend=False, + cond_token_dim=None, + global_cond_dim=None, + causal=False, + rotary_pos_emb=True, + zero_init_branch_outputs=True, + conformer=False, + use_sinusoidal_emb=False, + use_abs_pos_emb=False, + abs_pos_emb_max_length=10000, + **kwargs + ): + + super().__init__() + + self.dim = dim + self.depth = depth + self.causal = causal + self.layers = nn.ModuleList([]) + + self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity() + self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity() + + if rotary_pos_emb: + self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32)) + else: + self.rotary_pos_emb = None + + self.use_sinusoidal_emb = use_sinusoidal_emb + if use_sinusoidal_emb: + self.pos_emb = ScaledSinusoidalEmbedding(dim) + + self.use_abs_pos_emb = use_abs_pos_emb + if use_abs_pos_emb: + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + + for i in range(depth): + self.layers.append( + TransformerBlock( + dim, + dim_heads = dim_heads, + cross_attend = cross_attend, + dim_context = cond_token_dim, + global_cond_dim = global_cond_dim, + causal = causal, + zero_init_branch_outputs = zero_init_branch_outputs, + conformer=conformer, + layer_ix=i, + **kwargs + ) + ) + + def forward( + self, + x, + mask = None, + prepend_embeds = None, + prepend_mask = None, + add_cond = None, + global_cond = None, + return_info = False, + **kwargs + ): + batch, seq, device = *x.shape[:2], x.device + + info = { + "hidden_states": [], + } + + x = self.project_in(x) + if add_cond is not None: + x = x + add_cond + + if prepend_embeds is not None: + prepend_length, prepend_dim = prepend_embeds.shape[1:] + + assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension' + + x = torch.cat((prepend_embeds, x), dim = -2) + + if prepend_mask is not None or mask is not None: + mask = mask if mask is not None else torch.ones((batch, seq), device = 
device, dtype = torch.bool) + prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool) + + mask = torch.cat((prepend_mask, mask), dim = -1) + + + # Attention layers + + if self.rotary_pos_emb is not None: + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1]) + else: + rotary_pos_emb = None + + if self.use_sinusoidal_emb or self.use_abs_pos_emb: + x = x + self.pos_emb(x) + + # Iterate over the transformer layers + for layer in self.layers: + #x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + + if return_info: + info["hidden_states"].append(x) + + x = self.project_out(x) + + if return_info: + return x, info + + return x diff --git a/think_sound/models/utils.py b/think_sound/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..384a4ca718f025096f51cf8420a48a73480353f4 --- /dev/null +++ b/think_sound/models/utils.py @@ -0,0 +1,164 @@ +import torch +from safetensors.torch import load_file +from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor +from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline +from torch.nn.utils import remove_weight_norm + +def load_ckpt_state_dict(ckpt_path, prefix=None): + if ckpt_path.endswith(".safetensors"): + state_dict = load_file(ckpt_path) + else: + state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + # 过滤特定前缀的state_dict + filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict + + return filtered_state_dict + +def remove_weight_norm_from_model(model): + for module in model.modules(): + if hasattr(module, "weight"): + print(f"Removing weight norm from {module}") + remove_weight_norm(module) + + return model + +# Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license +# License can be found in LICENSES/LICENSE_META.txt + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + + if num_samples == 1: + q = torch.empty_like(input).exponential_(1, generator=generator) + return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64) + + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output + + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. 
+ """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +def next_power_of_two(n): + return 2 ** (n - 1).bit_length() + +def next_multiple_of_64(n): + return ((n + 63) // 64) * 64 + + +# mask construction helpers + +def mask_from_start_end_indices( + seq_len: int, + start: Tensor, + end: Tensor +): + assert start.shape == end.shape + device = start.device + + seq = torch.arange(seq_len, device = device, dtype = torch.long) + seq = seq.reshape(*((-1,) * start.ndim), seq_len) + seq = seq.expand(*start.shape, seq_len) + + mask = seq >= start[..., None].long() + mask &= seq < end[..., None].long() + return mask + +def mask_from_frac_lengths( + seq_len: int, + frac_lengths: Tensor +): + device = frac_lengths.device + + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) + start = (max_start * rand).clamp(min = 0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + +def _build_spline(video_feat, video_t, target_t): + # 三次样条插值核心实现 + coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1)) + spline = NaturalCubicSpline(coeffs) + return spline.evaluate(target_t).permute(0,2,1) + +def resample(video_feat, audio_latent): + """ + 9s + video_feat: [B, 72, D] + audio_latent: [B, D', 194] or int + """ + B, Tv, D = video_feat.shape + + if isinstance(audio_latent, torch.Tensor): + # audio_latent is a tensor + if audio_latent.shape[1] != D: + Ta = audio_latent.shape[1] + else: + Ta = audio_latent.shape[2] + elif isinstance(audio_latent, int): + # audio_latent is an int + Ta = audio_latent + else: + raise TypeError("audio_latent must be either a tensor or an int") + + # 构建时间戳 (关键改进点) + video_time = torch.linspace(0, 9, Tv, device=video_feat.device) + audio_time = torch.linspace(0, 9, Ta, device=video_feat.device) + + # 三维化处理 (Batch, Feature, Time) + video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv] + + # 三次样条插值 + aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta] + return aligned_video.permute(0, 2, 1) # [B, Ta, D] diff --git a/think_sound/models/wavelets.py b/think_sound/models/wavelets.py new file mode 100644 index 0000000000000000000000000000000000000000..a359e39110c168aab960d3f79262b464a660e55e --- /dev/null +++ b/think_sound/models/wavelets.py @@ -0,0 +1,82 @@ +"""The 1D discrete wavelet transform for PyTorch.""" + +from einops import rearrange +import pywt +import torch +from torch import nn +from torch.nn import functional as F +from typing import Literal + + +def get_filter_bank(wavelet): + filt = 
torch.tensor(pywt.Wavelet(wavelet).filter_bank) + if wavelet.startswith("bior") and torch.all(filt[:, 0] == 0): + filt = filt[:, 1:] + return filt + +class WaveletEncode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[:2, None] + kernel = torch.flip(kernel, dims=(-1,)) + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels], x[:, self.channels :] + pad = self.kernel.shape[-1] // 2 + low = F.pad(low, (pad, pad), "reflect") + low = F.conv1d(low, self.kernel, stride=2) + rest = rearrange( + rest, "n (c c2) (l l2) -> n (c l2 c2) l", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x + + +class WaveletDecode1d(nn.Module): + def __init__(self, + channels, + levels, + wavelet: Literal["bior2.2", "bior2.4", "bior2.6", "bior2.8", "bior4.4", "bior6.8"] = "bior4.4"): + super().__init__() + self.wavelet = wavelet + self.channels = channels + self.levels = levels + filt = get_filter_bank(wavelet) + assert filt.shape[-1] % 2 == 1 + kernel = filt[2:, None] + index_i = torch.repeat_interleave(torch.arange(2), channels) + index_j = torch.tile(torch.arange(channels), (2,)) + kernel_final = torch.zeros(channels * 2, channels, filt.shape[-1]) + kernel_final[index_i * channels + index_j, index_j] = kernel[index_i, 0] + self.register_buffer("kernel", kernel_final) + + def forward(self, x): + for i in range(self.levels): + low, rest = x[:, : self.channels * 2], x[:, self.channels * 2 :] + pad = self.kernel.shape[-1] // 2 + 2 + low = rearrange(low, "n (l2 c) l -> n c (l l2)", l2=2) + low = F.pad(low, (pad, pad), "reflect") + low = rearrange(low, "n c (l l2) -> n (l2 c) l", l2=2) + low = F.conv_transpose1d( + low, self.kernel, stride=2, padding=self.kernel.shape[-1] // 2 + ) + low = low[..., pad - 1 : -pad] + rest = rearrange( + rest, "n (c l2 c2) l -> n (c c2) (l l2)", l2=2, c2=self.channels + ) + x = torch.cat([low, rest], dim=1) + return x \ No newline at end of file diff --git a/think_sound/training/__init__.py b/think_sound/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f77486b07a478bc88359bf2ece8b9c860df1b054 --- /dev/null +++ b/think_sound/training/__init__.py @@ -0,0 +1 @@ +from .factory import create_training_wrapper_from_config, create_demo_callback_from_config diff --git a/think_sound/training/__pycache__/__init__.cpython-310.pyc b/think_sound/training/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5341ed35294363d4664819b34658372d1d88e9f5 Binary files /dev/null and b/think_sound/training/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/training/__pycache__/__init__.cpython-39.pyc b/think_sound/training/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eac873918dd6c84cad0a8ac58daa7733992010ec Binary files /dev/null and b/think_sound/training/__pycache__/__init__.cpython-39.pyc differ diff 
--git a/think_sound/training/autoencoders.py b/think_sound/training/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..f215393a12d446a64a8ebb84b0fabe49ce52258b --- /dev/null +++ b/think_sound/training/autoencoders.py @@ -0,0 +1,505 @@ +import torch +import torchaudio +import wandb +from einops import rearrange +from safetensors.torch import save_file, save_model +from ema_pytorch import EMA +from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss, SpatialSTFTLoss +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Callback +from ..models.autoencoders import AudioAutoencoder +from ..models.discriminators import EncodecDiscriminator, OobleckDiscriminator, DACGANLoss +from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, 
RVQVAEBottleneck, WassersteinBottleneck +from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss +from .utils import create_optimizer_from_config, create_scheduler_from_config + + +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image + +class AutoencoderTrainingWrapper(L.LightningModule): + def __init__( + self, + autoencoder: AudioAutoencoder, + lr: float = 1e-4, + warmup_steps: int = 0, + encoder_freeze_on_warmup: bool = False, + sample_rate=48000, + loss_config: dict = None, + optimizer_configs: dict = None, + use_ema: bool = True, + ema_copy = None, + force_input_mono = False, + latent_mask_ratio = 0.0, + teacher_model: AudioAutoencoder = None + ): + super().__init__() + + self.automatic_optimization = False + + self.autoencoder = autoencoder + + self.warmed_up = False + self.warmup_steps = warmup_steps + self.encoder_freeze_on_warmup = encoder_freeze_on_warmup + self.lr = lr + + self.force_input_mono = force_input_mono + + self.teacher_model = teacher_model + + if optimizer_configs is None: + optimizer_configs ={ + "autoencoder": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + }, + "discriminator": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (.8, .99) + } + } + } + + } + + self.optimizer_configs = optimizer_configs + + if loss_config is None: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + loss_config = { + "discriminator": { + "type": "encodec", + "config": { + "n_ffts": scales, + "hop_lengths": hop_sizes, + "win_lengths": win_lengths, + "filters": 32 + }, + "weights": { + "adversarial": 0.1, + "feature_matching": 5.0, + } + }, + "spectral": { + "type": "mrstft", + "config": { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + }, + "weights": { + "mrstft": 1.0, + } + }, + "time": { + "type": "l1", + "config": {}, + "weights": { + "l1": 0.0, + } + } + } + + self.loss_config = loss_config + + # Spectral reconstruction loss + + stft_loss_args = loss_config['spectral']['config'] + + if self.autoencoder.out_channels == 2: + self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + elif self.autoencoder.out_channels == 4: + # self.sdstft = SpatialSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Discriminator + + if loss_config['discriminator']['type'] == 'oobleck': + self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'encodec': + self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config']) + elif loss_config['discriminator']['type'] == 'dac': + self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config']) + + self.gen_loss_modules = [] + + # Adversarial and feature matching losses + self.gen_loss_modules += [ + ValueLoss(key='loss_adv', 
weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'), + ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'), + ] + + if self.teacher_model is not None: + # Distillation losses + + stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25 + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss + AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder + AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder + ] + + else: + + # Reconstruction loss + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.autoencoder.out_channels == 2: + + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2), + AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2), + ] + elif self.autoencoder.out_channels == 4: + # self.gen_loss_modules += [ + # AuralossLoss(self.lrstft, 'reals', 'decoded', name='stft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + # ] + # Add left and right channel reconstruction losses in addition to the sum and difference + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals_w', 'decoded_w', name='stft_loss_w', weight=self.loss_config['spectral']['weights']['mrstft']/4), + AuralossLoss(self.sdstft, 'reals_x', 'decoded_x', name='stft_loss_x', weight=self.loss_config['spectral']['weights']['mrstft']/4), + AuralossLoss(self.sdstft, 'reals_y', 'decoded_y', name='stft_loss_y', weight=self.loss_config['spectral']['weights']['mrstft']/4), + AuralossLoss(self.sdstft, 'reals_z', 'decoded_z', name='stft_loss_z', weight=self.loss_config['spectral']['weights']['mrstft']/4), + ] + + self.gen_loss_modules += [ + AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']), + ] + + if self.loss_config['time']['weights']['l1'] > 0.0: + self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss')) + + if self.autoencoder.bottleneck is not None: + self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config) + + self.losses_gen = MultiLoss(self.gen_loss_modules) + + self.disc_loss_modules = [ + ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'), + ] + + self.losses_disc = MultiLoss(self.disc_loss_modules) + + # Set up EMA for model weights + self.autoencoder_ema = None + + self.use_ema = use_ema + + if self.use_ema: + self.autoencoder_ema = EMA( + self.autoencoder, + ema_model=ema_copy, + beta=0.9999, + 
power=3/4, + update_every=1, + update_after_step=1 + ) + + self.latent_mask_ratio = latent_mask_ratio + + def configure_optimizers(self): + + opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters()) + opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters()) + + if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']: + sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen) + sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc) + return [opt_gen, opt_disc], [sched_gen, sched_disc] + + return [opt_gen, opt_disc] + + def training_step(self, batch, batch_idx): + reals, _ = batch + + # Remove extra dimension added by WebDataset + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if self.global_step >= self.warmup_steps: + self.warmed_up = True + + loss_info = {} + + loss_info["reals"] = reals + + encoder_input = reals + + if self.force_input_mono and encoder_input.shape[1] > 1: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + loss_info["encoder_input"] = encoder_input + + data_std = encoder_input.std() + + if self.warmed_up and self.encoder_freeze_on_warmup: + with torch.no_grad(): + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + else: + latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True) + + loss_info["latents"] = latents + + loss_info.update(encoder_info) + + # Encode with teacher model for distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_latents = self.teacher_model.encode(encoder_input, return_info=False) + loss_info['teacher_latents'] = teacher_latents + + # Optionally mask out some latents for noise resistance + if self.latent_mask_ratio > 0.0: + mask = torch.rand_like(latents) < self.latent_mask_ratio + latents = torch.where(mask, torch.zeros_like(latents), latents) + decoded = self.autoencoder.decode(latents) + + loss_info["decoded"] = decoded + + if self.autoencoder.out_channels == 2: + loss_info["decoded_left"] = decoded[:, 0:1, :] + loss_info["decoded_right"] = decoded[:, 1:2, :] + loss_info["reals_left"] = reals[:, 0:1, :] + loss_info["reals_right"] = reals[:, 1:2, :] + elif self.autoencoder.out_channels == 4: + loss_info["decoded_w"] = decoded[:, 0:1, :] + loss_info["decoded_x"] = decoded[:, 1:2, :] + loss_info["decoded_y"] = decoded[:, 2:3, :] + loss_info["decoded_z"] = decoded[:, 3:4, :] + loss_info["reals_w"] = reals[:, 0:1, :] + loss_info["reals_x"] = reals[:, 1:2, :] + loss_info["reals_y"] = reals[:, 2:3, :] + loss_info["reals_z"] = reals[:, 3:4, :] + + # Distillation + if self.teacher_model is not None: + with torch.no_grad(): + teacher_decoded = self.teacher_model.decode(teacher_latents) + own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher + teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model + + loss_info['teacher_decoded'] = teacher_decoded + loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded + loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded + + + if self.warmed_up: + loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded) + else: + loss_dis = 
torch.tensor(0.).to(reals) + loss_adv = torch.tensor(0.).to(reals) + feature_matching_distance = torch.tensor(0.).to(reals) + + loss_info["loss_dis"] = loss_dis + loss_info["loss_adv"] = loss_adv + loss_info["feature_matching_distance"] = feature_matching_distance + + opt_gen, opt_disc = self.optimizers() + + lr_schedulers = self.lr_schedulers() + + sched_gen = None + sched_disc = None + + if lr_schedulers is not None: + sched_gen, sched_disc = lr_schedulers + + # Train the discriminator + if self.global_step % 2 and self.warmed_up: + loss, losses = self.losses_disc(loss_info) + + log_dict = { + 'train/disc_lr': opt_disc.param_groups[0]['lr'] + } + + opt_disc.zero_grad() + self.manual_backward(loss) + + + opt_disc.step() + + if sched_disc is not None: + # sched step every step + sched_disc.step() + + # Train the generator + else: + + # import ipdb + # ipdb.set_trace() + loss, losses = self.losses_gen(loss_info) + + if self.use_ema: + self.autoencoder_ema.update() + + opt_gen.zero_grad() + self.manual_backward(loss) + opt_gen.step() + + if sched_gen is not None: + # scheduler step every step + sched_gen.step() + + log_dict = { + 'train/loss': loss.detach(), + 'train/latent_std': latents.std().detach(), + 'train/data_std': data_std.detach(), + 'train/gen_lr': opt_gen.param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f'train/{loss_name}'] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + + return loss + + def export_model(self, path, use_safetensors=False): + if self.autoencoder_ema is not None: + model = self.autoencoder_ema.ema_model + else: + model = self.autoencoder + + if use_safetensors: + save_model(model, path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AutoencoderDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + module.eval() + + try: + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + if module.force_input_mono: + encoder_input = encoder_input.mean(dim=1, keepdim=True) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad(): + if module.use_ema: + + latents = module.autoencoder_ema.ema_model.encode(encoder_input) + + fakes = module.autoencoder_ema.ema_model.decode(latents) + else: + latents = module.autoencoder.encode(encoder_input) + + fakes = module.autoencoder.decode(latents) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demos/recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + 
sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + finally: + module.train() + +def create_loss_modules_from_bottleneck(bottleneck, loss_config): + losses = [] + + if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + try: + kl_weight = loss_config['bottleneck']['weights']['kl'] + except: + kl_weight = 1e-6 + + kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss') + losses.append(kl_loss) + + if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck): + quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss') + losses.append(quantizer_loss) + + if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck): + codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss') + commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss') + losses.append(codebook_loss) + losses.append(commitment_loss) + + if isinstance(bottleneck, WassersteinBottleneck): + try: + mmd_weight = loss_config['bottleneck']['weights']['mmd'] + except: + mmd_weight = 100 + + mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss') + losses.append(mmd_loss) + + return losses \ No newline at end of file diff --git a/think_sound/training/diffusion.py b/think_sound/training/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..067a1630591fdeff36188abedfd42c6cac13e6ee --- /dev/null +++ b/think_sound/training/diffusion.py @@ -0,0 +1,2023 @@ +# import pytorch_lightning as pl +import lightning as L +from lightning.pytorch.callbacks import Callback +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb +# from beartype.typing import Tuple +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +import auraloss +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper +from ..models.autoencoders import DiffusionAutoencoder +from ..models.diffusion_prior import PriorType +from .autoencoders import create_loss_modules_from_bottleneck +from .losses import AuralossLoss, MSELoss, MultiLoss +from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask +import os +from pathlib import Path +from time import time +import numpy as np + +class Profiler: + + def __init__(self): + self.ticks = [[time(), None]] + + def tick(self, msg): + self.ticks.append([time(), msg]) + + def __repr__(self): + rep = 80 * "=" + "\n" + for i in range(1, len(self.ticks)): + msg = self.ticks[i][1] + ellapsed = self.ticks[i][0] - self.ticks[i - 1][0] + rep += msg + f": {ellapsed*1000:.2f}ms\n" + rep += 80 * "=" + "\n\n\n" + return rep + +class 
DiffusionUncondTrainingWrapper(L.LightningModule): + ''' + Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). + ''' + def __init__( + self, + model: DiffusionModelWrapper, + lr: float = 1e-4, + pre_encoded: bool = False + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + diffusion_input = reals + + loss_info = {} + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + targets = noise * alphas - diffusion_input * sigmas + + with torch.amp.autocast('cuda'): + v = self.diffusion(noised_inputs, t) + + loss_info.update({ + "v": v, + "targets": targets + }) + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionUncondDemoCallback(Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + demo_steps=250, + sample_rate=48000 + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_samples = module.diffusion.sample_size + + if module.diffusion.pretransform is not None: + demo_samples 
= demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + with torch.amp.autocast('cuda'): + fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + print(f'{type(e).__name__}: {e}') + finally: + gc.collect() + torch.cuda.empty_cache() + +class DiffusionInfillTrainingWrapper(L.LightningModule): + ''' + Wrapper for training an unconditional audio diffusion model (like Dance Diffusion). + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + optimizer_configs: dict = None, + pre_encoded: bool = False, + frac_lengths_mask = (0.7, 1.), + min_span_len = 10, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + diffusion_objective = 'rectified_flow', + ctx_drop: float = 0.1, + r_drop: float = 0.0, + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1 + ) + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. 
Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + self.frac_lengths_mask = frac_lengths_mask + self.min_span_len = min_span_len + self.timestep_sampler = timestep_sampler + self.ctx_drop = ctx_drop + self.r_drop = r_drop + self.diffusion_objective = diffusion_objective + print(f'Training in the {diffusion_objective} formulation') + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss", + mask_key="mask" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def training_step(self, batch, batch_idx): + reals, metadata = batch + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + # import ipdb + # ipdb.set_trace() + p_drop = torch.rand(1).item() + # r_drop = torch.rand(1).item() + # if p_drop >= self.ctx_drop and self.r_drop > 0.0 and r_drop < self.r_drop: + # generate_channel_mask(reals) + + diffusion_input = reals + assert torch.all(torch.isfinite(diffusion_input)), "Non-finite values detected in diffusion_input" + p = Profiler() + loss_info = {} + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + conditioning = {} + + p.tick("conditioning") + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + + # # Calculate the noise schedule parameters for those timesteps + # alphas, sigmas = get_alphas_sigmas(t) + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + # x_ctx = diffusion_input.detach().clone().transpose(1,2) + bsz, dim, seq_len = diffusion_input.shape + + + if p_drop < self.ctx_drop: + ctx_mask = torch.ones((bsz, seq_len), device = diffusion_input.device, dtype = torch.bool) + # elif self.r_drop > 0.0 and r_drop < self.r_drop: + # ctx_mask = torch.zeros((bsz, seq_len), device=diffusion_input.device, dtype=torch.bool) + else: + # 计算 frac_lengths 提前使用 + frac_lengths = 
torch.zeros((bsz,), device=diffusion_input.device).uniform_(*self.frac_lengths_mask) + # if self.r_drop > 0.0 and r_drop < self.r_drop: + # import ipdb + # ipdb.set_trace() + + # ctx_mask = torch.zeros((bsz, seq_len), device=diffusion_input.device, dtype=torch.bool) + # else: + ctx_mask = generate_mask(bsz, seq_len, frac_lengths, self.min_span_len) + + if ctx_mask.dim() == 2: + ctx_mask = ctx_mask.unsqueeze(1) + masked_sequence = diffusion_input * ~ctx_mask + conditioning['x_ctx'] = [masked_sequence] + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + with torch.amp.autocast('cuda'): + p.tick("amp") + v = self.diffusion(noised_inputs, t, cond=conditioning) + p.tick("diffusion") + loss_info.update({ + "v": v, + "targets": targets, + "mask": ctx_mask.squeeze(-1) + }) + # import ipdb + # ipdb.set_trace() + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionInfillDemoCallback(Callback): + def __init__(self, + demo_dl, + demo_every=2000, + num_demos=8, + demo_steps=250, + sample_rate=48000 + ): + super().__init__() + self.demo_dl = iter(demo_dl) + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + + try: + demo_reals, _ = next(self.demo_dl) + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + reals = demo_reals + log_dict = {} + + if not module.pre_encoded: + # Log the real audio + log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") + + if module.diffusion.pretransform is not None: + module.diffusion.pretransform.to(module.device) + with torch.amp.autocast('cuda'): + demo_reals = module.diffusion.pretransform.encode(demo_reals) + + demo_samples = demo_reals.shape[2] + + # Get conditioning + conditioning = {} + + noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) + frac_lengths = torch.zeros((demo_reals.shape[0],), device = module.device).uniform_(*(0.3,0.5)) + ctx_mask = generate_mask(demo_reals.shape[0],demo_reals.shape[2], frac_lengths, module.min_span_len) + # x_ctx = 
(demo_reals * ~ctx_mask.unsqueeze(1)).transpose(1,2) + x_ctx = demo_reals * ~ctx_mask.unsqueeze(1) + + conditioning['x_ctx'] = [x_ctx] + # x_ctx_mask = x_ctx * ~ctx_mask.unsqueeze(-1) + if module.diffusion.pretransform is not None: + log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(x_ctx.cpu())) + else: + log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(x_ctx, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + with torch.amp.autocast('cuda'): + if module.diffusion_objective == "v": + fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(module.diffusion_ema, noise, self.demo_steps, **cond_inputs) + # fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # #Interleave reals and fakes + # reals_fakes = rearrange([reals, fakes], 'i b d n -> (b i) d n') + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + + filename = f'results/audio_ssl/demo_ssl_{trainer.global_step:08}.wav' + os.makedirs(Path(filename).parent,exist_ok=True) + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + print(f'{type(e).__name__}: {e}') + finally: + gc.collect() + torch.cuda.empty_cache() + +class DiffusionCondTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. 
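+    Conditioning comes from the wrapped model's conditioner (e.g. the metaclip/sync video features used
+    below); supports the "v" and "rectified_flow" objectives, an optional EMA copy of the model,
+    pre-encoded latents, CFG dropout, and optional inpainting-style masking when max_mask_segments > 0.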
+ ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = None, + mask_padding: bool = False, + mask_padding_dropout: float = 0.0, + use_ema: bool = True, + log_loss_info: bool = False, + optimizer_configs: dict = None, + diffusion_objective: tp.Literal["rectified_flow", "v"] = "v", + pre_encoded: bool = False, + cfg_dropout_prob = 0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + max_mask_segments = 0, + ): + super().__init__() + + self.diffusion = model + + if use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.mask_padding = mask_padding + self.mask_padding_dropout = mask_padding_dropout + + self.cfg_dropout_prob = cfg_dropout_prob + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_sampler = timestep_sampler + + self.diffusion_objective = model.diffusion_objective + print(f'Training in the {self.diffusion_objective} formulation with timestep sampler: {timestep_sampler}') + + self.max_mask_segments = max_mask_segments + + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0, + mask_key="padding_mask" if self.mask_padding else None, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. 
Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def training_step(self, batch, batch_idx): + reals, metadata = batch + # import ipdb + # ipdb.set_trace() + p = Profiler() + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + with torch.amp.autocast('cuda'): + + conditioning = self.diffusion.conditioner(metadata, self.device) + + + video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat + conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat + # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding + use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout + + # Create batch tensor of attention masks from the "mask" field of the metadata array + if use_padding_mask: + padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) + + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + if use_padding_mask: + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + if self.max_mask_segments > 0: + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = masked_input + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + # import ipdb + # ipdb.set_trace() + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif 
self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + if use_padding_mask: + extra_args["mask"] = padding_masks + + with torch.amp.autocast('cuda'): + p.tick("amp") + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + "padding_mask": padding_masks if use_padding_mask else None, + }) + + loss, losses = self.losses(loss_info) + + p.tick("loss") + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def validation_step(self, batch, batch_idx): + reals, metadata = batch + # breakpoint() + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + + with torch.amp.autocast('cuda'): + + conditioning = self.diffusion.conditioner(metadata, self.device) + + video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0) + conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat + conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat + + if self.diffusion.pretransform is not None: + + if not self.pre_encoded: + self.diffusion.pretransform.to(self.device) + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad) + + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + if 
self.max_mask_segments > 0:
+                # Max mask size is the full sequence length
+                max_mask_length = diffusion_input.shape[2]
+
+                # Create a mask of random length for a random slice of the input
+                masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
+
+                conditioning['inpaint_mask'] = [mask]
+                conditioning['inpaint_masked_input'] = masked_input
+            if self.timestep_sampler == "uniform":
+                # Draw uniformly distributed continuous timesteps
+                t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
+            elif self.timestep_sampler == "logit_normal":
+                t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
+
+            # Calculate the noise schedule parameters for those timesteps
+            if self.diffusion_objective == "v":
+                alphas, sigmas = get_alphas_sigmas(t)
+            elif self.diffusion_objective == "rectified_flow":
+                alphas, sigmas = 1-t, t
+
+            # Combine the ground truth data and the noise
+            alphas = alphas[:, None, None]
+            sigmas = sigmas[:, None, None]
+            noise = torch.randn_like(diffusion_input)
+            noised_inputs = diffusion_input * alphas + noise * sigmas
+
+            if self.diffusion_objective == "v":
+                targets = noise * alphas - diffusion_input * sigmas
+            elif self.diffusion_objective == "rectified_flow":
+                targets = noise - diffusion_input
+
+        with torch.amp.autocast('cuda'):
+            output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.0)
+
+            loss_info.update({
+                "output": output,
+                "targets": targets,
+            })
+
+            loss, losses = self.losses(loss_info)
+
+        log_dict = {
+            'val_loss': loss.detach(),
+        }
+
+        self.log_dict(log_dict, prog_bar=True, batch_size=diffusion_input.size(0))
+
+    def predict_step(self, batch, batch_idx):
+        reals, metadata = batch
+        # import ipdb
+        # ipdb.set_trace()
+        ids = [item['id'] for item in metadata]
+        batch_size, length = reals.shape[0], reals.shape[2]
+        with torch.amp.autocast('cuda'):
+            conditioning = self.diffusion.conditioner(metadata, self.device)
+
+            video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
+            conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat
+            conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat
+
+        cond_inputs = self.diffusion.get_conditioning_inputs(conditioning)
+        if batch_size > 1:
+            noise_list = []
+            for _ in range(batch_size):
+                noise_1 = torch.randn([1, self.diffusion.io_channels, length]).to(self.device)  # each draw advances the RNG state so every item gets distinct noise
+                noise_list.append(noise_1)
+            noise = torch.cat(noise_list, dim=0)
+        else:
+            noise = torch.randn([batch_size, self.diffusion.io_channels, length]).to(self.device)
+        with torch.amp.autocast('cuda'):
+
+            model = self.diffusion.model
+            if self.diffusion_objective == "v":
+                fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
+            elif self.diffusion_objective == "rectified_flow":
+                import time
+                start_time = time.time()
+                fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
+                end_time = time.time()
+                execution_time = end_time - start_time
+                print(f"Sampling time: {execution_time:.2f} s")
+            if self.diffusion.pretransform is not None:
+                fakes = self.diffusion.pretransform.decode(fakes)
+
+            audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
+            return audios
+            # # Put the demos together
+            # fakes = rearrange(fakes, 'b d n -> d (b n)')
+
+    def random_mask(self, sequence, max_mask_length):
+        b, _, sequence_length = sequence.size()
+
+        # Create a mask tensor for each batch element
+        masks = 
[] + + for i in range(b): + mask_type = random.randint(0, 2) + + if mask_type == 0: # Random mask with multiple segments + num_segments = random.randint(1, self.max_mask_segments) + max_segment_length = max_mask_length // num_segments + + segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) + + mask = torch.ones((1, 1, sequence_length)) + for length in segment_lengths: + mask_start = random.randint(0, sequence_length - length) + mask[:, :, mask_start:mask_start + length] = 0 + + elif mask_type == 1: # Full mask + mask = torch.zeros((1, 1, sequence_length)) + + elif mask_type == 2: # Causal mask + mask = torch.ones((1, 1, sequence_length)) + mask_length = random.randint(1, max_mask_length) + mask[:, :, -mask_length:] = 0 + + mask = mask.to(sequence.device) + masks.append(mask) + + # Concatenate the mask tensors into a single tensor + mask = torch.cat(masks, dim=0).to(sequence.device) + + # Apply the mask to the sequence tensor for each batch element + masked_sequence = sequence * mask + + return masked_sequence, mask + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionCondDemoCallback(Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + demo_steps=250, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {}, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + demo_cond_from_batch: bool = False, + display_audio_cond: bool = False + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.demo_steps = demo_steps + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + # If true, the callback will use the metadata from the batch to generate the demo conditioning + self.demo_cond_from_batch = demo_cond_from_batch + + # If true, the callback will display the audio conditioning + self.display_audio_cond = display_audio_cond + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_samples = self.demo_samples + + demo_cond = self.demo_conditioning + + if self.demo_cond_from_batch: + # Get metadata from the batch + demo_cond = batch[1][:self.num_demos] + + if '.pth' in demo_cond[0]: + demo_cond_data = [] + for path in demo_cond: + # info = {} + data = torch.load(path, weights_only=True) + if 'caption_t5' not in data.keys(): + data['caption_t5'] = data['caption'] + data['seconds_start'] = 0 + data['seconds_total'] = 10 + demo_cond_data.append(data) + demo_cond = demo_cond_data + elif '.npz' in demo_cond[0]: + demo_cond_data = [] + for path in demo_cond: + # info = {} + npz_data = np.load(path,allow_pickle=True) + data = {key: npz_data[key] for key in npz_data.files} + for key in data.keys(): + # print(key) + if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, 
np.number): + data[key] = torch.from_numpy(data[key]) + + demo_cond_data.append(data) + demo_cond = demo_cond_data + if module.diffusion.pretransform is not None: + demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio + + noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device) + + try: + print("Getting conditioning") + with torch.amp.autocast('cuda'): + conditioning = module.diffusion.conditioner(demo_cond, module.device) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + log_dict = {} + + if self.display_audio_cond: + audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0) + audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)') + + filename = f'demo_audio_cond_{trainer.global_step:08}.wav' + audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, audio_inputs, self.sample_rate) + log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning") + log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs)) + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + + print(f"Generating demo for cfg scale {cfg_scale}") + + with torch.amp.autocast('cuda'): + # model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + model = module.diffusion.model + + if module.diffusion_objective == "v": + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + elif module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demos/demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + trainer.logger.experiment.log(log_dict) + + del fakes + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() + +class DiffusionCondInpaintTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a conditional audio diffusion model. 
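+    Inpainting variant: random spans of the (optionally pre-encoded) input are zeroed out by random_mask
+    and fed back to the model as the 'inpaint_masked_input' conditioning signal.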
+ ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + max_mask_segments = 10, + log_loss_info: bool = False, + optimizer_configs: dict = None, + use_ema: bool = True, + pre_encoded: bool = False, + cfg_dropout_prob = 0.1, + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + ): + super().__init__() + + self.diffusion = model + + self.use_ema = use_ema + + if self.use_ema: + self.diffusion_ema = EMA( + self.diffusion.model, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + else: + self.diffusion_ema = None + + self.cfg_dropout_prob = cfg_dropout_prob + + self.lr = lr + self.max_mask_segments = max_mask_segments + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.timestep_sampler = timestep_sampler + + self.diffusion_objective = model.diffusion_objective + + self.loss_modules = [ + MSELoss("output", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(self.loss_modules) + + self.log_loss_info = log_loss_info + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "diffusion": { + "optimizer": { + "type": "Adam", + "config": { + "lr": lr + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + diffusion_opt_config = self.optimizer_configs['diffusion'] + opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters()) + + if "scheduler" in diffusion_opt_config: + sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff) + sched_diff_config = { + "scheduler": sched_diff, + "interval": "step" + } + return [opt_diff], [sched_diff_config] + + return [opt_diff] + + def random_mask(self, sequence, max_mask_length): + b, _, sequence_length = sequence.size() + + # Create a mask tensor for each batch element + masks = [] + + for i in range(b): + mask_type = random.randint(0, 2) + + if mask_type == 0: # Random mask with multiple segments + num_segments = random.randint(1, self.max_mask_segments) + max_segment_length = max_mask_length // num_segments + + segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments) + + mask = torch.ones((1, 1, sequence_length)) + for length in segment_lengths: + mask_start = random.randint(0, sequence_length - length) + mask[:, :, mask_start:mask_start + length] = 0 + + elif mask_type == 1: # Full mask + mask = torch.zeros((1, 1, sequence_length)) + + elif mask_type == 2: # Causal mask + mask = torch.ones((1, 1, sequence_length)) + mask_length = random.randint(1, max_mask_length) + mask[:, :, -mask_length:] = 0 + + mask = mask.to(sequence.device) + masks.append(mask) + + # Concatenate the mask tensors into a single tensor + mask = torch.cat(masks, dim=0).to(sequence.device) + + # Apply the mask to the sequence tensor for each batch element + masked_sequence = sequence * mask + + return masked_sequence, mask + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + p = Profiler() + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + diffusion_input = reals + + if not self.pre_encoded: + 
loss_info["audio_reals"] = diffusion_input + + p.tick("setup") + + with torch.amp.autocast('cuda'): + conditioning = self.diffusion.conditioner(metadata, self.device) + + p.tick("conditioning") + + if self.diffusion.pretransform is not None: + self.diffusion.pretransform.to(self.device) + + if not self.pre_encoded: + with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + p.tick("pretransform") + + # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input + # if use_padding_mask: + # padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool() + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + # Max mask size is the full sequence length + max_mask_length = diffusion_input.shape[2] + + # Create a mask of random length for a random slice of the input + masked_input, mask = self.random_mask(diffusion_input, max_mask_length) + + # conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device)) + + # Calculate the noise schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(diffusion_input) + noised_inputs = diffusion_input * alphas + noise * sigmas + + if self.diffusion_objective == "v": + targets = noise * alphas - diffusion_input * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - diffusion_input + + p.tick("noise") + + extra_args = {} + + with torch.amp.autocast('cuda'): + p.tick("amp") + output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args) + p.tick("diffusion") + + loss_info.update({ + "output": output, + "targets": targets, + }) + + loss, losses = self.losses(loss_info) + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(output, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + log_dict = { + 'train/loss': 
loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + p.tick("log") + #print(f"Profiler: {p}") + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.diffusion_ema is not None: + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + if self.diffusion_ema is not None: + self.diffusion.model = self.diffusion_ema.ema_model + + if use_safetensors: + save_file(self.diffusion.state_dict(), path) + else: + torch.save({"state_dict": self.diffusion.state_dict()}, path) + +class DiffusionCondInpaintDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7] + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.demo_cfg_scales = demo_cfg_scales + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + try: + log_dict = {} + + demo_reals, metadata = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + if not module.pre_encoded: + # Log the real audio + log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu())) + # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals") + + if module.diffusion.pretransform is not None: + module.diffusion.pretransform.to(module.device) + with torch.amp.autocast('cuda'): + demo_reals = module.diffusion.pretransform.encode(demo_reals) + + demo_samples = demo_reals.shape[2] + + # Get conditioning + conditioning = module.diffusion.conditioner(metadata, module.device) + + masked_input, mask = module.random_mask(demo_reals, demo_reals.shape[2]) + + conditioning['inpaint_mask'] = [mask] + conditioning['inpaint_masked_input'] = [masked_input] + + if module.diffusion.pretransform is not None: + log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(masked_input.cpu())) + else: + log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(masked_input, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu())) + + cond_inputs = module.diffusion.get_conditioning_inputs(conditioning) + + noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device) + + trainer.logger.experiment.log(log_dict) + + for cfg_scale in self.demo_cfg_scales: + model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model + print(f"Generating demo for cfg scale {cfg_scale}") + + if module.diffusion_objective == "v": + fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + elif 
module.diffusion_objective == "rectified_flow": + fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True) + + if module.diffusion.pretransform is not None: + with torch.amp.autocast('cuda'): + fakes = module.diffusion.pretransform.decode(fakes) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + except Exception as e: + print(f'{type(e).__name__}: {e}') + raise e + +class DiffusionAutoencoderTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a diffusion autoencoder + ''' + def __init__( + self, + model: DiffusionAutoencoder, + lr: float = 1e-4, + ema_copy = None, + use_reconstruction_loss: bool = False + ): + super().__init__() + + self.diffae = model + + self.diffae_ema = EMA( + self.diffae, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + if model.bottleneck is not None: + # TODO: Use loss config for configurable bottleneck weights and reconstruction losses + loss_modules += create_loss_modules_from_bottleneck(model.bottleneck, {}) + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.out_channels + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + + if out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + def configure_optimizers(self): + return optim.Adam([*self.diffae.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals = batch[0] + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.diffae.pretransform is not None: + with torch.no_grad(): + reals = self.diffae.pretransform.encode(reals) + + loss_info["reals"] = reals + + #Encode reals, skipping the pretransform since it was already applied + latents, encoder_info = self.diffae.encode(reals, return_info=True, skip_pretransform=True) + + loss_info["latents"] = latents + loss_info.update(encoder_info) + + if self.diffae.decoder is not None: + latents = self.diffae.decoder(latents) + + # Upsample 
latents to match diffusion length + if latents.shape[2] != reals.shape[2]: + latents = F.interpolate(latents, size=reals.shape[2], mode='nearest') + + loss_info["latents_upsampled"] = latents + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.amp.autocast('cuda'): + v = self.diffae.diffusion(noised_reals, t, input_concat_cond=latents) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffae.pretransform is not None: + pred = self.diffae.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std(), + 'train/latent_std': latents.std(), + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffae_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.diffae_ema.ema_model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionAutoencoderDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, _ = next(self.demo_dl) + + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + encoder_input = demo_reals + + encoder_input = encoder_input.to(module.device) + + demo_reals = demo_reals.to(module.device) + + with torch.no_grad() and torch.amp.autocast('cuda'): + latents = module.diffae_ema.ema_model.encode(encoder_input).float() + fakes = module.diffae_ema.ema_model.decode(latents, steps=self.demo_steps) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents) + log_dict[f'embeddings_spec'] = 
wandb.Image(tokens_spectrogram_image(latents)) + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + if module.diffae_ema.ema_model.pretransform is not None: + with torch.no_grad() and torch.amp.autocast('cuda'): + initial_latents = module.diffae_ema.ema_model.pretransform.encode(encoder_input) + first_stage_fakes = module.diffae_ema.ema_model.pretransform.decode(initial_latents) + first_stage_fakes = rearrange(first_stage_fakes, 'b d n -> d (b n)') + first_stage_fakes = first_stage_fakes.to(torch.float32).mul(32767).to(torch.int16).cpu() + first_stage_filename = f'first_stage_{trainer.global_step:08}.wav' + torchaudio.save(first_stage_filename, first_stage_fakes, self.sample_rate) + + log_dict[f'first_stage_latents'] = wandb.Image(tokens_spectrogram_image(initial_latents)) + + log_dict[f'first_stage'] = wandb.Audio(first_stage_filename, + sample_rate=self.sample_rate, + caption=f'First Stage Reconstructed') + + log_dict[f'first_stage_melspec_left'] = wandb.Image(audio_spectrogram_image(first_stage_fakes)) + + + trainer.logger.experiment.log(log_dict) + +def create_source_mixture(reals, num_sources=2): + # Create a fake mixture source by mixing elements from the training batch together with random offsets + source = torch.zeros_like(reals) + for i in range(reals.shape[0]): + sources_added = 0 + + js = list(range(reals.shape[0])) + random.shuffle(js) + for j in js: + if i == j or (i != j and sources_added < num_sources): + # Randomly offset the mixed element between 0 and the length of the source + seq_len = reals.shape[2] + offset = random.randint(0, seq_len-1) + source[i, :, offset:] += reals[j, :, :-offset] + if i == j: + # If this is the real one, shift the reals as well to ensure alignment + new_reals = torch.zeros_like(reals[i]) + new_reals[:, offset:] = reals[i, :, :-offset] + reals[i] = new_reals + sources_added += 1 + + return source + +class DiffusionPriorTrainingWrapper(L.LightningModule): + ''' + Wrapper for training a diffusion prior for inverse problems + Prior types: + mono_stereo: The prior is conditioned on a mono version of the audio to generate a stereo version + ''' + def __init__( + self, + model: ConditionedDiffusionModelWrapper, + lr: float = 1e-4, + ema_copy = None, + prior_type: PriorType = PriorType.MonoToStereo, + use_reconstruction_loss: bool = False, + log_loss_info: bool = False, + ): + super().__init__() + + self.diffusion = model + + self.diffusion_ema = EMA( + self.diffusion, + ema_model=ema_copy, + beta=0.9999, + power=3/4, + update_every=1, + update_after_step=1, + include_online_model=False + ) + + self.lr = lr + + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.log_loss_info = log_loss_info + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.use_reconstruction_loss = use_reconstruction_loss + + if use_reconstruction_loss: + scales = [2048, 1024, 512, 256, 128, 64, 32] + hop_sizes = [] + win_lengths = [] + overlap = 0.75 + for s in scales: + hop_sizes.append(int(s * (1 - overlap))) + win_lengths.append(s) + + sample_rate = model.sample_rate + + stft_loss_args = { + "fft_sizes": scales, + "hop_sizes": hop_sizes, + "win_lengths": win_lengths, + "perceptual_weighting": True + } + + out_channels = model.io_channels + + + if model.pretransform is not None: + out_channels = model.pretransform.io_channels + self.audio_out_channels = out_channels + + if self.audio_out_channels == 2: + self.sdstft = auraloss.freq.SumAndDifferenceSTFTLoss(sample_rate=sample_rate, 
**stft_loss_args) + self.lrstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + # Add left and right channel reconstruction losses in addition to the sum and difference + loss_modules += [ + AuralossLoss(self.lrstft, 'audio_reals_left', 'pred_left', name='stft_loss_left', weight=0.05), + AuralossLoss(self.lrstft, 'audio_reals_right', 'pred_right', name='stft_loss_right', weight=0.05), + ] + + else: + self.sdstft = auraloss.freq.MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args) + + loss_modules.append( + AuralossLoss(self.sdstft, 'audio_reals', 'audio_pred', name='mrstft_loss', weight=0.1), # Reconstruction loss + ) + + self.losses = MultiLoss(loss_modules) + + self.prior_type = prior_type + + def configure_optimizers(self): + return optim.Adam([*self.diffusion.parameters()], lr=self.lr) + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + loss_info = {} + + loss_info["audio_reals"] = reals + + if self.prior_type == PriorType.MonoToStereo: + source = reals.mean(dim=1, keepdim=True).repeat(1, reals.shape[1], 1).to(self.device) + loss_info["audio_reals_mono"] = source + else: + raise ValueError(f"Unknown prior type {self.prior_type}") + + if self.diffusion.pretransform is not None: + with torch.no_grad(): + reals = self.diffusion.pretransform.encode(reals) + + if self.prior_type in [PriorType.MonoToStereo]: + source = self.diffusion.pretransform.encode(source) + + if self.diffusion.conditioner is not None: + with torch.amp.autocast('cuda'): + conditioning = self.diffusion.conditioner(metadata, self.device) + else: + conditioning = {} + + loss_info["reals"] = reals + + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(reals.shape[0])[:, 0].to(self.device) + + # Calculate the noise schedule parameters for those timesteps + alphas, sigmas = get_alphas_sigmas(t) + + # Combine the ground truth data and the noise + alphas = alphas[:, None, None] + sigmas = sigmas[:, None, None] + noise = torch.randn_like(reals) + noised_reals = reals * alphas + noise * sigmas + targets = noise * alphas - reals * sigmas + + with torch.amp.autocast('cuda'): + + conditioning['source'] = [source] + + v = self.diffusion(noised_reals, t, cond=conditioning, cfg_dropout_prob = 0.1) + + loss_info.update({ + "v": v, + "targets": targets + }) + + if self.use_reconstruction_loss: + pred = noised_reals * alphas - v * sigmas + + loss_info["pred"] = pred + + if self.diffusion.pretransform is not None: + pred = self.diffusion.pretransform.decode(pred) + loss_info["audio_pred"] = pred + + if self.audio_out_channels == 2: + loss_info["pred_left"] = pred[:, 0:1, :] + loss_info["pred_right"] = pred[:, 1:2, :] + loss_info["audio_reals_left"] = loss_info["audio_reals"][:, 0:1, :] + loss_info["audio_reals_right"] = loss_info["audio_reals"][:, 1:2, :] + + loss, losses = self.losses(loss_info) + + if self.log_loss_info: + # Loss debugging logs + num_loss_buckets = 10 + bucket_size = 1 / num_loss_buckets + loss_all = F.mse_loss(v, targets, reduction="none") + + sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze() + + # gather loss_all across all GPUs + loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n") + + # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size + loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, 
bucket_size).to(self.device)]) + + # Log bucketed losses with corresponding sigma bucket values, if it's not NaN + debug_log_dict = { + f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i]) + } + + self.log_dict(debug_log_dict) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': reals.std() + } + + for loss_name, loss_value in losses.items(): + log_dict[f"train/{loss_name}"] = loss_value.detach() + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + self.diffusion_ema.update() + + def export_model(self, path, use_safetensors=False): + + #model = self.diffusion_ema.ema_model + model = self.diffusion + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + +class DiffusionPriorDemoCallback(Callback): + def __init__( + self, + demo_dl, + demo_every=2000, + demo_steps=250, + sample_size=65536, + sample_rate=48000 + ): + super().__init__() + + self.demo_every = demo_every + self.demo_steps = demo_steps + self.demo_samples = sample_size + self.demo_dl = iter(demo_dl) + self.sample_rate = sample_rate + self.last_demo_step = -1 + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: DiffusionAutoencoderTrainingWrapper, outputs, batch, batch_idx): + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + self.last_demo_step = trainer.global_step + + demo_reals, metadata = next(self.demo_dl) + # import ipdb + # ipdb.set_trace() + # Remove extra dimension added by WebDataset + if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + demo_reals = demo_reals[0] + + demo_reals = demo_reals.to(module.device) + + encoder_input = demo_reals + + if module.diffusion.conditioner is not None: + with torch.amp.autocast('cuda'): + conditioning_tensors = module.diffusion.conditioner(metadata, module.device) + + else: + conditioning_tensors = {} + + + with torch.no_grad() and torch.amp.autocast('cuda'): + if module.prior_type == PriorType.MonoToStereo and encoder_input.shape[1] > 1: + source = encoder_input.mean(dim=1, keepdim=True).repeat(1, encoder_input.shape[1], 1).to(module.device) + + if module.diffusion.pretransform is not None: + encoder_input = module.diffusion.pretransform.encode(encoder_input) + source_input = module.diffusion.pretransform.encode(source) + else: + source_input = source + + conditioning_tensors['source'] = [source_input] + + fakes = sample(module.diffusion_ema.model, torch.randn_like(encoder_input), self.demo_steps, 0, cond=conditioning_tensors) + + if module.diffusion.pretransform is not None: + fakes = module.diffusion.pretransform.decode(fakes) + + #Interleave reals and fakes + reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n') + + # Put the demos together + reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'recon_mono_{trainer.global_step:08}.wav' + reals_fakes = reals_fakes.to(torch.float32).div(torch.max(torch.abs(reals_fakes))).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, reals_fakes, self.sample_rate) + + log_dict[f'recon'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes)) + + #Log the source + filename = f'source_{trainer.global_step:08}.wav' + source = rearrange(source, 
'b d n -> d (b n)') + source = source.to(torch.float32).mul(32767).to(torch.int16).cpu() + torchaudio.save(filename, source, self.sample_rate) + + log_dict[f'source'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Source') + + log_dict[f'source_melspec_left'] = wandb.Image(audio_spectrogram_image(source)) + + trainer.logger.experiment.log(log_dict) \ No newline at end of file diff --git a/think_sound/training/factory.py b/think_sound/training/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..d2558b161aa4772257de5f5b88751abd4f869932 --- /dev/null +++ b/think_sound/training/factory.py @@ -0,0 +1,262 @@ +import torch +from torch.nn import Parameter +from ..models.factory import create_model_from_config + +def create_training_wrapper_from_config(model_config, model): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + if model_type == 'autoencoder': + from .autoencoders import AutoencoderTrainingWrapper + + ema_copy = None + + if training_config.get("use_ema", False): + ema_copy = create_model_from_config(model_config) + ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + use_ema = training_config.get("use_ema", False) + + latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0) + + teacher_model = training_config.get("teacher_model", None) + if teacher_model is not None: + teacher_model = create_model_from_config(teacher_model) + teacher_model = teacher_model.eval().requires_grad_(False) + + teacher_model_ckpt = training_config.get("teacher_model_ckpt", None) + if teacher_model_ckpt is not None: + teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"]) + else: + raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified") + + return AutoencoderTrainingWrapper( + model, + lr=training_config["learning_rate"], + warmup_steps=training_config.get("warmup_steps", 0), + encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False), + sample_rate=model_config["sample_rate"], + loss_config=training_config.get("loss_configs", None), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=use_ema, + ema_copy=ema_copy if use_ema else None, + force_input_mono=training_config.get("force_input_mono", False), + latent_mask_ratio=latent_mask_ratio, + teacher_model=teacher_model + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondTrainingWrapper + return DiffusionUncondTrainingWrapper( + model, + lr=training_config["learning_rate"], + pre_encoded=training_config.get("pre_encoded", False), + ) + elif model_type == 'diffusion_infill': + from .diffusion import DiffusionInfillTrainingWrapper + return DiffusionInfillTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + frac_lengths_mask=training_config.get("frac_lengths_mask", (0.7, 1.)), + 
min_span_len=training_config.get("min_span_len", 10), + timestep_sampler = training_config.get("timestep_sampler", "uniform"), + ctx_drop = training_config.get("ctx_drop", 0.1), + r_drop = training_config.get("r_drop", 0.0) + ) + elif model_type == 'diffusion_cond' or model_type == 'mm_diffusion_cond': + from .diffusion import DiffusionCondTrainingWrapper + return DiffusionCondTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + mask_padding=training_config.get("mask_padding", False), + mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0), + use_ema = training_config.get("use_ema", True), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + diffusion_objective=training_config.get("diffusion_objective","v"), + cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler = training_config.get("timestep_sampler", "uniform"), + max_mask_segments = training_config.get("max_mask_segments", 0) + ) + elif model_type == 'diffusion_prior': + from .diffusion import DiffusionPriorTrainingWrapper + from ..models.diffusion_prior import PriorType + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + prior_type = training_config.get("prior_type", "mono_stereo") + + if prior_type == "mono_stereo": + prior_type_enum = PriorType.MonoToStereo + else: + raise ValueError(f"Unknown prior type: {prior_type}") + + return DiffusionPriorTrainingWrapper( + model, + lr=training_config["learning_rate"], + ema_copy=ema_copy, + prior_type=prior_type_enum, + log_loss_info=training_config.get("log_loss_info", False), + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False), + ) + elif model_type == 'diffusion_cond_inpaint': + from .diffusion import DiffusionCondInpaintTrainingWrapper + return DiffusionCondInpaintTrainingWrapper( + model, + lr=training_config.get("learning_rate", None), + max_mask_segments = training_config.get("max_mask_segments", 10), + log_loss_info=training_config.get("log_loss_info", False), + optimizer_configs=training_config.get("optimizer_configs", None), + use_ema=training_config.get("use_ema", True), + pre_encoded=training_config.get("pre_encoded", False), + cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1), + timestep_sampler = training_config.get("timestep_sampler", "uniform") + ) + elif model_type == 'diffusion_autoencoder' : + from .diffusion import DiffusionAutoencoderTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + # Copy each weight to the ema copy + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return DiffusionAutoencoderTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config["learning_rate"], + use_reconstruction_loss=training_config.get("use_reconstruction_loss", False) + ) + elif model_type == 'lm': + from .lm import AudioLanguageModelTrainingWrapper + + ema_copy = create_model_from_config(model_config) + + for name, param in model.state_dict().items(): + if isinstance(param, Parameter): + # backwards compatibility for 
serialized parameters + param = param.data + ema_copy.state_dict()[name].copy_(param) + + return AudioLanguageModelTrainingWrapper( + model, + ema_copy=ema_copy, + lr=training_config.get("learning_rate", None), + use_ema=training_config.get("use_ema", False), + optimizer_configs=training_config.get("optimizer_configs", None), + pre_encoded=training_config.get("pre_encoded", False), + ) + + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_demo_callback_from_config(model_config, **kwargs): + model_type = model_config.get('model_type', None) + assert model_type is not None, 'model_type must be specified in model config' + + training_config = model_config.get('training', None) + assert training_config is not None, 'training config must be specified in model config' + + demo_config = training_config.get("demo", {}) + + if model_type == 'autoencoder': + from .autoencoders import AutoencoderDemoCallback + return AutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == 'diffusion_uncond': + from .diffusion import DiffusionUncondDemoCallback + return DiffusionUncondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"] + ) + elif model_type == 'diffusion_infill': + from .diffusion import DiffusionInfillDemoCallback + return DiffusionInfillDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_autoencoder": + from .diffusion import DiffusionAutoencoderDemoCallback + return DiffusionAutoencoderDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_prior": + from .diffusion import DiffusionPriorDemoCallback + return DiffusionPriorDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + demo_steps=demo_config.get("demo_steps", 250), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + **kwargs + ) + elif model_type == "diffusion_cond" or model_type == 'mm_diffusion_cond': + from .diffusion import DiffusionCondDemoCallback + + return DiffusionCondDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + num_demos=demo_config["num_demos"], + demo_cfg_scales=demo_config["demo_cfg_scales"], + demo_conditioning=demo_config.get("demo_cond", {}), + demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False), + display_audio_cond=demo_config.get("display_audio_cond", False), + ) + elif model_type == "diffusion_cond_inpaint": + from .diffusion import DiffusionCondInpaintDemoCallback + + return DiffusionCondInpaintDemoCallback( + demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_steps=demo_config.get("demo_steps", 250), + demo_cfg_scales=demo_config["demo_cfg_scales"], + **kwargs + ) + + elif model_type == "lm": + from .lm import AudioLanguageModelDemoCallback + + return AudioLanguageModelDemoCallback( + 
demo_every=demo_config.get("demo_every", 2000), + sample_size=model_config["sample_size"], + sample_rate=model_config["sample_rate"], + demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]), + demo_conditioning=demo_config.get("demo_cond", None), + num_demos=demo_config.get("num_demos", 8), + **kwargs + ) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') \ No newline at end of file diff --git a/think_sound/training/lm.py b/think_sound/training/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fa9f71c805f8d4083919d5c46422c5b7eeb4a8 --- /dev/null +++ b/think_sound/training/lm.py @@ -0,0 +1,267 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from pytorch_lightning.utilities.rank_zero import rank_zero_only + +from ..models.lm import AudioLanguageModelWrapper +from .utils import create_optimizer_from_config, create_scheduler_from_config + +class AudioLanguageModelTrainingWrapper(pl.LightningModule): + def __init__( + self, + model: AudioLanguageModelWrapper, + lr = 1e-4, + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + pre_encoded=False + ): + super().__init__() + + self.model = model + + self.model.pretransform.requires_grad_(False) + + self.model_ema = None + if use_ema: + self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "lm": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1 + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + self.optimizer_configs = optimizer_configs + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + lm_opt_config = self.optimizer_configs['lm'] + opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + + if "scheduler" in lm_opt_config: + sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) + sched_lm_config = { + "scheduler": sched_lm, + "interval": "step" + } + return [opt_lm], [sched_lm_config] + + return [opt_lm] + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. 
+ Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + if not self.pre_encoded: + codes = self.model.pretransform.tokenize(reals) + else: + codes = reals + + padding_masks = [] + for md in metadata: + if md["padding_mask"].ndim == 1: + padding_masks.append(md["padding_mask"]) + else: + padding_masks.append(md["padding_mask"][0]) + + padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) + + # Interpolate padding masks to the same length as the codes + padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=codes.shape[2], mode='nearest').bool() + + condition_tensors = None + + # If the model is conditioned, get the conditioning tensors + if self.model.conditioner is not None: + condition_tensors = self.model.conditioner(metadata, self.device) + + lm_output = self.model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) + + logits = lm_output.logits # [b, k, t, c] + logits_mask = lm_output.mask # [b, k, t] + + logits_mask = logits_mask & padding_masks + + cross_entropy, cross_entropy_per_codebook = self._compute_cross_entropy(logits, codes, logits_mask) + + loss = cross_entropy + + log_dict = { + 'train/loss': loss.detach(), + 'train/cross_entropy': cross_entropy.detach(), + 'train/perplexity': torch.exp(cross_entropy).detach(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + for k, ce_q in enumerate(cross_entropy_per_codebook): + log_dict[f'cross_entropy_q{k + 1}'] = ce_q + log_dict[f'perplexity_q{k + 1}'] = torch.exp(ce_q) + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.model_ema is not None: + self.model_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.model_ema.ema_model if self.model_ema is not None else self.model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AudioLanguageModelDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + @torch.no_grad() + def 
on_train_batch_end(self, trainer, module: AudioLanguageModelTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print(f"Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio + + #demo_reals = batch[0][:self.num_demos] + + # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + # demo_reals = demo_reals[0] + + #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + + ##Limit to first 50 tokens + #demo_reals_tokens = demo_reals_tokens[:, :, :50] + + try: + print("Getting conditioning") + + for cfg_scale in self.demo_cfg_scales: + + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = model.generate_audio( + batch_size=self.num_demos, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + #init_data = demo_reals_tokens, + cfg_scale=cfg_scale, + temp=1.0, + top_p=0.95 + ) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes / fakes.abs().max() + fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename, + sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/think_sound/training/lm_continuous.py b/think_sound/training/lm_continuous.py new file mode 100644 index 0000000000000000000000000000000000000000..0ecc1a92336a0623f3f9b1c455a1f8198e4cacb8 --- /dev/null +++ b/think_sound/training/lm_continuous.py @@ -0,0 +1,294 @@ +import pytorch_lightning as pl +import sys, gc +import random +import torch +import torchaudio +import typing as tp +import wandb + +from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image +from ema_pytorch import EMA +from einops import rearrange +from safetensors.torch import save_file +from torch import optim +from torch.nn import functional as F +from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper + +from ..models.lm import AudioLMContinuousModelWrapper +from .utils import create_optimizer_from_config, create_scheduler_from_config + +class AudioLMContinuousModelTrainingWrapper(pl.LightningModule): + def __init__( + self, + model: AudioLanguageModelWrapper, + lr = 1e-4, + diffusion_objective: tp.Literal["rectified_flow", "v"] = "v", + timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform", + use_ema=False, + ema_copy=None, + optimizer_configs: dict = None, + diffusion_batch_mul=4, + pre_encoded=False + ): + super().__init__() + + self.model = model + self.diffusion = diffusion + self.rng = torch.quasirandom.SobolEngine(1, scramble=True) + + self.model.pretransform.requires_grad_(False) + + self.timestep_sampler = timestep_sampler + + 
self.diffusion_objective = model.diffusion_objective + + loss_modules = [ + MSELoss("v", + "targets", + weight=1.0, + name="mse_loss" + ) + ] + + self.losses = MultiLoss(loss_modules) + + self.model_ema = None + if use_ema: + self.model_ema = EMA(self.model, ema_model=ema_copy, beta=0.99, update_every=10) + + assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config" + + if optimizer_configs is None: + optimizer_configs = { + "lm": { + "optimizer": { + "type": "AdamW", + "config": { + "lr": lr, + "betas": (0.9, 0.95), + "weight_decay": 0.1 + } + } + } + } + else: + if lr is not None: + print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.") + + + self.optimizer_configs = optimizer_configs + + self.diffusion_batch_mul = diffusion_batch_mul + + self.pre_encoded = pre_encoded + + def configure_optimizers(self): + lm_opt_config = self.optimizer_configs['lm'] + opt_lm = create_optimizer_from_config(lm_opt_config['optimizer'], self.model.parameters()) + + if "scheduler" in lm_opt_config: + sched_lm = create_scheduler_from_config(lm_opt_config['scheduler'], opt_lm) + sched_lm_config = { + "scheduler": sched_lm, + "interval": "step" + } + return [opt_lm], [sched_lm_config] + + return [opt_lm] + + # Copied and modified from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/solvers/musicgen.py under MIT license + # License can be found in LICENSES/LICENSE_META.txt + + + def training_step(self, batch, batch_idx): + reals, metadata = batch + + if reals.ndim == 4 and reals.shape[0] == 1: + reals = reals[0] + + diffusion_input = reals + + loss_info = {} + + if not self.pre_encoded: + loss_info["audio_reals"] = diffusion_input + + if self.diffusion.pretransform is not None: + if not self.pre_encoded: + with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad): + diffusion_input = self.diffusion.pretransform.encode(diffusion_input) + else: + # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run + if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0: + diffusion_input = diffusion_input / self.diffusion.pretransform.scale + + loss_info["reals"] = diffusion_input + + padding_masks = [] + for md in metadata: + if md["padding_mask"].ndim == 1: + padding_masks.append(md["padding_mask"]) + else: + padding_masks.append(md["padding_mask"][0]) + + padding_masks = torch.stack(padding_masks, dim=0).to(self.device) # Shape (batch_size, sequence_length) + + condition_tensors = None + + # If the model is conditioned, get the conditioning tensors + if self.model.conditioner is not None: + with torch.cuda.amp.autocast(): + condition_tensors = self.model.conditioner(metadata, self.device) + + z = self.model.compute_logits(diffusion_input, condition_tensors=condition_tensors, cfg_dropout_prob=0.1) + bsz, seq_len, _ = z.shape + gt_inputs = diffusion_input.clone().detach() + gt_inputs = gt_inputs.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) + z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1) + mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul) + if self.timestep_sampler == "uniform": + # Draw uniformly distributed continuous timesteps + t = self.rng.draw(z.shape[0])[:, 0].to(self.device) + elif self.timestep_sampler == "logit_normal": + t = torch.sigmoid(torch.randn(z.shape[0], device=self.device)) + + # Calculate the noise 
schedule parameters for those timesteps + if self.diffusion_objective == "v": + alphas, sigmas = get_alphas_sigmas(t) + elif self.diffusion_objective == "rectified_flow": + alphas, sigmas = 1-t, t + + # Combine the ground truth data and the noise + alphas = alphas[:, None] + sigmas = sigmas[:, None] + + noise = torch.randn_like(gt_inputs) + noised_inputs = gt_inputs * alphas + noise * sigmas + if self.diffusion_objective == "v": + targets = noise * alphas - gt_inputs * sigmas + elif self.diffusion_objective == "rectified_flow": + targets = noise - gt_inputs + cond = {} + cond['z'] = z + with torch.cuda.amp.autocast(): + v = self.diffusion(noised_inputs, t, cond=cond) + + loss_info.update({ + "v": v, + "targets": targets + }) + + loss, losses = self.losses(loss_info) + + log_dict = { + 'train/loss': loss.detach(), + 'train/std_data': diffusion_input.std(), + 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr'] + } + + + self.log_dict(log_dict, prog_bar=True, on_step=True) + return loss + + def on_before_zero_grad(self, *args, **kwargs): + if self.model_ema is not None: + self.model_ema.update() + + def export_model(self, path, use_safetensors=False): + + model = self.model_ema.ema_model if self.model_ema is not None else self.model + + if use_safetensors: + save_file(model.state_dict(), path) + else: + torch.save({"state_dict": model.state_dict()}, path) + + +class AudioLanguageModelDemoCallback(pl.Callback): + def __init__(self, + demo_every=2000, + num_demos=8, + sample_size=65536, + sample_rate=48000, + demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None, + demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7], + **kwargs + ): + super().__init__() + + self.demo_every = demo_every + self.num_demos = num_demos + self.demo_samples = sample_size + self.sample_rate = sample_rate + self.last_demo_step = -1 + self.demo_conditioning = demo_conditioning + self.demo_cfg_scales = demo_cfg_scales + + @rank_zero_only + @torch.no_grad() + def on_train_batch_end(self, trainer, module: AudioLMContinuousModelTrainingWrapper, outputs, batch, batch_idx): + + if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step: + return + + module.eval() + + print("Generating demo") + self.last_demo_step = trainer.global_step + + demo_length_tokens = self.demo_samples // module.model.pretransform.downsampling_ratio + + #demo_reals = batch[0][:self.num_demos] + + # if demo_reals.ndim == 4 and demo_reals.shape[0] == 1: + # demo_reals = demo_reals[0] + + #demo_reals_tokens = module.model.pretransform.tokenize(demo_reals) + + ##Limit to first 50 tokens + #demo_reals_tokens = demo_reals_tokens[:, :, :50] + + try: + print("Getting conditioning") + + for cfg_scale in self.demo_cfg_scales: + + model = module.model # module.model_ema.ema_model if module.model_ema is not None else module.model + + print(f"Generating demo for cfg scale {cfg_scale}") + fakes = model.generate_audio( + batch_size=self.num_demos, + max_gen_len=demo_length_tokens, + conditioning=self.demo_conditioning, + #init_data = demo_reals_tokens, + cfg_scale=cfg_scale, + temp=1.0, + top_p=0.95 + ) + + # Put the demos together + fakes = rearrange(fakes, 'b d n -> d (b n)') + + log_dict = {} + + filename = f'demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav' + fakes = fakes / fakes.abs().max() + fakes = fakes.type(torch.float32).clamp(-1, 1).mul(32767).type(torch.int16).cpu() + torchaudio.save(filename, fakes, self.sample_rate) + + log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
sample_rate=self.sample_rate, + caption=f'Reconstructed') + + log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes)) + + trainer.logger.experiment.log(log_dict) + + except Exception as e: + raise e + finally: + gc.collect() + torch.cuda.empty_cache() + module.train() \ No newline at end of file diff --git a/think_sound/training/losses/__init__.py b/think_sound/training/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37fdea0eb6c3190e7001567cfe17dc296bf811e8 --- /dev/null +++ b/think_sound/training/losses/__init__.py @@ -0,0 +1 @@ +from .losses import * \ No newline at end of file diff --git a/think_sound/training/losses/__pycache__/__init__.cpython-310.pyc b/think_sound/training/losses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1a2eac94cf919a5e241ffc55f7749b6d32fce3a Binary files /dev/null and b/think_sound/training/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/think_sound/training/losses/__pycache__/__init__.cpython-39.pyc b/think_sound/training/losses/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9d511b8392c6e8906ef62a166e5948fb28f3dfc Binary files /dev/null and b/think_sound/training/losses/__pycache__/__init__.cpython-39.pyc differ diff --git a/think_sound/training/losses/__pycache__/auraloss.cpython-310.pyc b/think_sound/training/losses/__pycache__/auraloss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f736432cc8ac2062b5fdb0d6e6ed21620d06231a Binary files /dev/null and b/think_sound/training/losses/__pycache__/auraloss.cpython-310.pyc differ diff --git a/think_sound/training/losses/__pycache__/auraloss.cpython-39.pyc b/think_sound/training/losses/__pycache__/auraloss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4ef7ab4b2d3d8c80493dd9208a8469b6000a307 Binary files /dev/null and b/think_sound/training/losses/__pycache__/auraloss.cpython-39.pyc differ diff --git a/think_sound/training/losses/__pycache__/losses.cpython-310.pyc b/think_sound/training/losses/__pycache__/losses.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..330d5bcce476194ab113cc90782747dbb2779346 Binary files /dev/null and b/think_sound/training/losses/__pycache__/losses.cpython-310.pyc differ diff --git a/think_sound/training/losses/__pycache__/losses.cpython-39.pyc b/think_sound/training/losses/__pycache__/losses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af0513ae4e07e8877b2253e4f1df0153a0ba20e Binary files /dev/null and b/think_sound/training/losses/__pycache__/losses.cpython-39.pyc differ diff --git a/think_sound/training/losses/auraloss.py b/think_sound/training/losses/auraloss.py new file mode 100644 index 0000000000000000000000000000000000000000..933ec0ab66cb6f8fbb26e249f96b09fb1982a132 --- /dev/null +++ b/think_sound/training/losses/auraloss.py @@ -0,0 +1,691 @@ +# Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0 +# You can find the license at LICENSES/LICENSE_AURALOSS.txt + +import torch +import numpy as np +from typing import List, Any +import scipy.signal + +def apply_reduction(losses, reduction="none"): + """Apply reduction to collection of losses.""" + if reduction == "mean": + losses = losses.mean() + elif reduction == "sum": + losses = losses.sum() + return losses + +def 
compute_direction(w, x, y, z): + # 计算各个声道的权重 + phi = torch.atan2(y, x) + theta = torch.atan2(torch.sqrt(x**2 + y**2), z) + return phi.unsqueeze(1), theta.unsqueeze(1) + +def get_window(win_type: str, win_length: int): + """Return a window function. + + Args: + win_type (str): Window type. Can either be one of the window function provided in PyTorch + ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). + win_length (int): Window length + + Returns: + win: The window as a 1D torch tensor + """ + + try: + win = getattr(torch, win_type)(win_length) + except: + win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length)) + + return win + +class SumAndDifference(torch.nn.Module): + """Sum and difference signal extraction module.""" + + def __init__(self): + """Initialize sum and difference extraction module.""" + super(SumAndDifference, self).__init__() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, #channels, #samples). + Returns: + Tensor: Sum signal. + Tensor: Difference signal. + """ + if not (x.size(1) == 2): # inputs must be stereo + raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).") + + sum_sig = self.sum(x).unsqueeze(1) + diff_sig = self.diff(x).unsqueeze(1) + + return sum_sig, diff_sig + + @staticmethod + def sum(x): + return x[:, 0, :] + x[:, 1, :] + + @staticmethod + def diff(x): + return x[:, 0, :] - x[:, 1, :] + + +class FIRFilter(torch.nn.Module): + """FIR pre-emphasis filtering module. + + Args: + filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp" + coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85 + ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101 + plot (bool): Plot the magnitude respond of the filter. Default: False + + Based upon the perceptual loss pre-empahsis filters proposed by + [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922). + + A-weighting filter - "aw" + First-order highpass - "hp" + Folded differentiator - "fd" + + Note that the default coefficeint value of 0.85 is optimized for + a sampling rate of 44.1 kHz, considering adjusting this value at differnt sampling rates. + """ + + def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False): + """Initilize FIR pre-emphasis filtering module.""" + super(FIRFilter, self).__init__() + self.filter_type = filter_type + self.coef = coef + self.fs = fs + self.ntaps = ntaps + self.plot = plot + + import scipy.signal + + if ntaps % 2 == 0: + raise ValueError(f"ntaps must be odd (ntaps={ntaps}).") + + if filter_type == "hp": + self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1) + elif filter_type == "fd": + self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1) + elif filter_type == "aw": + # Definition of analog A-weighting filter according to IEC/CD 1672. 
+ f1 = 20.598997 + f2 = 107.65265 + f3 = 737.86223 + f4 = 12194.217 + A1000 = 1.9997 + + NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0] + DENs = np.polymul( + [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2], + [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2], + ) + DENs = np.polymul( + np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2] + ) + + # convert analog filter to digital filter + b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs) + + # compute the digital filter frequency response + w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs) + + # then we fit to 101 tap FIR filter with least squares + taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs) + + # now implement this digital FIR filter as a Conv1d layer + self.fir = torch.nn.Conv1d( + 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2 + ) + self.fir.weight.requires_grad = False + self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1) + + if plot: + from .plotting import compare_filters + compare_filters(b, a, taps, fs=fs) + + def forward(self, input, target): + """Calculate forward propagation. + Args: + input (Tensor): Predicted signal (B, #channels, #samples). + target (Tensor): Groundtruth signal (B, #channels, #samples). + Returns: + Tensor: Filtered signal. + """ + input = torch.nn.functional.conv1d( + input, self.fir.weight.data, padding=self.ntaps // 2 + ) + target = torch.nn.functional.conv1d( + target, self.fir.weight.data, padding=self.ntaps // 2 + ) + return input, target + +class SpectralConvergenceLoss(torch.nn.Module): + """Spectral convergence loss module. + + See [Arik et al., 2018](https://arxiv.org/abs/1808.06719). + """ + + def __init__(self): + super(SpectralConvergenceLoss, self).__init__() + + def forward(self, x_mag, y_mag): + return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean() + +class STFTMagnitudeLoss(torch.nn.Module): + """STFT magnitude loss module. + + See [Arik et al., 2018](https://arxiv.org/abs/1808.06719) + and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1) + + Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the + compression strength (larger value results in more compression), and `log_eps` can be used + to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive + output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression. + + Args: + log (bool, optional): Log-scale the STFT magnitudes, + or use linear scale. Default: True + log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm. + Default: 0.0 + log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm. + Default: 1.0 + distance (str, optional): Distance function ["L1", "L2"]. Default: "L1" + reduction (str, optional): Reduction of the loss elements. 
Default: "mean" + """ + + def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"): + super(STFTMagnitudeLoss, self).__init__() + + self.log = log + self.log_eps = log_eps + self.log_fac = log_fac + + if distance == "L1": + self.distance = torch.nn.L1Loss(reduction=reduction) + elif distance == "L2": + self.distance = torch.nn.MSELoss(reduction=reduction) + else: + raise ValueError(f"Invalid distance: '{distance}'.") + + def forward(self, x_mag, y_mag): + if self.log: + x_mag = torch.log(self.log_fac * x_mag + self.log_eps) + y_mag = torch.log(self.log_fac * y_mag + self.log_eps) + return self.distance(x_mag, y_mag) + + +class STFTLoss(torch.nn.Module): + """STFT loss module. + + See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472). + + Args: + fft_size (int, optional): FFT size in samples. Default: 1024 + hop_size (int, optional): Hop size of the FFT in samples. Default: 256 + win_length (int, optional): Length of the FFT analysis window. Default: 1024 + window (str, optional): Window to apply before FFT, can either be one of the window function provided in PyTorch + ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html). + Default: 'hann_window' + w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 + w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 + w_lin_mag_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 + w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 + sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None + scale (str, optional): Optional frequency scaling method, options include: + ['mel', 'chroma'] + Default: None + n_bins (int, optional): Number of scaling frequency bins. Default: None. + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False + eps (float, optional): Small epsilon value for stablity. Default: 1e-8 + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + reduction (str, optional): Specifies the reduction to apply to the output: + 'none': no reduction will be applied, + 'mean': the sum of the output will be divided by the number of elements in the output, + 'sum': the output will be summed. + Default: 'mean' + mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms. + device (str, optional): Place the filterbanks on specified device. Default: None + + Returns: + loss: + Aggreate loss term. Only returned if output='loss'. By default. + loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss: + Aggregate and intermediate loss terms. Only returned if output='full'. 
+ """ + + def __init__( + self, + fft_size: int = 1024, + hop_size: int = 256, + win_length: int = 1024, + window: str = "hann_window", + w_sc: float = 1.0, + w_log_mag: float = 1.0, + w_lin_mag: float = 0.0, + w_phs: float = 0.0, + sample_rate: float = None, + scale: str = None, + n_bins: int = None, + perceptual_weighting: bool = False, + scale_invariance: bool = False, + eps: float = 1e-8, + output: str = "loss", + reduction: str = "mean", + mag_distance: str = "L1", + device: Any = None, + **kwargs + ): + super().__init__() + self.fft_size = fft_size + self.hop_size = hop_size + self.win_length = win_length + self.window = get_window(window, win_length) + self.w_sc = w_sc + self.w_log_mag = w_log_mag + self.w_lin_mag = w_lin_mag + self.w_phs = w_phs + self.sample_rate = sample_rate + self.scale = scale + self.n_bins = n_bins + self.perceptual_weighting = perceptual_weighting + self.scale_invariance = scale_invariance + self.eps = eps + self.output = output + self.reduction = reduction + self.mag_distance = mag_distance + self.device = device + + self.phs_used = bool(self.w_phs) + + self.spectralconv = SpectralConvergenceLoss() + self.logstft = STFTMagnitudeLoss( + log=True, + reduction=reduction, + distance=mag_distance, + **kwargs + ) + self.linstft = STFTMagnitudeLoss( + log=False, + reduction=reduction, + distance=mag_distance, + **kwargs + ) + + # setup mel filterbank + if scale is not None: + try: + import librosa.filters + except Exception as e: + print(e) + print("Try `pip install auraloss[all]`.") + + if self.scale == "mel": + assert sample_rate != None # Must set sample rate to use mel scale + assert n_bins <= fft_size # Must be more FFT bins than Mel bins + fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins) + fb = torch.tensor(fb).unsqueeze(0) + + elif self.scale == "chroma": + assert sample_rate != None # Must set sample rate to use chroma scale + assert n_bins <= fft_size # Must be more FFT bins than chroma bins + fb = librosa.filters.chroma( + sr=sample_rate, n_fft=fft_size, n_chroma=n_bins + ) + + else: + raise ValueError( + f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'." + ) + + self.register_buffer("fb", fb) + + if scale is not None and device is not None: + self.fb = self.fb.to(self.device) # move filterbank to device + + if self.perceptual_weighting: + if sample_rate is None: + raise ValueError( + f"`sample_rate` must be supplied when `perceptual_weighting = True`." + ) + self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate) + + def stft(self, x): + """Perform STFT. + Args: + x (Tensor): Input signal tensor (B, T). + + Returns: + Tensor: x_mag, x_phs + Magnitude and phase spectra (B, fft_size // 2 + 1, frames). 
+ """ + x_stft = torch.stft( + x, + self.fft_size, + self.hop_size, + self.win_length, + self.window, + return_complex=True, + ) + x_mag = torch.sqrt( + torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps) + ) + + # torch.angle is expensive, so it is only evaluated if the values are used in the loss + if self.phs_used: + x_phs = torch.angle(x_stft) + else: + x_phs = None + + return x_mag, x_phs + + def forward(self, input: torch.Tensor, target: torch.Tensor): + bs, chs, seq_len = input.size() + + if self.perceptual_weighting: # apply optional A-weighting via FIR filter + # since FIRFilter only support mono audio we will move channels to batch dim + input = input.view(bs * chs, 1, -1) + target = target.view(bs * chs, 1, -1) + + # now apply the filter to both + self.prefilter.to(input.device) + input, target = self.prefilter(input, target) + + # now move the channels back + input = input.view(bs, chs, -1) + target = target.view(bs, chs, -1) + + # compute the magnitude and phase spectra of input and target + self.window = self.window.to(input.device) + + x_mag, x_phs = self.stft(input.view(-1, input.size(-1))) + y_mag, y_phs = self.stft(target.view(-1, target.size(-1))) + + # apply relevant transforms + if self.scale is not None: + self.fb = self.fb.to(input.device) + x_mag = torch.matmul(self.fb, x_mag) + y_mag = torch.matmul(self.fb, y_mag) + + # normalize scales + if self.scale_invariance: + alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1])) + y_mag = y_mag * alpha.unsqueeze(-1) + + # compute loss terms + sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0 + log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0 + lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0 + phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0 + + # combine loss terms + loss = ( + (self.w_sc * sc_mag_loss) + + (self.w_log_mag * log_mag_loss) + + (self.w_lin_mag * lin_mag_loss) + + (self.w_phs * phs_loss) + ) + + loss = apply_reduction(loss, reduction=self.reduction) + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module. + + See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480) + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str, optional): Window to apply before FFT, options include: + 'hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window'] + Default: 'hann_window' + w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0 + w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0 + w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0 + w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0 + sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None + scale (str, optional): Optional frequency scaling method, options include: + ['mel', 'chroma'] + Default: None + n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None. + scale_invariance (bool, optional): Perform an optimal scaling of the target. 
Default: False + """ + + def __init__( + self, + fft_sizes: List[int] = [1024, 2048, 512], + hop_sizes: List[int] = [120, 240, 50], + win_lengths: List[int] = [600, 1200, 240], + window: str = "hann_window", + w_sc: float = 1.0, + w_log_mag: float = 1.0, + w_lin_mag: float = 0.0, + w_phs: float = 0.0, + sample_rate: float = None, + scale: str = None, + n_bins: int = None, + perceptual_weighting: bool = False, + scale_invariance: bool = False, + **kwargs, + ): + super().__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all + self.fft_sizes = fft_sizes + self.hop_sizes = hop_sizes + self.win_lengths = win_lengths + + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [ + STFTLoss( + fs, + ss, + wl, + window, + w_sc, + w_log_mag, + w_lin_mag, + w_phs, + sample_rate, + scale, + n_bins, + perceptual_weighting, + scale_invariance, + **kwargs, + ) + ] + + def forward(self, x, y): + mrstft_loss = 0.0 + sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], [] + # import ipdb + # ipdb.set_trace() + for f in self.stft_losses: + if f.output == "full": # extract just first term + tmp_loss = f(x, y) + mrstft_loss += tmp_loss[0] + sc_mag_loss.append(tmp_loss[1]) + log_mag_loss.append(tmp_loss[2]) + lin_mag_loss.append(tmp_loss[3]) + phs_loss.append(tmp_loss[4]) + else: + mrstft_loss += f(x, y) + + mrstft_loss /= len(self.stft_losses) + + if f.output == "loss": + return mrstft_loss + else: + return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss + + +class SumAndDifferenceSTFTLoss(torch.nn.Module): + """Sum and difference sttereo STFT loss module. + + See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) + + Args: + fft_sizes (List[int]): List of FFT sizes. + hop_sizes (List[int]): List of hop sizes. + win_lengths (List[int]): List of window lengths. + window (str, optional): Window function type. + w_sum (float, optional): Weight of the sum loss component. Default: 1.0 + w_diff (float, optional): Weight of the difference loss component. Default: 1.0 + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False + n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128 + sample_rate (float, optional): Audio sample rate. Default: None + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + """ + + def __init__( + self, + fft_sizes: List[int], + hop_sizes: List[int], + win_lengths: List[int], + window: str = "hann_window", + w_sum: float = 1.0, + w_diff: float = 1.0, + output: str = "loss", + **kwargs, + ): + super().__init__() + self.sd = SumAndDifference() + self.w_sum = w_sum + self.w_diff = w_diff + self.output = output + self.mrstft = MultiResolutionSTFTLoss( + fft_sizes, + hop_sizes, + win_lengths, + window, + **kwargs, + ) + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """This loss function assumes batched input of stereo audio in the time domain. + + Args: + input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). + target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). + + Returns: + loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. 
+ loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): + Aggregate and intermediate loss terms. Only returned if output='full'. + """ + assert input.shape == target.shape # must have same shape + bs, chs, seq_len = input.size() + + # compute sum and difference signals for both + input_sum, input_diff = self.sd(input) + target_sum, target_diff = self.sd(target) + + # compute error in STFT domain + sum_loss = self.mrstft(input_sum, target_sum) + diff_loss = self.mrstft(input_diff, target_diff) + loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2 + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sum_loss, diff_loss + +class SpatialSTFTLoss(torch.nn.Module): + """Sum and difference sttereo STFT loss module. + + See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291) + + Args: + fft_sizes (List[int]): List of FFT sizes. + hop_sizes (List[int]): List of hop sizes. + win_lengths (List[int]): List of window lengths. + window (str, optional): Window function type. + w_sum (float, optional): Weight of the sum loss component. Default: 1.0 + w_diff (float, optional): Weight of the difference loss component. Default: 1.0 + perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False + mel_stft (bool, optional): Use Multi-resoltuion mel spectrograms. Default: False + n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128 + sample_rate (float, optional): Audio sample rate. Default: None + output (str, optional): Format of the loss returned. + 'loss' : Return only the raw, aggregate loss term. + 'full' : Return the raw loss, plus intermediate loss terms. + Default: 'loss' + """ + + def __init__( + self, + fft_sizes: List[int], + hop_sizes: List[int], + win_lengths: List[int], + window: str = "hann_window", + w_phi: float = 1.0, + w_theta: float = 1.0, + output: str = "loss", + **kwargs, + ): + super().__init__() + self.w_phi = w_phi + self.w_theta = w_theta + self.output = output + self.mrstft = MultiResolutionSTFTLoss( + fft_sizes, + hop_sizes, + win_lengths, + window, + **kwargs, + ) + + + def forward(self, input: torch.Tensor, target: torch.Tensor): + """This loss function assumes batched input of stereo audio in the time domain. + + Args: + input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len). + target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len). + + Returns: + loss (torch.Tensor): Aggreate loss term. Only returned if output='loss'. + loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor): + Aggregate and intermediate loss terms. Only returned if output='full'. 
+ """ + assert input.shape == target.shape # must have same shape + bs, chs, seq_len = input.size() + + w_o, x_o, y_o, z_o = input[:, 0], input[:, 1], input[:, 2], input[:, 3] + w_r, x_r, y_r, z_r = target[:, 0], target[:, 1], target[:, 2], target[:, 3] + + phi_o, theta_o = compute_direction(w_o, x_o, y_o, z_o) + phi_r, theta_r = compute_direction(w_r, x_r, y_r, z_r) + + # compute error in STFT domain + phi_loss = self.mrstft(phi_o, phi_r) + theta_loss = self.mrstft(theta_o, theta_r) + loss = ((self.w_phi * phi_loss) + (self.w_theta * theta_loss)) / 2 + + if self.output == "loss": + return loss + elif self.output == "full": + return loss, sum_loss, diff_loss \ No newline at end of file diff --git a/think_sound/training/losses/losses.py b/think_sound/training/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..7285850c3ff873e0dda6a83265536dcb0bcb5b4f --- /dev/null +++ b/think_sound/training/losses/losses.py @@ -0,0 +1,100 @@ +import typing as tp + +from torch.nn import functional as F +from torch import nn + +class LossModule(nn.Module): + def __init__(self, name: str, weight: float = 1.0): + super().__init__() + + self.name = name + self.weight = weight + + def forward(self, info, *args, **kwargs): + raise NotImplementedError + +class ValueLoss(LossModule): + def __init__(self, key: str, name, weight: float = 1.0): + super().__init__(name=name, weight=weight) + + self.key = key + + def forward(self, info): + return self.weight * info[self.key] + +class L1Loss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none') + + if self.mask_key is not None and self.mask_key in info: + mse_loss = mse_loss[info[self.mask_key]] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class MSELoss(LossModule): + def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'): + super().__init__(name=name, weight=weight) + + self.key_a = key_a + self.key_b = key_b + + self.mask_key = mask_key + + def forward(self, info): + mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none') + if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None: + mask = info[self.mask_key] + + if mask.ndim == 2 and mse_loss.ndim == 3: + mask = mask.unsqueeze(1) + + if mask.shape[1] != mse_loss.shape[1]: + mask = mask.repeat(1, mse_loss.shape[1], 1) + + mse_loss = mse_loss[mask] + + mse_loss = mse_loss.mean() + + return self.weight * mse_loss + +class AuralossLoss(LossModule): + def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1): + super().__init__(name, weight) + + self.auraloss_module = auraloss_module + + self.input_key = input_key + self.target_key = target_key + + def forward(self, info): + loss = self.auraloss_module(info[self.input_key], info[self.target_key]) + + return self.weight * loss + +class MultiLoss(nn.Module): + def __init__(self, losses: tp.List[LossModule]): + super().__init__() + + self.losses = nn.ModuleList(losses) + + def forward(self, info): + total_loss = 0 + + losses = {} + + for loss_module in self.losses: + module_loss = loss_module(info) + total_loss += module_loss + losses[loss_module.name] = module_loss + + return 
total_loss, losses \ No newline at end of file diff --git a/think_sound/training/utils.py b/think_sound/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a848098c14931c9791d032642cbd6eda93addd7 --- /dev/null +++ b/think_sound/training/utils.py @@ -0,0 +1,200 @@ +import torch +import os +from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor +import random + + + +def get_rank(): + """Get rank of current process.""" + + print(os.environ.keys()) + + if "SLURM_PROCID" in os.environ: + return int(os.environ["SLURM_PROCID"]) + + if not torch.distributed.is_available() or not torch.distributed.is_initialized(): + return 0 + + return torch.distributed.get_rank() + +class InverseLR(torch.optim.lr_scheduler._LRScheduler): + """Implements an inverse decay learning rate schedule with an optional exponential + warmup. When last_epoch=-1, sets initial lr as lr. + inv_gamma is the number of steps/epochs required for the learning rate to decay to + (1 / 2)**power of its original value. + Args: + optimizer (Optimizer): Wrapped optimizer. + inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1. + power (float): Exponential factor of learning rate decay. Default: 1. + warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable) + Default: 0. + final_lr (float): The final learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0., + last_epoch=-1, verbose=False): + self.inv_gamma = inv_gamma + self.power = power + if not 0. <= warmup < 1: + raise ValueError('Invalid value for warmup') + self.warmup = warmup + self.final_lr = final_lr + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + import warnings + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + + return self._get_closed_form_lr() + + def _get_closed_form_lr(self): + warmup = 1 - self.warmup ** (self.last_epoch + 1) + lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power + return [warmup * max(self.final_lr, base_lr * lr_mult) + for base_lr in self.base_lrs] + +def copy_state_dict(model, state_dict): + """Load state_dict to model, but only for keys that match exactly. + + Args: + model (nn.Module): model to load state_dict. + state_dict (OrderedDict): state_dict to load. 
+ """ + model_state_dict = model.state_dict() + + # 创建一个列表存储不匹配的参数 + missing_keys = [] + unexpected_keys = [] + # 手动加载并检查不匹配的参数 + for key in state_dict: + if key not in model_state_dict: + unexpected_keys.append(key) + elif state_dict[key].shape != model_state_dict[key].shape: + unexpected_keys.append(key) + + for key in model_state_dict: + if key not in state_dict: + missing_keys.append(key) + + # 打印不匹配的参数 + print("Missing keys in state_dict:", missing_keys) + print("Unexpected keys in state_dict:", unexpected_keys) + for key in state_dict: + if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape: + if isinstance(state_dict[key], torch.nn.Parameter): + # backwards compatibility for serialized parameters + state_dict[key] = state_dict[key].data + model_state_dict[key] = state_dict[key] + + model.load_state_dict(model_state_dict, strict=False) + +def create_optimizer_from_config(optimizer_config, parameters): + """Create optimizer from config. + + Args: + parameters (iterable): parameters to optimize. + optimizer_config (dict): optimizer config. + + Returns: + torch.optim.Optimizer: optimizer. + """ + + optimizer_type = optimizer_config["type"] + + if optimizer_type == "FusedAdam": + from deepspeed.ops.adam import FusedAdam + optimizer = FusedAdam(parameters, **optimizer_config["config"]) + else: + optimizer_fn = getattr(torch.optim, optimizer_type) + optimizer = optimizer_fn(parameters, **optimizer_config["config"]) + return optimizer + +def create_scheduler_from_config(scheduler_config, optimizer): + """Create scheduler from config. + + Args: + scheduler_config (dict): scheduler config. + optimizer (torch.optim.Optimizer): optimizer. + + Returns: + torch.optim.lr_scheduler._LRScheduler: scheduler. + """ + if scheduler_config["type"] == "InverseLR": + scheduler_fn = InverseLR + else: + scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"]) + scheduler = scheduler_fn(optimizer, **scheduler_config["config"]) + return scheduler + +# mask construction helpers + +def mask_from_start_end_indices( + seq_len: int, + start: Tensor, + end: Tensor +): + assert start.shape == end.shape + device = start.device + + seq = torch.arange(seq_len, device = device, dtype = torch.long) + seq = seq.reshape(*((-1,) * start.ndim), seq_len) + seq = seq.expand(*start.shape, seq_len) + + mask = seq >= start[..., None].long() + mask &= seq < end[..., None].long() + return mask + +def mask_from_frac_lengths( + seq_len: int, + frac_lengths: Tensor +): + device = frac_lengths.device + + lengths = (frac_lengths * seq_len).long() + max_start = seq_len - lengths + + rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1) + start = (max_start * rand).clamp(min = 0) + end = start + lengths + + return mask_from_start_end_indices(seq_len, start, end) + +def generate_mask(batch_size, seq_len, frac_lengths, min_span_len): + # 计算需要掩盖的起始数量 + n_mask = (frac_lengths * seq_len // min_span_len).long() # 每个 span 为 10 + # 初始化掩码张量,初始为全 0(未掩盖) + mask_tensor = torch.zeros((batch_size, seq_len), device=frac_lengths.device, dtype=torch.bool) + + for b in range(batch_size): + # 随机挑选起始帧 + start_frames = random.sample(range(0, seq_len - min_span_len + 1), n_mask[b]) # 0 到 seq_len-10 的范围 + + for start in start_frames: + # 将 span 为 10 的区域标记为 1(掩盖) + mask_tensor[b, start:start + 10] = 1.0 + + return mask_tensor + +def generate_channel_mask(diffusion_input): + + # 如果 r_drop 小于 threshold,则对每个样本选择一个随机声道进行完全 mask + batchsize, num_channels, dim = diffusion_input.shape + for i in 
range(batchsize): + channel_means = torch.mean(torch.abs(diffusion_input[i]), dim=1) # Mean of the absolute values for each channel + # Determine if any channel is 'small enough' + if torch.all(channel_means > 0.01): + # If all channels are not 'small enough', apply the mask + channel = torch.randint(num_channels, (1,)).item() + diffusion_input[i, channel, :] = 1e-8 # Mask the channel by setting its values + else: + # Optionally log that at least one channel is 'small enough' and no mask is applied + print(f"Sample {i}: At least one channel is 'small enough', skipping masking.") + + return diffusion_input +
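
The span-masking helpers in `think_sound/training/utils.py` are what the `diffusion_infill` wrapper relies on to choose which latent frames to corrupt. Below is a minimal, standalone sketch (not part of the patch) showing how they behave; the batch size and sequence length are illustrative assumptions, while the (0.7, 1.0) fraction range and `min_span_len=10` mirror the defaults read in `factory.py`.

```python
# Sketch only: exercising the masking helpers added in think_sound/training/utils.py.
# batch_size and seq_len are arbitrary illustrative values (assumptions).
import torch
from think_sound.training.utils import mask_from_frac_lengths, generate_mask

batch_size, seq_len, min_span_len = 2, 215, 10

# Fraction of each sequence to mask, one value per batch item,
# drawn here from the frac_lengths_mask default range (0.7, 1.0).
frac_lengths = torch.zeros(batch_size).uniform_(0.7, 1.0)

# One contiguous masked span per item, covering roughly frac_lengths of the sequence.
contiguous = mask_from_frac_lengths(seq_len, frac_lengths)              # (batch_size, seq_len), bool

# Several short spans of length min_span_len at random start frames (spans may overlap).
spans = generate_mask(batch_size, seq_len, frac_lengths, min_span_len)  # (batch_size, seq_len), bool

print(contiguous.float().mean(dim=1))  # approximately frac_lengths
print(spans.sum(dim=1))                # at most n_mask * min_span_len frames masked per item
```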
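For reference, `create_optimizer_from_config` and `create_scheduler_from_config` expect the same `{"type": ..., "config": {...}}` dictionaries used throughout the training configs. A hedged sketch follows: the AdamW settings mirror the defaults in `lm.py`, while the `InverseLR` numbers and the stand-in model are illustrative assumptions.

```python
# Sketch only: building an optimizer/scheduler pair the way the training wrappers do.
import torch
from think_sound.training.utils import create_optimizer_from_config, create_scheduler_from_config

model = torch.nn.Linear(64, 64)  # stand-in for the real model (assumption)

optimizer = create_optimizer_from_config(
    {"type": "AdamW", "config": {"lr": 1e-4, "betas": (0.9, 0.95), "weight_decay": 0.1}},
    model.parameters(),
)
# InverseLR is the custom schedule defined in utils.py; these values are illustrative.
scheduler = create_scheduler_from_config(
    {"type": "InverseLR", "config": {"inv_gamma": 1_000_000, "power": 0.5, "warmup": 0.99}},
    optimizer,
)

for _ in range(10):
    optimizer.step()   # no-op here since no gradients were computed
    scheduler.step()
print(scheduler.get_last_lr())
```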