diff --git a/README.md b/README.md index 4fd9618dde607663fcd69b94b821a8603a243b8b..20cf8a7bc5b236d5e1064470df8d11d56bb8c752 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ --- title: DeepSound-V1 -emoji: 🔊 colorFrom: blue colorTo: indigo sdk: gradio @@ -9,155 +8,160 @@ pinned: false --- -# [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio) + -[Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/) -University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation +
+

+

DeepSound-V1

+ + Paper | Webpage | Huggingface Demo +

+
+
+## [DeepSound-V1: Start to Think Step-by-Step in the Audio Generation from Videos](https://github.com/lym0302/DeepSound-V1) -[[Paper (being prepared)]](https://hkchengrex.github.io/MMAudio) [[Project Page]](https://hkchengrex.github.io/MMAudio) + + -**Note: This repository is still under construction. Single-example inference should work as expected. The training code will be added. Code is subject to non-backward-compatible changes.** + ## Highlight -MMAudio generates synchronized audio given video and/or text inputs. -Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets. -Moreover, a synchronization module aligns the generated audio with the video frames. +DeepSound-V1 is a framework that enables audio generation from videos with an initial step-by-step thinking process, requiring no extra annotations, by leveraging the internal chain-of-thought (CoT) of a multi-modal large language model (MLLM). - -## Results + + ## Installation +```bash +conda create -n deepsound-v1 python=3.10.16 -y +conda activate deepsound-v1 +pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu120 +pip install flash-attn==2.5.8 --no-build-isolation +pip install -e . +pip install -r requirements.txt +``` + -We have only tested this on Ubuntu. + -**Clone our repository:** + -```bash -cd MMAudio -pip install -e . + -(If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip) - -**Pretrained models:** - -The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py` - -| Model | Download link | File size | -| -------- | ------- | ------- | -| Flow prediction network, small 16kHz | mmaudio_small_16k.pth | 601M | -| Flow prediction network, small 44.1kHz | mmaudio_small_44k.pth | 601M | -| Flow prediction network, medium 44.1kHz | mmaudio_medium_44k.pth | 2.4G | -| Flow prediction network, large 44.1kHz **(recommended)** | mmaudio_large_44k.pth | 3.9G | -| 16kHz VAE | v1-16.pth | 655M | -| 16kHz BigVGAN vocoder |best_netG.pt | 429M | -| 44.1kHz VAE |v1-44.pth | 1.2G | -| Synchformer visual encoder |synchformer_state_dict.pth | 907M | - -The 44.1kHz vocoder will be downloaded automatically. - -The expected directory structure (full): + + + + + ## Demo -By default, these scripts use the `large_44k` model. -In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs. +### Pretrained models +See [MODELS.md](docs/MODELS.md). ### Command-line interface With `demo.py` + ```bash -python demo.py --duration=8 --video=<video_path> --prompt "your prompt" +python demo.py -i <video_path> ``` -The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`. + +All training parameters are [here](). + + - -### Gradio interface + -### Known limitations -1. The model sometimes generates undesired unintelligible human speech-like sounds -2. The model sometimes generates undesired background music -3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing". -We believe all of these three limitations can be addressed with more high-quality training data. +## Evaluation +Refer to [av-benchmark](https://github.com/hkchengrex/av-benchmark) for benchmarking results. +See [EVAL.md](docs/EVAL.md). -## Training -Work in progress. -## Evaluation -Work in progress. 
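The new `app.py` in this diff drives everything through the `Pipeline` class added in `pipeline/pipeline.py`. As a reading aid, here is a minimal sketch of how that pipeline could be called outside Gradio; the checkpoint directories and mode strings are copied from the defaults in `app.py`, and `video.mp4` / `output` are placeholder paths, so treat this as an illustration rather than a documented API.

```python
import os
import sys

# app.py appends the bundled MMAudio to sys.path before importing the pipeline.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'third_party', 'MMAudio'))

from pipeline.pipeline import Pipeline

# Checkpoint directories and step modes mirror the defaults in app.py of this PR.
pipeline = Pipeline(
    step0_model_dir='pretrained/mllm/models--lym0302--VideoLLaMA2.1-7B-AV-CoT',
    step1_mode='mmaudio_medium_44k',   # Step1: video-to-audio backbone
    step2_model_dir='pretrained/mllm/models--lym0302--VideoLLaMA2.1-7B-AV-CoT',
    step2_mode='cot',                  # Step2: voice-over judgment via CoT
    step3_mode='bs_roformer',          # Step3: voice-over removal by source separation
)

# run() returns a dict with temp_final_audio_path / temp_final_video_path
# (see pipeline/pipeline.py); 'video.mp4' is a placeholder input path.
results = pipeline.run(
    video_input='video.mp4',
    output_dir='output',
    mode='s4',          # 's3' stops after voice-over removal; 's4' also runs the silence check
    postp_mode='neg',   # fallback when the cleaned audio is silent: 'rm' | 'rep' | 'neg'
    prompt='',
    negative_prompt='',
    duration=10,
    seed=42,
)
print(results['temp_final_audio_path'], results['temp_final_video_path'])
```

Per the logic in `pipeline/pipeline.py`, `mode='s3'` returns the Step3 (voice-over-removed) audio directly, while `mode='s4'` additionally runs the Step4 silence check and, if the cleaned audio is silent, falls back according to `postp_mode`: 'rm' drops the audio, 'rep' keeps the original Step1 output, and 'neg' regenerates with a 'human voice' negative prompt.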
+## Citation + + + +## Relevant Repositories + +- [av-benchmark](https://github.com/hkchengrex/av-benchmark) for benchmarking results. + ## Acknowledgement -Many thanks to: -- [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model -- [BigVGAN](https://github.com/NVIDIA/BigVGAN) -- [Synchformer](https://github.com/v-iashin/Synchformer) +Many thanks to: +- [VideoLLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2) +- [MMAudio](https://github.com/hkchengrex/MMAudio) +- [FoleyCrafter](https://github.com/open-mmlab/FoleyCrafter) +- [BS-RoFormer](https://github.com/ZFTurbo/Music-Source-Separation-Training) diff --git a/app.py b/app.py index f0b95a36fcbd063edd1e513eda76754713d72acd..dffa1c92bf98355261be7e900e24074655658506 100644 --- a/app.py +++ b/app.py @@ -1,275 +1,162 @@ -import spaces -import logging -from datetime import datetime -from pathlib import Path - -import gradio as gr -import torch -import torchaudio import os +import sys +import time +import gradio as gr +import subprocess +from pathlib import Path +import requests +from moviepy.editor import AudioFileClip, VideoFileClip + +project_root = os.path.dirname(os.path.abspath(__file__)) +mmaudio_path = os.path.join(project_root, 'third_party', 'MMAudio') +sys.path.append(mmaudio_path) + +from pipeline.pipeline import Pipeline +from third_party.MMAudio.mmaudio.eval_utils import setup_eval_logging + +# # download model +# os.makedirs("pretrained/mllm", exist_ok=True) +# from huggingface_hub import snapshot_download +# repo_local_path = snapshot_download(repo_id="lym0302/VideoLLaMA2.1-7B-AV-CoT", cache_dir='pretrained/mllm') + +# remove_vo_model_dir = "pretrained/remove_vo/checkpoints" +# os.makedirs(remove_vo_model_dir, exist_ok=True) +# urls = ["https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/model_bs_roformer_ep_317_sdr_12.9755.ckpt", +# "https://raw.githubusercontent.com/ZFTurbo/Music-Source-Separation-Training/main/configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml"] +# for url in urls: +# file_name = url.split("/")[-1] # Extract file name from URL +# file_path = os.path.join(remove_vo_model_dir, file_name) +# response = requests.get(url, stream=True) +# if response.status_code == 200: +# with open(file_path, "wb") as f: +# for chunk in response.iter_content(chunk_size=8192): # Use a chunk size of 8 KB +# f.write(chunk) +# print(f"File downloaded successfully and saved to {file_path}") +# else: +# print(f"Failed to download the file. 
Status code: {response.status_code}") + +# os.makedirs("pretrained/v2a/mmaudio", exist_ok=True) -try: - import mmaudio -except ImportError: - os.system("pip install -e .") - import mmaudio - -from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video, - setup_eval_logging) -from mmaudio.model.flow_matching import FlowMatching -from mmaudio.model.networks import MMAudio, get_my_mmaudio -from mmaudio.model.sequence_config import SequenceConfig -from mmaudio.model.utils.features_utils import FeaturesUtils -import tempfile - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - -log = logging.getLogger() - -device = 'cpu' -dtype = torch.bfloat16 - -model: ModelConfig = all_model_cfg['large_44k_v2'] -model.download_if_needed() -output_dir = Path('./output/gradio') setup_eval_logging() +pipeline = Pipeline( + step0_model_dir='pretrained/mllm/models--lym0302--VideoLLaMA2.1-7B-AV-CoT', + step1_mode='mmaudio_medium_44k', + step2_model_dir='pretrained/mllm/models--lym0302--VideoLLaMA2.1-7B-AV-CoT', + step2_mode='cot', + step3_mode='bs_roformer', +) - -def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]: - seq_cfg = model.seq_cfg - - net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval() - net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True)) - log.info(f'Loaded weights from {model.model_path}') - - feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path, - synchformer_ckpt=model.synchformer_ckpt, - enable_conditions=True, - mode=model.mode, - bigvgan_vocoder_ckpt=model.bigvgan_16k_path, - need_vae_encoder=False) - feature_utils = feature_utils.to(device, dtype).eval() - - return net, feature_utils, seq_cfg - - -net, feature_utils, seq_cfg = get_model() - - -@spaces.GPU(duration=120) -@torch.inference_mode() -def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int, - cfg_strength: float, duration: float): - - rng = torch.Generator(device=device) - if seed >= 0: - rng.manual_seed(seed) - else: - rng.seed() - fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps) - - video_info = load_video(video, duration) - clip_frames = video_info.clip_frames - sync_frames = video_info.sync_frames - duration = video_info.duration_sec - clip_frames = clip_frames.unsqueeze(0) - sync_frames = sync_frames.unsqueeze(0) - seq_cfg.duration = duration - net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len) - - audios = generate(clip_frames, - sync_frames, [prompt], - negative_text=[negative_prompt], - feature_utils=feature_utils, - net=net, - fm=fm, - rng=rng, - cfg_strength=cfg_strength) - audio = audios.float().cpu()[0] - - # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S') - video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name - # output_dir.mkdir(exist_ok=True, parents=True) - # video_save_path = output_dir / f'{current_time_string}.mp4' - make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate) - log.info(f'Saved video to {video_save_path}') - return video_save_path - - -@spaces.GPU(duration=120) -@torch.inference_mode() -def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float, - duration: float): - - rng = torch.Generator(device=device) - if seed >= 0: - rng.manual_seed(seed) - else: - rng.seed() - fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps) - - 
clip_frames = sync_frames = None - seq_cfg.duration = duration - net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len) - - audios = generate(clip_frames, - sync_frames, [prompt], - negative_text=[negative_prompt], - feature_utils=feature_utils, - net=net, - fm=fm, - rng=rng, - cfg_strength=cfg_strength) - audio = audios.float().cpu()[0] - - audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name - torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate) - log.info(f'Saved audio to {audio_save_path}') - return audio_save_path +output_dir = "output_gradio" +os.makedirs(output_dir, exist_ok=True) +skip_final_video = False +def video_to_audio( + video_input: gr.Video, + prompt: str='', + negative_prompt: str='', + mode: str='s4', + postp_mode: str='neg', + duration: float=10, + seed: int=42,): + + log_messages = [] # used to store the log messages + def log_info(msg): + log_messages.append(msg) + return "\n".join(log_messages) # return the full log history each time + + if not video_input: + yield None, log_info("Error: No video input provided.") + return + + yield None, log_info("Generate high-quality audio from video step-by-step...") # initialize the log + + st_infer = time.time() + video_input = str(video_input) + + for step_results in pipeline.run_for_gradio( + video_input=video_input, + output_dir=output_dir, + mode=mode, + postp_mode=postp_mode, + prompt=prompt, + negative_prompt=negative_prompt, + duration=duration, + seed=seed + ): + if step_results['log'] == 'Finish step-by-step v2a.': + break + else: + yield None, log_info(step_results['log']) + + + temp_final_audio_path = step_results["temp_final_audio_path"] + temp_final_video_path = step_results["temp_final_video_path"] + + video_name_stem = Path(video_input).stem + final_audio_path = str(Path(output_dir) / f'{video_name_stem}.wav') + final_video_path = str(Path(output_dir) / f'{video_name_stem}.mp4') + + if temp_final_audio_path is not None: + subprocess.run(['cp', str(temp_final_audio_path), final_audio_path], check=True) + step_results["final_audio_path"] = final_audio_path + + if skip_final_video: + step_results["final_video_path"] = None + else: + if temp_final_video_path is not None: + subprocess.run(['cp', str(temp_final_video_path), final_video_path], check=True) + else: + audio = AudioFileClip(final_audio_path) + video = VideoFileClip(video_input) + duration = min(audio.duration, video.duration) + audio = audio.subclip(0, duration) + video.audio = audio + video = video.subclip(0, duration) + video.write_videofile(final_video_path) + step_results["final_video_path"] = final_video_path + + et_infer = time.time() + print(f"Inference time: {et_infer - st_infer:.2f} s.") + print("step_results: ", step_results) + + yield (final_video_path if os.path.exists(final_video_path) else None), log_info(step_results['log']) video_to_audio_tab = gr.Interface( fn=video_to_audio, + # Project page: https://hkchengrex.com/MMAudio/
description=""" - Project page: https://hkchengrex.com/MMAudio/
- Code: https://github.com/hkchengrex/MMAudio
+ Code: https://github.com/lym0302/DeepSound-V1
NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side). Doing so does not improve results. - The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine. + This is a step-by-step v2a process and may take a long time. + If Post Processing is set to 'rm', the generated video may be None. """, inputs=[ gr.Video(), gr.Text(label='Prompt'), - gr.Text(label='Negative prompt', value='music'), - gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1), - gr.Number(label='Num steps', value=25, precision=0, minimum=1), - gr.Number(label='Guidance Strength', value=4.5, minimum=1), - gr.Number(label='Duration (sec)', value=8, minimum=1), - ], - outputs='playable_video', - cache_examples=False, - title='MMAudio โ€” Video-to-Audio Synthesis', - examples=[ - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4', - 'waves, seagulls', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4', - '', - 'music', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4', - 'bubbles', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4', - 'Indian holy music', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4', - 'galloping', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4', - 'waves, storm', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4', - '', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4', - 'storm', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4', - '', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4', - 'typing', - '', - 0, - 25, - 4.5, - 10, - ], - [ - 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4', - '', - '', - 0, - 25, - 4.5, - 10, - ], - ]) + gr.Text(label='Negative prompt', value=''), + gr.Radio(["s3", "s4"], label="Mode", value="s4"), + gr.Radio(["rm", "rep", "neg"], label="Post Processing", value="neg"), + gr.Number(label='Duration (sec)', value=10, minimum=1), + gr.Number(label='Seed (42: random)', value=42, precision=0, minimum=-1), -text_to_audio_tab = gr.Interface( - fn=text_to_audio, - inputs=[ - gr.Text(label='Prompt'), - gr.Text(label='Negative prompt'), - gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1), - gr.Number(label='Num steps', value=25, precision=0, minimum=1), - gr.Number(label='Guidance Strength', value=4.5, minimum=1), - gr.Number(label='Duration (sec)', value=8, minimum=1), ], - outputs='audio', + outputs=[gr.Video(label="Generated Video"), gr.Text(label="Logs"),], cache_examples=False, - title='MMAudio โ€” Text-to-Audio Synthesis', + title='DeepSound-V1 โ€” Video-to-Audio Synthesis', ) + if __name__ == "__main__": - gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab], - ['Video-to-Audio', 'Text-to-Audio']).launch(allowed_paths=[output_dir]) + gr.TabbedInterface([video_to_audio_tab], + 
['Video-to-Audio']).launch(allowed_paths=[output_dir]) + + +# if __name__ == "__main__": +# port = 8000 +# gr.TabbedInterface([video_to_audio_tab, ], +# ['Video-to-Audio', ]).launch( +# server_port=port, allowed_paths=[output_dir]) diff --git a/demo.py b/demo.py deleted file mode 100644 index ab66f5bd3599b5960f2b7386d600173c6c541369..0000000000000000000000000000000000000000 --- a/demo.py +++ /dev/null @@ -1,135 +0,0 @@ -import logging -from argparse import ArgumentParser -from pathlib import Path - -import torch -import torchaudio - -from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video, - setup_eval_logging) -from mmaudio.model.flow_matching import FlowMatching -from mmaudio.model.networks import MMAudio, get_my_mmaudio -from mmaudio.model.utils.features_utils import FeaturesUtils - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - -log = logging.getLogger() - - -@torch.inference_mode() -def main(): - setup_eval_logging() - - parser = ArgumentParser() - parser.add_argument('--variant', - type=str, - default='large_44k_v2', - help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2') - parser.add_argument('--video', type=Path, help='Path to the video file') - parser.add_argument('--prompt', type=str, help='Input prompt', default='') - parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='') - parser.add_argument('--duration', type=float, default=8.0) - parser.add_argument('--cfg_strength', type=float, default=4.5) - parser.add_argument('--num_steps', type=int, default=25) - - parser.add_argument('--mask_away_clip', action='store_true') - - parser.add_argument('--output', type=Path, help='Output directory', default='./output') - parser.add_argument('--seed', type=int, help='Random seed', default=42) - parser.add_argument('--skip_video_composite', action='store_true') - parser.add_argument('--full_precision', action='store_true') - - args = parser.parse_args() - - if args.variant not in all_model_cfg: - raise ValueError(f'Unknown model variant: {args.variant}') - model: ModelConfig = all_model_cfg[args.variant] - model.download_if_needed() - seq_cfg = model.seq_cfg - - if args.video: - video_path: Path = Path(args.video).expanduser() - else: - video_path = None - prompt: str = args.prompt - negative_prompt: str = args.negative_prompt - output_dir: str = args.output.expanduser() - seed: int = args.seed - num_steps: int = args.num_steps - duration: float = args.duration - cfg_strength: float = args.cfg_strength - skip_video_composite: bool = args.skip_video_composite - mask_away_clip: bool = args.mask_away_clip - - device = 'cuda' - dtype = torch.float32 if args.full_precision else torch.bfloat16 - - output_dir.mkdir(parents=True, exist_ok=True) - - # load a pretrained model - net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval() - net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True)) - log.info(f'Loaded weights from {model.model_path}') - - # misc setup - rng = torch.Generator(device=device) - rng.manual_seed(seed) - fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps) - - feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path, - synchformer_ckpt=model.synchformer_ckpt, - enable_conditions=True, - mode=model.mode, - bigvgan_vocoder_ckpt=model.bigvgan_16k_path, - need_vae_encoder=False) - feature_utils = feature_utils.to(device, dtype).eval() - - if video_path is not None: - log.info(f'Using video 
{video_path}') - video_info = load_video(video_path, duration) - clip_frames = video_info.clip_frames - sync_frames = video_info.sync_frames - duration = video_info.duration_sec - if mask_away_clip: - clip_frames = None - else: - clip_frames = clip_frames.unsqueeze(0) - sync_frames = sync_frames.unsqueeze(0) - else: - log.info('No video provided -- text-to-audio mode') - clip_frames = sync_frames = None - - seq_cfg.duration = duration - net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len) - - log.info(f'Prompt: {prompt}') - log.info(f'Negative prompt: {negative_prompt}') - - audios = generate(clip_frames, - sync_frames, [prompt], - negative_text=[negative_prompt], - feature_utils=feature_utils, - net=net, - fm=fm, - rng=rng, - cfg_strength=cfg_strength) - audio = audios.float().cpu()[0] - if video_path is not None: - save_path = output_dir / f'{video_path.stem}.flac' - else: - safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '') - save_path = output_dir / f'{safe_filename}.flac' - torchaudio.save(save_path, audio, seq_cfg.sampling_rate) - - log.info(f'Audio saved to {save_path}') - if video_path is not None and not skip_video_composite: - video_save_path = output_dir / f'{video_path.stem}.mp4' - make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate) - log.info(f'Video saved to {output_dir / video_save_path}') - - log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30)) - - -if __name__ == '__main__': - main() diff --git a/docs/images/icon.png b/docs/images/icon.png deleted file mode 100644 index c337eee9868e61173e61f583ef098668681555f5..0000000000000000000000000000000000000000 Binary files a/docs/images/icon.png and /dev/null differ diff --git a/docs/index.html b/docs/index.html deleted file mode 100644 index c2792c3baef3baad9e0220853d1232eebf3c266d..0000000000000000000000000000000000000000 --- a/docs/index.html +++ /dev/null @@ -1,147 +0,0 @@ - - - - - - - - - - - - MMAudio - - - - - - - - - - - - - -



-
-
-
- Taming Multimodal Joint Training for High-Quality
Video-to-Audio Synthesis -
-
- -
-
-
- arXiv 2024 -
-
-
- - - -
-
- 1University of Illinois Urbana-Champaign -
-
- 2Sony AI -
-
- 3Sony Group Corporation -
-
- -
- -
- -
- - -
- [Code] -
- -
- -
- -
- -
-
- TL;DR -
-
-
-
-
-

- MMAudio generates synchronized audio given video and/or text inputs. -

-
-
- -
-
-
- -
-
- Demo -
-
-
-
- -
-
-
- -
- -
- -

-

- -
- - - \ No newline at end of file diff --git a/docs/style.css b/docs/style.css deleted file mode 100644 index 4946ef1f17b794d2122351bf24e4eb08f19b9637..0000000000000000000000000000000000000000 --- a/docs/style.css +++ /dev/null @@ -1,78 +0,0 @@ -body { - font-family: 'Source Sans 3', sans-serif; - font-size: 18px; - margin-left: auto; - margin-right: auto; - font-weight: 400; - height: 100%; - max-width: 1000px; -} - -table { - width: 100%; - border-collapse: collapse; -} -th, td { - border: 1px solid #ddd; - padding: 8px; - text-align: center; -} -th { - background-color: #f2f2f2; -} -video { - width: 100%; - height: auto; -} -p { - font-size: 28px; -} -h2 { - font-size: 36px; -} - -.strong { - font-weight: 700; -} - -.light { - font-weight: 100; -} - -.heavy { - font-weight: 900; -} - -.column { - float: left; -} - -a:link, -a:visited { - color: #05538f; - text-decoration: none; -} - -a:hover { - color: #63cbdd; -} - -hr { - border: 0; - height: 1px; - background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); -} - -.video-container { - position: relative; - padding-bottom: 56.25%; /* 16:9 */ - height: 0; - } - -.video-container iframe { - position: absolute; - top: 0; - left: 0; - width: 100%; - height: 100%; -} \ No newline at end of file diff --git a/docs/style_videos.css b/docs/style_videos.css deleted file mode 100644 index 9d641122166e3c3fdd8f3e104628686ed5dc9258..0000000000000000000000000000000000000000 --- a/docs/style_videos.css +++ /dev/null @@ -1,52 +0,0 @@ -body { - font-family: 'Source Sans 3', sans-serif; - font-size: 1.5vh; - font-weight: 400; -} - -table { - width: 100%; - border-collapse: collapse; -} -th, td { - border: 1px solid #ddd; - padding: 8px; - text-align: center; -} -th { - background-color: #f2f2f2; -} -video { - width: 100%; - height: auto; -} -p { - font-size: 1.5vh; - font-weight: bold; -} -h2 { - font-size: 2vh; - font-weight: bold; -} - -.video-container { - position: relative; - padding-bottom: 56.25%; /* 16:9 */ - height: 0; - } - -.video-container iframe { - position: absolute; - top: 0; - left: 0; - width: 100%; - height: 100%; -} - -.video-header { - background-color: #f2f2f2; - text-align: center; - font-size: 1.5vh; - font-weight: bold; - padding: 8px; -} \ No newline at end of file diff --git a/docs/video_gen.html b/docs/video_gen.html deleted file mode 100644 index da1a9d95393153b1264cc4c58ee78bda23379a2a..0000000000000000000000000000000000000000 --- a/docs/video_gen.html +++ /dev/null @@ -1,254 +0,0 @@ - - - - - - - - - - MMAudio - - - - - - - - - - - - -
-

Comparisons with Movie Gen Audio on Videos Generated by MovieGen

-

- Example 1: Ice cracking with sharp snapping sound, and metal tool scraping against the ice surface. - Back to index -

- -
-
-
Movie Gen Audio
-
- -
-
-
-
Ours
-
- -
-
-
-
- - - -

- Example 2: Rhythmic splashing and lapping of water. - Back to index -

-
-
-
Movie Gen Audio
-
- -
-
-
-
Ours
-
- -
-
-
-
- -

- Example 3: Shovel scrapes against dry earth. - Back to index -

-
-
-
Movie Gen Audio
-
- -
-
-
-
Ours
-
- -
-
-
-
- - -

- (Failure case) Example 4: Creamy sound of mashed potatoes being scooped. - Back to index -

-
-
-
Movie Gen Audio
-
- -
-
-
-
Ours
-
- -
-
-
-
- -
- -
- -

Results on Videos Generated by Hunyuan

-

- Back to index -

-
-
-
Typing
-
- -
-
-
-
Water is rushing down a stream and pouring
-
- -
-
-
-
-
-
Waves on beach
-
- -
-
-
-
Water droplet
-
- -
-
-
-
- -

Results on Videos Generated by Sora

-

- Back to index -

-
-
-
Ships riding waves
-
- -
-
-
-
Train (no text prompt given)
-
- -
-
-
-
-
-
Seashore (no text prompt given)
-
- -
-
-
-
Surfing (failure: unprompted music)
-
- -
-
-
-
- -
-

Results on Videos Generated by Mochi 1

-

- Back to index -

-
-
-
Magical fire and lightning (no text prompt given)
-
- -
-
-
-
Storm (no text prompt given)
-
- -
-
-
-
- -

Results on Videos Generated by LTX-Video

-

- Back to index -

-
-
-
Firewood burning and cracking
-
- -
-
-
-
Waterfall, water splashing
-
- -
-
-
-
- -
- - - \ No newline at end of file diff --git a/docs/video_main.html b/docs/video_main.html deleted file mode 100644 index 36c3d996cb5bc0e9050fd217b2b1a056b085a88e..0000000000000000000000000000000000000000 --- a/docs/video_main.html +++ /dev/null @@ -1,98 +0,0 @@ - - - - - - - - - - MMAudio - - - - - - - - - - - - - -

Index

-

(Click on the links to load the corresponding videos) Back to project page

- -
    -
  1. - Comparisons with Movie Gen Audio on Videos Generated by MovieGen -
  2. -
  3. - Results on Videos Generated by Hunyuan and Sora -
  4. -
  5. - Results on Videos Generated by Mochi 1 and LTX-Video -
  6. -
  7. - On VGGSound -
      -
    1. Example 1: Wolf howling
    2. -
    3. Example 2: Striking a golf ball
    4. -
    5. Example 3: Hitting a drum
    6. -
    7. Example 4: Dog barking
    8. -
    9. Example 5: Playing a string instrument
    10. -
    11. Example 6: A group of people playing tambourines
    12. -
    13. Extra results & failure cases
    14. -
    -
  8. -
- -
- -
-
-
- - - \ No newline at end of file diff --git a/docs/video_vgg.html b/docs/video_vgg.html deleted file mode 100644 index 945b33660ed46c3f7acad3157c4181219b248533..0000000000000000000000000000000000000000 --- a/docs/video_vgg.html +++ /dev/null @@ -1,452 +0,0 @@ - - - - - - - - - - MMAudio - - - - - - - - - - -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Example 1: Wolf howling. - Back to index -

-
-
-
Ground-truth
-
- -
-
-
-
Ours
-
- -
-
-
-
V2A-Mapper
-
- -
-
-
-
FoleyCrafter
-
- -
-
-
-
-
-
Frieren
-
- -
-
-
-
VATT
-
- -
-
-
-
V-AURA
-
- -
-
-
-
Seeing and Hearing
-
- -
-
-
-
- -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Example 2: Striking a golf ball. - Back to index -

- -
-
-
Ground-truth
-
- -
-
-
-
Ours
-
- -
-
-
-
V2A-Mapper
-
- -
-
-
-
FoleyCrafter
-
- -
-
-
-
-
-
Frieren
-
- -
-
-
-
VATT
-
- -
-
-
-
V-AURA
-
- -
-
-
-
Seeing and Hearing
-
- -
-
-
-
- -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Example 3: Hitting a drum. - Back to index -

- -
-
-
Ground-truth
-
- -
-
-
-
Ours
-
- -
-
-
-
V2A-Mapper
-
- -
-
-
-
FoleyCrafter
-
- -
-
-
-
-
-
Frieren
-
- -
-
-
-
VATT
-
- -
-
-
-
V-AURA
-
- -
-
-
-
Seeing and Hearing
-
- -
-
-
-
-
- -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Example 4: Dog barking. - Back to index -

- -
-
-
Ground-truth
-
- -
-
-
-
Ours
-
- -
-
-
-
V2A-Mapper
-
- -
-
-
-
FoleyCrafter
-
- -
-
-
-
-
-
Frieren
-
- -
-
-
-
VATT
-
- -
-
-
-
V-AURA
-
- -
-
-
-
Seeing and Hearing
-
- -
-
-
-
- -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Example 5: Playing a string instrument. - Back to index -

- -
-
-
Ground-truth
-
- -
-
-
-
Ours
-
- -
-
-
-
V2A-Mapper
-
- -
-
-
-
FoleyCrafter
-
- -
-
-
-
-
-
Frieren
-
- -
-
-
-
VATT
-
- -
-
-
-
V-AURA
-
- -
-
-
-
Seeing and Hearing
-
- -
-
-
-
- -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Example 6: A group of people playing tambourines. - Back to index -

- -
-
-
Ground-truth
-
- -
-
-
-
Ours
-
- -
-
-
-
V2A-Mapper
-
- -
-
-
-
FoleyCrafter
-
- -
-
-
-
-
-
Frieren
-
- -
-
-
-
VATT
-
- -
-
-
-
V-AURA
-
- -
-
-
-
Seeing and Hearing
-
- -
-
-
-
- -
-

Comparisons with state-of-the-art methods in VGGSound

-

- Back to index -

- -
-
-
Moving train
-
- -
-
-
-
Water splashing
-
- -
-
-
-
Skateboarding
-
- -
-
-
-
Synchronized clapping
-
- -
-
-
- -

- -
-

Failure cases

-

- Back to index -

- -
-
-
Human speech
-
- -
-
-
-
Unfamiliar vision input
-
- -
-
-
-
-
- - - \ No newline at end of file diff --git a/mmaudio/__init__.py b/pipeline/__init__.py similarity index 100% rename from mmaudio/__init__.py rename to pipeline/__init__.py diff --git a/pipeline/__pycache__/__init__.cpython-310.pyc b/pipeline/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27c645b9e977c6466bd93033dc0b3238218ee0a8 Binary files /dev/null and b/pipeline/__pycache__/__init__.cpython-310.pyc differ diff --git a/pipeline/__pycache__/__init__.cpython-38.pyc b/pipeline/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8a471e1eabb5f1693534a146f8221d4d1f99ae6 Binary files /dev/null and b/pipeline/__pycache__/__init__.cpython-38.pyc differ diff --git a/pipeline/__pycache__/pipeline.cpython-310.pyc b/pipeline/__pycache__/pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6883284b060817ae41c215b0f6cc3f06fce9794b Binary files /dev/null and b/pipeline/__pycache__/pipeline.cpython-310.pyc differ diff --git a/pipeline/__pycache__/pipeline.cpython-38.pyc b/pipeline/__pycache__/pipeline.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5089252ee24dcbe73680ff3e1946c7263f737e78 Binary files /dev/null and b/pipeline/__pycache__/pipeline.cpython-38.pyc differ diff --git a/pipeline/__pycache__/step0.cpython-310.pyc b/pipeline/__pycache__/step0.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd6e53e8f0e44f211852b2013dd0c169a80553eb Binary files /dev/null and b/pipeline/__pycache__/step0.cpython-310.pyc differ diff --git a/pipeline/__pycache__/step0.cpython-38.pyc b/pipeline/__pycache__/step0.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1e778f54e7c344fb5dc054814af66923e11bbe5 Binary files /dev/null and b/pipeline/__pycache__/step0.cpython-38.pyc differ diff --git a/pipeline/__pycache__/step1.cpython-310.pyc b/pipeline/__pycache__/step1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02f91be3fbdf746fef1772f83b51fbc847a5e2cf Binary files /dev/null and b/pipeline/__pycache__/step1.cpython-310.pyc differ diff --git a/pipeline/__pycache__/step1.cpython-38.pyc b/pipeline/__pycache__/step1.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a190ed47a0e4453a63a8aa145674b2f4982b4ee Binary files /dev/null and b/pipeline/__pycache__/step1.cpython-38.pyc differ diff --git a/pipeline/__pycache__/step2.cpython-310.pyc b/pipeline/__pycache__/step2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2cc54c6c3b5538af0b2ba2eaaa0acc0e369a27a Binary files /dev/null and b/pipeline/__pycache__/step2.cpython-310.pyc differ diff --git a/pipeline/__pycache__/step2.cpython-38.pyc b/pipeline/__pycache__/step2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79cae7189ede70e20c51e42596a6a5805981658a Binary files /dev/null and b/pipeline/__pycache__/step2.cpython-38.pyc differ diff --git a/pipeline/__pycache__/step3.cpython-310.pyc b/pipeline/__pycache__/step3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b262d52dbd05c03954a51fd21a842ec82ee6e133 Binary files /dev/null and b/pipeline/__pycache__/step3.cpython-310.pyc differ diff --git a/pipeline/__pycache__/step3.cpython-38.pyc b/pipeline/__pycache__/step3.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..04937460c4d2b393f09bd1dc3d2ca93a8fe06cca Binary files /dev/null and b/pipeline/__pycache__/step3.cpython-38.pyc differ diff --git a/pipeline/__pycache__/step4.cpython-310.pyc b/pipeline/__pycache__/step4.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74e6b7f15b48fbeb441d84a6311a6dba9896562c Binary files /dev/null and b/pipeline/__pycache__/step4.cpython-310.pyc differ diff --git a/pipeline/__pycache__/step4.cpython-38.pyc b/pipeline/__pycache__/step4.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6332e0bcd38b0c4eb270418fae097299f8190193 Binary files /dev/null and b/pipeline/__pycache__/step4.cpython-38.pyc differ diff --git a/pipeline/pipeline.py b/pipeline/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..1f89add9f42b30cb03f19f5e69095002a13b3345 --- /dev/null +++ b/pipeline/pipeline.py @@ -0,0 +1,175 @@ +# coding=utf-8 + +from .step0 import Step0 +from .step1 import Step1 +from .step2 import Step2 +from .step3 import Step3 +from .step4 import Step4 +import logging +import re +import os + +class Pipeline: + def __init__(self, step0_model_dir, step1_mode, step2_model_dir, step2_mode, step3_mode): + self.step0 = Step0(step0_model_dir) + self.step1 = Step1(step1_mode) + self.step2 = Step2(step2_model_dir, step2_mode) + self.step3 = Step3(model_type=step3_mode) + self.step4 = Step4() + self.step_processors = [self.step1, self.step2, self.step3, self.step4] + self.log = logging.getLogger(self.__class__.__name__) + self.log.setLevel(logging.INFO) + + + def run(self, video_input, output_dir, mode='s4', postp_mode='rep', prompt='', negative_prompt='', duration=10, seed=42): + step0_resp = self.step0.run(video_input) + step0_resp_list = re.findall(r'(Step\d:.*?)(?=Step\d:|$)', step0_resp, re.DOTALL) + step_infos = [step_info.strip().split("\n")[0] for step_info in step0_resp_list] + step3_temp_dir = os.path.join(output_dir, "remove_vo") + + step_results = {"temp_final_audio_path": None, "temp_final_video_path": None} + for step_info in step_infos: + self.log.info(f"Start to {step_info}") + if step_info == 'Step1: Generate audio from video.': + step1_audio_path, step1_video_path = self.step1.run(video_input, output_dir, prompt, negative_prompt, duration=duration, seed=seed) + step_results["step1_audio_path"] = step1_audio_path + step_results["step1_video_path"] = step1_video_path + + elif step_info == 'Step2: Given a video and its generated audio, determine whether the audio contains voice-over.': + is_vo = self.step2.run(str(step_results["step1_video_path"])) + step_results["is_vo"] = is_vo + if not step_results["is_vo"]: # not voice-over + step_results["temp_final_audio_path"] = step_results["step1_audio_path"] + step_results["temp_final_video_path"] = step_results["step1_video_path"] + return step_results + + elif step_info == 'Step3: Remove voice-over from audio.': + step3_audio_path = self.step3.run(input_audio_path=step_results["step1_audio_path"], + temp_store_dir=step3_temp_dir, + output_dir=output_dir) + step_results["step3_audio_path"] = step3_audio_path + if mode == 's3': + step_results["temp_final_audio_path"] = step_results["step3_audio_path"] + return step_results + + elif step_info == 'Step4: Determine whether the audio is silent.': + is_silent = self.step4.run(step_results["step3_audio_path"]) + step_results["is_silent"] = is_silent + + else: + self.log.error(f"Step-by-Step Error !!!!!!!!!") + return step_results + + if not 
step_results["is_silent"]: # not silent + step_results["temp_final_audio_path"] = step_results["step3_audio_path"] + else: + self.log.info(f"Start to post process, use mode: {postp_mode}") + if postp_mode == "rm": + step_results["temp_final_audio_path"] = None + elif postp_mode == "rep": + step_results["temp_final_audio_path"] = step_results["step1_audio_path"] + step_results["temp_final_video_path"] = step_results["step1_video_path"] + elif postp_mode == "neg": + neg_audio_path, neg_video_path = self.step1.run(video_input, output_dir, prompt, negative_prompt='human voice', duration=duration, seed=seed, is_postp=True) + step_results["temp_final_audio_path"] = neg_audio_path + step_results["temp_final_video_path"] = neg_video_path + else: + self.log.error(f"Error postp_mode: {postp_mode}") + + self.log.info(f"After post-processing, audio is {step_results['temp_final_audio_path']} and video is {step_results['temp_final_video_path']}") + self.log.info(f"Finish Post-Process successfully.\n") + + return step_results + + + + def run_for_gradio(self, video_input, output_dir, mode='s4', postp_mode='rep', prompt='', negative_prompt='', duration=10, seed=42): + step_results = {"temp_final_audio_path": None, + "temp_final_video_path": None, + 'log': ''} + + step0_resp = self.step0.run(video_input) + step0_resp_list = re.findall(r'(Step\d:.*?)(?=Step\d:|$)', step0_resp, re.DOTALL) + step_infos = [step_info.strip().split("\n")[0] for step_info in step0_resp_list] + step3_temp_dir = os.path.join(output_dir, "remove_vo") + + + for step_info in step_infos: + self.log.info(f"Start to {step_info}") + step_results['log'] = f"Start to {step_info}" + yield step_results + + if step_info == 'Step1: Generate audio from video.': + step1_audio_path, step1_video_path = self.step1.run(video_input, output_dir, prompt, negative_prompt, duration=duration, seed=seed) + step_results["step1_audio_path"] = step1_audio_path + step_results["step1_video_path"] = step1_video_path + step_results['log'] = "Step1 completed." + yield step_results + + elif step_info == 'Step2: Given a video and its generated audio, determine whether the audio contains voice-over.': + is_vo = self.step2.run(str(step_results["step1_video_path"])) + step_results["is_vo"] = is_vo + step_results['log'] = f"Step2 completed. Contain voice-over? {'Yes' if is_vo else 'No'}" + yield step_results + if not step_results["is_vo"]: # not voice-over + step_results["temp_final_audio_path"] = step_results["step1_audio_path"] + step_results["temp_final_video_path"] = step_results["step1_video_path"] + step_results['log'] = "Finish step-by-step v2a." + yield step_results + + elif step_info == 'Step3: Remove voice-over from audio.': + step3_audio_path = self.step3.run(input_audio_path=step_results["step1_audio_path"], + temp_store_dir=step3_temp_dir, + output_dir=output_dir) + step_results["step3_audio_path"] = step3_audio_path + step_results['log'] = f"Step3 completed." + yield step_results + if mode == 's3': + step_results["temp_final_audio_path"] = step_results["step3_audio_path"] + step_results['log'] = "Finish step-by-step v2a." + yield step_results + + elif step_info == 'Step4: Determine whether the audio is silent.': + is_silent = self.step4.run(step_results["step3_audio_path"]) + step_results["is_silent"] = is_silent + step_results['log'] = f"Step4 completed. Silent? {'Yes' if is_silent else 'No'}" + yield step_results + + else: + self.log.error(f"Step-by-Step Error !!!!!!!!!") + step_results['log'] = f"Step-by-Step Error !!!!!!!!!" 
+ yield step_results + step_results['log'] = "Finish step-by-step v2a." + yield step_results + + if not step_results["is_silent"]: # not silent + step_results["temp_final_audio_path"] = step_results["step3_audio_path"] + step_results['log'] = "Finish step-by-step v2a." + yield step_results + + else: + step_results['log'] = f"Post-processing with mode: {postp_mode}" + yield step_results + self.log.info(f"Start to post process, use mode: {postp_mode}") + + if postp_mode == "rm": + step_results["temp_final_audio_path"] = None + elif postp_mode == "rep": + step_results["temp_final_audio_path"] = step_results["step1_audio_path"] + step_results["temp_final_video_path"] = step_results["step1_video_path"] + elif postp_mode == "neg": + neg_audio_path, neg_video_path = self.step1.run(video_input, output_dir, prompt, negative_prompt='human voice', duration=duration, seed=seed, is_postp=True) + step_results["temp_final_audio_path"] = neg_audio_path + step_results["temp_final_video_path"] = neg_video_path + else: + self.log.error(f"Error postp_mode: {postp_mode}") + + self.log.info(f"After post-processing, audio is {step_results['temp_final_audio_path']} and video is {step_results['temp_final_video_path']}") + self.log.info(f"Finish Post-Process successfully.\n") + step_results['log'] = f"Post-processing completed." + yield step_results + + + step_results['log'] = "Finish step-by-step v2a." + yield step_results + diff --git a/pipeline/step0.py b/pipeline/step0.py new file mode 100644 index 0000000000000000000000000000000000000000..0448cc6242d161716ec9e6b447cd17efa6437b9e --- /dev/null +++ b/pipeline/step0.py @@ -0,0 +1,39 @@ +# coding=utf-8 +# CoT generate step-by-step + +from third_party.VideoLLaMA2.videollama2 import model_init, mm_infer +import logging + +class Step0: + def __init__(self, model_path, modal_type='v'): + self.log = logging.getLogger(self.__class__.__name__) + self.log.setLevel(logging.INFO) + + self.model, self.processor, self.tokenizer = model_init(model_path) + self.modal_type=modal_type + if modal_type == "a": + self.model.model.vision_tower = None + elif modal_type == "v": + self.model.model.audio_tower = None + elif modal_type == "av": + pass + else: + raise NotImplementedError + self.modal = 'audio' if modal_type == "a" else "video" + self.question = f"Generate high-quality audio from video step-by-step." 
+ self.preprocess = self.processor[self.modal] + + def run(self, video_path): + self.log.info("######################################################################################################") + self.log.info("Generate high-quality audio from video step-by-step...") + audio_video_tensor = self.preprocess(video_path, va=False) + output = mm_infer( + audio_video_tensor, + self.question, + model=self.model, + tokenizer=self.tokenizer, + modal=self.modal, + do_sample=False, + ) + + return output diff --git a/pipeline/step1.py b/pipeline/step1.py new file mode 100644 index 0000000000000000000000000000000000000000..a926d305cf8e6f83a8b687ff64999eccc3fc4a6f --- /dev/null +++ b/pipeline/step1.py @@ -0,0 +1,36 @@ +# coding=utf-8 +# V2A +import logging + + +class Step1: + def __init__(self, step1_mode): + self.log = logging.getLogger(self.__class__.__name__) + self.log.setLevel(logging.INFO) + + if step1_mode.startswith('mmaudio'): + from v2a_models.v2a_mmaudio import V2A_MMAudio + variant = step1_mode.replace("mmaudio_", "") + self.v2a_model = V2A_MMAudio(variant) + elif step1_mode == "foleycrafter": + from v2a_models.v2a_foleycrafter import V2A_FoleyCrafter + self.v2a_model = V2A_FoleyCrafter() + else: + self.log.error(f"Error step1_mode: {step1_mode}") + + + + def run(self, video_path, output_dir, prompt='', negative_prompt='', duration=10, seed=42, is_postp=False,): + # self.log.info("Step1: Generate audio from video.") + step1_audio_path, step1_video_path = self.v2a_model.generate_audio( + video_path=video_path, + output_dir=output_dir, + prompt=prompt, + negative_prompt=negative_prompt, + duration=duration, + seed=seed, + is_postp=is_postp) + + self.log.info(f"The audio generated by Step1 is in {step1_audio_path}, and the video is in {step1_video_path}") + self.log.info("Finish Step1 successfully.\n") + return step1_audio_path, step1_video_path diff --git a/pipeline/step2.py b/pipeline/step2.py new file mode 100644 index 0000000000000000000000000000000000000000..416a7f8e99d2ff8332c9d866df37345d90e301b9 --- /dev/null +++ b/pipeline/step2.py @@ -0,0 +1,52 @@ +# coding=utf-8 +# judge voice-over + +from third_party.VideoLLaMA2.videollama2 import model_init, mm_infer +import logging + +class Step2: + def __init__(self, model_path, step2_mode, modal_type="av"): + self.log = logging.getLogger(self.__class__.__name__) + self.log.setLevel(logging.INFO) + + self.model, self.processor, self.tokenizer = model_init(model_path) + self.modal_type=modal_type + if modal_type == "a": + self.model.model.vision_tower = None + elif modal_type == "v": + self.model.model.audio_tower = None + elif modal_type == "av": + pass + else: + raise NotImplementedError + self.modal = 'audio' if modal_type == "a" else "video" + + self.question = f"Given a video and its corresponding audio, determine whether the audio contains voice-over? Options: A. Yes, B. No. Choose A or B." 
+ self.preprocess = self.processor[self.modal] + + self.step2_mode = step2_mode + + def run(self, video_audio_path): + # self.log.info("Step2: Given a video and its generated audio, determine whether the audio contains voice-over.") + audio_video_tensor = self.preprocess(video_audio_path, va=True) + output = mm_infer( + audio_video_tensor, + self.question, + model=self.model, + tokenizer=self.tokenizer, + modal=self.modal, + do_sample=False, + ) + # print("oooooooooooooooooooooo: ", output) + + if self.step2_mode == "cot": + output = output.split("")[-1][1] + print("1111111111111111111111111: ", output) + output = (output == "A") + + if output: + self.log.info(f"The video generated by Step1 ({video_audio_path}) contains voice-over.") + else: + self.log.info(f"The video generated by Step1 ({video_audio_path}) does not contain voice-over.") + self.log.info("Finish Step2 successfully.\n") + return output diff --git a/pipeline/step3.py b/pipeline/step3.py new file mode 100644 index 0000000000000000000000000000000000000000..f31fda44de1e66a3bb76e730059ff762e48467ef --- /dev/null +++ b/pipeline/step3.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Remove voice-over +import logging +import argparse +import subprocess +import librosa +import os +import torch +import soundfile as sf +import numpy as np + + +# Using the embedded version of Python can also correctly import the utils module. +# current_dir = os.path.dirname(os.path.abspath(__file__)) +# sys.path.append(current_dir) + +from third_party.MusicSourceSeparationTraining.utils import demix, load_config, normalize_audio, denormalize_audio, draw_spectrogram +from third_party.MusicSourceSeparationTraining.utils import prefer_target_instrument, apply_tta, load_start_checkpoint +from third_party.MusicSourceSeparationTraining.models.bs_roformer import BSRoformer +import warnings + +warnings.filterwarnings("ignore") + +model_base_dir = "pretrained/remove_vo/checkpoints" +MODEL_PATHS = {"bs_roformer": [f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755.ckpt", f"{model_base_dir}/model_bs_roformer_ep_317_sdr_12.9755.yaml"]} + + +class Step3: + def __init__(self, model_type="bs_roformer"): + model_path, config_path = MODEL_PATHS[model_type] + + self.log = logging.getLogger(self.__class__.__name__) + self.log.setLevel(logging.INFO) + self.device = 'cpu' + if torch.cuda.is_available(): + self.device = 'cuda' + elif torch.backends.mps.is_available(): + self.device = 'mps' + else: + self.log.warning('CUDA/MPS are not available, running on CPU') + + self.model_type = model_type + + # self.model, self.config = get_model_from_config(model_type, config_path) + self.config = load_config(model_type, config_path) + self.model = BSRoformer(**dict(self.config.model)) + args = argparse.Namespace() + args.start_check_point = model_path + args.model_type = model_type + args.lora_checkpoint = '' + load_start_checkpoint(args, self.model, type_='inference') + self.model = self.model.to(self.device) + self.sample_rate = getattr(self.config.audio, 'sample_rate', 44100) + + + def run(self, + input_audio_path, + temp_store_dir, # for remove result dir + output_dir, # for final dir + disable_detailed_pbar: bool=False, + use_tta: bool= False, + extract_instrumental: bool=True, + codec="wav", + subtype="FLOAT", + draw_spectro=0, + ): + + # self.log.info("Step3: Remove voice-over from audio.") + + os.makedirs(output_dir, exist_ok=True) + + if disable_detailed_pbar: + detailed_pbar = False + else: + detailed_pbar = True + + instruments = prefer_target_instrument(self.config)[:] + + 
mix, sr = librosa.load(input_audio_path, sr=self.sample_rate, mono=False) + # If mono audio we must adjust it depending on model + if len(mix.shape) == 1: + mix = np.expand_dims(mix, axis=0) + if 'num_channels' in self.config.audio: + if self.config.audio['num_channels'] == 2: + print(f'Convert mono track to stereo...') + mix = np.concatenate([mix, mix], axis=0) + + mix_orig = mix.copy() + if 'normalize' in self.config.inference: + if self.config.inference['normalize'] is True: + mix, norm_params = normalize_audio(mix) + + waveforms_orig = demix(self.config, self.model, mix, self.device, model_type=self.model_type, pbar=detailed_pbar) + if use_tta: + waveforms_orig = apply_tta(self.config, self.model, mix, waveforms_orig, self.device, self.model_type) + + if extract_instrumental: + instr = 'vocals' if 'vocals' in instruments else instruments[0] + waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr] + if 'instrumental' not in instruments: + instruments.append('instrumental') + + file_name = os.path.splitext(os.path.basename(input_audio_path))[0].replace(".step1", "") + temp_output_dir = os.path.join(temp_store_dir, file_name) + os.makedirs(temp_output_dir, exist_ok=True) + + for instr in instruments: + estimates = waveforms_orig[instr] + if 'normalize' in self.config.inference: + if self.config.inference['normalize'] is True: + estimates = denormalize_audio(estimates, norm_params) + + output_path = os.path.join(temp_output_dir, f"{instr}.{codec}") + sf.write(output_path, estimates.T, sr, subtype=subtype) + if draw_spectro > 0: + output_img_path = os.path.join(temp_output_dir, f"{instr}.jpg") + draw_spectrogram(estimates.T, sr, draw_spectro, output_img_path) + + + instrumental_file = os.path.join(temp_output_dir, 'instrumental.wav') + step3_audio_path = f"{output_dir}/{file_name}.step3.wav" + subprocess.run(['cp', instrumental_file, step3_audio_path]) + + self.log.info(f"The voice-over has been removed, and the audio is saved in {step3_audio_path}") + self.log.info("Finish Step3 successfully.\n") + return step3_audio_path + + + diff --git a/pipeline/step4.py b/pipeline/step4.py new file mode 100644 index 0000000000000000000000000000000000000000..3fc5c0d848b21e7c45ae6cd99cc08863a83bab13 --- /dev/null +++ b/pipeline/step4.py @@ -0,0 +1,31 @@ +# coding=utf-8 +# Silence detection +import logging +import librosa +import numpy as np + + +class Step4: + def __init__(self): + self.log = logging.getLogger(self.__class__.__name__) + self.log.setLevel(logging.INFO) + + + def run(self, + audio_path, + silence_thresh=-50, + duration_thresh=0.9): + # self.log.info("Step4: Determine whether the audio is silent.") + y, sr = librosa.load(audio_path, sr=None) + energy = librosa.feature.rms(y=y)[0] + energy_db = librosa.amplitude_to_db(energy) + silent_ratio = np.sum(energy_db < silence_thresh) / len(energy_db) + is_silent = silent_ratio > duration_thresh + + if is_silent: + self.log.info(f"The audio after removing the voiceover ({audio_path}) is silent.") + else: + self.log.info(f"The audio after removing the voiceover ({audio_path}) is not silent.") + self.log.info("Finish Step4 successfully.\n") + + return is_silent diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 160d9d00777a11dafb4b56f553f76c1be06213a6..0000000000000000000000000000000000000000 --- a/pyproject.toml +++ /dev/null @@ -1,52 +0,0 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.metadata] -allow-direct-references = true - -[tool.yapf] -based_on_style = "pep8" 
-indent_width = 4 -column_limit = 100 - -[project] -name = "mmaudio" -version = "1.0.0" -authors = [{ name = "Rex Cheng", email = "hkchengrex@gmail.com" }] -description = "" -readme = "README.md" -requires-python = ">=3.9" -classifiers = [ - "Programming Language :: Python :: 3", - "Operating System :: OS Independent", -] -dependencies = [ - 'torch >= 2.5.1', - 'python-dotenv', - 'cython', - 'gitpython >= 3.1', - 'tensorboard >= 2.11', - 'numpy >= 1.21, <2.1', - 'Pillow >= 9.5', - 'opencv-python >= 4.8', - 'scipy >= 1.7', - 'tqdm >= 4.66.1', - 'gradio >= 3.34', - 'einops >= 0.6', - 'hydra-core >= 1.3.2', - 'requests', - 'torchdiffeq', - 'librosa >= 0.8.1', - 'nitrous-ema', - 'safetensors', - 'auraloss', - 'hydra_colorlog', - 'tensordict', - 'colorlog', - 'open_clip_torch', - 'soundfile', -] - -[tool.hatch.build.targets.wheel] -packages = ["mmaudio"] diff --git a/requirements.txt.bak b/requirements.txt.bak deleted file mode 100644 index 9e461d6d33dbcfd8c06c060bae9752beda85d428..0000000000000000000000000000000000000000 --- a/requirements.txt.bak +++ /dev/null @@ -1,27 +0,0 @@ -torch == 2.4.0 -torchvision -torchaudio -python-dotenv -cython -gitpython >= 3.1 -tensorboard >= 2.11 -numpy >= 1.21, <2.1 -Pillow >= 9.5 -opencv-python >= 4.8 -scipy >= 1.7 -tqdm >= 4.66.1 -gradio >= 3.34 -einops >= 0.6 -hydra-core >= 1.3.2 -requests -torchdiffeq -librosa >= 0.8.1 -nitrous-ema -safetensors -auraloss -hydra_colorlog -tensordict -colorlog -open_clip_torch -soundfile -av \ No newline at end of file diff --git a/third_party/MMAudio/.gitignore b/third_party/MMAudio/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f732c933876160a511e008af864962edd1ae5620 --- /dev/null +++ b/third_party/MMAudio/.gitignore @@ -0,0 +1,146 @@ +run_*.sh +log/ +saves +saves/ +weights/ +weights +output/ +output +pretrained/ +workspace +workspace/ +ext_weights/ +ext_weights +.checkpoints/ +.vscode/ +training/example_output/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/third_party/MMAudio/LICENSE b/third_party/MMAudio/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..0ea89b1e3b1b756f25d9a9995a9b5a137647ebf4 --- /dev/null +++ b/third_party/MMAudio/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Sony Research Inc. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/mmaudio/data/__init__.py b/third_party/MMAudio/mmaudio/__init__.py similarity index 100% rename from mmaudio/data/__init__.py rename to third_party/MMAudio/mmaudio/__init__.py diff --git a/mmaudio/ext/bigvgan_v2/__init__.py b/third_party/MMAudio/mmaudio/data/__init__.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/__init__.py rename to third_party/MMAudio/mmaudio/data/__init__.py diff --git a/mmaudio/data/av_utils.py b/third_party/MMAudio/mmaudio/data/av_utils.py similarity index 81% rename from mmaudio/data/av_utils.py rename to third_party/MMAudio/mmaudio/data/av_utils.py index 7d4945b9658b8208f039e72d78e1dac45ae5e12d..39e23349b12dcd90f3c2b530b4c95c88e94d122a 100644 --- a/mmaudio/data/av_utils.py +++ b/third_party/MMAudio/mmaudio/data/av_utils.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from fractions import Fraction from pathlib import Path -from typing import Optional +from typing import Optional, List, Tuple import av import numpy as np @@ -15,7 +15,7 @@ class VideoInfo: fps: Fraction clip_frames: torch.Tensor sync_frames: torch.Tensor - all_frames: Optional[list[np.ndarray]] + all_frames: Optional[List[np.ndarray]] @property def height(self): @@ -25,9 +25,35 @@ class VideoInfo: def width(self): return self.all_frames[0].shape[1] + @classmethod + def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float, + fps: Fraction) -> 'VideoInfo': + num_frames = int(duration_sec * fps) + all_frames = [image_info.original_frame] * num_frames + return cls(duration_sec=duration_sec, + fps=fps, + clip_frames=image_info.clip_frames, + sync_frames=image_info.sync_frames, + all_frames=all_frames) -def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float, - need_all_frames: bool) -> 
tuple[list[np.ndarray], list[np.ndarray], Fraction]: + +@dataclass +class ImageInfo: + clip_frames: torch.Tensor + sync_frames: torch.Tensor + original_frame: Optional[np.ndarray] + + @property + def height(self): + return self.original_frame.shape[0] + + @property + def width(self): + return self.original_frame.shape[1] + + +def read_frames(video_path: Path, list_of_fps: List[float], start_sec: float, end_sec: float, + need_all_frames: bool) -> Tuple[List[np.ndarray], List[np.ndarray], Fraction]: output_frames = [[] for _ in list_of_fps] next_frame_time_for_each_fps = [0.0 for _ in list_of_fps] time_delta_for_each_fps = [1 / fps for fps in list_of_fps] diff --git a/third_party/MMAudio/mmaudio/data/data_setup.py b/third_party/MMAudio/mmaudio/data/data_setup.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ebcea712012d811aa39ae28be628fd94c8bd13 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/data_setup.py @@ -0,0 +1,174 @@ +import logging +import random + +import numpy as np +import torch +from omegaconf import DictConfig +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.dataloader import default_collate +from torch.utils.data.distributed import DistributedSampler + +from mmaudio.data.eval.audiocaps import AudioCapsData +from mmaudio.data.eval.video_dataset import MovieGen, VGGSound +from mmaudio.data.extracted_audio import ExtractedAudio +from mmaudio.data.extracted_vgg import ExtractedVGG +from mmaudio.data.mm_dataset import MultiModalDataset +from mmaudio.utils.dist_utils import local_rank + +log = logging.getLogger() + + +# Re-seed randomness every time we start a worker +def worker_init_fn(worker_id: int): + worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000 + np.random.seed(worker_seed) + random.seed(worker_seed) + log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}') + + +def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset: + dataset = ExtractedVGG(tsv_path=data_cfg.tsv, + data_dim=cfg.data_dim, + premade_mmap_dir=data_cfg.memmap_dir) + + return dataset + + +def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset: + dataset = ExtractedAudio(tsv_path=data_cfg.tsv, + data_dim=cfg.data_dim, + premade_mmap_dir=data_cfg.memmap_dir) + + return dataset + + +def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]: + if cfg.mini_train: + vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val) + audiocaps = load_audio_data(cfg, cfg.data.AudioCaps) + dataset = MultiModalDataset([vgg], [audiocaps]) + if cfg.example_train: + video = load_vgg_data(cfg, cfg.data.Example_video) + audio = load_audio_data(cfg, cfg.data.Example_audio) + dataset = MultiModalDataset([video], [audio]) + else: + # load the largest one first + freesound = load_audio_data(cfg, cfg.data.FreeSound) + vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG) + audiocaps = load_audio_data(cfg, cfg.data.AudioCaps) + audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL) + bbcsound = load_audio_data(cfg, cfg.data.BBCSound) + clotho = load_audio_data(cfg, cfg.data.Clotho) + dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate, + [audiocaps, audioset_sl, bbcsound, freesound, clotho]) + + batch_size = cfg.batch_size + num_workers = cfg.num_workers + pin_memory = cfg.pin_memory + sampler, loader = construct_loader(dataset, + batch_size, + num_workers, + shuffle=True, + drop_last=True, + pin_memory=pin_memory) + + return dataset, sampler, loader + + +def 
setup_test_datasets(cfg): + dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test) + + batch_size = cfg.batch_size + num_workers = cfg.num_workers + pin_memory = cfg.pin_memory + sampler, loader = construct_loader(dataset, + batch_size, + num_workers, + shuffle=False, + drop_last=False, + pin_memory=pin_memory) + + return dataset, sampler, loader + + +def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]: + if cfg.example_train: + dataset = load_vgg_data(cfg, cfg.data.Example_video) + else: + dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val) + + val_batch_size = cfg.batch_size + val_eval_batch_size = cfg.eval_batch_size + num_workers = cfg.num_workers + pin_memory = cfg.pin_memory + _, val_loader = construct_loader(dataset, + val_batch_size, + num_workers, + shuffle=False, + drop_last=False, + pin_memory=pin_memory) + _, eval_loader = construct_loader(dataset, + val_eval_batch_size, + num_workers, + shuffle=False, + drop_last=False, + pin_memory=pin_memory) + + return dataset, val_loader, eval_loader + + +def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]: + if dataset_name.startswith('audiocaps_full'): + dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path, + cfg.eval_data.AudioCaps_full.csv_path) + elif dataset_name.startswith('audiocaps'): + dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path, + cfg.eval_data.AudioCaps.csv_path) + elif dataset_name.startswith('moviegen'): + dataset = MovieGen(cfg.eval_data.MovieGen.video_path, + cfg.eval_data.MovieGen.jsonl_path, + duration_sec=cfg.duration_s) + elif dataset_name.startswith('vggsound'): + dataset = VGGSound(cfg.eval_data.VGGSound.video_path, + cfg.eval_data.VGGSound.csv_path, + duration_sec=cfg.duration_s) + else: + raise ValueError(f'Invalid dataset name: {dataset_name}') + + batch_size = cfg.batch_size + num_workers = cfg.num_workers + pin_memory = cfg.pin_memory + _, loader = construct_loader(dataset, + batch_size, + num_workers, + shuffle=False, + drop_last=False, + pin_memory=pin_memory, + error_avoidance=True) + return dataset, loader + + +def error_avoidance_collate(batch): + batch = list(filter(lambda x: x is not None, batch)) + return default_collate(batch) + + +def construct_loader(dataset: Dataset, + batch_size: int, + num_workers: int, + *, + shuffle: bool = True, + drop_last: bool = True, + pin_memory: bool = False, + error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]: + train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle) + train_loader = DataLoader(dataset, + batch_size, + sampler=train_sampler, + num_workers=num_workers, + worker_init_fn=worker_init_fn, + drop_last=drop_last, + persistent_workers=num_workers > 0, + pin_memory=pin_memory, + collate_fn=error_avoidance_collate if error_avoidance else None) + return train_sampler, train_loader diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/__init__.py b/third_party/MMAudio/mmaudio/data/eval/__init__.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/__init__.py rename to third_party/MMAudio/mmaudio/data/eval/__init__.py diff --git a/third_party/MMAudio/mmaudio/data/eval/audiocaps.py b/third_party/MMAudio/mmaudio/data/eval/audiocaps.py new file mode 100644 index 0000000000000000000000000000000000000000..35f4fd9e1e300503b0100825e698f82edfd735d1 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/eval/audiocaps.py @@ -0,0 +1,39 @@ +import logging +import os +from collections import 
defaultdict +from pathlib import Path +from typing import Union + +import pandas as pd +import torch +from torch.utils.data.dataset import Dataset + +log = logging.getLogger() + + +class AudioCapsData(Dataset): + + def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]): + df = pd.read_csv(csv_path).to_dict(orient='records') + + audio_files = sorted(os.listdir(audio_path)) + audio_files = set( + [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')]) + + self.data = [] + for row in df: + self.data.append({ + 'name': row['name'], + 'caption': row['caption'], + }) + + self.audio_path = Path(audio_path) + self.csv_path = Path(csv_path) + + log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}') + + def __getitem__(self, idx: int) -> torch.Tensor: + return self.data[idx] + + def __len__(self): + return len(self.data) diff --git a/third_party/MMAudio/mmaudio/data/eval/moviegen.py b/third_party/MMAudio/mmaudio/data/eval/moviegen.py new file mode 100644 index 0000000000000000000000000000000000000000..97969d68385f70eb49e8eb25fc6c3733a0cedda8 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/eval/moviegen.py @@ -0,0 +1,131 @@ +import json +import logging +import os +from pathlib import Path +from typing import Union + +import torch +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder + +from mmaudio.utils.dist_utils import local_rank + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class MovieGenData(Dataset): + + def __init__( + self, + video_root: Union[str, Path], + sync_root: Union[str, Path], + jsonl_root: Union[str, Path], + *, + duration_sec: float = 10.0, + read_clip: bool = True, + ): + self.video_root = Path(video_root) + self.sync_root = Path(sync_root) + self.jsonl_root = Path(jsonl_root) + self.read_clip = read_clip + + videos = sorted(os.listdir(self.video_root)) + videos = [v[:-4] for v in videos] # remove extensions + self.captions = {} + + for v in videos: + with open(self.jsonl_root / (v + '.jsonl')) as f: + data = json.load(f) + self.captions[v] = data['audio_prompt'] + + if local_rank == 0: + log.info(f'{len(videos)} videos found in {video_root}') + + self.duration_sec = duration_sec + + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_augment = v2.Compose([ + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + + self.sync_augment = v2.Compose([ + v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.videos = videos + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + caption = self.captions[video_id] + + reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + if clip_chunk is 
None: + raise RuntimeError(f'CLIP video returned None {video_id}') + if clip_chunk.shape[0] < self.clip_expected_length: + raise RuntimeError(f'CLIP video too short {video_id}') + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + if sync_chunk.shape[0] < self.sync_expected_length: + raise RuntimeError(f'Sync video too short {video_id}') + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + if clip_chunk.shape[0] != self.clip_expected_length: + raise RuntimeError(f'CLIP video wrong length {video_id}, ' + f'expected {self.clip_expected_length}, ' + f'got {clip_chunk.shape[0]}') + clip_chunk = self.clip_augment(clip_chunk) + + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + raise RuntimeError(f'Sync video wrong length {video_id}, ' + f'expected {self.sync_expected_length}, ' + f'got {sync_chunk.shape[0]}') + sync_chunk = self.sync_augment(sync_chunk) + + data = { + 'name': video_id, + 'caption': caption, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + return self.sample(idx) + + def __len__(self): + return len(self.captions) diff --git a/third_party/MMAudio/mmaudio/data/eval/video_dataset.py b/third_party/MMAudio/mmaudio/data/eval/video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..0b84a963e6da0c31984a3105dc87a6e9a1918c62 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/eval/video_dataset.py @@ -0,0 +1,197 @@ +import json +import logging +import os +from pathlib import Path +from typing import Union + +import pandas as pd +import torch +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder + +from mmaudio.utils.dist_utils import local_rank + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class VideoDataset(Dataset): + + def __init__( + self, + video_root: Union[str, Path], + *, + duration_sec: float = 8.0, + ): + self.video_root = Path(video_root) + + self.duration_sec = duration_sec + + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + # to be implemented by subclasses + self.captions = {} + self.videos = sorted(list(self.captions.keys())) + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + caption = self.captions[video_id] + + reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None 
{video_id}') + if clip_chunk.shape[0] < self.clip_expected_length: + raise RuntimeError( + f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + if sync_chunk.shape[0] < self.sync_expected_length: + raise RuntimeError( + f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + ) + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + if clip_chunk.shape[0] != self.clip_expected_length: + raise RuntimeError(f'CLIP video wrong length {video_id}, ' + f'expected {self.clip_expected_length}, ' + f'got {clip_chunk.shape[0]}') + clip_chunk = self.clip_transform(clip_chunk) + + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + raise RuntimeError(f'Sync video wrong length {video_id}, ' + f'expected {self.sync_expected_length}, ' + f'got {sync_chunk.shape[0]}') + sync_chunk = self.sync_transform(sync_chunk) + + data = { + 'name': video_id, + 'caption': caption, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.captions) + + +class VGGSound(VideoDataset): + + def __init__( + self, + video_root: Union[str, Path], + csv_path: Union[str, Path], + *, + duration_sec: float = 8.0, + ): + super().__init__(video_root, duration_sec=duration_sec) + self.video_root = Path(video_root) + self.csv_path = Path(csv_path) + + videos = sorted(os.listdir(self.video_root)) + if local_rank == 0: + log.info(f'{len(videos)} videos found in {video_root}') + self.captions = {} + + df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption', + 'split']).to_dict(orient='records') + + videos_no_found = [] + for row in df: + if row['split'] == 'test': + start_sec = int(row['sec']) + video_id = str(row['id']) + # this is how our videos are named + video_name = f'{video_id}_{start_sec:06d}' + if video_name + '.mp4' not in videos: + videos_no_found.append(video_name) + continue + + self.captions[video_name] = row['caption'] + + if local_rank == 0: + log.info(f'{len(videos)} videos found in {video_root}') + log.info(f'{len(self.captions)} useable videos found') + if videos_no_found: + log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}') + log.info( + 'A small amount is expected, as not all videos are still available on YouTube') + + self.videos = sorted(list(self.captions.keys())) + + +class MovieGen(VideoDataset): + + def __init__( + self, + video_root: Union[str, Path], + jsonl_root: Union[str, Path], + *, + duration_sec: float = 10.0, + ): + super().__init__(video_root, duration_sec=duration_sec) + self.video_root = Path(video_root) + self.jsonl_root = Path(jsonl_root) + + videos = sorted(os.listdir(self.video_root)) + videos = [v[:-4] for v in videos] # remove extensions + self.captions = {} + + for v in videos: + with open(self.jsonl_root / (v + '.jsonl')) as f: + data = json.load(f) + self.captions[v] = data['audio_prompt'] + + if local_rank == 0: + log.info(f'{len(videos)} videos found in {video_root}') + + self.videos = videos diff --git a/third_party/MMAudio/mmaudio/data/extracted_audio.py b/third_party/MMAudio/mmaudio/data/extracted_audio.py new file mode 100644 
index 0000000000000000000000000000000000000000..d23fd6890fedb8d24c167fadbee338a417b0f6a3 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/extracted_audio.py @@ -0,0 +1,88 @@ +import logging +from pathlib import Path +from typing import Union + +import pandas as pd +import torch +from tensordict import TensorDict +from torch.utils.data.dataset import Dataset + +from mmaudio.utils.dist_utils import local_rank + +log = logging.getLogger() + + +class ExtractedAudio(Dataset): + + def __init__( + self, + tsv_path: Union[str, Path], + *, + premade_mmap_dir: Union[str, Path], + data_dim: dict[str, int], + ): + super().__init__() + + self.data_dim = data_dim + self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records') + self.ids = [str(d['id']) for d in self.df_list] + + log.info(f'Loading precomputed mmap from {premade_mmap_dir}') + # load precomputed memory mapped tensors + premade_mmap_dir = Path(premade_mmap_dir) + td = TensorDict.load_memmap(premade_mmap_dir) + log.info(f'Loaded precomputed mmap from {premade_mmap_dir}') + self.mean = td['mean'] + self.std = td['std'] + self.text_features = td['text_features'] + + log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.') + log.info(f'Loaded mean: {self.mean.shape}.') + log.info(f'Loaded std: {self.std.shape}.') + log.info(f'Loaded text features: {self.text_features.shape}.') + + assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \ + f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}' + assert self.std.shape[1] == self.data_dim['latent_seq_len'], \ + f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}' + + assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \ + f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}' + assert self.text_features.shape[-1] == self.data_dim['text_dim'], \ + f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}' + + self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'], + self.data_dim['clip_dim']) + self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'], + self.data_dim['sync_dim']) + self.video_exist = torch.tensor(0, dtype=torch.bool) + self.text_exist = torch.tensor(1, dtype=torch.bool) + + def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: + latents = self.mean + return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1)) + + def get_memory_mapped_tensor(self) -> TensorDict: + td = TensorDict({ + 'mean': self.mean, + 'std': self.std, + 'text_features': self.text_features, + }) + return td + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + data = { + 'id': str(self.df_list[idx]['id']), + 'a_mean': self.mean[idx], + 'a_std': self.std[idx], + 'clip_features': self.fake_clip_features, + 'sync_features': self.fake_sync_features, + 'text_features': self.text_features[idx], + 'caption': self.df_list[idx]['caption'], + 'video_exist': self.video_exist, + 'text_exist': self.text_exist, + } + return data + + def __len__(self): + return len(self.ids) diff --git a/third_party/MMAudio/mmaudio/data/extracted_vgg.py b/third_party/MMAudio/mmaudio/data/extracted_vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..39c8e4b7e72e0bae5dd8d0ba802abb47738ebc4f --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/extracted_vgg.py @@ -0,0 +1,101 @@ +import logging +from pathlib import Path +from typing import Union + +import pandas as pd +import torch +from tensordict import TensorDict +from torch.utils.data.dataset import Dataset + +from mmaudio.utils.dist_utils import 
local_rank + +log = logging.getLogger() + + +class ExtractedVGG(Dataset): + + def __init__( + self, + tsv_path: Union[str, Path], + *, + premade_mmap_dir: Union[str, Path], + data_dim: dict[str, int], + ): + super().__init__() + + self.data_dim = data_dim + self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records') + self.ids = [d['id'] for d in self.df_list] + + log.info(f'Loading precomputed mmap from {premade_mmap_dir}') + # load precomputed memory mapped tensors + premade_mmap_dir = Path(premade_mmap_dir) + td = TensorDict.load_memmap(premade_mmap_dir) + log.info(f'Loaded precomputed mmap from {premade_mmap_dir}') + self.mean = td['mean'] + self.std = td['std'] + self.clip_features = td['clip_features'] + self.sync_features = td['sync_features'] + self.text_features = td['text_features'] + + if local_rank == 0: + log.info(f'Loaded {len(self)} samples.') + log.info(f'Loaded mean: {self.mean.shape}.') + log.info(f'Loaded std: {self.std.shape}.') + log.info(f'Loaded clip_features: {self.clip_features.shape}.') + log.info(f'Loaded sync_features: {self.sync_features.shape}.') + log.info(f'Loaded text_features: {self.text_features.shape}.') + + assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \ + f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}' + assert self.std.shape[1] == self.data_dim['latent_seq_len'], \ + f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}' + + assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \ + f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}' + assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \ + f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}' + assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \ + f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}' + + assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \ + f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}' + assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \ + f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}' + assert self.text_features.shape[-1] == self.data_dim['text_dim'], \ + f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}' + + self.video_exist = torch.tensor(1, dtype=torch.bool) + self.text_exist = torch.tensor(1, dtype=torch.bool) + + def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: + latents = self.mean + return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1)) + + def get_memory_mapped_tensor(self) -> TensorDict: + td = TensorDict({ + 'mean': self.mean, + 'std': self.std, + 'clip_features': self.clip_features, + 'sync_features': self.sync_features, + 'text_features': self.text_features, + }) + return td + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + data = { + 'id': self.df_list[idx]['id'], + 'a_mean': self.mean[idx], + 'a_std': self.std[idx], + 'clip_features': self.clip_features[idx], + 'sync_features': self.sync_features[idx], + 'text_features': self.text_features[idx], + 'caption': self.df_list[idx]['label'], + 'video_exist': self.video_exist, + 'text_exist': self.text_exist, + } + + return data + + def __len__(self): + return len(self.ids) diff --git a/mmaudio/model/__init__.py b/third_party/MMAudio/mmaudio/data/extraction/__init__.py similarity index 100% rename from mmaudio/model/__init__.py rename to third_party/MMAudio/mmaudio/data/extraction/__init__.py diff --git a/third_party/MMAudio/mmaudio/data/extraction/vgg_sound.py 
b/third_party/MMAudio/mmaudio/data/extraction/vgg_sound.py new file mode 100644 index 0000000000000000000000000000000000000000..116710d1fac2518807611564e1a1dc32dbd0bf07 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/extraction/vgg_sound.py @@ -0,0 +1,193 @@ +import logging +import os +from pathlib import Path +from typing import Optional, Union + +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset +from torchvision.transforms import v2 +from torio.io import StreamingMediaDecoder + +from mmaudio.utils.dist_utils import local_rank + +log = logging.getLogger() + +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 + + +class VGGSound(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv', + sample_rate: int = 16_000, + duration_sec: float = 8.0, + audio_samples: Optional[int] = None, + normalize_audio: bool = False, + ): + self.root = Path(root) + self.normalize_audio = normalize_audio + if audio_samples is None: + self.audio_samples = int(sample_rate * duration_sec) + else: + self.audio_samples = audio_samples + effective_duration = audio_samples / sample_rate + # make sure the duration is close enough, within 15ms + assert abs(effective_duration - duration_sec) < 0.015, \ + f'audio_samples {audio_samples} does not match duration_sec {duration_sec}' + + videos = sorted(os.listdir(self.root)) + videos = set([Path(v).stem for v in videos]) # remove extensions + self.labels = {} + self.videos = [] + missing_videos = [] + + # read the tsv for subset information + df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records') + for record in df_list: + id = record['id'] + label = record['label'] + if id in videos: + self.labels[id] = label + self.videos.append(id) + else: + missing_videos.append(id) + + if local_rank == 0: + log.info(f'{len(videos)} videos found in {root}') + log.info(f'{len(self.videos)} videos found in {tsv_path}') + log.info(f'{len(missing_videos)} videos missing in {root}') + + self.sample_rate = sample_rate + self.duration_sec = duration_sec + + self.expected_audio_length = audio_samples + self.clip_expected_length = int(_CLIP_FPS * self.duration_sec) + self.sync_expected_length = int(_SYNC_FPS * self.duration_sec) + + self.clip_transform = v2.Compose([ + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + + self.sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.resampler = {} + + def sample(self, idx: int) -> dict[str, torch.Tensor]: + video_id = self.videos[idx] + label = self.labels[video_id] + + reader = StreamingMediaDecoder(self.root / (video_id + '.mp4')) + reader.add_basic_video_stream( + frames_per_chunk=int(_CLIP_FPS * self.duration_sec), + frame_rate=_CLIP_FPS, + format='rgb24', + ) + reader.add_basic_video_stream( + frames_per_chunk=int(_SYNC_FPS * self.duration_sec), + frame_rate=_SYNC_FPS, + format='rgb24', + ) + reader.add_basic_audio_stream(frames_per_chunk=2**30, ) + + reader.fill_buffer() + data_chunk = reader.pop_chunks() + + clip_chunk = data_chunk[0] + sync_chunk = data_chunk[1] + audio_chunk = data_chunk[2] + + if clip_chunk is None: + raise RuntimeError(f'CLIP video returned None {video_id}') + if 
clip_chunk.shape[0] < self.clip_expected_length: + raise RuntimeError( + f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}' + ) + + if sync_chunk is None: + raise RuntimeError(f'Sync video returned None {video_id}') + if sync_chunk.shape[0] < self.sync_expected_length: + raise RuntimeError( + f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}' + ) + + # process audio + sample_rate = int(reader.get_out_stream_info(2).sample_rate) + audio_chunk = audio_chunk.transpose(0, 1) + audio_chunk = audio_chunk.mean(dim=0) # mono + if self.normalize_audio: + abs_max = audio_chunk.abs().max() + audio_chunk = audio_chunk / abs_max * 0.95 + if abs_max <= 1e-6: + raise RuntimeError(f'Audio is silent {video_id}') + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[0] < self.expected_audio_length: + raise RuntimeError(f'Audio too short {video_id}') + audio_chunk = audio_chunk[:self.expected_audio_length] + + # truncate the video + clip_chunk = clip_chunk[:self.clip_expected_length] + if clip_chunk.shape[0] != self.clip_expected_length: + raise RuntimeError(f'CLIP video wrong length {video_id}, ' + f'expected {self.clip_expected_length}, ' + f'got {clip_chunk.shape[0]}') + clip_chunk = self.clip_transform(clip_chunk) + + sync_chunk = sync_chunk[:self.sync_expected_length] + if sync_chunk.shape[0] != self.sync_expected_length: + raise RuntimeError(f'Sync video wrong length {video_id}, ' + f'expected {self.sync_expected_length}, ' + f'got {sync_chunk.shape[0]}') + sync_chunk = self.sync_transform(sync_chunk) + + data = { + 'id': video_id, + 'caption': label, + 'audio': audio_chunk, + 'clip_video': clip_chunk, + 'sync_video': sync_chunk, + } + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + try: + return self.sample(idx) + except Exception as e: + log.error(f'Error loading video {self.videos[idx]}: {e}') + return None + + def __len__(self): + return len(self.labels) diff --git a/third_party/MMAudio/mmaudio/data/extraction/wav_dataset.py b/third_party/MMAudio/mmaudio/data/extraction/wav_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..95bfbb3d7dea50ad9c8822e4626dda9582d7cd55 --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/extraction/wav_dataset.py @@ -0,0 +1,132 @@ +import logging +import os +from pathlib import Path +from typing import Union + +import open_clip +import pandas as pd +import torch +import torchaudio +from torch.utils.data.dataset import Dataset + +log = logging.getLogger() + + +class WavTextClipsDataset(Dataset): + + def __init__( + self, + root: Union[str, Path], + *, + captions_tsv: Union[str, Path], + clips_tsv: Union[str, Path], + sample_rate: int, + num_samples: int, + normalize_audio: bool = False, + reject_silent: bool = False, + tokenizer_id: str = 'ViT-H-14-378-quickgelu', + ): + self.root = Path(root) + self.sample_rate = sample_rate + self.num_samples = num_samples + self.normalize_audio = normalize_audio + self.reject_silent = reject_silent + 
self.tokenizer = open_clip.get_tokenizer(tokenizer_id) + + audios = sorted(os.listdir(self.root)) + audios = set([ + Path(audio).stem for audio in audios + if audio.endswith('.wav') or audio.endswith('.flac') + ]) + self.captions = {} + + # read the caption tsv + df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records') + for record in df_list: + id = record['id'] + caption = record['caption'] + self.captions[id] = caption + + # read the clip tsv + df_list = pd.read_csv(clips_tsv, sep='\t', dtype={ + 'id': str, + 'name': str + }).to_dict('records') + self.clips = [] + for record in df_list: + record['id'] = record['id'] + record['name'] = record['name'] + id = record['id'] + name = record['name'] + if name not in self.captions: + log.warning(f'Audio {name} not found in {captions_tsv}') + continue + record['caption'] = self.captions[name] + self.clips.append(record) + + log.info(f'Found {len(self.clips)} audio files in {self.root}') + + self.resampler = {} + + def __getitem__(self, idx: int) -> torch.Tensor: + try: + clip = self.clips[idx] + audio_name = clip['name'] + audio_id = clip['id'] + caption = clip['caption'] + start_sample = clip['start_sample'] + end_sample = clip['end_sample'] + + audio_path = self.root / f'{audio_name}.flac' + if not audio_path.exists(): + audio_path = self.root / f'{audio_name}.wav' + assert audio_path.exists() + + audio_chunk, sample_rate = torchaudio.load(audio_path) + audio_chunk = audio_chunk.mean(dim=0) # mono + abs_max = audio_chunk.abs().max() + if self.normalize_audio: + audio_chunk = audio_chunk / abs_max * 0.95 + + if self.reject_silent and abs_max < 1e-6: + log.warning(f'Rejecting silent audio') + return None + + audio_chunk = audio_chunk[start_sample:end_sample] + + # resample + if sample_rate == self.sample_rate: + audio_chunk = audio_chunk + else: + if sample_rate not in self.resampler: + # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best + self.resampler[sample_rate] = torchaudio.transforms.Resample( + sample_rate, + self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + audio_chunk = self.resampler[sample_rate](audio_chunk) + + if audio_chunk.shape[0] < self.num_samples: + raise ValueError('Audio is too short') + audio_chunk = audio_chunk[:self.num_samples] + + tokens = self.tokenizer([caption])[0] + + output = { + 'waveform': audio_chunk, + 'id': audio_id, + 'caption': caption, + 'tokens': tokens, + } + + return output + except Exception as e: + log.error(f'Error reading {audio_path}: {e}') + return None + + def __len__(self): + return len(self.clips) diff --git a/third_party/MMAudio/mmaudio/data/mm_dataset.py b/third_party/MMAudio/mmaudio/data/mm_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a9c7d3d02fc0534592e7c990d19be2d6b378b56c --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/mm_dataset.py @@ -0,0 +1,45 @@ +import bisect + +import torch +from torch.utils.data.dataset import Dataset + + +# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset +class MultiModalDataset(Dataset): + datasets: list[Dataset] + cumulative_sizes: list[int] + + @staticmethod + def cumsum(sequence): + r, s = [], 0 + for e in sequence: + l = len(e) + r.append(l + s) + s += l + return r + + def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]): + super().__init__() + self.video_datasets = 
list(video_datasets) + self.audio_datasets = list(audio_datasets) + self.datasets = self.video_datasets + self.audio_datasets + + self.cumulative_sizes = self.cumsum(self.datasets) + + def __len__(self): + return self.cumulative_sizes[-1] + + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx] + + def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.video_datasets[0].compute_latent_stats() diff --git a/third_party/MMAudio/mmaudio/data/utils.py b/third_party/MMAudio/mmaudio/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f782ceaf4a506c6f81886981cd55492fd0a5cccf --- /dev/null +++ b/third_party/MMAudio/mmaudio/data/utils.py @@ -0,0 +1,148 @@ +import logging +import os +import random +import tempfile +from pathlib import Path +from typing import Any, Optional, Union + +import torch +import torch.distributed as dist +from tensordict import MemoryMappedTensor +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset +from tqdm import tqdm + +from mmaudio.utils.dist_utils import local_rank, world_size + +scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm') +shm_path = Path('/dev/shm') + +log = logging.getLogger() + + +def reseed(seed): + random.seed(seed) + torch.manual_seed(seed) + + +def local_scatter_torch(obj: Optional[Any]): + if world_size == 1: + # Just one worker. Do nothing. + return obj + + array = [obj] * world_size + target_array = [None] + if local_rank == 0: + dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0) + else: + dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0) + return target_array[0] + + +class ShardDataset(Dataset): + + def __init__(self, root): + self.root = root + self.shards = sorted(os.listdir(root)) + + def __len__(self): + return len(self.shards) + + def __getitem__(self, idx): + return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True) + + +def get_tmp_dir(in_memory: bool) -> Path: + return shm_path if in_memory else scratch_path + + +def load_shards_and_share(data_path: Union[str, Path], ids: list[int], + in_memory: bool) -> MemoryMappedTensor: + if local_rank == 0: + with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f: + log.info(f'Loading shards from {data_path} into {f.name}...') + data = load_shards(data_path, ids=ids, tmp_file_path=f.name) + data = share_tensor_to_all(data) + torch.distributed.barrier() + f.close() # why does the context manager not close the file for me? 
+ else: + log.info('Waiting for the data to be shared with me...') + data = share_tensor_to_all(None) + torch.distributed.barrier() + + return data + + +def load_shards( + data_path: Union[str, Path], + ids: list[int], + *, + tmp_file_path: str, +) -> Union[torch.Tensor, dict[str, torch.Tensor]]: + + id_set = set(ids) + shards = sorted(os.listdir(data_path)) + log.info(f'Found {len(shards)} shards in {data_path}.') + first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True) + + log.info(f'Rank {local_rank} created file {tmp_file_path}') + first_item = next(iter(first_shard.values())) + log.info(f'First item shape: {first_item.shape}') + mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape), + dtype=torch.float32, + filename=tmp_file_path, + existsok=True) + total_count = 0 + used_index = set() + id_indexing = {i: idx for idx, i in enumerate(ids)} + # faster with no workers; otherwise we need to set_sharing_strategy('file_system') + loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0) + for data in tqdm(loader, desc='Loading shards'): + for i, v in data.items(): + if i not in id_set: + continue + + # tensor_index = ids.index(i) + tensor_index = id_indexing[i] + if tensor_index in used_index: + raise ValueError(f'Duplicate id {i} found in {data_path}.') + used_index.add(tensor_index) + mm_tensor[tensor_index] = v + total_count += 1 + + assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.' + log.info(f'Loaded {total_count} tensors from {data_path}.') + + return mm_tensor + + +def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor: + """ + x: the tensor to be shared; None if local_rank != 0 + return: the shared tensor + """ + + # there is no need to share your stuff with anyone if you are alone; must be in memory + if world_size == 1: + return x + + if local_rank == 0: + assert x is not None, 'x must not be None if local_rank == 0' + else: + assert x is None, 'x must be None if local_rank != 0' + + if local_rank == 0: + filename = x.filename + meta_information = (filename, x.shape, x.dtype) + else: + meta_information = None + + filename, data_shape, data_type = local_scatter_torch(meta_information) + if local_rank == 0: + data = x + else: + data = MemoryMappedTensor.from_filename(filename=filename, + dtype=data_type, + shape=data_shape) + + return data diff --git a/mmaudio/eval_utils.py b/third_party/MMAudio/mmaudio/eval_utils.py similarity index 70% rename from mmaudio/eval_utils.py rename to third_party/MMAudio/mmaudio/eval_utils.py index a5c9291f2687855b10b63b3f6e67e299c86cbbbe..74357cbd70aced61563798a88c38d3d486b206c5 100644 --- a/mmaudio/eval_utils.py +++ b/third_party/MMAudio/mmaudio/eval_utils.py @@ -1,16 +1,18 @@ import dataclasses import logging from pathlib import Path -from typing import Optional +from typing import Optional, Tuple, List, Dict +import numpy as np import torch from colorlog import ColoredFormatter +from PIL import Image from torchvision.transforms import v2 -from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio +from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio from mmaudio.model.flow_matching import FlowMatching from mmaudio.model.networks import MMAudio -from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig) +from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig from mmaudio.model.utils.features_utils import FeaturesUtils from 
mmaudio.utils.download_utils import download_model_if_needed @@ -24,7 +26,7 @@ class ModelConfig: vae_path: Path bigvgan_16k_path: Optional[Path] mode: str - synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth') + synchformer_ckpt: Path = Path('./pretrained/v2a/mmaudio/ext_weights/synchformer_state_dict.pth') @property def seq_cfg(self) -> SequenceConfig: @@ -42,31 +44,31 @@ class ModelConfig: small_16k = ModelConfig(model_name='small_16k', - model_path=Path('./weights/mmaudio_small_16k.pth'), - vae_path=Path('./ext_weights/v1-16.pth'), - bigvgan_16k_path=Path('./ext_weights/best_netG.pt'), + model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_16k.pth'), + vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-16.pth'), + bigvgan_16k_path=Path('./pretrained/v2a/mmaudio/ext_weights/best_netG.pt'), mode='16k') small_44k = ModelConfig(model_name='small_44k', - model_path=Path('./weights/mmaudio_small_44k.pth'), - vae_path=Path('./ext_weights/v1-44.pth'), + model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_44k.pth'), + vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'), bigvgan_16k_path=None, mode='44k') medium_44k = ModelConfig(model_name='medium_44k', - model_path=Path('./weights/mmaudio_medium_44k.pth'), - vae_path=Path('./ext_weights/v1-44.pth'), + model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_medium_44k.pth'), + vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'), bigvgan_16k_path=None, mode='44k') large_44k = ModelConfig(model_name='large_44k', - model_path=Path('./weights/mmaudio_large_44k.pth'), - vae_path=Path('./ext_weights/v1-44.pth'), + model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k.pth'), + vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'), bigvgan_16k_path=None, mode='44k') large_44k_v2 = ModelConfig(model_name='large_44k_v2', - model_path=Path('./weights/mmaudio_large_44k_v2.pth'), - vae_path=Path('./ext_weights/v1-44.pth'), + model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k_v2.pth'), + vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'), bigvgan_16k_path=None, mode='44k') -all_model_cfg: dict[str, ModelConfig] = { +all_model_cfg: Dict[str, ModelConfig] = { 'small_16k': small_16k, 'small_44k': small_44k, 'medium_44k': medium_44k, @@ -78,9 +80,9 @@ all_model_cfg: dict[str, ModelConfig] = { def generate( clip_video: Optional[torch.Tensor], sync_video: Optional[torch.Tensor], - text: Optional[list[str]], + text: Optional[List[str]], *, - negative_text: Optional[list[str]] = None, + negative_text: Optional[List[str]] = None, feature_utils: FeaturesUtils, net: MMAudio, fm: FlowMatching, @@ -88,6 +90,7 @@ def generate( cfg_strength: float, clip_batch_size_multiplier: int = 40, sync_batch_size_multiplier: int = 40, + image_input: bool = False, ) -> torch.Tensor: device = feature_utils.device dtype = feature_utils.dtype @@ -98,10 +101,12 @@ def generate( clip_features = feature_utils.encode_video_with_clip(clip_video, batch_size=bs * clip_batch_size_multiplier) + if image_input: + clip_features = clip_features.expand(-1, net.clip_seq_len, -1) else: clip_features = net.get_empty_clip_sequence(bs) - if sync_video is not None: + if sync_video is not None and not image_input: sync_video = sync_video.to(device, dtype, non_blocking=True) sync_features = feature_utils.encode_video_with_sync(sync_video, batch_size=bs * @@ -139,7 +144,7 @@ def generate( return audio -LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | 
%(log_color)s%(message)s%(reset)s" +LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s" def setup_eval_logging(log_level: int = logging.INFO): @@ -153,12 +158,14 @@ def setup_eval_logging(log_level: int = logging.INFO): log.addHandler(stream) -def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo: - _CLIP_SIZE = 384 - _CLIP_FPS = 8.0 +_CLIP_SIZE = 384 +_CLIP_FPS = 8.0 + +_SYNC_SIZE = 224 +_SYNC_FPS = 25.0 - _SYNC_SIZE = 224 - _SYNC_FPS = 25.0 + +def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo: clip_transform = v2.Compose([ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), @@ -213,5 +220,36 @@ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = Tr return video_info +def load_image(image_path: Path) -> VideoInfo: + clip_transform = v2.Compose([ + v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + ]) + + sync_transform = v2.Compose([ + v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC), + v2.CenterCrop(_SYNC_SIZE), + v2.ToImage(), + v2.ToDtype(torch.float32, scale=True), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + frame = np.array(Image.open(image_path)) + + clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2) + sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2) + + clip_frames = clip_transform(clip_chunk) + sync_frames = sync_transform(sync_chunk) + + video_info = ImageInfo( + clip_frames=clip_frames, + sync_frames=sync_frames, + original_frame=frame, + ) + return video_info + + def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int): reencode_with_audio(video_info, output_path, audio, sampling_rate) diff --git a/mmaudio/ext/__init__.py b/third_party/MMAudio/mmaudio/ext/__init__.py similarity index 100% rename from mmaudio/ext/__init__.py rename to third_party/MMAudio/mmaudio/ext/__init__.py diff --git a/mmaudio/ext/autoencoder/__init__.py b/third_party/MMAudio/mmaudio/ext/autoencoder/__init__.py similarity index 100% rename from mmaudio/ext/autoencoder/__init__.py rename to third_party/MMAudio/mmaudio/ext/autoencoder/__init__.py diff --git a/mmaudio/ext/autoencoder/autoencoder.py b/third_party/MMAudio/mmaudio/ext/autoencoder/autoencoder.py similarity index 96% rename from mmaudio/ext/autoencoder/autoencoder.py rename to third_party/MMAudio/mmaudio/ext/autoencoder/autoencoder.py index 5b444656112f9c4e5d9493c8fce40c118a2e31d5..b77db4ebc0b68a68dfe23a28b2ff83537fd0e619 100644 --- a/mmaudio/ext/autoencoder/autoencoder.py +++ b/third_party/MMAudio/mmaudio/ext/autoencoder/autoencoder.py @@ -20,7 +20,7 @@ class AutoEncoderModule(nn.Module): super().__init__() self.vae: VAE = get_my_vae(mode).eval() vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu') - self.vae.load_state_dict(vae_state_dict, strict=False) + self.vae.load_state_dict(vae_state_dict) self.vae.remove_weight_norm() if mode == '16k': diff --git a/mmaudio/ext/autoencoder/edm2_utils.py b/third_party/MMAudio/mmaudio/ext/autoencoder/edm2_utils.py similarity index 100% rename from mmaudio/ext/autoencoder/edm2_utils.py rename to third_party/MMAudio/mmaudio/ext/autoencoder/edm2_utils.py diff --git a/mmaudio/ext/autoencoder/vae.py b/third_party/MMAudio/mmaudio/ext/autoencoder/vae.py similarity index 92% rename from 
mmaudio/ext/autoencoder/vae.py rename to third_party/MMAudio/mmaudio/ext/autoencoder/vae.py index 204c2e01cf9fc89eb718f8aa266a1c6a7e443312..b69ed59d2fcffb3d69ed29c6306b9f103dbadc93 100644 --- a/mmaudio/ext/autoencoder/vae.py +++ b/third_party/MMAudio/mmaudio/ext/autoencoder/vae.py @@ -1,5 +1,5 @@ import logging -from typing import Optional +from typing import Optional, Tuple, List import torch import torch.nn as nn @@ -75,15 +75,11 @@ class VAE(nn.Module): super().__init__() if data_dim == 80: - # self.data_mean = torch.tensor(DATA_MEAN_80D, dtype=torch.float32).cuda() - # self.data_std = torch.tensor(DATA_STD_80D, dtype=torch.float32).cuda() - self.register_buffer('data_mean', torch.tensor(DATA_MEAN_80D, dtype=torch.float32)) - self.register_buffer('data_std', torch.tensor(DATA_STD_80D, dtype=torch.float32)) + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32)) elif data_dim == 128: - # torch.tensor(DATA_MEAN_128D, dtype=torch.float32).cuda() - # self.data_std = torch.tensor(DATA_STD_128D, dtype=torch.float32).cuda() - self.register_buffer('data_mean', torch.tensor(DATA_MEAN_128D, dtype=torch.float32)) - self.register_buffer('data_std', torch.tensor(DATA_STD_128D, dtype=torch.float32)) + self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32)) + self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32)) self.data_mean = self.data_mean.view(1, -1, 1) self.data_std = self.data_std.view(1, -1, 1) @@ -143,7 +139,7 @@ class VAE(nn.Module): rng: Optional[torch.Generator] = None, normalize: bool = True, unnormalize: bool = True, - ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]: + ) -> Tuple[torch.Tensor, DiagonalGaussianDistribution]: posterior = self.encode(x, normalize=normalize) if sample_posterior: @@ -176,10 +172,10 @@ class Encoder1D(nn.Module): def __init__(self, *, dim: int, - ch_mult: tuple[int] = (1, 2, 4, 8), + ch_mult: Tuple[int] = (1, 2, 4, 8), num_res_blocks: int, - attn_layers: list[int] = [], - down_layers: list[int] = [], + attn_layers: List[int] = [], + down_layers: List[int] = [], resamp_with_conv: bool = True, in_dim: int, embed_dim: int, @@ -273,10 +269,10 @@ class Decoder1D(nn.Module): *, dim: int, out_dim: int, - ch_mult: tuple[int] = (1, 2, 4, 8), + ch_mult: Tuple[int] = (1, 2, 4, 8), num_res_blocks: int, - attn_layers: list[int] = [], - down_layers: list[int] = [], + attn_layers: List[int] = [], + down_layers: List[int] = [], kernel_size: int = 3, resamp_with_conv: bool = True, in_dim: int, diff --git a/mmaudio/ext/autoencoder/vae_modules.py b/third_party/MMAudio/mmaudio/ext/autoencoder/vae_modules.py similarity index 100% rename from mmaudio/ext/autoencoder/vae_modules.py rename to third_party/MMAudio/mmaudio/ext/autoencoder/vae_modules.py diff --git a/mmaudio/ext/bigvgan/LICENSE b/third_party/MMAudio/mmaudio/ext/bigvgan/LICENSE similarity index 100% rename from mmaudio/ext/bigvgan/LICENSE rename to third_party/MMAudio/mmaudio/ext/bigvgan/LICENSE diff --git a/mmaudio/ext/bigvgan/__init__.py b/third_party/MMAudio/mmaudio/ext/bigvgan/__init__.py similarity index 100% rename from mmaudio/ext/bigvgan/__init__.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/__init__.py diff --git a/mmaudio/ext/bigvgan/activations.py b/third_party/MMAudio/mmaudio/ext/bigvgan/activations.py similarity index 100% rename from mmaudio/ext/bigvgan/activations.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/activations.py diff --git 
a/mmaudio/ext/bigvgan/alias_free_torch/__init__.py b/third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/__init__.py similarity index 100% rename from mmaudio/ext/bigvgan/alias_free_torch/__init__.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/__init__.py diff --git a/mmaudio/ext/bigvgan/alias_free_torch/act.py b/third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/act.py similarity index 100% rename from mmaudio/ext/bigvgan/alias_free_torch/act.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/act.py diff --git a/mmaudio/ext/bigvgan/alias_free_torch/filter.py b/third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/filter.py similarity index 100% rename from mmaudio/ext/bigvgan/alias_free_torch/filter.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/filter.py diff --git a/mmaudio/ext/bigvgan/alias_free_torch/resample.py b/third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/resample.py similarity index 100% rename from mmaudio/ext/bigvgan/alias_free_torch/resample.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/alias_free_torch/resample.py diff --git a/mmaudio/ext/bigvgan/bigvgan.py b/third_party/MMAudio/mmaudio/ext/bigvgan/bigvgan.py similarity index 100% rename from mmaudio/ext/bigvgan/bigvgan.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/bigvgan.py diff --git a/mmaudio/ext/bigvgan/bigvgan_vocoder.yml b/third_party/MMAudio/mmaudio/ext/bigvgan/bigvgan_vocoder.yml similarity index 100% rename from mmaudio/ext/bigvgan/bigvgan_vocoder.yml rename to third_party/MMAudio/mmaudio/ext/bigvgan/bigvgan_vocoder.yml diff --git a/mmaudio/ext/bigvgan/env.py b/third_party/MMAudio/mmaudio/ext/bigvgan/env.py similarity index 100% rename from mmaudio/ext/bigvgan/env.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/env.py diff --git a/mmaudio/ext/bigvgan/incl_licenses/LICENSE_1 b/third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_1 similarity index 100% rename from mmaudio/ext/bigvgan/incl_licenses/LICENSE_1 rename to third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_1 diff --git a/mmaudio/ext/bigvgan/incl_licenses/LICENSE_2 b/third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_2 similarity index 100% rename from mmaudio/ext/bigvgan/incl_licenses/LICENSE_2 rename to third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_2 diff --git a/mmaudio/ext/bigvgan/incl_licenses/LICENSE_3 b/third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_3 similarity index 100% rename from mmaudio/ext/bigvgan/incl_licenses/LICENSE_3 rename to third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_3 diff --git a/mmaudio/ext/bigvgan/incl_licenses/LICENSE_4 b/third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_4 similarity index 100% rename from mmaudio/ext/bigvgan/incl_licenses/LICENSE_4 rename to third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_4 diff --git a/mmaudio/ext/bigvgan/incl_licenses/LICENSE_5 b/third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_5 similarity index 100% rename from mmaudio/ext/bigvgan/incl_licenses/LICENSE_5 rename to third_party/MMAudio/mmaudio/ext/bigvgan/incl_licenses/LICENSE_5 diff --git a/mmaudio/ext/bigvgan/models.py b/third_party/MMAudio/mmaudio/ext/bigvgan/models.py similarity index 100% rename from mmaudio/ext/bigvgan/models.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/models.py diff --git a/mmaudio/ext/bigvgan/utils.py b/third_party/MMAudio/mmaudio/ext/bigvgan/utils.py similarity index 100% rename from 
mmaudio/ext/bigvgan/utils.py rename to third_party/MMAudio/mmaudio/ext/bigvgan/utils.py diff --git a/mmaudio/ext/bigvgan_v2/LICENSE b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/LICENSE similarity index 100% rename from mmaudio/ext/bigvgan_v2/LICENSE rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/LICENSE diff --git a/mmaudio/model/utils/__init__.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/__init__.py similarity index 100% rename from mmaudio/model/utils/__init__.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/__init__.py diff --git a/mmaudio/ext/bigvgan_v2/activations.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/activations.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/activations.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/activations.py diff --git a/mmaudio/utils/__init__.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/__init__.py similarity index 100% rename from mmaudio/utils/__init__.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/__init__.py diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/activation1d.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/activation1d.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/activation1d.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/activation1d.py diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation.cpp b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation.cpp similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation.cpp rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation.cpp diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation_cuda.cu b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation_cuda.cu similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation_cuda.cu rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/anti_alias_activation_cuda.cu diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/compat.h b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/compat.h similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/compat.h rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/compat.h diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/load.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/load.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/load.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/load.py diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/type_shim.h b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/type_shim.h similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/type_shim.h rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/cuda/type_shim.h diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/__init__.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/__init__.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/torch/__init__.py rename to 
third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/__init__.py diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/act.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/act.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/torch/act.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/act.py diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/filter.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/filter.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/torch/filter.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/filter.py diff --git a/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/resample.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/resample.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/alias_free_activation/torch/resample.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/alias_free_activation/torch/resample.py diff --git a/mmaudio/ext/bigvgan_v2/bigvgan.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/bigvgan.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/bigvgan.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/bigvgan.py diff --git a/mmaudio/ext/bigvgan_v2/env.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/env.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/env.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/env.py diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_1 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_1 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_1 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_1 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_2 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_2 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_2 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_2 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_3 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_3 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_3 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_3 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_4 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_4 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_4 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_4 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_5 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_5 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_5 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_5 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_6 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_6 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_6 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_6 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_7 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_7 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_7 rename to 
third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_7 diff --git a/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_8 b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_8 similarity index 100% rename from mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_8 rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/incl_licenses/LICENSE_8 diff --git a/mmaudio/ext/bigvgan_v2/utils.py b/third_party/MMAudio/mmaudio/ext/bigvgan_v2/utils.py similarity index 100% rename from mmaudio/ext/bigvgan_v2/utils.py rename to third_party/MMAudio/mmaudio/ext/bigvgan_v2/utils.py diff --git a/mmaudio/ext/mel_converter.py b/third_party/MMAudio/mmaudio/ext/mel_converter.py similarity index 67% rename from mmaudio/ext/mel_converter.py rename to third_party/MMAudio/mmaudio/ext/mel_converter.py index 6fc589c9468e077fc580965db250fd502e229672..15266d22fb95176229643597a5fea8304888007d 100644 --- a/mmaudio/ext/mel_converter.py +++ b/third_party/MMAudio/mmaudio/ext/mel_converter.py @@ -1,11 +1,12 @@ # Reference: # https://github.com/bytedance/Make-An-Audio-2 +from typing import Literal import torch import torch.nn as nn from librosa.filters import mel as librosa_mel_fn -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10): +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn): return norm_fn(torch.clamp(x, min=clip_val) * C) @@ -19,14 +20,14 @@ class MelConverter(nn.Module): def __init__( self, *, - sampling_rate: float = 16_000, - n_fft: int = 1024, - num_mels: int = 80, - hop_size: int = 256, - win_size: int = 1024, - fmin: float = 0, - fmax: float = 8_000, - norm_fn=torch.log10, + sampling_rate: float, + n_fft: int, + num_mels: int, + hop_size: int, + win_size: int, + fmin: float, + fmax: float, + norm_fn, ): super().__init__() self.sampling_rate = sampling_rate @@ -80,3 +81,26 @@ class MelConverter(nn.Module): spec = spectral_normalize_torch(spec, self.norm_fn) return spec + + +def get_mel_converter(mode: Literal['16k', '44k']) -> MelConverter: + if mode == '16k': + return MelConverter(sampling_rate=16_000, + n_fft=1024, + num_mels=80, + hop_size=256, + win_size=1024, + fmin=0, + fmax=8_000, + norm_fn=torch.log10) + elif mode == '44k': + return MelConverter(sampling_rate=44_100, + n_fft=2048, + num_mels=128, + hop_size=512, + win_size=2048, + fmin=0, + fmax=44100 / 2, + norm_fn=torch.log) + else: + raise ValueError(f'Unknown mode: {mode}') diff --git a/mmaudio/ext/rotary_embeddings.py b/third_party/MMAudio/mmaudio/ext/rotary_embeddings.py similarity index 74% rename from mmaudio/ext/rotary_embeddings.py rename to third_party/MMAudio/mmaudio/ext/rotary_embeddings.py index 16a9cf813a9cf24e35019986bcd1c38b25564c4e..41a7555424720ff971dad5ef0cff69d471f3828d 100644 --- a/mmaudio/ext/rotary_embeddings.py +++ b/third_party/MMAudio/mmaudio/ext/rotary_embeddings.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Tuple import torch from einops import rearrange @@ -7,7 +7,7 @@ from torch import Tensor # Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py # Ref: https://github.com/lucidrains/rotary-embedding-torch -DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + def compute_rope_rotations(length: int, dim: int, theta: int, @@ -16,8 +16,7 @@ def compute_rope_rotations(length: int, device: Union[torch.device, str] = 'cpu') -> Tensor: assert dim % 2 == 0 - # with torch.amp.autocast(device_type='cuda', enabled=False): - with torch.amp.autocast(device_type=DEVICE, enabled=False): + with 
torch.amp.autocast(device_type='cuda', enabled=False): pos = torch.arange(length, dtype=torch.float32, device=device) freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) freqs *= freq_scaling @@ -28,9 +27,8 @@ def compute_rope_rotations(length: int, return rot -def apply_rope(x: Tensor, rot: Tensor) -> tuple[Tensor, Tensor]: - # with torch.amp.autocast(device_type='cuda', enabled=False): - with torch.amp.autocast(device_type=DEVICE, enabled=False): +def apply_rope(x: Tensor, rot: Tensor) -> Tuple[Tensor, Tensor]: + with torch.amp.autocast(device_type='cuda', enabled=False): _x = x.float() _x = _x.view(*_x.shape[:-1], -1, 1, 2) x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1] diff --git a/mmaudio/ext/stft_converter.py b/third_party/MMAudio/mmaudio/ext/stft_converter.py similarity index 100% rename from mmaudio/ext/stft_converter.py rename to third_party/MMAudio/mmaudio/ext/stft_converter.py diff --git a/mmaudio/ext/stft_converter_mel.py b/third_party/MMAudio/mmaudio/ext/stft_converter_mel.py similarity index 100% rename from mmaudio/ext/stft_converter_mel.py rename to third_party/MMAudio/mmaudio/ext/stft_converter_mel.py diff --git a/mmaudio/ext/synchformer/LICENSE b/third_party/MMAudio/mmaudio/ext/synchformer/LICENSE similarity index 100% rename from mmaudio/ext/synchformer/LICENSE rename to third_party/MMAudio/mmaudio/ext/synchformer/LICENSE diff --git a/mmaudio/ext/synchformer/__init__.py b/third_party/MMAudio/mmaudio/ext/synchformer/__init__.py similarity index 100% rename from mmaudio/ext/synchformer/__init__.py rename to third_party/MMAudio/mmaudio/ext/synchformer/__init__.py diff --git a/mmaudio/ext/synchformer/divided_224_16x4.yaml b/third_party/MMAudio/mmaudio/ext/synchformer/divided_224_16x4.yaml similarity index 100% rename from mmaudio/ext/synchformer/divided_224_16x4.yaml rename to third_party/MMAudio/mmaudio/ext/synchformer/divided_224_16x4.yaml diff --git a/mmaudio/ext/synchformer/motionformer.py b/third_party/MMAudio/mmaudio/ext/synchformer/motionformer.py similarity index 100% rename from mmaudio/ext/synchformer/motionformer.py rename to third_party/MMAudio/mmaudio/ext/synchformer/motionformer.py diff --git a/mmaudio/ext/synchformer/synchformer.py b/third_party/MMAudio/mmaudio/ext/synchformer/synchformer.py similarity index 84% rename from mmaudio/ext/synchformer/synchformer.py rename to third_party/MMAudio/mmaudio/ext/synchformer/synchformer.py index dcaeda1a5a529a2b465a4be464384bdbfded2f75..80871f004d6f4c57f48594d90195f84f89d7cb0a 100644 --- a/mmaudio/ext/synchformer/synchformer.py +++ b/third_party/MMAudio/mmaudio/ext/synchformer/synchformer.py @@ -41,14 +41,14 @@ class Synchformer(nn.Module): return super().load_state_dict(sd, strict) -# if __name__ == "__main__": -# model = Synchformer().cuda().eval() -# sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True) -# model.load_state_dict(sd) - -# vid = torch.randn(2, 7, 16, 3, 224, 224).cuda() -# features = model.extract_vfeats(vid, for_loop=False).detach().cpu() -# print(features.shape) +if __name__ == "__main__": + model = Synchformer().cuda().eval() + sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True) + model.load_state_dict(sd) + + vid = torch.randn(2, 7, 16, 3, 224, 224).cuda() + features = model.extract_vfeats(vid, for_loop=False).detach().cpu() + print(features.shape) # extract and save the state dict only # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model'] diff --git a/mmaudio/ext/synchformer/utils.py 
b/third_party/MMAudio/mmaudio/ext/synchformer/utils.py similarity index 100% rename from mmaudio/ext/synchformer/utils.py rename to third_party/MMAudio/mmaudio/ext/synchformer/utils.py diff --git a/mmaudio/ext/synchformer/video_model_builder.py b/third_party/MMAudio/mmaudio/ext/synchformer/video_model_builder.py similarity index 100% rename from mmaudio/ext/synchformer/video_model_builder.py rename to third_party/MMAudio/mmaudio/ext/synchformer/video_model_builder.py diff --git a/mmaudio/ext/synchformer/vit_helper.py b/third_party/MMAudio/mmaudio/ext/synchformer/vit_helper.py similarity index 100% rename from mmaudio/ext/synchformer/vit_helper.py rename to third_party/MMAudio/mmaudio/ext/synchformer/vit_helper.py diff --git a/third_party/MMAudio/mmaudio/model/__init__.py b/third_party/MMAudio/mmaudio/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmaudio/model/embeddings.py b/third_party/MMAudio/mmaudio/model/embeddings.py similarity index 89% rename from mmaudio/model/embeddings.py rename to third_party/MMAudio/mmaudio/model/embeddings.py index d447a98f941f1231d1b1dac716db3047a6a8eb88..297feb4d2c79d306771f5436dbd4ada1a976b3bc 100644 --- a/mmaudio/model/embeddings.py +++ b/third_party/MMAudio/mmaudio/model/embeddings.py @@ -21,11 +21,12 @@ class TimestepEmbedder(nn.Module): assert dim % 2 == 0, 'dim must be even.' with torch.autocast('cuda', enabled=False): - self.freqs = ( + self.freqs = nn.Buffer( 1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) / - frequency_embedding_size))) + frequency_embedding_size)), + persistent=False) freq_scale = 10000 / max_period - self.freqs = nn.Parameter(freq_scale * self.freqs) + self.freqs = freq_scale * self.freqs def timestep_embedding(self, t): """ diff --git a/mmaudio/model/flow_matching.py b/third_party/MMAudio/mmaudio/model/flow_matching.py similarity index 72% rename from mmaudio/model/flow_matching.py rename to third_party/MMAudio/mmaudio/model/flow_matching.py index a04510ab888c0c3c3398360f97b8b7e3c55998ad..3aec539041fd285ade004d17390e45296950a57b 100644 --- a/mmaudio/model/flow_matching.py +++ b/third_party/MMAudio/mmaudio/model/flow_matching.py @@ -1,11 +1,9 @@ import logging -from typing import Callable, Iterable, Optional +from typing import Callable, Optional, List, Tuple import torch from torchdiffeq import odeint -# from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher - log = logging.getLogger() @@ -42,15 +40,11 @@ class FlowMatching: self, x1: torch.Tensor, t: torch.Tensor, - Cs: list[torch.Tensor], + Cs: List[torch.Tensor], generator: Optional[torch.Generator] = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # x0 = torch.randn_like(x1, generator=generator) + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: x0 = torch.empty_like(x1).normal_(generator=generator) - # find mini-batch optimal transport - # x0, x1, _, Cs = self.fm.ot_sampler.sample_plan_with_labels(x0, x1, None, Cs, replace=True) - xt = self.get_conditional_flow(x0, x1, t) return x0, x1, xt, Cs @@ -74,15 +68,4 @@ class FlowMatching: dt = next_t - t x = x + dt * flow - # return odeint(fn, - # x0, - # torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype), - # method='rk4', - # options=dict(step_size=(t1 - t0) / self.num_steps))[-1] - # return odeint(fn, - # x0, - # torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype), - # method='euler', - # 
options=dict(step_size=(t1 - t0) / self.num_steps))[-1] - return x diff --git a/mmaudio/model/low_level.py b/third_party/MMAudio/mmaudio/model/low_level.py similarity index 100% rename from mmaudio/model/low_level.py rename to third_party/MMAudio/mmaudio/model/low_level.py diff --git a/mmaudio/model/networks.py b/third_party/MMAudio/mmaudio/model/networks.py similarity index 98% rename from mmaudio/model/networks.py rename to third_party/MMAudio/mmaudio/model/networks.py index e60e309c89d92cec70e7e673a4e842cc6716fae9..d38d0623c95e77653dd8ad9a4284738b8faa4d0d 100644 --- a/mmaudio/model/networks.py +++ b/third_party/MMAudio/mmaudio/model/networks.py @@ -166,10 +166,8 @@ class MMAudio(nn.Module): self._clip_seq_len, device=self.device) - # self.latent_rot = latent_rot.to(self.device) - # self.clip_rot = clip_rot.to(self.device) - self.register_buffer('latent_rot', latent_rot) - self.register_buffer('clip_rot', clip_rot) + self.latent_rot = nn.Buffer(latent_rot, persistent=False) + self.clip_rot = nn.Buffer(clip_rot, persistent=False) def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None: self._latent_seq_len = latent_seq_len @@ -285,6 +283,7 @@ class MMAudio(nn.Module): for block in self.fused_blocks: latent = block(latent, extended_c, self.latent_rot) + # should be extended_c; this is a minor implementation error #55 flow = self.final_layer(latent, global_c) # (B, N, out_dim), remove t return flow @@ -348,7 +347,7 @@ class MMAudio(nn.Module): if 'clip_rot' in src_dict: del src_dict['clip_rot'] - self.load_state_dict(src_dict, strict=False) + self.load_state_dict(src_dict, strict=True) @property def device(self) -> torch.device: diff --git a/mmaudio/model/sequence_config.py b/third_party/MMAudio/mmaudio/model/sequence_config.py similarity index 100% rename from mmaudio/model/sequence_config.py rename to third_party/MMAudio/mmaudio/model/sequence_config.py diff --git a/mmaudio/model/transformer_layers.py b/third_party/MMAudio/mmaudio/model/transformer_layers.py similarity index 96% rename from mmaudio/model/transformer_layers.py rename to third_party/MMAudio/mmaudio/model/transformer_layers.py index 3ca02ec3b6c00b9c39624d97d55a211cdd2e427d..6264d1debe1b3e29925ee4ad1811597ad147d2f0 100644 --- a/mmaudio/model/transformer_layers.py +++ b/third_party/MMAudio/mmaudio/model/transformer_layers.py @@ -1,11 +1,10 @@ -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange -from torch.nn.attention import SDPBackend, sdpa_kernel from mmaudio.ext.rotary_embeddings import apply_rope from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP @@ -45,7 +44,7 @@ class SelfAttention(nn.Module): def pre_attention( self, x: torch.Tensor, - rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + rot: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # x: batch_size * n_tokens * n_channels qkv = self.qkv(x) q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1) @@ -118,7 +117,7 @@ class MMDitSingleBlock(nn.Module): q, k, v = self.attn.pre_attention(x, rot) return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp) - def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor]): + def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: Tuple[torch.Tensor]): if self.pre_only: return x @@ -161,7 +160,7 @@ class 
JointBlock(nn.Module): def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor, global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor, - clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + clip_rot: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # latent: BS * N1 * D # clip_f: BS * N2 * D # c: BS * (1/N) * D diff --git a/third_party/MMAudio/mmaudio/model/utils/__init__.py b/third_party/MMAudio/mmaudio/model/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmaudio/model/utils/distributions.py b/third_party/MMAudio/mmaudio/model/utils/distributions.py similarity index 100% rename from mmaudio/model/utils/distributions.py rename to third_party/MMAudio/mmaudio/model/utils/distributions.py diff --git a/mmaudio/model/utils/features_utils.py b/third_party/MMAudio/mmaudio/model/utils/features_utils.py similarity index 96% rename from mmaudio/model/utils/features_utils.py rename to third_party/MMAudio/mmaudio/model/utils/features_utils.py index 8b5ebcf685d98d9f024ce29df239e93312418bae..5a059e135d3b4ea697f875c12feef5e1882ae656 100644 --- a/mmaudio/model/utils/features_utils.py +++ b/third_party/MMAudio/mmaudio/model/utils/features_utils.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional +from typing import Literal, Optional, Tuple, List import open_clip import torch @@ -9,7 +9,7 @@ from open_clip import create_model_from_pretrained from torchvision.transforms import Normalize from mmaudio.ext.autoencoder import AutoEncoderModule -from mmaudio.ext.mel_converter import MelConverter +from mmaudio.ext.mel_converter import get_mel_converter from mmaudio.ext.synchformer import Synchformer from mmaudio.model.utils.distributions import DiagonalGaussianDistribution @@ -63,13 +63,13 @@ class FeaturesUtils(nn.Module): self.tokenizer = None if tod_vae_ckpt is not None: + self.mel_converter = get_mel_converter(mode) self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt, vocoder_ckpt_path=bigvgan_vocoder_ckpt, mode=mode, need_vae_encoder=need_vae_encoder) else: self.tod = None - self.mel_converter = MelConverter() def compile(self): if self.clip_model is not None: @@ -129,7 +129,7 @@ class FeaturesUtils(nn.Module): return x @torch.inference_mode() - def encode_text(self, text: list[str]) -> torch.Tensor: + def encode_text(self, text: List[str]) -> torch.Tensor: assert self.clip_model is not None, 'CLIP is not loaded' assert self.tokenizer is not None, 'Tokenizer is not loaded' # x: (B, L) diff --git a/mmaudio/model/utils/parameter_groups.py b/third_party/MMAudio/mmaudio/model/utils/parameter_groups.py similarity index 100% rename from mmaudio/model/utils/parameter_groups.py rename to third_party/MMAudio/mmaudio/model/utils/parameter_groups.py diff --git a/mmaudio/model/utils/sample_utils.py b/third_party/MMAudio/mmaudio/model/utils/sample_utils.py similarity index 100% rename from mmaudio/model/utils/sample_utils.py rename to third_party/MMAudio/mmaudio/model/utils/sample_utils.py diff --git a/third_party/MMAudio/mmaudio/runner.py b/third_party/MMAudio/mmaudio/runner.py new file mode 100644 index 0000000000000000000000000000000000000000..755ee76bea7de3f31a14a5512710c39743dc9239 --- /dev/null +++ b/third_party/MMAudio/mmaudio/runner.py @@ -0,0 +1,609 @@ +""" +runner.py - wrapper and utility functions for network training +Compute loss, back-prop, update parameters, logging, etc.
+""" +import os +from pathlib import Path +from typing import Optional, Union + +import torch +import torch.distributed +import torch.optim as optim +from av_bench.evaluate import evaluate +from av_bench.extract import extract +from nitrous_ema import PostHocEMA +from omegaconf import DictConfig +from torch.nn.parallel import DistributedDataParallel as DDP + +from mmaudio.model.flow_matching import FlowMatching +from mmaudio.model.networks import get_my_mmaudio +from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K +from mmaudio.model.utils.features_utils import FeaturesUtils +from mmaudio.model.utils.parameter_groups import get_parameter_groups +from mmaudio.model.utils.sample_utils import log_normal_sample +from mmaudio.utils.dist_utils import (info_if_rank_zero, local_rank, string_if_rank_zero) +from mmaudio.utils.log_integrator import Integrator +from mmaudio.utils.logger import TensorboardLogger +from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator +from mmaudio.utils.video_joiner import VideoJoiner + + +class Runner: + + def __init__(self, + cfg: DictConfig, + log: TensorboardLogger, + run_path: Union[str, Path], + for_training: bool = True, + latent_mean: Optional[torch.Tensor] = None, + latent_std: Optional[torch.Tensor] = None): + self.exp_id = cfg.exp_id + self.use_amp = cfg.amp + self.enable_grad_scaler = cfg.enable_grad_scaler + self.for_training = for_training + self.cfg = cfg + + if cfg.model.endswith('16k'): + self.seq_cfg = CONFIG_16K + mode = '16k' + elif cfg.model.endswith('44k'): + self.seq_cfg = CONFIG_44K + mode = '44k' + else: + raise ValueError(f'Unknown model: {cfg.model}') + + self.sample_rate = self.seq_cfg.sampling_rate + self.duration_sec = self.seq_cfg.duration + + # setting up the model + empty_string_feat = torch.load('./ext_weights/empty_string.pth', weights_only=True)[0] + self.network = DDP(get_my_mmaudio(cfg.model, + latent_mean=latent_mean, + latent_std=latent_std, + empty_string_feat=empty_string_feat).cuda(), + device_ids=[local_rank], + broadcast_buffers=False) + if cfg.compile: + # NOTE: though train_fn and val_fn are very similar + # (early on they are implemented as a single function) + # keeping them separate and compiling them separately are CRUCIAL for high performance + self.train_fn = torch.compile(self.train_fn) + self.val_fn = torch.compile(self.val_fn) + + self.fm = FlowMatching(cfg.sampling.min_sigma, + inference_mode=cfg.sampling.method, + num_steps=cfg.sampling.num_steps) + + # ema profile + if for_training and cfg.ema.enable and local_rank == 0: + self.ema = PostHocEMA(self.network.module, + sigma_rels=cfg.ema.sigma_rels, + update_every=cfg.ema.update_every, + checkpoint_every_num_steps=cfg.ema.checkpoint_every, + checkpoint_folder=cfg.ema.checkpoint_folder, + step_size_correction=True).cuda() + self.ema_start = cfg.ema.start + else: + self.ema = None + + self.rng = torch.Generator(device='cuda') + self.rng.manual_seed(cfg['seed'] + local_rank) + + # setting up feature extractors and VAEs + if mode == '16k': + self.features = FeaturesUtils( + tod_vae_ckpt=cfg['vae_16k_ckpt'], + bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'], + synchformer_ckpt=cfg['synchformer_ckpt'], + enable_conditions=True, + mode=mode, + need_vae_encoder=False, + ) + elif mode == '44k': + self.features = FeaturesUtils( + tod_vae_ckpt=cfg['vae_44k_ckpt'], + synchformer_ckpt=cfg['synchformer_ckpt'], + enable_conditions=True, + mode=mode, + need_vae_encoder=False, + ) + self.features = self.features.cuda().eval() + + if cfg.compile: 
+ self.features.compile() + + # hyperparameters + self.log_normal_sampling_mean = cfg.sampling.mean + self.log_normal_sampling_scale = cfg.sampling.scale + self.null_condition_probability = cfg.null_condition_probability + self.cfg_strength = cfg.cfg_strength + + # setting up logging + self.log = log + self.run_path = Path(run_path) + vgg_cfg = cfg.data.VGGSound + if for_training: + self.val_video_joiner = VideoJoiner(vgg_cfg.root, self.run_path / 'val-sampled-videos', + self.sample_rate, self.duration_sec) + else: + self.test_video_joiner = VideoJoiner(vgg_cfg.root, + self.run_path / 'test-sampled-videos', + self.sample_rate, self.duration_sec) + string_if_rank_zero(self.log, 'model_size', + f'{sum([param.nelement() for param in self.network.parameters()])}') + string_if_rank_zero( + self.log, 'number_of_parameters_that_require_gradient: ', + str( + sum([ + param.nelement() + for param in filter(lambda p: p.requires_grad, self.network.parameters()) + ]))) + info_if_rank_zero(self.log, 'torch version: ' + torch.__version__) + self.train_integrator = Integrator(self.log, distributed=True) + self.val_integrator = Integrator(self.log, distributed=True) + + # setting up optimizer and loss + if for_training: + self.enter_train() + parameter_groups = get_parameter_groups(self.network, cfg, print_log=(local_rank == 0)) + self.optimizer = optim.AdamW(parameter_groups, + lr=cfg['learning_rate'], + weight_decay=cfg['weight_decay'], + betas=[0.9, 0.95], + eps=1e-6 if self.use_amp else 1e-8, + fused=True) + if self.enable_grad_scaler: + self.scaler = torch.amp.GradScaler(init_scale=2048) + self.clip_grad_norm = cfg['clip_grad_norm'] + + # linearly warmup learning rate + linear_warmup_steps = cfg['linear_warmup_steps'] + + def warmup(current_step: int): + return (current_step + 1) / (linear_warmup_steps + 1) + + warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup) + + # setting up learning rate scheduler + if cfg['lr_schedule'] == 'constant': + next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1) + elif cfg['lr_schedule'] == 'poly': + total_num_iter = cfg['iterations'] + next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, + lr_lambda=lambda x: + (1 - (x / total_num_iter))**0.9) + elif cfg['lr_schedule'] == 'step': + next_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, + cfg['lr_schedule_steps'], + cfg['lr_schedule_gamma']) + else: + raise NotImplementedError + + self.scheduler = optim.lr_scheduler.SequentialLR(self.optimizer, + [warmup_scheduler, next_scheduler], + [linear_warmup_steps]) + + # Logging info + self.log_text_interval = cfg['log_text_interval'] + self.log_extra_interval = cfg['log_extra_interval'] + self.save_weights_interval = cfg['save_weights_interval'] + self.save_checkpoint_interval = cfg['save_checkpoint_interval'] + self.save_copy_iterations = cfg['save_copy_iterations'] + self.num_iterations = cfg['num_iterations'] + if cfg['debug']: + self.log_text_interval = self.log_extra_interval = 1 + + # update() is called when we log metrics, within the logger + self.log.batch_timer = TimeEstimator(self.num_iterations, self.log_text_interval) + # update() is called every iteration, in this script + self.log.data_timer = PartialTimeEstimator(self.num_iterations, 1, ema_alpha=0.9) + else: + self.enter_val() + + def train_fn( + self, + clip_f: torch.Tensor, + sync_f: torch.Tensor, + text_f: torch.Tensor, + a_mean: torch.Tensor, + a_std: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, 
torch.Tensor]: + # sample + a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) + x1 = a_mean + a_std * a_randn + bs = x1.shape[0] # batch_size * seq_len * num_channels + + # normalize the latents + x1 = self.network.module.normalize(x1) + + t = log_normal_sample(x1, + generator=self.rng, + m=self.log_normal_sampling_mean, + s=self.log_normal_sampling_scale) + x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, + t, + Cs=[clip_f, sync_f, text_f], + generator=self.rng) + + # classifier-free training + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_video = (samples < self.null_condition_probability) + clip_f[null_video] = self.network.module.empty_clip_feat + sync_f[null_video] = self.network.module.empty_sync_feat + + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_text = (samples < self.null_condition_probability) + text_f[null_text] = self.network.module.empty_string_feat + + pred_v = self.network(xt, clip_f, sync_f, text_f, t) + loss = self.fm.loss(pred_v, x0, x1) + mean_loss = loss.mean() + return x1, loss, mean_loss, t + + def val_fn( + self, + clip_f: torch.Tensor, + sync_f: torch.Tensor, + text_f: torch.Tensor, + x1: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + bs = x1.shape[0] # batch_size * seq_len * num_channels + # normalize the latents + x1 = self.network.module.normalize(x1) + t = log_normal_sample(x1, + generator=self.rng, + m=self.log_normal_sampling_mean, + s=self.log_normal_sampling_scale) + x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1, + t, + Cs=[clip_f, sync_f, text_f], + generator=self.rng) + + # classifier-free training + samples = torch.rand(bs, device=x1.device, generator=self.rng) + # null mask is for when a video is provided but we decided to ignore it + null_video = (samples < self.null_condition_probability) + # complete mask is for when a video is not provided or we decided to ignore it + clip_f[null_video] = self.network.module.empty_clip_feat + sync_f[null_video] = self.network.module.empty_sync_feat + + samples = torch.rand(bs, device=x1.device, generator=self.rng) + null_text = (samples < self.null_condition_probability) + text_f[null_text] = self.network.module.empty_string_feat + + pred_v = self.network(xt, clip_f, sync_f, text_f, t) + + loss = self.fm.loss(pred_v, x0, x1) + mean_loss = loss.mean() + return loss, mean_loss, t + + def train_pass(self, data, it: int = 0): + + if not self.for_training: + raise ValueError('train_pass() should not be called when not training.') + + self.enter_train() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) + a_std = data['a_std'].cuda(non_blocking=True) + + # these masks are for non-existent data; masking for CFG training is in train_fn + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + + self.log.data_timer.end() + if it % self.log_extra_interval == 0: + unmasked_clip_f = clip_f.clone() + unmasked_sync_f = sync_f.clone() + unmasked_text_f = text_f.clone() + x1, loss, mean_loss, t = 
self.train_fn(clip_f, sync_f, text_f, a_mean, a_std) + + self.train_integrator.add_dict({'loss': mean_loss}) + + if it % self.log_text_interval == 0 and it != 0: + self.train_integrator.add_scalar('lr', self.scheduler.get_last_lr()[0]) + self.train_integrator.add_binned_tensor('binned_loss', loss, t) + self.train_integrator.finalize('train', it) + self.train_integrator.reset_except_hooks() + + # Backward pass + self.optimizer.zero_grad(set_to_none=True) + if self.enable_grad_scaler: + self.scaler.scale(mean_loss).backward() + self.scaler.unscale_(self.optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.clip_grad_norm) + self.scaler.step(self.optimizer) + self.scaler.update() + else: + mean_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(), + self.clip_grad_norm) + self.optimizer.step() + + if self.ema is not None and it >= self.ema_start: + self.ema.update() + self.scheduler.step() + self.integrator.add_scalar('grad_norm', grad_norm) + + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, + dtype=torch.bfloat16), torch.inference_mode(): + try: + if it % self.log_extra_interval == 0: + # save GT audio + # unnormalize the latents + x1 = self.network.module.unnormalize(x1[0:1]) + mel = self.features.decode(x1) + audio = self.features.vocode(mel).cpu()[0] # 1 * num_samples + self.log.log_spectrogram('train', f'spec-gt-r{local_rank}', mel.cpu()[0], it) + self.log.log_audio('train', + f'audio-gt-r{local_rank}', + audio, + it, + sample_rate=self.sample_rate) + + # save audio from sampling + x0 = torch.empty_like(x1[0:1]).normal_(generator=self.rng) + clip_f = unmasked_clip_f[0:1] + sync_f = unmasked_sync_f[0:1] + text_f = unmasked_text_f[0:1] + conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) + empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) + cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( + t, x, conditions, empty_conditions, self.cfg_strength) + x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) + x1_hat = self.network.module.unnormalize(x1_hat) + mel = self.features.decode(x1_hat) + audio = self.features.vocode(mel).cpu()[0] + self.log.log_spectrogram('train', f'spec-r{local_rank}', mel.cpu()[0], it) + self.log.log_audio('train', + f'audio-r{local_rank}', + audio, + it, + sample_rate=self.sample_rate) + except Exception as e: + self.log.warning(f'Error in extra logging: {e}') + if self.cfg.debug: + raise + + # Save network weights and checkpoint if needed + save_copy = it in self.save_copy_iterations + + if (it % self.save_weights_interval == 0 and it != 0) or save_copy: + self.save_weights(it) + + if it % self.save_checkpoint_interval == 0 and it != 0: + self.save_checkpoint(it, save_copy=save_copy) + + self.log.data_timer.start() + + @torch.inference_mode() + def validation_pass(self, data, it: int = 0): + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) + a_std = data['a_std'].cuda(non_blocking=True) + + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = 
self.network.module.empty_string_feat + a_randn = torch.empty_like(a_mean).normal_(generator=self.rng) + x1 = a_mean + a_std * a_randn + + self.log.data_timer.end() + loss, mean_loss, t = self.val_fn(clip_f.clone(), sync_f.clone(), text_f.clone(), x1) + + self.val_integrator.add_binned_tensor('binned_loss', loss, t) + self.val_integrator.add_dict({'loss': mean_loss}) + + self.log.data_timer.start() + + @torch.inference_mode() + def inference_pass(self, + data, + it: int, + data_cfg: DictConfig, + *, + save_eval: bool = True) -> Path: + self.enter_val() + with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16): + clip_f = data['clip_features'].cuda(non_blocking=True) + sync_f = data['sync_features'].cuda(non_blocking=True) + text_f = data['text_features'].cuda(non_blocking=True) + video_exist = data['video_exist'].cuda(non_blocking=True) + text_exist = data['text_exist'].cuda(non_blocking=True) + a_mean = data['a_mean'].cuda(non_blocking=True) # for the shape only + + clip_f[~video_exist] = self.network.module.empty_clip_feat + sync_f[~video_exist] = self.network.module.empty_sync_feat + text_f[~text_exist] = self.network.module.empty_string_feat + + # sample + x0 = torch.empty_like(a_mean).normal_(generator=self.rng) + conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f) + empty_conditions = self.network.module.get_empty_conditions(x0.shape[0]) + cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper( + t, x, conditions, empty_conditions, self.cfg_strength) + x1_hat = self.fm.to_data(cfg_ode_wrapper, x0) + x1_hat = self.network.module.unnormalize(x1_hat) + mel = self.features.decode(x1_hat) + audio = self.features.vocode(mel).cpu() + for i in range(audio.shape[0]): + video_id = data['id'][i] + if (not self.for_training) and i == 0: + # save very few videos + self.test_video_joiner.join(video_id, f'{video_id}', audio[i].transpose(0, 1)) + + if data_cfg.output_subdir is not None: + # validation + if save_eval: + iter_naming = f'{it:09d}' + else: + iter_naming = 'val-cache' + audio_dir = self.log.log_audio(iter_naming, + f'{video_id}', + audio[i], + it=None, + sample_rate=self.sample_rate, + subdir=Path(data_cfg.output_subdir)) + if save_eval and i == 0: + self.val_video_joiner.join(video_id, f'{iter_naming}-{video_id}', + audio[i].transpose(0, 1)) + else: + # full test set, usually + audio_dir = self.log.log_audio(f'{data_cfg.tag}-sampled', + f'{video_id}', + audio[i], + it=None, + sample_rate=self.sample_rate) + + return Path(audio_dir) + + @torch.inference_mode() + def eval(self, audio_dir: Path, it: int, data_cfg: DictConfig) -> dict[str, float]: + with torch.amp.autocast('cuda', enabled=False): + if local_rank == 0: + extract(audio_path=audio_dir, + output_path=audio_dir / 'cache', + device='cuda', + batch_size=32, + audio_length=8) + output_metrics = evaluate(gt_audio_cache=Path(data_cfg.gt_cache), + pred_audio_cache=audio_dir / 'cache') + for k, v in output_metrics.items(): + # pad k to 10 characters + # pad v to 10 decimal places + self.log.log_scalar(f'{data_cfg.tag}/{k}', v, it) + self.log.info(f'{data_cfg.tag}/{k:<10}: {v:.10f}') + else: + output_metrics = None + + return output_metrics + + def save_weights(self, it, save_copy=False): + if local_rank != 0: + return + + os.makedirs(self.run_path, exist_ok=True) + if save_copy: + model_path = self.run_path / f'{self.exp_id}_{it}.pth' + torch.save(self.network.module.state_dict(), model_path) + self.log.info(f'Network weights saved to {model_path}.') + + # if last exists, move it 
to a shadow copy + model_path = self.run_path / f'{self.exp_id}_last.pth' + if model_path.exists(): + shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) + model_path.replace(shadow_path) + self.log.info(f'Network weights shadowed to {shadow_path}.') + + torch.save(self.network.module.state_dict(), model_path) + self.log.info(f'Network weights saved to {model_path}.') + + def save_checkpoint(self, it, save_copy=False): + if local_rank != 0: + return + + checkpoint = { + 'it': it, + 'weights': self.network.module.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict(), + 'ema': self.ema.state_dict() if self.ema is not None else None, + } + + os.makedirs(self.run_path, exist_ok=True) + if save_copy: + model_path = self.run_path / f'{self.exp_id}_ckpt_{it}.pth' + torch.save(checkpoint, model_path) + self.log.info(f'Checkpoint saved to {model_path}.') + + # if ckpt_last exists, move it to a shadow copy + model_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' + if model_path.exists(): + shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow')) + model_path.replace(shadow_path) # moves the file + self.log.info(f'Checkpoint shadowed to {shadow_path}.') + + torch.save(checkpoint, model_path) + self.log.info(f'Checkpoint saved to {model_path}.') + + def get_latest_checkpoint_path(self): + ckpt_path = self.run_path / f'{self.exp_id}_ckpt_last.pth' + if not ckpt_path.exists(): + info_if_rank_zero(self.log, f'No checkpoint found at {ckpt_path}.') + return None + return ckpt_path + + def get_latest_weight_path(self): + weight_path = self.run_path / f'{self.exp_id}_last.pth' + if not weight_path.exists(): + self.log.info(f'No weight found at {weight_path}.') + return None + return weight_path + + def get_final_ema_weight_path(self): + weight_path = self.run_path / f'{self.exp_id}_ema_final.pth' + if not weight_path.exists(): + self.log.info(f'No weight found at {weight_path}.') + return None + return weight_path + + def load_checkpoint(self, path): + # This method loads everything and should be used to resume training + map_location = 'cuda:%d' % local_rank + checkpoint = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) + + it = checkpoint['it'] + weights = checkpoint['weights'] + optimizer = checkpoint['optimizer'] + scheduler = checkpoint['scheduler'] + if self.ema is not None: + self.ema.load_state_dict(checkpoint['ema']) + self.log.info(f'EMA states loaded from step {self.ema.step}') + + map_location = 'cuda:%d' % local_rank + self.network.module.load_state_dict(weights) + self.optimizer.load_state_dict(optimizer) + self.scheduler.load_state_dict(scheduler) + + self.log.info(f'Global iteration {it} loaded.') + self.log.info('Network weights, optimizer states, and scheduler states loaded.') + + return it + + def load_weights_in_memory(self, src_dict): + self.network.module.load_weights(src_dict) + self.log.info('Network weights loaded from memory.') + + def load_weights(self, path): + # This method loads only the network weight and should be used to load a pretrained model + map_location = 'cuda:%d' % local_rank + src_dict = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True) + + self.log.info(f'Importing network weights from {path}...') + self.load_weights_in_memory(src_dict) + + def weights(self): + return self.network.module.state_dict() + + def enter_train(self): + self.integrator = self.train_integrator + self.network.train() + return self + + def 
enter_val(self): + self.network.eval() + return self diff --git a/third_party/MMAudio/mmaudio/sample.py b/third_party/MMAudio/mmaudio/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..72b83389d7dbb55bed02991f51731b0d1e346a2b --- /dev/null +++ b/third_party/MMAudio/mmaudio/sample.py @@ -0,0 +1,90 @@ +import json +import logging +import os +import random + +import numpy as np +import torch +from hydra.core.hydra_config import HydraConfig +from omegaconf import DictConfig, open_dict +from tqdm import tqdm + +from mmaudio.data.data_setup import setup_test_datasets +from mmaudio.runner import Runner +from mmaudio.utils.dist_utils import info_if_rank_zero +from mmaudio.utils.logger import TensorboardLogger + +local_rank = int(os.environ['LOCAL_RANK']) +world_size = int(os.environ['WORLD_SIZE']) + + +def sample(cfg: DictConfig): + # initial setup + num_gpus = world_size + run_dir = HydraConfig.get().run.dir + + # wrap python logger with a tensorboard logger + log = TensorboardLogger(cfg.exp_id, + run_dir, + logging.getLogger(), + is_rank0=(local_rank == 0), + enable_email=cfg.enable_email and not cfg.debug) + + info_if_rank_zero(log, f'All configuration: {cfg}') + info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}') + + # cuda setup + torch.cuda.set_device(local_rank) + torch.backends.cudnn.benchmark = cfg.cudnn_benchmark + + # number of dataloader workers + info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}') + + # Set seeds to ensure the same initialization + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + random.seed(cfg.seed) + + # setting up configurations + info_if_rank_zero(log, f'Configuration: {cfg}') + info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}') + + # construct the trainer + runner = Runner(cfg, log=log, run_path=run_dir, for_training=False).enter_val() + + # load the last weights if needed + if cfg['weights'] is not None: + info_if_rank_zero(log, f'Loading weights from the disk: {cfg["weights"]}') + runner.load_weights(cfg['weights']) + cfg['weights'] = None + else: + weights = runner.get_final_ema_weight_path() + if weights is not None: + info_if_rank_zero(log, f'Automatically finding weight: {weights}') + runner.load_weights(weights) + + # setup datasets + dataset, sampler, loader = setup_test_datasets(cfg) + data_cfg = cfg.data.ExtractedVGG_test + with open_dict(data_cfg): + if cfg.output_name is not None: + # append to the tag + data_cfg.tag = f'{data_cfg.tag}-{cfg.output_name}' + + # loop + audio_path = None + for curr_iter, data in enumerate(tqdm(loader)): + new_audio_path = runner.inference_pass(data, curr_iter, data_cfg) + if audio_path is None: + audio_path = new_audio_path + else: + assert audio_path == new_audio_path, 'Different audio path detected' + + info_if_rank_zero(log, f'Inference completed. 
Audio path: {audio_path}') + output_metrics = runner.eval(audio_path, curr_iter, data_cfg) + + if local_rank == 0: + # write the output metrics to run_dir + output_metrics_path = os.path.join(run_dir, f'{data_cfg.tag}-output_metrics.json') + with open(output_metrics_path, 'w') as f: + json.dump(output_metrics, f, indent=4) diff --git a/third_party/MMAudio/mmaudio/utils/__init__.py b/third_party/MMAudio/mmaudio/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mmaudio/utils/dist_utils.py b/third_party/MMAudio/mmaudio/utils/dist_utils.py similarity index 100% rename from mmaudio/utils/dist_utils.py rename to third_party/MMAudio/mmaudio/utils/dist_utils.py diff --git a/mmaudio/utils/download_utils.py b/third_party/MMAudio/mmaudio/utils/download_utils.py similarity index 100% rename from mmaudio/utils/download_utils.py rename to third_party/MMAudio/mmaudio/utils/download_utils.py diff --git a/third_party/MMAudio/mmaudio/utils/email_utils.py b/third_party/MMAudio/mmaudio/utils/email_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..873b77c2b3b8cc0abe4df61c7fb981159c8bccfc --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/email_utils.py @@ -0,0 +1,50 @@ +import logging +import os +from datetime import datetime + +import requests +from dotenv import load_dotenv +from pytz import timezone + +from mmaudio.utils.timezone import my_timezone + +_source = 'USE YOURS' +_target = 'USE YOURS' + +log = logging.getLogger() + +_fmt = "%Y-%m-%d %H:%M:%S %Z%z" + + +class EmailSender: + + def __init__(self, exp_id: str, enable: bool): + self.exp_id = exp_id + self.enable = enable + if enable: + load_dotenv() + self.MAILGUN_API_KEY = os.getenv('MAILGUN_API_KEY') + if self.MAILGUN_API_KEY is None: + log.warning('MAILGUN_API_KEY is not set') + self.enable = False + + def send(self, subject, content): + if self.enable: + subject = str(subject) + content = str(content) + try: + return requests.post(f'https://api.mailgun.net/v3/{_source}/messages', + auth=('api', self.MAILGUN_API_KEY), + data={ + 'from': + f'🤖 ', + 'to': [f'{_target}'], + 'subject': + f'[{self.exp_id}] {subject}', + 'text': + ('\n\n' + content + '\n\n\n' + + datetime.now(timezone(my_timezone)).strftime(_fmt)), + }, + timeout=20) + except Exception as e: + log.error(f'Failed to send email: {e}') diff --git a/third_party/MMAudio/mmaudio/utils/log_integrator.py b/third_party/MMAudio/mmaudio/utils/log_integrator.py new file mode 100644 index 0000000000000000000000000000000000000000..af042fa05bd73d3411f3464de3b0ed61dad61ab7 --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/log_integrator.py @@ -0,0 +1,112 @@ +""" +Integrate numerical values for some iterations +Typically used for loss computation / logging to tensorboard +Call finalize and create a new Integrator when you want to display/log +""" +from typing import Callable, Union + +import torch + +from mmaudio.utils.logger import TensorboardLogger +from mmaudio.utils.tensor_utils import distribute_into_histogram + + +class Integrator: + + def __init__(self, logger: TensorboardLogger, distributed: bool = True): + self.values = {} + self.counts = {} + self.hooks = [] # List is used here to maintain insertion order + + # for binned tensors + self.binned_tensors = {} + self.binned_tensor_indices = {} + + self.logger = logger + + self.distributed = distributed + self.local_rank = torch.distributed.get_rank() + self.world_size = torch.distributed.get_world_size() + + def 
add_scalar(self, key: str, x: Union[torch.Tensor, int, float]): + if isinstance(x, torch.Tensor): + x = x.detach() + if x.dtype in [torch.long, torch.int, torch.bool]: + x = x.float() + + if key not in self.values: + self.counts[key] = 1 + self.values[key] = x + else: + self.counts[key] += 1 + self.values[key] += x + + def add_dict(self, tensor_dict: dict[str, torch.Tensor]): + for k, v in tensor_dict.items(): + self.add_scalar(k, v) + + def add_binned_tensor(self, key: str, x: torch.Tensor, indices: torch.Tensor): + if key not in self.binned_tensors: + self.binned_tensors[key] = [x.detach().flatten()] + self.binned_tensor_indices[key] = [indices.detach().flatten()] + else: + self.binned_tensors[key].append(x.detach().flatten()) + self.binned_tensor_indices[key].append(indices.detach().flatten()) + + def add_hook(self, hook: Callable[[torch.Tensor], tuple[str, torch.Tensor]]): + """ + Adds a custom hook, i.e. compute new metrics using values in the dict + The hook takes the dict as argument, and returns a (k, v) tuple + e.g. for computing IoU + """ + self.hooks.append(hook) + + def reset_except_hooks(self): + self.values = {} + self.counts = {} + + # Average and output the metrics + def finalize(self, prefix: str, it: int, ignore_timer: bool = False) -> None: + + for hook in self.hooks: + k, v = hook(self.values) + self.add_scalar(k, v) + + # for the metrics + outputs = {} + for k, v in self.values.items(): + avg = v / self.counts[k] + if self.distributed: + # Inplace operation + if isinstance(avg, torch.Tensor): + avg = avg.cuda() + else: + avg = torch.tensor(avg).cuda() + torch.distributed.reduce(avg, dst=0) + + if self.local_rank == 0: + avg = (avg / self.world_size).cpu().item() + outputs[k] = avg + else: + # Simple does it + outputs[k] = avg + + if (not self.distributed) or (self.local_rank == 0): + self.logger.log_metrics(prefix, outputs, it, ignore_timer=ignore_timer) + + # for the binned tensors + for k, v in self.binned_tensors.items(): + x = torch.cat(v, dim=0) + indices = torch.cat(self.binned_tensor_indices[k], dim=0) + hist, count = distribute_into_histogram(x, indices) + + if self.distributed: + torch.distributed.reduce(hist, dst=0) + torch.distributed.reduce(count, dst=0) + if self.local_rank == 0: + hist = hist / count + else: + hist = hist / count + + if (not self.distributed) or (self.local_rank == 0): + self.logger.log_histogram(f'{prefix}/{k}', hist, it) diff --git a/third_party/MMAudio/mmaudio/utils/logger.py b/third_party/MMAudio/mmaudio/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..3878232ec3e0097f4063fcaa9111564e84ea98b9 --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/logger.py @@ -0,0 +1,231 @@ +""" +Dumps things to tensorboard and console +""" + +import datetime +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchaudio +from PIL import Image +from pytz import timezone +from torch.utils.tensorboard import SummaryWriter + +from mmaudio.utils.email_utils import EmailSender +from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator +from mmaudio.utils.timezone import my_timezone + + +def tensor_to_numpy(image: torch.Tensor): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + + +def detach_to_cpu(x: torch.Tensor): + return x.detach().cpu() + + +def fix_width_trunc(x: float): + return 
('{:.9s}'.format('{:0.9f}'.format(x))) + + +def plot_spectrogram(spectrogram: np.ndarray, title=None, ylabel="freq_bin", ax=None): + if ax is None: + _, ax = plt.subplots(1, 1) + if title is not None: + ax.set_title(title) + ax.set_ylabel(ylabel) + ax.imshow(spectrogram, origin="lower", aspect="auto", interpolation="nearest") + + +class TensorboardLogger: + + def __init__(self, + exp_id: str, + run_dir: Union[Path, str], + py_logger: logging.Logger, + *, + is_rank0: bool = False, + enable_email: bool = False): + self.exp_id = exp_id + self.run_dir = Path(run_dir) + self.py_log = py_logger + self.email_sender = EmailSender(exp_id, enable=(is_rank0 and enable_email)) + if is_rank0: + self.tb_log = SummaryWriter(run_dir) + else: + self.tb_log = None + + # Get current git info for logging + try: + import git + repo = git.Repo(".") + git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha) + except (ImportError, RuntimeError, TypeError): + print('Failed to fetch git info. Defaulting to None') + git_info = 'None' + + self.log_string('git', git_info) + + # log the SLURM job id if available + job_id = os.environ.get('SLURM_JOB_ID', None) + if job_id is not None: + self.log_string('slurm_job_id', job_id) + self.email_sender.send(f'Job {job_id} started', f'Job started {run_dir}') + + # used when logging metrics + self.batch_timer: TimeEstimator = None + self.data_timer: PartialTimeEstimator = None + + self.nan_count = defaultdict(int) + + def log_scalar(self, tag: str, x: float, it: int): + if self.tb_log is None: + return + if math.isnan(x) and 'grad_norm' not in tag: + self.nan_count[tag] += 1 + if self.nan_count[tag] == 10: + self.email_sender.send( + f'Nan detected in {tag} @ {self.run_dir}', + f'Nan detected in {tag} at iteration {it}; run_dir: {self.run_dir}') + else: + self.nan_count[tag] = 0 + self.tb_log.add_scalar(tag, x, it) + + def log_metrics(self, + prefix: str, + metrics: dict[str, float], + it: int, + ignore_timer: bool = False): + msg = f'{self.exp_id}-{prefix} - it {it:6d}: ' + metrics_msg = '' + for k, v in sorted(metrics.items()): + self.log_scalar(f'{prefix}/{k}', v, it) + metrics_msg += f'{k: >10}:{v:.7f},\t' + + if self.batch_timer is not None and not ignore_timer: + self.batch_timer.update() + avg_time = self.batch_timer.get_and_reset_avg_time() + data_time = self.data_timer.get_and_reset_avg_time() + + # add time to tensorboard + self.log_scalar(f'{prefix}/avg_time', avg_time, it) + self.log_scalar(f'{prefix}/data_time', data_time, it) + + est = self.batch_timer.get_est_remaining(it) + est = datetime.timedelta(seconds=est) + if est.days > 0: + remaining_str = f'{est.days}d {est.seconds // 3600}h' + else: + remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m' + eta = datetime.datetime.now(timezone(my_timezone)) + est + eta_str = eta.strftime('%Y-%m-%d %H:%M:%S %Z%z') + time_msg = f'avg_time:{avg_time:.3f},data:{data_time:.3f},remaining:{remaining_str},eta:{eta_str},\t' + msg = f'{msg} {time_msg}' + + msg = f'{msg} {metrics_msg}' + self.py_log.info(msg) + + def log_histogram(self, tag: str, hist: torch.Tensor, it: int): + if self.tb_log is None: + return + # hist should be a 1D tensor + hist = hist.cpu().numpy() + fig, ax = plt.subplots() + x_range = np.linspace(0, 1, len(hist)) + ax.bar(x_range, hist, width=1 / (len(hist) - 1)) + ax.set_xticks(x_range) + ax.set_xticklabels(x_range) + plt.tight_layout() + self.tb_log.add_figure(tag, fig, it) + plt.close() + + def log_image(self, prefix: str, tag: str, image: np.ndarray, it: int): + image_dir = 
self.run_dir / f'{prefix}_images' + image_dir.mkdir(exist_ok=True, parents=True) + + image = Image.fromarray(image) + image.save(image_dir / f'{it:09d}_{tag}.png') + + def log_audio(self, + prefix: str, + tag: str, + waveform: torch.Tensor, + it: Optional[int] = None, + *, + subdir: Optional[Path] = None, + sample_rate: int = 16000) -> Path: + if subdir is None: + audio_dir = self.run_dir / prefix + else: + audio_dir = self.run_dir / subdir / prefix + audio_dir.mkdir(exist_ok=True, parents=True) + + if it is None: + name = f'{tag}.flac' + else: + name = f'{it:09d}_{tag}.flac' + + torchaudio.save(audio_dir / name, + waveform.cpu().float(), + sample_rate=sample_rate, + channels_first=True) + return Path(audio_dir) + + def log_spectrogram( + self, + prefix: str, + tag: str, + spec: torch.Tensor, + it: Optional[int], + *, + subdir: Optional[Path] = None, + ): + if subdir is None: + spec_dir = self.run_dir / prefix + else: + spec_dir = self.run_dir / subdir / prefix + spec_dir.mkdir(exist_ok=True, parents=True) + + if it is None: + name = f'{tag}.png' + else: + name = f'{it:09d}_{tag}.png' + + plot_spectrogram(spec.cpu().float()) + plt.tight_layout() + plt.savefig(spec_dir / name) + plt.close() + + def log_string(self, tag: str, x: str): + self.py_log.info(f'{tag} - {x}') + if self.tb_log is None: + return + self.tb_log.add_text(tag, x) + + def debug(self, x): + self.py_log.debug(x) + + def info(self, x): + self.py_log.info(x) + + def warning(self, x): + self.py_log.warning(x) + + def error(self, x): + self.py_log.error(x) + + def critical(self, x): + self.py_log.critical(x) + + self.email_sender.send(f'Error occurred in {self.run_dir}', x) + + def complete(self): + self.email_sender.send(f'Job completed in {self.run_dir}', 'Job completed') diff --git a/third_party/MMAudio/mmaudio/utils/synthesize_ema.py b/third_party/MMAudio/mmaudio/utils/synthesize_ema.py new file mode 100644 index 0000000000000000000000000000000000000000..d71348010f5776360460152d9d910e0bec62c1f5 --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/synthesize_ema.py @@ -0,0 +1,19 @@ +from typing import Optional + +from nitrous_ema import PostHocEMA +from omegaconf import DictConfig + +from mmaudio.model.networks import get_my_mmaudio + + +def synthesize_ema(cfg: DictConfig, sigma: float, step: Optional[int]): + vae = get_my_mmaudio(cfg.model) + emas = PostHocEMA(vae, + sigma_rels=cfg.ema.sigma_rels, + update_every=cfg.ema.update_every, + checkpoint_every_num_steps=cfg.ema.checkpoint_every, + checkpoint_folder=cfg.ema.checkpoint_folder) + + synthesized_ema = emas.synthesize_ema_model(sigma_rel=sigma, step=step, device='cpu') + state_dict = synthesized_ema.ema_model.state_dict() + return state_dict diff --git a/third_party/MMAudio/mmaudio/utils/tensor_utils.py b/third_party/MMAudio/mmaudio/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b650955b04ce097d0a03bbafb6424f9528c631c2 --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/tensor_utils.py @@ -0,0 +1,14 @@ +import torch + + +def distribute_into_histogram(loss: torch.Tensor, + t: torch.Tensor, + num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]: + loss = loss.detach().flatten() + t = t.detach().flatten() + t = (t * num_bins).long() + hist = torch.zeros(num_bins, device=loss.device) + count = torch.zeros(num_bins, device=loss.device) + hist.scatter_add_(0, t, loss) + count.scatter_add_(0, t, torch.ones_like(loss)) + return hist, count diff --git a/third_party/MMAudio/mmaudio/utils/time_estimator.py 
b/third_party/MMAudio/mmaudio/utils/time_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..62ff3ca189cda8f9524c11196fdc292eedb1d354 --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/time_estimator.py @@ -0,0 +1,72 @@ +import time + + +class TimeEstimator: + + def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7): + self.avg_time_window = [] # window-based average + self.exp_avg_time = None # exponential moving average + self.alpha = ema_alpha # for exponential moving average + + self.last_time = time.time() # would not be accurate for the first iteration but well + self.total_iter = total_iter + self.step_size = step_size + + self._buffering_exp = True + + # call this at a fixed interval + # does not have to be every step + def update(self): + curr_time = time.time() + time_per_iter = curr_time - self.last_time + self.last_time = curr_time + + self.avg_time_window.append(time_per_iter) + + if self._buffering_exp: + if self.exp_avg_time is not None: + # discard the first iteration call to not pollute the ema + self._buffering_exp = False + self.exp_avg_time = time_per_iter + else: + self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter + + def get_est_remaining(self, it: int): + if self.exp_avg_time is None: + return 0 + + remaining_iter = self.total_iter - it + return remaining_iter * self.exp_avg_time / self.step_size + + def get_and_reset_avg_time(self): + avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size + self.avg_time_window = [] + return avg + + +class PartialTimeEstimator(TimeEstimator): + """ + Used where the start_time and the end_time do not align + """ + + def update(self): + raise RuntimeError('Please use start() and end() for PartialTimeEstimator') + + def start(self): + self.last_time = time.time() + + def end(self): + assert self.last_time is not None, 'Please call start() before calling end()' + curr_time = time.time() + time_per_iter = curr_time - self.last_time + self.last_time = None + + self.avg_time_window.append(time_per_iter) + + if self._buffering_exp: + if self.exp_avg_time is not None: + # discard the first iteration call to not pollute the ema + self._buffering_exp = False + self.exp_avg_time = time_per_iter + else: + self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter diff --git a/third_party/MMAudio/mmaudio/utils/timezone.py b/third_party/MMAudio/mmaudio/utils/timezone.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7f0e6e753816a421f8e5d829ac131c95192a03 --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/timezone.py @@ -0,0 +1 @@ +my_timezone = 'US/Central' diff --git a/third_party/MMAudio/mmaudio/utils/video_joiner.py b/third_party/MMAudio/mmaudio/utils/video_joiner.py new file mode 100644 index 0000000000000000000000000000000000000000..1a05ae84a079e03f9af96bb2dc0bf38f004732ca --- /dev/null +++ b/third_party/MMAudio/mmaudio/utils/video_joiner.py @@ -0,0 +1,66 @@ +from pathlib import Path +from typing import Union + +import torch +from torio.io import StreamingMediaDecoder, StreamingMediaEncoder + + +class VideoJoiner: + + def __init__(self, src_root: Union[str, Path], output_root: Union[str, Path], sample_rate: int, + duration_seconds: float): + self.src_root = Path(src_root) + self.output_root = Path(output_root) + self.sample_rate = sample_rate + self.duration_seconds = duration_seconds + + self.output_root.mkdir(parents=True, exist_ok=True) + + def join(self, video_id: str, 
output_name: str, audio: torch.Tensor): + video_path = self.src_root / f'{video_id}.mp4' + output_path = self.output_root / f'{output_name}.mp4' + merge_audio_into_video(video_path, output_path, audio, self.sample_rate, + self.duration_seconds) + + +def merge_audio_into_video(video_path: Union[str, Path], output_path: Union[str, Path], + audio: torch.Tensor, sample_rate: int, duration_seconds: float): + # audio: (num_samples, num_channels=1/2) + + frame_rate = 24 + # read the video + reader = StreamingMediaDecoder(video_path) + reader.add_basic_video_stream( + frames_per_chunk=int(frame_rate * duration_seconds), + # buffer_chunk_size=1, # does not work with this -- extracted audio would be too short + format="rgb24", + frame_rate=frame_rate, + ) + + reader.fill_buffer() + video_chunk = reader.pop_chunks()[0] + t, _, h, w = video_chunk.shape + + writer = StreamingMediaEncoder(output_path) + writer.add_audio_stream( + sample_rate=sample_rate, + num_channels=audio.shape[-1], + encoder="libmp3lame", + ) + writer.add_video_stream(frame_rate=frame_rate, + width=w, + height=h, + format="rgb24", + encoder="libx264", + encoder_format="yuv420p") + + with writer.open(): + writer.write_audio_chunk(0, audio.float()) + writer.write_video_chunk(1, video_chunk) + + +if __name__ == '__main__': + # Usage example + import sys + audio = torch.randn(16000 * 4, 1) + merge_audio_into_video(sys.argv[1], sys.argv[2], audio, 16000, 4) diff --git a/third_party/MusicSourceSeparationTraining/LICENSE b/third_party/MusicSourceSeparationTraining/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..9d7186e88bca9975edd65956cd499fa60bd04251 --- /dev/null +++ b/third_party/MusicSourceSeparationTraining/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Roman Solovyev (ZFTurbo) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
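
The `mmaudio.utils.video_joiner` module added above is the piece that muxes a generated waveform back onto its source clip via `StreamingMediaDecoder`/`StreamingMediaEncoder`. As a reading aid (not part of the patch), here is a minimal usage sketch assuming `third_party/MMAudio` is on `PYTHONPATH`; the paths `videos/demo.mp4` and `output/` are hypothetical placeholders.

```python
# Illustrative sketch only: drives the vendored video_joiner helpers with dummy audio.
import torch

from mmaudio.utils.video_joiner import VideoJoiner, merge_audio_into_video

sample_rate = 16000        # must match the generated waveform's sample rate
duration_seconds = 4.0     # how many seconds of video frames to read

# One-off muxing: writes a new .mp4 whose audio track is the given waveform.
audio = torch.randn(int(sample_rate * duration_seconds), 1)  # (num_samples, num_channels)
merge_audio_into_video('videos/demo.mp4', 'output/demo_with_audio.mp4',
                       audio, sample_rate, duration_seconds)

# Batch-style muxing: VideoJoiner resolves '<src_root>/<video_id>.mp4'
# and writes '<output_root>/<output_name>.mp4'.
joiner = VideoJoiner('videos', 'output', sample_rate, duration_seconds)
joiner.join('demo', 'demo_with_audio', audio)
```

Note that the encoder is fixed to 24 fps `libx264` video with `libmp3lame` audio, so the output container re-encodes both streams rather than copying them.
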
diff --git a/third_party/MusicSourceSeparationTraining/__pycache__/utils.cpython-310.pyc b/third_party/MusicSourceSeparationTraining/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40c9f273cc431f158ee006c171f16a2720979721 Binary files /dev/null and b/third_party/MusicSourceSeparationTraining/__pycache__/utils.cpython-310.pyc differ diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/__init__.py b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..980e0afa5b7b4fd66168bce6905a94e7c91c380e --- /dev/null +++ b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__init__.py @@ -0,0 +1,2 @@ +from models.bs_roformer.bs_roformer import BSRoformer +from models.bs_roformer.mel_band_roformer import MelBandRoformer diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/__init__.cpython-310.pyc b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..691ea4d427783a86db00e4685692484064f6de45 Binary files /dev/null and b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/attend.cpython-310.pyc b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/attend.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4240dee88df09225f5aa31963589c16581b4b472 Binary files /dev/null and b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/attend.cpython-310.pyc differ diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d9b39e8e48c00e0593d315dfaded68a357e0a04 Binary files /dev/null and b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/bs_roformer.cpython-310.pyc differ diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26bfed7129e411219e6ff38bf58347ff83aaf354 Binary files /dev/null and b/third_party/MusicSourceSeparationTraining/models/bs_roformer/__pycache__/mel_band_roformer.cpython-310.pyc differ diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/attend.py b/third_party/MusicSourceSeparationTraining/models/bs_roformer/attend.py new file mode 100644 index 0000000000000000000000000000000000000000..d6dc4b3079cff5b3c8c90cea8df2301afd18918b --- /dev/null +++ b/third_party/MusicSourceSeparationTraining/models/bs_roformer/attend.py @@ -0,0 +1,126 @@ +from functools import wraps +from packaging import version +from collections import namedtuple + +import os +import torch +from torch import nn, einsum +import torch.nn.functional as F + +from einops import rearrange, reduce + +# constants + +FlashAttentionConfig = namedtuple('FlashAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) + +# helpers + +def exists(val): + return val is not None + +def default(v, d): 
+ return v if exists(v) else d + +def once(fn): + called = False + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + return inner + +print_once = once(print) + +# main class + +class Attend(nn.Module): + def __init__( + self, + dropout = 0., + flash = False, + scale = None + ): + super().__init__() + self.scale = scale + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.flash = flash + assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + + # determine efficient attention configs for cuda and cpu + + self.cpu_config = FlashAttentionConfig(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + device_version = version.parse(f'{device_properties.major}.{device_properties.minor}') + + if device_version >= version.parse('8.0'): + if os.name == 'nt': + print_once('Windows OS detected, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + else: + print_once('GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(True, False, False) + else: + print_once('GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda') + self.cuda_config = FlashAttentionConfig(False, True, True) + + def flash_attn(self, q, k, v): + _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + + if exists(self.scale): + default_scale = q.shape[-1] ** -0.5 + q = q * (self.scale / default_scale) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, k, v, + dropout_p = self.dropout if self.training else 0. 
+ ) + + return out + + def forward(self, q, k, v): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + q_len, k_len, device = q.shape[-2], k.shape[-2], q.device + + scale = default(self.scale, q.shape[-1] ** -0.5) + + if self.flash: + return self.flash_attn(q, k, v) + + # similarity + + sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, b h j d -> b h i d", attn, v) + + return out diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/bs_roformer.py b/third_party/MusicSourceSeparationTraining/models/bs_roformer/bs_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..195593ed4794f808034ca422993bc63d68ae9643 --- /dev/null +++ b/third_party/MusicSourceSeparationTraining/models/bs_roformer/bs_roformer.py @@ -0,0 +1,622 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from models.bs_roformer.attend import Attend +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack +from einops.layers.torch import Rearrange + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# norm + +def l2norm(t): + return F.normalize(t, dim = -1, p = 2) + + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. 
+ ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +DEFAULT_FREQS_PER_BANDS = ( + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 12, 12, 12, 12, 12, 12, 12, 12, + 24, 24, 24, 24, 24, 24, 24, 24, + 48, 48, 48, 48, 48, 48, 48, 48, + 128, 129, +) + + +class BSRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS, + # in the paper, they divide into ~60 bands, test with 1 for starters + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + flash_attn=True, + dim_freqs_in=1025, + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=2, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn, + norm_output=False + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.final_norm = RMSNorm(dim) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1] + + assert len(freqs_per_bands) > 1 + assert sum( + freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}' + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + # defining whether model is loaded on MPS (MacOS GPU accelerator) + x_is_mps = True if device.type == "mps" else False + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + channels = raw_audio.shape[1] + assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + # RuntimeError: FFT operations are only supported on MacOS 14+ + # Since it's tedious to define whether we're on correct MacOS version - simple try-catch is used + try: + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + except: + stft_repr = torch.stft(raw_audio.cpu() if x_is_mps else raw_audio, **self.stft_kwargs, + window=stft_window.cpu() if x_is_mps else stft_window, return_complex=True).to( + device) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + + # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c') + + x = rearrange(stft_repr, 'b f t c -> b t (f c)') + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + if self.skip_connection: + store[i] = x + + x = self.final_norm(x) + + num_stems = len(self.mask_estimators) + + if self.use_torch_checkpoint: + mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1) + else: + mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + mask = torch.view_as_complex(mask) + + stft_repr = stft_repr * mask + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + # same as torch.stft() fix for MacOS MPS above + try: + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1]) + except: + recon_audio = torch.istft(stft_repr.cpu() if x_is_mps else stft_repr, **self.stft_kwargs, window=stft_window.cpu() if x_is_mps else stft_window, return_complex=False, length=raw_audio.shape[-1]).to(device) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss 
for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) \ No newline at end of file diff --git a/third_party/MusicSourceSeparationTraining/models/bs_roformer/mel_band_roformer.py b/third_party/MusicSourceSeparationTraining/models/bs_roformer/mel_band_roformer.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d2c40f2e00eb0b99521e6506cf0c0027561541 --- /dev/null +++ b/third_party/MusicSourceSeparationTraining/models/bs_roformer/mel_band_roformer.py @@ -0,0 +1,668 @@ +from functools import partial + +import torch +from torch import nn, einsum, Tensor +from torch.nn import Module, ModuleList +import torch.nn.functional as F + +from models.bs_roformer.attend import Attend +from torch.utils.checkpoint import checkpoint + +from beartype.typing import Tuple, Optional, List, Callable +from beartype import beartype + +from rotary_embedding_torch import RotaryEmbedding + +from einops import rearrange, pack, unpack, reduce, repeat +from einops.layers.torch import Rearrange + +from librosa import filters + + +# helper functions + +def exists(val): + return val is not None + + +def default(v, d): + return v if exists(v) else d + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def pad_at_dim(t, pad, dim=-1, value=0.): + dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = ((0, 0) * dims_from_right) + return F.pad(t, (*zeros, *pad), value=value) + + +def l2norm(t): + return F.normalize(t, dim=-1, p=2) + + +# norm + +class RMSNorm(Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim=-1) * self.scale * self.gamma + + +# attention + +class FeedForward(Module): + def __init__( + self, + dim, + mult=4, + dropout=0. 
+ ): + super().__init__() + dim_inner = int(dim * mult) + self.net = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, dim_inner), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim_inner, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.net(x) + + +class Attention(Module): + def __init__( + self, + dim, + heads=8, + dim_head=64, + dropout=0., + rotary_embed=None, + flash=True + ): + super().__init__() + self.heads = heads + self.scale = dim_head ** -0.5 + dim_inner = heads * dim_head + + self.rotary_embed = rotary_embed + + self.attend = Attend(flash=flash, dropout=dropout) + + self.norm = RMSNorm(dim) + self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) + + self.to_gates = nn.Linear(dim, heads) + + self.to_out = nn.Sequential( + nn.Linear(dim_inner, dim, bias=False), + nn.Dropout(dropout) + ) + + def forward(self, x): + x = self.norm(x) + + q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads) + + if exists(self.rotary_embed): + q = self.rotary_embed.rotate_queries_or_keys(q) + k = self.rotary_embed.rotate_queries_or_keys(k) + + out = self.attend(q, k, v) + + gates = self.to_gates(x) + out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() + + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + + +class LinearAttention(Module): + """ + this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al. + """ + + @beartype + def __init__( + self, + *, + dim, + dim_head=32, + heads=8, + scale=8, + flash=False, + dropout=0. + ): + super().__init__() + dim_inner = dim_head * heads + self.norm = RMSNorm(dim) + + self.to_qkv = nn.Sequential( + nn.Linear(dim, dim_inner * 3, bias=False), + Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads) + ) + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.attend = Attend( + scale=scale, + dropout=dropout, + flash=flash + ) + + self.to_out = nn.Sequential( + Rearrange('b h d n -> b n (h d)'), + nn.Linear(dim_inner, dim, bias=False) + ) + + def forward( + self, + x + ): + x = self.norm(x) + + q, k, v = self.to_qkv(x) + + q, k = map(l2norm, (q, k)) + q = q * self.temperature.exp() + + out = self.attend(q, k, v) + + return self.to_out(out) + + +class Transformer(Module): + def __init__( + self, + *, + dim, + depth, + dim_head=64, + heads=8, + attn_dropout=0., + ff_dropout=0., + ff_mult=4, + norm_output=True, + rotary_embed=None, + flash_attn=True, + linear_attn=False + ): + super().__init__() + self.layers = ModuleList([]) + + for _ in range(depth): + if linear_attn: + attn = LinearAttention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, flash=flash_attn) + else: + attn = Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout, + rotary_embed=rotary_embed, flash=flash_attn) + + self.layers.append(ModuleList([ + attn, + FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout) + ])) + + self.norm = RMSNorm(dim) if norm_output else nn.Identity() + + def forward(self, x): + + for attn, ff in self.layers: + x = attn(x) + x + x = ff(x) + x + + return self.norm(x) + + +# bandsplit module + +class BandSplit(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...] 
+ ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_features = ModuleList([]) + + for dim_in in dim_inputs: + net = nn.Sequential( + RMSNorm(dim_in), + nn.Linear(dim_in, dim) + ) + + self.to_features.append(net) + + def forward(self, x): + x = x.split(self.dim_inputs, dim=-1) + + outs = [] + for split_input, to_feature in zip(x, self.to_features): + split_output = to_feature(split_input) + outs.append(split_output) + + return torch.stack(outs, dim=-2) + + +def MLP( + dim_in, + dim_out, + dim_hidden=None, + depth=1, + activation=nn.Tanh +): + dim_hidden = default(dim_hidden, dim_in) + + net = [] + dims = (dim_in, *((dim_hidden,) * depth), dim_out) + + for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): + is_last = ind == (len(dims) - 2) + + net.append(nn.Linear(layer_dim_in, layer_dim_out)) + + if is_last: + continue + + net.append(activation()) + + return nn.Sequential(*net) + + +class MaskEstimator(Module): + @beartype + def __init__( + self, + dim, + dim_inputs: Tuple[int, ...], + depth, + mlp_expansion_factor=4 + ): + super().__init__() + self.dim_inputs = dim_inputs + self.to_freqs = ModuleList([]) + dim_hidden = dim * mlp_expansion_factor + + for dim_in in dim_inputs: + net = [] + + mlp = nn.Sequential( + MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), + nn.GLU(dim=-1) + ) + + self.to_freqs.append(mlp) + + def forward(self, x): + x = x.unbind(dim=-2) + + outs = [] + + for band_features, mlp in zip(x, self.to_freqs): + freq_out = mlp(band_features) + outs.append(freq_out) + + return torch.cat(outs, dim=-1) + + +# main class + +class MelBandRoformer(Module): + + @beartype + def __init__( + self, + dim, + *, + depth, + stereo=False, + num_stems=1, + time_transformer_depth=2, + freq_transformer_depth=2, + linear_transformer_depth=0, + num_bands=60, + dim_head=64, + heads=8, + attn_dropout=0.1, + ff_dropout=0.1, + flash_attn=True, + dim_freqs_in=1025, + sample_rate=44100, # needed for mel filter bank from librosa + stft_n_fft=2048, + stft_hop_length=512, + # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction + stft_win_length=2048, + stft_normalized=False, + stft_window_fn: Optional[Callable] = None, + mask_estimator_depth=1, + multi_stft_resolution_loss_weight=1., + multi_stft_resolutions_window_sizes: Tuple[int, ...] 
= (4096, 2048, 1024, 512, 256), + multi_stft_hop_size=147, + multi_stft_normalized=False, + multi_stft_window_fn: Callable = torch.hann_window, + match_input_audio_length=False, # if True, pad output tensor to match length of input tensor + mlp_expansion_factor=4, + use_torch_checkpoint=False, + skip_connection=False, + ): + super().__init__() + + self.stereo = stereo + self.audio_channels = 2 if stereo else 1 + self.num_stems = num_stems + self.use_torch_checkpoint = use_torch_checkpoint + self.skip_connection = skip_connection + + self.layers = ModuleList([]) + + transformer_kwargs = dict( + dim=dim, + heads=heads, + dim_head=dim_head, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + flash_attn=flash_attn + ) + + time_rotary_embed = RotaryEmbedding(dim=dim_head) + freq_rotary_embed = RotaryEmbedding(dim=dim_head) + + for _ in range(depth): + tran_modules = [] + if linear_transformer_depth > 0: + tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)) + tran_modules.append( + Transformer(depth=time_transformer_depth, rotary_embed=time_rotary_embed, **transformer_kwargs) + ) + tran_modules.append( + Transformer(depth=freq_transformer_depth, rotary_embed=freq_rotary_embed, **transformer_kwargs) + ) + self.layers.append(nn.ModuleList(tran_modules)) + + self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length) + + self.stft_kwargs = dict( + n_fft=stft_n_fft, + hop_length=stft_hop_length, + win_length=stft_win_length, + normalized=stft_normalized + ) + + freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1] + + # create mel filter bank + # with librosa.filters.mel as in section 2 of paper + + mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands) + + mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy) + + # for some reason, it doesn't include the first freq? just force a value for now + + mel_filter_bank[0][0] = 1. + + # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position, + # so let's force a positive value + + mel_filter_bank[-1, -1] = 1. 
+ + # binary as in paper (then estimated masks are averaged for overlapping regions) + + freqs_per_band = mel_filter_bank > 0 + assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now' + + repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands) + freq_indices = repeated_freq_indices[freqs_per_band] + + if stereo: + freq_indices = repeat(freq_indices, 'f -> f s', s=2) + freq_indices = freq_indices * 2 + torch.arange(2) + freq_indices = rearrange(freq_indices, 'f s -> (f s)') + + self.register_buffer('freq_indices', freq_indices, persistent=False) + self.register_buffer('freqs_per_band', freqs_per_band, persistent=False) + + num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum') + num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum') + + self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False) + self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False) + + # band split and mask estimator + + freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist()) + + self.band_split = BandSplit( + dim=dim, + dim_inputs=freqs_per_bands_with_complex + ) + + self.mask_estimators = nn.ModuleList([]) + + for _ in range(num_stems): + mask_estimator = MaskEstimator( + dim=dim, + dim_inputs=freqs_per_bands_with_complex, + depth=mask_estimator_depth, + mlp_expansion_factor=mlp_expansion_factor, + ) + + self.mask_estimators.append(mask_estimator) + + # for the multi-resolution stft loss + + self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight + self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes + self.multi_stft_n_fft = stft_n_fft + self.multi_stft_window_fn = multi_stft_window_fn + + self.multi_stft_kwargs = dict( + hop_length=multi_stft_hop_size, + normalized=multi_stft_normalized + ) + + self.match_input_audio_length = match_input_audio_length + + def forward( + self, + raw_audio, + target=None, + return_loss_breakdown=False + ): + """ + einops + + b - batch + f - freq + t - time + s - audio channel (1 for mono, 2 for stereo) + n - number of 'stems' + c - complex (2) + d - feature dimension + """ + + device = raw_audio.device + + if raw_audio.ndim == 2: + raw_audio = rearrange(raw_audio, 'b t -> b 1 t') + + batch, channels, raw_audio_length = raw_audio.shape + + istft_length = raw_audio_length if self.match_input_audio_length else None + + assert (not self.stereo and channels == 1) or ( + self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). 
also need to be False if mono (channel dimension of 1)' + + # to stft + + raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') + + stft_window = self.stft_window_fn(device=device) + + stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True) + stft_repr = torch.view_as_real(stft_repr) + + stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') + + # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting + stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c') + + # index out all frequencies for all frequency ranges across bands ascending in one go + + batch_arange = torch.arange(batch, device=device)[..., None] + + # account for stereo + + x = stft_repr[batch_arange, self.freq_indices] + + # fold the complex (real and imag) into the frequencies dimension + + x = rearrange(x, 'b f t c -> b t (f c)') + + if self.use_torch_checkpoint: + x = checkpoint(self.band_split, x, use_reentrant=False) + else: + x = self.band_split(x) + + # axial / hierarchical attention + + store = [None] * len(self.layers) + for i, transformer_block in enumerate(self.layers): + + if len(transformer_block) == 3: + linear_transformer, time_transformer, freq_transformer = transformer_block + + x, ft_ps = pack([x], 'b * d') + if self.use_torch_checkpoint: + x = checkpoint(linear_transformer, x, use_reentrant=False) + else: + x = linear_transformer(x) + x, = unpack(x, ft_ps, 'b * d') + else: + time_transformer, freq_transformer = transformer_block + + if self.skip_connection: + # Sum all previous + for j in range(i): + x = x + store[j] + + x = rearrange(x, 'b t f d -> b f t d') + x, ps = pack([x], '* t d') + + if self.use_torch_checkpoint: + x = checkpoint(time_transformer, x, use_reentrant=False) + else: + x = time_transformer(x) + + x, = unpack(x, ps, '* t d') + x = rearrange(x, 'b f t d -> b t f d') + x, ps = pack([x], '* f d') + + if self.use_torch_checkpoint: + x = checkpoint(freq_transformer, x, use_reentrant=False) + else: + x = freq_transformer(x) + + x, = unpack(x, ps, '* f d') + + if self.skip_connection: + store[i] = x + + num_stems = len(self.mask_estimators) + if self.use_torch_checkpoint: + masks = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators], dim=1) + else: + masks = torch.stack([fn(x) for fn in self.mask_estimators], dim=1) + masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2) + + # modulate frequency representation + + stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c') + + # complex number multiplication + + stft_repr = torch.view_as_complex(stft_repr) + masks = torch.view_as_complex(masks) + + masks = masks.type(stft_repr.dtype) + + # need to average the estimated mask for the overlapped frequencies + + scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1]) + + stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... 
-> b n ...', n=num_stems) + masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks) + + denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels) + + masks_averaged = masks_summed / denom.clamp(min=1e-8) + + # modulate stft repr with estimated mask + + stft_repr = stft_repr * masks_averaged + + # istft + + stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels) + + recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, + length=istft_length) + + recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems) + + if num_stems == 1: + recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t') + + # if a target is passed in, calculate loss for learning + + if not exists(target): + return recon_audio + + if self.num_stems > 1: + assert target.ndim == 4 and target.shape[1] == self.num_stems + + if target.ndim == 2: + target = rearrange(target, '... t -> ... 1 t') + + target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft + + loss = F.l1_loss(recon_audio, target) + + multi_stft_resolution_loss = 0. + + for window_size in self.multi_stft_resolutions_window_sizes: + res_stft_kwargs = dict( + n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft + win_length=window_size, + return_complex=True, + window=self.multi_stft_window_fn(window_size, device=device), + **self.multi_stft_kwargs, + ) + + recon_Y = torch.stft(rearrange(recon_audio, '... s t -> (... s) t'), **res_stft_kwargs) + target_Y = torch.stft(rearrange(target, '... s t -> (... s) t'), **res_stft_kwargs) + + multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y) + + weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight + + total_loss = loss + weighted_multi_resolution_loss + + if not return_loss_breakdown: + return total_loss + + return total_loss, (loss, multi_stft_resolution_loss) diff --git a/third_party/MusicSourceSeparationTraining/utils.py b/third_party/MusicSourceSeparationTraining/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c277fc66889576987120d969bf6344349a457ef --- /dev/null +++ b/third_party/MusicSourceSeparationTraining/utils.py @@ -0,0 +1,665 @@ +# coding: utf-8 +__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/' +import sys +import os +sys.path.append(os.path.dirname(__file__)) + + +import argparse +import numpy as np +import torch +import torch.nn as nn +import yaml +import os +import soundfile as sf +import matplotlib.pyplot as plt +from ml_collections import ConfigDict +from omegaconf import OmegaConf +from tqdm.auto import tqdm +from typing import Dict, List, Tuple, Any, Union +import loralib as lora + + +def load_config(model_type: str, config_path: str) -> Union[ConfigDict, OmegaConf]: + """ + Load the configuration from the specified path based on the model type. + + Parameters: + ---------- + model_type : str + The type of model to load (e.g., 'htdemucs', 'mdx23c', etc.). + config_path : str + The path to the YAML or OmegaConf configuration file. + + Returns: + ------- + config : Any + The loaded configuration, which can be in different formats (e.g., OmegaConf or ConfigDict). + + Raises: + ------ + FileNotFoundError: + If the configuration file at `config_path` is not found. + ValueError: + If there is an error loading the configuration file. 
+ """ + try: + with open(config_path, 'r') as f: + if model_type == 'htdemucs': + config = OmegaConf.load(config_path) + else: + config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader)) + return config + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found at {config_path}") + except Exception as e: + raise ValueError(f"Error loading configuration: {e}") + +''' +def get_model_from_config(model_type: str, config_path: str) -> Tuple: + """ + Load the model specified by the model type and configuration file. + + Parameters: + ---------- + model_type : str + The type of model to load (e.g., 'mdx23c', 'htdemucs', 'scnet', etc.). + config_path : str + The path to the configuration file (YAML or OmegaConf format). + + Returns: + ------- + model : nn.Module or None + The initialized model based on the `model_type`, or None if the model type is not recognized. + config : Any + The configuration used to initialize the model. This could be in different formats + depending on the model type (e.g., OmegaConf, ConfigDict). + + Raises: + ------ + ValueError: + If the `model_type` is unknown or an error occurs during model initialization. + """ + + config = load_config(model_type, config_path) + + if model_type == 'mdx23c': + from models.mdx23c_tfc_tdf_v3 import TFC_TDF_net + model = TFC_TDF_net(config) + elif model_type == 'htdemucs': + from models.demucs4ht import get_model + model = get_model(config) + elif model_type == 'segm_models': + from models.segm_models import Segm_Models_Net + model = Segm_Models_Net(config) + elif model_type == 'torchseg': + from models.torchseg_models import Torchseg_Net + model = Torchseg_Net(config) + elif model_type == 'mel_band_roformer': + from models.bs_roformer import MelBandRoformer + model = MelBandRoformer(**dict(config.model)) + elif model_type == 'bs_roformer': + from models.bs_roformer import BSRoformer + model = BSRoformer(**dict(config.model)) + elif model_type == 'swin_upernet': + from models.upernet_swin_transformers import Swin_UperNet_Model + model = Swin_UperNet_Model(config) + elif model_type == 'bandit': + from models.bandit.core.model import MultiMaskMultiSourceBandSplitRNNSimple + model = MultiMaskMultiSourceBandSplitRNNSimple(**config.model) + elif model_type == 'bandit_v2': + from models.bandit_v2.bandit import Bandit + model = Bandit(**config.kwargs) + elif model_type == 'scnet_unofficial': + from models.scnet_unofficial import SCNet + model = SCNet(**config.model) + elif model_type == 'scnet': + from models.scnet import SCNet + model = SCNet(**config.model) + elif model_type == 'apollo': + from models.look2hear.models import BaseModel + model = BaseModel.apollo(**config.model) + elif model_type == 'bs_mamba2': + from models.ts_bs_mamba2 import Separator + model = Separator(**config.model) + else: + raise ValueError(f"Unknown model type: {model_type}") + + return model, config +''' + +def read_audio_transposed(path: str, instr: str = None, skip_err: bool = False) -> Tuple[np.ndarray, int]: + """ + Reads an audio file, ensuring mono audio is converted to two-dimensional format, + and transposes the data to have channels as the first dimension. + Parameters + ---------- + path : str + Path to the audio file. + skip_err: bool + If true, not raise errors + instr: + name of instument + Returns + ------- + Tuple[np.ndarray, int] + A tuple containing: + - Transposed audio data as a NumPy array with shape (channels, length). + For mono audio, the shape will be (1, length). + - Sampling rate (int), e.g., 44100. 
+ """ + + try: + mix, sr = sf.read(path) + except Exception as e: + if skip_err: + print(f"No stem {instr}: skip!") + return None, None + else: + raise RuntimeError(f"Error reading the file at {path}: {e}") + else: + if len(mix.shape) == 1: # For mono audio + mix = np.expand_dims(mix, axis=-1) + return mix.T, sr + + +def normalize_audio(audio: np.ndarray) -> tuple[np.ndarray, Dict[str, float]]: + """ + Normalize an audio signal by subtracting the mean and dividing by the standard deviation. + + Parameters: + ---------- + audio : np.ndarray + Input audio array with shape (channels, time) or (time,). + + Returns: + ------- + tuple[np.ndarray, dict[str, float]] + - Normalized audio array with the same shape as the input. + - Dictionary containing the mean and standard deviation of the original audio. + """ + + mono = audio.mean(0) + mean, std = mono.mean(), mono.std() + return (audio - mean) / std, {"mean": mean, "std": std} + + +def denormalize_audio(audio: np.ndarray, norm_params: Dict[str, float]) -> np.ndarray: + """ + Denormalize an audio signal by reversing the normalization process (multiplying by the standard deviation + and adding the mean). + + Parameters: + ---------- + audio : np.ndarray + Normalized audio array to be denormalized. + norm_params : dict[str, float] + Dictionary containing the 'mean' and 'std' values used for normalization. + + Returns: + ------- + np.ndarray + Denormalized audio array with the same shape as the input. + """ + + return audio * norm_params["std"] + norm_params["mean"] + + +def apply_tta( + config, + model: torch.nn.Module, + mix: torch.Tensor, + waveforms_orig: Dict[str, torch.Tensor], + device: torch.device, + model_type: str +) -> Dict[str, torch.Tensor]: + """ + Apply Test-Time Augmentation (TTA) for source separation. + + This function processes the input mixture with test-time augmentations, including + channel inversion and polarity inversion, to enhance the separation results. The + results from all augmentations are averaged to produce the final output. + + Parameters: + ---------- + config : Any + Configuration object containing model and processing parameters. + model : torch.nn.Module + The trained model used for source separation. + mix : torch.Tensor + The mixed audio tensor with shape (channels, time). + waveforms_orig : Dict[str, torch.Tensor] + Dictionary of original separated waveforms (before TTA) for each instrument. + device : torch.device + Device (CPU or CUDA) on which the model will be executed. + model_type : str + Type of the model being used (e.g., "demucs", "custom_model"). + + Returns: + ------- + Dict[str, torch.Tensor] + Updated dictionary of separated waveforms after applying TTA. + """ + # Create augmentations: channel inversion and polarity inversion + track_proc_list = [mix[::-1].copy(), -1.0 * mix.copy()] + + # Process each augmented mixture + for i, augmented_mix in enumerate(track_proc_list): + waveforms = demix(config, model, augmented_mix, device, model_type=model_type) + for el in waveforms: + if i == 0: + waveforms_orig[el] += waveforms[el][::-1].copy() + else: + waveforms_orig[el] -= waveforms[el] + + # Average the results across augmentations + for el in waveforms_orig: + waveforms_orig[el] /= len(track_proc_list) + 1 + + return waveforms_orig + + +def _getWindowingArray(window_size: int, fade_size: int) -> torch.Tensor: + """ + Generate a windowing array with a linear fade-in at the beginning and a fade-out at the end. 
+ + This function creates a window of size `window_size` where the first `fade_size` elements + linearly increase from 0 to 1 (fade-in) and the last `fade_size` elements linearly decrease + from 1 to 0 (fade-out). The middle part of the window is filled with ones. + + Parameters: + ---------- + window_size : int + The total size of the window. + fade_size : int + The size of the fade-in and fade-out regions. + + Returns: + ------- + torch.Tensor + A tensor of shape (window_size,) containing the generated windowing array. + + Example: + ------- + If `window_size=10` and `fade_size=3`, the output will be: + tensor([0.0000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.5000, 0.0000]) + """ + + fadein = torch.linspace(0, 1, fade_size) + fadeout = torch.linspace(1, 0, fade_size) + + window = torch.ones(window_size) + window[-fade_size:] = fadeout + window[:fade_size] = fadein + return window + + +def demix( + config: ConfigDict, + model: torch.nn.Module, + mix: torch.Tensor, + device: torch.device, + model_type: str, + pbar: bool = False +) -> Tuple[List[Dict[str, np.ndarray]], np.ndarray]: + """ + Unified function for audio source separation with support for multiple processing modes. + + This function separates audio into its constituent sources using either a generic custom logic + or a Demucs-specific logic. It supports batch processing and overlapping window-based chunking + for efficient and artifact-free separation. + + Parameters: + ---------- + config : ConfigDict + Configuration object containing audio and inference settings. + model : torch.nn.Module + The trained model used for audio source separation. + mix : torch.Tensor + Input audio tensor with shape (channels, time). + device : torch.device + The computation device (CPU or CUDA). + model_type : str, optional + Processing mode: + - "demucs" for logic specific to the Demucs model. + Default is "generic". + pbar : bool, optional + If True, displays a progress bar during chunk processing. Default is False. + + Returns: + ------- + Union[Dict[str, np.ndarray], np.ndarray] + - A dictionary mapping target instruments to separated audio sources if multiple instruments are present. + - A numpy array of the separated source if only one instrument is present. 
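+
+    Example:
+    -------
+    A rough usage sketch (instrument names depend on the loaded config; "vocals" is only an example):
+
+    >>> sources = demix(config, model, mix, torch.device("cuda"), model_type="bs_roformer")
+    >>> vocals = sources["vocals"]   # np.ndarray with shape (channels, time)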
+ """ + + mix = torch.tensor(mix, dtype=torch.float32) + + if model_type == 'htdemucs': + mode = 'demucs' + else: + mode = 'generic' + # Define processing parameters based on the mode + if mode == 'demucs': + chunk_size = config.training.samplerate * config.training.segment + num_instruments = len(config.training.instruments) + num_overlap = config.inference.num_overlap + step = chunk_size // num_overlap + else: + chunk_size = config.audio.chunk_size + num_instruments = len(prefer_target_instrument(config)) + num_overlap = config.inference.num_overlap + + fade_size = chunk_size // 10 + step = chunk_size // num_overlap + border = chunk_size - step + length_init = mix.shape[-1] + windowing_array = _getWindowingArray(chunk_size, fade_size) + # Add padding for generic mode to handle edge artifacts + if length_init > 2 * border and border > 0: + mix = nn.functional.pad(mix, (border, border), mode="reflect") + + batch_size = config.inference.batch_size + + use_amp = getattr(config.training, 'use_amp', True) + + with torch.cuda.amp.autocast(enabled=use_amp): + with torch.inference_mode(): + # Initialize result and counter tensors + req_shape = (num_instruments,) + mix.shape + result = torch.zeros(req_shape, dtype=torch.float32) + counter = torch.zeros(req_shape, dtype=torch.float32) + + i = 0 + batch_data = [] + batch_locations = [] + progress_bar = tqdm( + total=mix.shape[1], desc="Processing audio chunks", leave=False + ) if pbar else None + + while i < mix.shape[1]: + # Extract chunk and apply padding if necessary + part = mix[:, i:i + chunk_size].to(device) + chunk_len = part.shape[-1] + if mode == "generic" and chunk_len > chunk_size // 2: + pad_mode = "reflect" + else: + pad_mode = "constant" + part = nn.functional.pad(part, (0, chunk_size - chunk_len), mode=pad_mode, value=0) + + batch_data.append(part) + batch_locations.append((i, chunk_len)) + i += step + + # Process batch if it's full or the end is reached + if len(batch_data) >= batch_size or i >= mix.shape[1]: + arr = torch.stack(batch_data, dim=0) + x = model(arr) + + if mode == "generic": + window = windowing_array.clone() # using clone() fixes the clicks at chunk edges when using batch_size=1 + if i - step == 0: # First audio chunk, no fadein + window[:fade_size] = 1 + elif i >= mix.shape[1]: # Last audio chunk, no fadeout + window[-fade_size:] = 1 + + for j, (start, seg_len) in enumerate(batch_locations): + if mode == "generic": + result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() * window[..., :seg_len] + counter[..., start:start + seg_len] += window[..., :seg_len] + else: + result[..., start:start + seg_len] += x[j, ..., :seg_len].cpu() + counter[..., start:start + seg_len] += 1.0 + + batch_data.clear() + batch_locations.clear() + + if progress_bar: + progress_bar.update(step) + + if progress_bar: + progress_bar.close() + + # Compute final estimated sources + estimated_sources = result / counter + estimated_sources = estimated_sources.cpu().numpy() + np.nan_to_num(estimated_sources, copy=False, nan=0.0) + + # Remove padding for generic mode + if mode == "generic": + if length_init > 2 * border and border > 0: + estimated_sources = estimated_sources[..., border:-border] + + # Return the result as a dictionary or a single array + if mode == "demucs": + instruments = config.training.instruments + else: + instruments = prefer_target_instrument(config) + + ret_data = {k: v for k, v in zip(instruments, estimated_sources)} + + if mode == "demucs" and num_instruments <= 1: + return estimated_sources + else: + return 
ret_data + + +def prefer_target_instrument(config: ConfigDict) -> List[str]: + """ + Return the list of target instruments based on the configuration. + If a specific target instrument is specified in the configuration, + it returns a list with that instrument. Otherwise, it returns the list of instruments. + + Parameters: + ---------- + config : ConfigDict + Configuration object containing the list of instruments or the target instrument. + + Returns: + ------- + List[str] + A list of target instruments. + """ + if getattr(config.training, 'target_instrument', None): + return [config.training.target_instrument] + else: + return config.training.instruments + + +def load_not_compatible_weights(model: torch.nn.Module, weights: str, verbose: bool = False) -> None: + """ + Load weights into a model, handling mismatched shapes and dimensions. + + Args: + model: PyTorch model into which the weights will be loaded. + weights: Path to the weights file. + verbose: If True, prints detailed information about matching and mismatched layers. + """ + + new_model = model.state_dict() + old_model = torch.load(weights) + if 'state' in old_model: + # Fix for htdemucs weights loading + old_model = old_model['state'] + if 'state_dict' in old_model: + # Fix for apollo weights loading + old_model = old_model['state_dict'] + + for el in new_model: + if el in old_model: + if verbose: + print(f'Match found for {el}!') + if new_model[el].shape == old_model[el].shape: + if verbose: + print('Action: Just copy weights!') + new_model[el] = old_model[el] + else: + if len(new_model[el].shape) != len(old_model[el].shape): + if verbose: + print('Action: Different dimension! Too lazy to write the code... Skip it') + else: + if verbose: + print(f'Shape is different: {tuple(new_model[el].shape)} != {tuple(old_model[el].shape)}') + ln = len(new_model[el].shape) + max_shape = [] + slices_old = [] + slices_new = [] + for i in range(ln): + max_shape.append(max(new_model[el].shape[i], old_model[el].shape[i])) + slices_old.append(slice(0, old_model[el].shape[i])) + slices_new.append(slice(0, new_model[el].shape[i])) + # print(max_shape) + # print(slices_old, slices_new) + slices_old = tuple(slices_old) + slices_new = tuple(slices_new) + max_matrix = np.zeros(max_shape, dtype=np.float32) + for i in range(ln): + max_matrix[slices_old] = old_model[el].cpu().numpy() + max_matrix = torch.from_numpy(max_matrix) + new_model[el] = max_matrix[slices_new] + else: + if verbose: + print(f'Match not found for {el}!') + model.load_state_dict( + new_model + ) + + +def load_lora_weights(model: torch.nn.Module, lora_path: str, device: str = 'cpu') -> None: + """ + Load LoRA weights into a model. + This function updates the given model with LoRA-specific weights from the specified checkpoint file. + It does not require the checkpoint to match the model's full state dictionary, as only LoRA layers are updated. + + Parameters: + ---------- + model : Module + The PyTorch model into which the LoRA weights will be loaded. + lora_path : str + Path to the LoRA checkpoint file. + device : str, optional + The device to load the weights onto, by default 'cpu'. Common values are 'cpu' or 'cuda'. + + Returns: + ------- + None + The model is updated in place. + """ + lora_state_dict = torch.load(lora_path, map_location=device) + model.load_state_dict(lora_state_dict, strict=False) + + +def load_start_checkpoint(args: argparse.Namespace, model: torch.nn.Module, type_='train') -> None: + """ + Load the starting checkpoint for a model. 
+
+    Args:
+        args: Parsed command-line arguments containing the checkpoint path.
+        model: PyTorch model to load the checkpoint into.
+        type_: How to load the weights; for 'train', partially compatible checkpoints are accepted.
+    """
+
+    print(f'Start from checkpoint: {args.start_check_point}')
+    if type_ in ['train']:
+        load_not_compatible_weights(model, args.start_check_point, verbose=False)
+    else:
+        device = 'cpu'
+        if args.model_type in ['htdemucs', 'apollo']:
+            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=False)
+            # Fix for htdemucs pretrained models
+            if 'state' in state_dict:
+                state_dict = state_dict['state']
+            # Fix for apollo pretrained models
+            if 'state_dict' in state_dict:
+                state_dict = state_dict['state_dict']
+        else:
+            state_dict = torch.load(args.start_check_point, map_location=device, weights_only=True)
+        model.load_state_dict(state_dict)
+
+    if args.lora_checkpoint:
+        print(f"Loading LoRA weights from: {args.lora_checkpoint}")
+        load_lora_weights(model, args.lora_checkpoint)
+
+
+def bind_lora_to_model(config: Dict[str, Any], model: nn.Module) -> nn.Module:
+    """
+    Replaces specific layers in the model with LoRA-extended versions.
+
+    Parameters:
+    ----------
+    config : Dict[str, Any]
+        Configuration containing parameters for LoRA. It should include a 'lora' key with parameters for `MergedLinear`.
+    model : nn.Module
+        The original model in which the layers will be replaced.
+
+    Returns:
+    -------
+    nn.Module
+        The modified model with the replaced layers.
+    """
+
+    if 'lora' not in config:
+        raise ValueError("Configuration must contain the 'lora' key with parameters for LoRA.")
+
+    replaced_layers = 0  # Counter for replaced layers
+
+    for name, module in model.named_modules():
+        hierarchy = name.split('.')
+        layer_name = hierarchy[-1]
+
+        # Check if this is the target layer to replace (and layer_name == 'to_qkv')
+        if isinstance(module, nn.Linear):
+            try:
+                # Get the parent module
+                parent_module = model
+                for submodule_name in hierarchy[:-1]:
+                    parent_module = getattr(parent_module, submodule_name)
+
+                # Replace the module with LoRA-enabled layer
+                setattr(
+                    parent_module,
+                    layer_name,
+                    lora.MergedLinear(
+                        in_features=module.in_features,
+                        out_features=module.out_features,
+                        bias=module.bias is not None,
+                        **config['lora']
+                    )
+                )
+                replaced_layers += 1  # Increment the counter
+
+            except Exception as e:
+                print(f"Error replacing layer {name}: {e}")
+
+    if replaced_layers == 0:
+        print("Warning: No layers were replaced. Check the model structure and configuration.")
+    else:
+        print(f"Number of layers replaced with LoRA: {replaced_layers}")
+
+    return model
+
+
+def draw_spectrogram(waveform, sample_rate, length, output_file):
+    """Plot a dB-scaled spectrogram of the first `length` seconds of `waveform` and save it to `output_file`."""
+    import librosa.display
+
+    # Cut only the required part of the spectrogram
+    x = waveform[:int(length * sample_rate), :]
+    X = librosa.stft(x.mean(axis=-1))  # perform short-time Fourier transform on the mono signal
+    Xdb = librosa.amplitude_to_db(np.abs(X), ref=np.max)  # convert an amplitude spectrogram to dB-scaled spectrogram.
+ fig, ax = plt.subplots() + # plt.figure(figsize=(30, 10)) # initialize the fig size + img = librosa.display.specshow( + Xdb, + cmap='plasma', + sr=sample_rate, + x_axis='time', + y_axis='linear', + ax=ax + ) + ax.set(title='File: ' + os.path.basename(output_file)) + fig.colorbar(img, ax=ax, format="%+2.f dB") + if output_file is not None: + plt.savefig(output_file) diff --git a/third_party/VideoLLaMA2/.gitignore b/third_party/VideoLLaMA2/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..5d2b4c1ab07337c2106b40a88610bf991e0707aa --- /dev/null +++ b/third_party/VideoLLaMA2/.gitignore @@ -0,0 +1,58 @@ +# Python +__pycache__ +*.pyc +*.egg-info +dist + +# Log +*.log +*.log.* +*.json +*.jsonl +log_dir*/ +temp*/ + +# Data +!**/alpaca-data-conversation.json + +# Editor +.idea +*.swp + +# Other +.DS_Store +3rd_parties + +# jupyter +.ipynb_checkpoints +*.ipynb + +# DevContainer +!.devcontainer/* + +# Demo +serve_images/ +temp/ + +# data folder +data/ +dataset/ +datasets/ + +# training folder +wandb +ckpts* +output +output/ +checkpoints +checkpoints/ +work_dirs*/ + +# evaluation folder +/eval +/eval* + +# pretrained weights +pretrained/ +publish_models/ +public_models/ \ No newline at end of file diff --git a/third_party/VideoLLaMA2/LICENSE b/third_party/VideoLLaMA2/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/third_party/VideoLLaMA2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/third_party/VideoLLaMA2/README.md b/third_party/VideoLLaMA2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..99f6dea9ef4a61c1cb7551cf1042c05a56ee6dfa --- /dev/null +++ b/third_party/VideoLLaMA2/README.md @@ -0,0 +1,365 @@ +

+ +

+ +

+VideoLLaMA 2: Advancing Spatial-Temporal Modeling and Audio Understanding in Video-LLMs

+
If our project helps you, please give us a star ⭐ on GitHub to support us. 🙏🙏
+ +
+
+[![hf_space](https://img.shields.io/badge/🤗-Demo-9C276A.svg)](https://huggingface.co/spaces/lixin4ever/VideoLLaMA2)
+[![hf_checkpoint](https://img.shields.io/badge/🤗-Checkpoints-9C276A.svg)](https://huggingface.co/collections/DAMO-NLP-SG/videollama-2-6669b6b6f0493188305c87ed)
+[![hf_data](https://img.shields.io/badge/🤗-MSVC-9C276A.svg)](https://huggingface.co/datasets/DAMO-NLP-SG/Multi-Source-Video-Captioning)
+[![arXiv](https://img.shields.io/badge/Arxiv-2406.07476-AD1C18.svg?logo=arXiv)](https://arxiv.org/abs/2406.07476)
+[![License](https://img.shields.io/badge/License-Apache%202.0-yellow)](https://github.com/DAMO-NLP-SG/VideoLLaMA2/blob/main/LICENSE) +[![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FDAMO-NLP-SG%2FVideoLLaMA2&count_bg=%2379C83D&title_bg=%23555555&icon=&icon_color=%23E7E7E7&title=Visitor&edge_flat=false)](https://hits.seeyoufarm.com) +[![GitHub issues](https://img.shields.io/github/issues/DAMO-NLP-SG/VideoLLaMA2?color=critical&label=Issues)](https://github.com/DAMO-NLP-SG/VideoLLaMA2/issues?q=is%3Aopen+is%3Aissue) +[![GitHub closed issues](https://img.shields.io/github/issues-closed/DAMO-NLP-SG/VideoLLaMA2?color=success&label=Issues)](https://github.com/DAMO-NLP-SG/VideoLLaMA2/issues?q=is%3Aissue+is%3Aclosed)
+ +
+ +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/videollama-2-advancing-spatial-temporal/zero-shot-video-question-answer-on-egoschema-1)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-egoschema-1?p=videollama-2-advancing-spatial-temporal)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/videollama-2-advancing-spatial-temporal/video-question-answering-on-perception-test)](https://paperswithcode.com/sota/video-question-answering-on-perception-test?p=videollama-2-advancing-spatial-temporal)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/videollama-2-advancing-spatial-temporal/video-question-answering-on-mvbench)](https://paperswithcode.com/sota/video-question-answering-on-mvbench?p=videollama-2-advancing-spatial-temporal)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/videollama-2-advancing-spatial-temporal/zero-shot-video-question-answer-on-video-mme-1)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-video-mme-1?p=videollama-2-advancing-spatial-temporal)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/videollama-2-advancing-spatial-temporal/zero-shot-video-question-answer-on-video-mme)](https://paperswithcode.com/sota/zero-shot-video-question-answer-on-video-mme?p=videollama-2-advancing-spatial-temporal)
+ +
💡 Some other multimodal-LLM projects from our team may interest you ✨.

+ + +> [**Video-LLaMA: An Instruction-tuned Audio-Visual Language Model for Video Understanding**](https://github.com/DAMO-NLP-SG/Video-LLaMA)
+> Hang Zhang, Xin Li, Lidong Bing
+[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/Video-LLaMA) [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/Video-LLaMA.svg?style=social)](https://github.com/DAMO-NLP-SG/Video-LLaMA) [![arXiv](https://img.shields.io/badge/Arxiv-2306.02858-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2306.02858)
+ +> [**VCD: Mitigating Object Hallucinations in Large Vision-Language Models through Visual Contrastive Decoding**](https://arxiv.org/abs/2311.16922)
+> Sicong Leng, Hang Zhang, Guanzheng Chen, Xin Li, Shijian Lu, Chunyan Miao, Lidong Bing
+[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/VCD) [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/VCD.svg?style=social)](https://github.com/DAMO-NLP-SG/VCD) [![arXiv](https://img.shields.io/badge/Arxiv-2311.16922-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.16922)
+ +> [**The Curse of Multi-Modalities: Evaluating Hallucinations of Large Multimodal Models across Language, Visual, and Audio**](https://arxiv.org/abs/2410.12787)
+> Sicong Leng, Yun Xing, Zesen Cheng, Yang Zhou, Hang Zhang, Xin Li, Deli Zhao, Shijian Lu, Chunyan Miao, Lidong Bing
+[![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/DAMO-NLP-SG/CMM) [![github](https://img.shields.io/github/stars/DAMO-NLP-SG/CMM.svg?style=social)](https://github.com/DAMO-NLP-SG/CMM) [![arXiv](https://img.shields.io/badge/Arxiv-2410.12787-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2410.12787)
+ +

+ +
+ + +## ๐Ÿ“ฐ News +* **[2024.10.22]** Release checkpoints of [VideoLLaMA2.1-7B-AV](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-AV). +* **[2024.10.15]** Release checkpoints of [VideoLLaMA2.1-7B-16F-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-16F-Base) and [VideoLLaMA2.1-7B-16F](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-16F). +* **[2024.08.14]** Release checkpoints of [VideoLLaMA2-72B-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-72B-Base) and [VideoLLaMA2-72B](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-72B). +* **[2024.07.30]** Release checkpoints of [VideoLLaMA2-8x7B-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-8x7B-Base) and [VideoLLaMA2-8x7B](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-8x7B). +* **[2024.06.25]** ๐Ÿ”ฅ๐Ÿ”ฅ As of Jun 25, our [VideoLLaMA2-7B-16F](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-7B-16F) is the **Top-1** ~7B-sized VideoLLM on the [MLVU Leaderboard](https://github.com/JUNJIE99/MLVU?tab=readme-ov-file#trophy-mini-leaderboard). +* **[2024.06.18]** ๐Ÿ”ฅ๐Ÿ”ฅ As of Jun 18, our [VideoLLaMA2-7B-16F](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-7B-16F) is the **Top-1** ~7B-sized VideoLLM on the [VideoMME Leaderboard](https://video-mme.github.io/home_page.html#leaderboard). +* **[2024.06.17]** ๐Ÿ‘‹๐Ÿ‘‹ Update technical report with the latest results and the missing references. If you have works closely related to VideoLLaMA 2 but not mentioned in the paper, feel free to let us know. +* **[2024.06.14]** ๐Ÿ”ฅ๐Ÿ”ฅ [Online Demo](https://huggingface.co/spaces/lixin4ever/VideoLLaMA2) is available. +* **[2024.06.03]** Release training, evaluation, and serving codes of VideoLLaMA 2. + + + + +## ๐Ÿ› ๏ธ Requirements and Installation +Basic Dependencies: +* Python >= 3.8 +* Pytorch >= 2.2.0 +* CUDA Version >= 11.8 +* transformers == 4.40.0 (for reproducing paper results) +* tokenizers == 0.19.1 + +**[Online Mode]** Install required packages (better for development): +```bash +git clone https://github.com/DAMO-NLP-SG/VideoLLaMA2 +cd VideoLLaMA2 +git checkout audio_visual +pip install -r requirements.txt +pip install flash-attn==2.5.8 --no-build-isolation +pip install opencv-python==4.5.5.64 +apt-get update && apt-get install ffmpeg libsm6 libxext6 -y +``` + +**[Offline Mode]** Install VideoLLaMA2 as a Python package (better for direct use): +```bash +git clone https://github.com/DAMO-NLP-SG/VideoLLaMA2 +cd VideoLLaMA2 +git checkout audio_visual +pip install --upgrade pip # enable PEP 660 support +pip install -e . +pip install flash-attn==2.5.8 --no-build-isolation +pip install opencv-python==4.5.5.64 +apt-get update && apt-get install ffmpeg libsm6 libxext6 -y +``` + +## ๐Ÿš€ Main Results + +### Multi-Choice Video QA & Video Captioning +

+ +### Open-Ended Video QA +

+ +### Audio QA +

+ +### Audio-Visual QA +

+ + +## :earth_americas: Model Zoo +### Vision-only Checkpoints +| Model Name | Model Type | Visual Encoder | Language Decoder | # Training Frames | +|:----------------|:------------:|:----------------|:------------------|:----------------:| +| [VideoLLaMA2-7B-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-7B-Base) | Base | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | 8 | +| [VideoLLaMA2-7B](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-7B) | Chat | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | 8 | +| [VideoLLaMA2-7B-16F-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-7B-16F-Base) | Base | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | 16 | +| [VideoLLaMA2-7B-16F](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-7B-16F) | Chat | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) | 16 | +| [VideoLLaMA2-8x7B-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-8x7B-Base) | Base | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) | 8 | +| [VideoLLaMA2-8x7B](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-8x7B) | Chat | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Mixtral-8x7B-Instruct-v0.1](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) | 8 | +| [VideoLLaMA2-72B-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-72B-Base) | Base | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) | 8 | +| [VideoLLaMA2-72B](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2-72B) | Chat | [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336) | [Qwen2-72B-Instruct](https://huggingface.co/Qwen/Qwen2-72B-Instruct) | 8 | +| [VideoLLaMA2.1-7B-16F-Base](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-16F-Base) | Base | [siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) | [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | 16 | +| [VideoLLaMA2.1-7B-16F](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-16F) | Chat | [siglip-so400m-patch14-384](https://huggingface.co/google/siglip-so400m-patch14-384) | [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) | 16 | + +### Audio-Visual Checkpoints +| Model Name | Type | Audio Encoder | Language Decoder | +|:-------------------|:----------------|:----------------|:------------------| +| [VideoLLaMA2.1-7B-AV](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-AV) | Chat | [Fine-tuned BEATs_iter3+(AS2M)(cpt2)](https://1drv.ms/u/s!AqeByhGUtINrgcpj8ujXH1YUtxooEg?e=E9Ncea) | [VideoLLaMA2.1-7B-16F](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-16F) | + + +## [๐Ÿค— Demo](https://huggingface.co/spaces/lixin4ever/VideoLLaMA2-AV) + +It is highly recommended to try our [online demo](https://huggingface.co/spaces/lixin4ever/VideoLLaMA2-AV) first. 
+ +To run a video-based LLM (Large Language Model) web demonstration on your device, you will first need to ensure that you have the necessary model checkpoints prepared, followed by adhering to the steps outlined to successfully launch the demo. + +### Single-model Version + +* Launch a gradio app directly ([VideoLLaMA2.1-7B-AV](https://huggingface.co/DAMO-NLP-SG/VideoLLaMA2.1-7B-AV) is adopted by default): +```bash +python videollama2/serve/gradio_web_server_adhoc_av.py +``` + +## ๐Ÿ—๏ธ Training & Evaluation + +### Quick Start + +To facilitate further development on top of our codebase, we provide a quick-start guide on how to train a customized [VideoLLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2) with [VideoLLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA) dataset and evaluate the trained model on the mainstream video-llm benchmarks. + +1. Training Data Structure: +Follow the main branch(https://github.com/DAMO-NLP-SG/VideoLLaMA2/tree/main) of this VideoLLaMA2 codebase. +2. Command: +```bash +# VideoLLaMA2.1-audio pretraining +bash scripts/custom/pretrain_audio.sh +# VideoLLaMA2.1-audio finetuning +bash scripts/custom/finetune_audio.sh +# VideoLLaMA2.1-audio_visual finetuning +bash scripts/custom/va_joint.sh +``` +3. Evaluation Data Structure: +Follow the main branch(https://github.com/DAMO-NLP-SG/VideoLLaMA2/tree/main) of this VideoLLaMA2 codebase. + +4. Command: +```bash +# ClothoAQA.sh evaluation +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/eval_audio_clothoAQA.sh +# TUT2017 evaluation +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/eval_audio_TUT2017.sh +# VocalSound evaluation +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/eval_audio_vocalsound.sh +# AVQA_music evaluation +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/eval_audio_video_AVQA.sh +# AVSD evaluation (need to set azure openai key/endpoint/deployname) +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/eval_audio_video_AVSD.sh +# AVSSD evaluation (need to set azure openai key/endpoint/deployname) +CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 bash scripts/eval/eval_audio_video_AVSSD.sh +``` + +### Data Format + +If you want to train a video-llm on your data, you need to follow the procedures below to prepare the audio/video/image sft data: + +1. Suppose your data structure is like: +```bash +VideoLLaMA2 +โ”œโ”€โ”€ datasets +โ”‚ โ”œโ”€โ”€ custom_sft +โ”‚ | โ”œโ”€โ”€ audio +โ”‚ | โ”œโ”€โ”€ video +โ”‚ | โ”œโ”€โ”€ image +| | โ””โ”€โ”€ custom.json +``` +2. Then you should re-organize the annotated audio/video/image sft data according to the following format: +```json +[ + { + "id": 0, + "audio": "audio/xxx.wav", + "conversations": [ + { + "from": "human", + "value": "