lym0302 committed on
Commit
eedfa8e
·
1 Parent(s): bafca5a
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. LICENSE +21 -0
  2. README.md +185 -14
  3. app.py +343 -0
  4. batch_eval.py +110 -0
  5. config/__init__.py +0 -0
  6. config/base_config.yaml +62 -0
  7. config/data/base.yaml +70 -0
  8. config/eval_config.yaml +17 -0
  9. config/eval_data/base.yaml +22 -0
  10. config/hydra/job_logging/custom-eval.yaml +32 -0
  11. config/hydra/job_logging/custom-no-rank.yaml +32 -0
  12. config/hydra/job_logging/custom-simplest.yaml +26 -0
  13. config/hydra/job_logging/custom.yaml +33 -0
  14. config/train_config.yaml +41 -0
  15. demo.py +141 -0
  16. docs/EVAL.md +22 -0
  17. docs/MODELS.md +50 -0
  18. docs/TRAINING.md +184 -0
  19. docs/images/icon.png +0 -0
  20. docs/index.html +149 -0
  21. docs/style.css +78 -0
  22. docs/style_videos.css +52 -0
  23. docs/video_gen.html +254 -0
  24. docs/video_main.html +98 -0
  25. docs/video_vgg.html +452 -0
  26. gradio_demo.py +343 -0
  27. mmaudio/__init__.py +0 -0
  28. mmaudio/__pycache__/__init__.cpython-310.pyc +0 -0
  29. mmaudio/__pycache__/__init__.cpython-38.pyc +0 -0
  30. mmaudio/__pycache__/eval_utils.cpython-310.pyc +0 -0
  31. mmaudio/__pycache__/eval_utils.cpython-38.pyc +0 -0
  32. mmaudio/data/__init__.py +0 -0
  33. mmaudio/data/__pycache__/__init__.cpython-310.pyc +0 -0
  34. mmaudio/data/__pycache__/__init__.cpython-38.pyc +0 -0
  35. mmaudio/data/__pycache__/av_utils.cpython-310.pyc +0 -0
  36. mmaudio/data/__pycache__/av_utils.cpython-38.pyc +0 -0
  37. mmaudio/data/av_utils.py +162 -0
  38. mmaudio/data/data_setup.py +174 -0
  39. mmaudio/data/eval/__init__.py +0 -0
  40. mmaudio/data/eval/audiocaps.py +39 -0
  41. mmaudio/data/eval/moviegen.py +131 -0
  42. mmaudio/data/eval/video_dataset.py +197 -0
  43. mmaudio/data/extracted_audio.py +88 -0
  44. mmaudio/data/extracted_vgg.py +101 -0
  45. mmaudio/data/extraction/__init__.py +0 -0
  46. mmaudio/data/extraction/vgg_sound.py +193 -0
  47. mmaudio/data/extraction/wav_dataset.py +132 -0
  48. mmaudio/data/mm_dataset.py +45 -0
  49. mmaudio/data/utils.py +148 -0
  50. mmaudio/eval_utils.py +255 -0
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sony Research Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,14 +1,185 @@
1
- ---
2
- title: DeepSound V1
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.22.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: DeepSound-V1 demo
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <div align="center">
2
+ <p align="center">
3
+ <h2>MMAudio</h2>
4
+ <a href="https://arxiv.org/abs/2412.15322">Paper</a> | <a href="https://hkchengrex.github.io/MMAudio">Webpage</a> | <a href="https://huggingface.co/hkchengrex/MMAudio/tree/main">Models</a> | <a href="https://huggingface.co/spaces/hkchengrex/MMAudio"> Huggingface Demo</a> | <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">Colab Demo</a> | <a href="https://replicate.com/zsxkib/mmaudio">Replicate Demo</a>
5
+ </p>
6
+ </div>
7
+
8
+ ## [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio)
9
+
10
+ [Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
11
+
12
+ University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
13
+
14
+ CVPR 2025
15
+
16
+ ## Highlight
17
+
18
+ MMAudio generates synchronized audio given video and/or text inputs.
19
+ Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
20
+ Moreover, a synchronization module aligns the generated audio with the video frames.
21
+
22
+ ## Results
23
+
24
+ (All audio from our algorithm MMAudio)
25
+
26
+ Videos from Sora:
27
+
28
+ https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
29
+
30
+ Videos from Veo 2:
31
+
32
+ https://github.com/user-attachments/assets/8a11419e-fee2-46e0-9e67-dfb03c48d00e
33
+
34
+ Videos from MovieGen/Hunyuan Video/VGGSound:
35
+
36
+ https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
37
+
38
+ For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
39
+
40
+
41
+ ## Installation
42
+
43
+ We have only tested this on Ubuntu.
44
+
45
+ ### Prerequisites
46
+
47
+ We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
48
+
49
+ - Python 3.9+
50
+ - PyTorch **2.5.1+** and the corresponding torchvision/torchaudio (pick your CUDA version at https://pytorch.org/; pip install recommended)
51
+ <!-- - ffmpeg<7 ([this is required by torchaudio](https://pytorch.org/audio/master/installation.html#optional-dependencies), you can install it in a miniforge environment with `conda install -c conda-forge 'ffmpeg<7'`) -->
52
+
53
+ **1. Install the prerequisites if not yet met:**
54
+
55
+ ```bash
56
+ pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade
57
+ ```
58
+
59
+ (Or any other CUDA version that your GPU/driver supports.)
60
+
61
+ <!-- ```
62
+ conda install -c conda-forge 'ffmpeg<7'
63
+ ```
64
+ (Optional, if you use miniforge and don't already have the appropriate ffmpeg) -->
65
+
66
+ **2. Clone our repository:**
67
+
68
+ ```bash
69
+ git clone https://github.com/hkchengrex/MMAudio.git
70
+ ```
71
+
72
+ **3. Install with pip (install PyTorch first before attempting this!):**
73
+
74
+ ```bash
75
+ cd MMAudio
76
+ pip install -e .
77
+ ```
78
+
79
+ (If you encounter a `File "setup.py" not found` error, upgrade pip with `pip install --upgrade pip`.)
80
+
81
+
82
+ **Pretrained models:**
83
+
84
+ The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
85
+ The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
86
+ See [MODELS.md](docs/MODELS.md) for more details.
87
+
88
+ ## Demo
89
+
90
+ By default, these scripts use the `large_44k_v2` model.
91
+ In our experiments, inference only takes around 6 GB of GPU memory (in 16-bit mode), which should fit in most modern GPUs.
92
+
93
+ ### Command-line interface
94
+
95
+ With `demo.py`
96
+
97
+ ```bash
98
+ python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
99
+ ```
100
+
101
+ The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
102
+ See the script for more options.
103
+ Simply omit the `--video` option for text-to-audio synthesis.
104
+ The default output (and training) duration is 8 seconds. Longer or shorter durations can also work, but a large deviation from the training duration may result in lower quality.
105
+
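+ For example, a text-to-audio run (no `--video`) with an explicit seed and output directory could look like the following; the flags mirror the argument list in `demo.py`, and the prompt and values are only illustrative:
+
+ ```bash
+ python demo.py --duration=8 --prompt "waves crashing on a beach" --negative_prompt "music" --seed 42 --output ./output
+ ```
+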
106
+ ### Gradio interface
107
+
108
+ Supports video-to-audio and text-to-audio synthesis.
109
+ You can also try experimental image-to-audio synthesis, which duplicates the input image into a video for processing. This might be interesting to some, but it is not something MMAudio has been trained for.
110
+ Use [port forwarding](https://unix.stackexchange.com/questions/115897/whats-ssh-port-forwarding-and-whats-the-difference-between-ssh-local-and-remot) (e.g., `ssh -L 7860:localhost:7860 server`) if necessary. The default port is `7860`, which you can change with `--port`.
111
+
112
+ ```bash
113
+ python gradio_demo.py
114
+ ```
115
+
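+ For remote servers, a typical combination of port forwarding and an explicit port would be (illustrative):
+
+ ```bash
+ # on your local machine: forward local port 7860 to the server
+ ssh -L 7860:localhost:7860 server
+ # on the server: launch the demo, optionally on a different port
+ python gradio_demo.py --port 7860
+ ```
+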
116
+ ### FAQ
117
+
118
+ 1. Video processing
119
+ - Processing higher-resolution videos takes longer due to encoding and decoding (which can take >95% of the processing time!), but it does not improve the quality of results.
120
+ - The CLIP encoder resizes input frames to 384×384 pixels.
121
+ - Synchformer resizes the shorter edge to 224 pixels and applies a center crop, focusing only on the central square of each frame.
122
+ 2. Frame rates
123
+ - The CLIP model operates at 8 FPS, while Synchformer works at 25 FPS.
124
+ - Frame rate conversion happens on-the-fly via the video reader.
125
+ - For input videos with a frame rate below 25 FPS, frames will be duplicated to match the required rate.
126
+ 3. Failure cases
127
+ As with most models of this type, failures can occur, and the reasons are not always clear. Below are some known failure modes. If you notice a failure mode or believe there’s a bug, feel free to open an issue in the repository.
128
+ 4. Performance variations
129
+ We have noticed subtle performance variations across different hardware and software environments. Contributing factors include using/not using `torch.compile`, the video reader library/backend, inference precision, batch sizes, random seeds, etc. We (will) provide pre-computed results on standard benchmarks for reference. Results obtained from this codebase should be similar but might not be exactly the same.
130
+
131
+ ### Known limitations
132
+
133
+ 1. The model sometimes generates unintelligible human speech-like sounds
134
+ 2. The model sometimes generates background music (without explicit training, it would not be high quality)
135
+ 3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfire" but not "RPG firing".
136
+
137
+ We believe all three of these limitations can be addressed with more high-quality training data.
138
+
139
+ ## Training
140
+
141
+ See [TRAINING.md](docs/TRAINING.md).
142
+
143
+ ## Evaluation
144
+
145
+ See [EVAL.md](docs/EVAL.md).
146
+
147
+ ## Training Datasets
148
+
149
+ MMAudio was trained on several datasets, including [AudioSet](https://research.google.com/audioset/), [Freesound](https://github.com/LAION-AI/audio-dataset/blob/main/laion-audio-630k/README.md), [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://github.com/XinhaoMei/WavCaps). These datasets are subject to specific licenses, which can be accessed on their respective websites. We do not guarantee that the pre-trained models are suitable for commercial use. Please use them at your own risk.
150
+
151
+ ## Update Logs
152
+
153
+ - 2025-03-09: Uploaded the corrected tsv files. See [TRAINING.md](docs/TRAINING.md).
154
+ - 2025-02-27: Disabled the GradScaler by default to improve training stability. See #49.
155
+ - 2024-12-23: Added training and batch evaluation scripts.
156
+ - 2024-12-14: Removed the `ffmpeg<7` requirement for the demos by replacing `torio.io.StreamingMediaDecoder` with `pyav` for reading frames. The read frames are also cached, so we are not reading the same frames again during reconstruction. This should speed things up and make installation less of a hassle.
157
+ - 2024-12-13: Improved for-loop processing in CLIP/Sync feature extraction by introducing a batch size multiplier. We can use approximately a 40x larger batch size for CLIP/Sync without using more memory, thereby speeding up processing. Removed the VAE encoder during inference -- we don't need it.
158
+ - 2024-12-11: Replaced `torio.io.StreamingMediaDecoder` with `pyav` for reading the frame rate when reconstructing the input video. `torio.io.StreamingMediaDecoder` does not work reliably in Hugging Face ZeroGPU's environment, and I suspect that it might not work in some other environments as well.
159
+
160
+ ## Citation
161
+
162
+ ```bibtex
163
+ @inproceedings{cheng2025taming,
164
+ title={Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis},
165
+ author={Cheng, Ho Kei and Ishii, Masato and Hayakawa, Akio and Shibuya, Takashi and Schwing, Alexander and Mitsufuji, Yuki},
166
+ booktitle={CVPR},
167
+ year={2025}
168
+ }
169
+ ```
170
+
171
+ ## Relevant Repositories
172
+
173
+ - [av-benchmark](https://github.com/hkchengrex/av-benchmark) for benchmarking results.
174
+
175
+ ## Disclaimer
176
+
177
+ We have no affiliation with and have no knowledge of the party behind the domain "mmaudio.net".
178
+
179
+ ## Acknowledgement
180
+
181
+ Many thanks to:
182
+ - [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model and the VAE architecture
183
+ - [BigVGAN](https://github.com/NVIDIA/BigVGAN)
184
+ - [Synchformer](https://github.com/v-iashin/Synchformer)
185
+ - [EDM2](https://github.com/NVlabs/edm2) for the magnitude-preserving VAE network architecture
app.py ADDED
@@ -0,0 +1,343 @@
1
+ import gc
2
+ import logging
3
+ from argparse import ArgumentParser
4
+ from datetime import datetime
5
+ from fractions import Fraction
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torchaudio
11
+
12
+ from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
13
+ load_video, make_video, setup_eval_logging)
14
+ from mmaudio.model.flow_matching import FlowMatching
15
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
16
+ from mmaudio.model.sequence_config import SequenceConfig
17
+ from mmaudio.model.utils.features_utils import FeaturesUtils
18
+
19
+ torch.backends.cuda.matmul.allow_tf32 = True
20
+ torch.backends.cudnn.allow_tf32 = True
21
+
22
+ log = logging.getLogger()
23
+
24
+ device = 'cpu'
25
+ if torch.cuda.is_available():
26
+ device = 'cuda'
27
+ elif torch.backends.mps.is_available():
28
+ device = 'mps'
29
+ else:
30
+ log.warning('CUDA/MPS are not available, running on CPU')
31
+ dtype = torch.bfloat16
32
+
33
+ model: ModelConfig = all_model_cfg['large_44k_v2']
34
+ model.download_if_needed()
35
+ output_dir = Path('./output/gradio')
36
+
37
+ setup_eval_logging()
38
+
39
+
40
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
41
+ seq_cfg = model.seq_cfg
42
+
43
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
44
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
45
+ log.info(f'Loaded weights from {model.model_path}')
46
+
47
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
48
+ synchformer_ckpt=model.synchformer_ckpt,
49
+ enable_conditions=True,
50
+ mode=model.mode,
51
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
52
+ need_vae_encoder=False)
53
+ feature_utils = feature_utils.to(device, dtype).eval()
54
+
55
+ return net, feature_utils, seq_cfg
56
+
57
+
58
+ net, feature_utils, seq_cfg = get_model()
59
+
60
+
61
+ @torch.inference_mode()
62
+ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
63
+ cfg_strength: float, duration: float):
64
+
65
+ rng = torch.Generator(device=device)
66
+ if seed >= 0:
67
+ rng.manual_seed(seed)
68
+ else:
69
+ rng.seed()
70
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
71
+
72
+ video_info = load_video(video, duration)
73
+ clip_frames = video_info.clip_frames
74
+ sync_frames = video_info.sync_frames
75
+ duration = video_info.duration_sec
76
+ clip_frames = clip_frames.unsqueeze(0)
77
+ sync_frames = sync_frames.unsqueeze(0)
78
+ seq_cfg.duration = duration
79
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
80
+
81
+ audios = generate(clip_frames,
82
+ sync_frames, [prompt],
83
+ negative_text=[negative_prompt],
84
+ feature_utils=feature_utils,
85
+ net=net,
86
+ fm=fm,
87
+ rng=rng,
88
+ cfg_strength=cfg_strength)
89
+ audio = audios.float().cpu()[0]
90
+
91
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
92
+ output_dir.mkdir(exist_ok=True, parents=True)
93
+ video_save_path = output_dir / f'{current_time_string}.mp4'
94
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
95
+ gc.collect()
96
+ return video_save_path
97
+
98
+
99
+ @torch.inference_mode()
100
+ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
101
+ cfg_strength: float, duration: float):
102
+
103
+ rng = torch.Generator(device=device)
104
+ if seed >= 0:
105
+ rng.manual_seed(seed)
106
+ else:
107
+ rng.seed()
108
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
+
110
+ image_info = load_image(image)
111
+ clip_frames = image_info.clip_frames
112
+ sync_frames = image_info.sync_frames
113
+ clip_frames = clip_frames.unsqueeze(0)
114
+ sync_frames = sync_frames.unsqueeze(0)
115
+ seq_cfg.duration = duration
116
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
117
+
118
+ audios = generate(clip_frames,
119
+ sync_frames, [prompt],
120
+ negative_text=[negative_prompt],
121
+ feature_utils=feature_utils,
122
+ net=net,
123
+ fm=fm,
124
+ rng=rng,
125
+ cfg_strength=cfg_strength,
126
+ image_input=True)
127
+ audio = audios.float().cpu()[0]
128
+
129
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
130
+ output_dir.mkdir(exist_ok=True, parents=True)
131
+ video_save_path = output_dir / f'{current_time_string}.mp4'
132
+ video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
133
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
134
+ gc.collect()
135
+ return video_save_path
136
+
137
+
138
+ @torch.inference_mode()
139
+ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
140
+ duration: float):
141
+
142
+ rng = torch.Generator(device=device)
143
+ if seed >= 0:
144
+ rng.manual_seed(seed)
145
+ else:
146
+ rng.seed()
147
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
148
+
149
+ clip_frames = sync_frames = None
150
+ seq_cfg.duration = duration
151
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
152
+
153
+ audios = generate(clip_frames,
154
+ sync_frames, [prompt],
155
+ negative_text=[negative_prompt],
156
+ feature_utils=feature_utils,
157
+ net=net,
158
+ fm=fm,
159
+ rng=rng,
160
+ cfg_strength=cfg_strength)
161
+ audio = audios.float().cpu()[0]
162
+
163
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
164
+ output_dir.mkdir(exist_ok=True, parents=True)
165
+ audio_save_path = output_dir / f'{current_time_string}.flac'
166
+ torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
167
+ gc.collect()
168
+ return audio_save_path
169
+
170
+
171
+ video_to_audio_tab = gr.Interface(
172
+ fn=video_to_audio,
173
+ description="""
174
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
175
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
176
+
177
+ NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
178
+ Doing so does not improve results.
179
+ """,
180
+ inputs=[
181
+ gr.Video(),
182
+ gr.Text(label='Prompt'),
183
+ gr.Text(label='Negative prompt', value='music'),
184
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
185
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
186
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
187
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
188
+ ],
189
+ outputs='playable_video',
190
+ cache_examples=False,
191
+ title='MMAudio — Video-to-Audio Synthesis',
192
+ examples=[
193
+ [
194
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
195
+ 'waves, seagulls',
196
+ '',
197
+ 0,
198
+ 25,
199
+ 4.5,
200
+ 10,
201
+ ],
202
+ [
203
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
204
+ '',
205
+ 'music',
206
+ 0,
207
+ 25,
208
+ 4.5,
209
+ 10,
210
+ ],
211
+ [
212
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
213
+ 'bubbles',
214
+ '',
215
+ 0,
216
+ 25,
217
+ 4.5,
218
+ 10,
219
+ ],
220
+ [
221
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
222
+ 'Indian holy music',
223
+ '',
224
+ 0,
225
+ 25,
226
+ 4.5,
227
+ 10,
228
+ ],
229
+ [
230
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
231
+ 'galloping',
232
+ '',
233
+ 0,
234
+ 25,
235
+ 4.5,
236
+ 10,
237
+ ],
238
+ [
239
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
240
+ 'waves, storm',
241
+ '',
242
+ 0,
243
+ 25,
244
+ 4.5,
245
+ 10,
246
+ ],
247
+ [
248
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
249
+ 'storm',
250
+ '',
251
+ 0,
252
+ 25,
253
+ 4.5,
254
+ 10,
255
+ ],
256
+ [
257
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
258
+ '',
259
+ '',
260
+ 0,
261
+ 25,
262
+ 4.5,
263
+ 10,
264
+ ],
265
+ [
266
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
267
+ 'typing',
268
+ '',
269
+ 0,
270
+ 25,
271
+ 4.5,
272
+ 10,
273
+ ],
274
+ [
275
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
276
+ '',
277
+ '',
278
+ 0,
279
+ 25,
280
+ 4.5,
281
+ 10,
282
+ ],
283
+ [
284
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
285
+ '',
286
+ '',
287
+ 0,
288
+ 25,
289
+ 4.5,
290
+ 10,
291
+ ],
292
+ ])
293
+
294
+ text_to_audio_tab = gr.Interface(
295
+ fn=text_to_audio,
296
+ description="""
297
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
298
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
299
+ """,
300
+ inputs=[
301
+ gr.Text(label='Prompt'),
302
+ gr.Text(label='Negative prompt'),
303
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
304
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
305
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
306
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
307
+ ],
308
+ outputs='audio',
309
+ cache_examples=False,
310
+ title='MMAudio — Text-to-Audio Synthesis',
311
+ )
312
+
313
+ image_to_audio_tab = gr.Interface(
314
+ fn=image_to_audio,
315
+ description="""
316
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
317
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
318
+
319
+ NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
320
+ Doing so does not improve results.
321
+ """,
322
+ inputs=[
323
+ gr.Image(type='filepath'),
324
+ gr.Text(label='Prompt'),
325
+ gr.Text(label='Negative prompt'),
326
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
327
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
328
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
329
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
330
+ ],
331
+ outputs='playable_video',
332
+ cache_examples=False,
333
+ title='MMAudio — Image-to-Audio Synthesis (experimental)',
334
+ )
335
+
336
+ if __name__ == "__main__":
337
+ parser = ArgumentParser()
338
+ parser.add_argument('--port', type=int, default=7860)
339
+ args = parser.parse_args()
340
+
341
+ gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
342
+ ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
343
+ server_port=args.port, allowed_paths=[output_dir])
batch_eval.py ADDED
@@ -0,0 +1,110 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import hydra
6
+ import torch
7
+ import torch.distributed as distributed
8
+ import torchaudio
9
+ from hydra.core.hydra_config import HydraConfig
10
+ from omegaconf import DictConfig
11
+ from tqdm import tqdm
12
+
13
+ from mmaudio.data.data_setup import setup_eval_dataset
14
+ from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
15
+ from mmaudio.model.flow_matching import FlowMatching
16
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
17
+ from mmaudio.model.utils.features_utils import FeaturesUtils
18
+
19
+ torch.backends.cuda.matmul.allow_tf32 = True
20
+ torch.backends.cudnn.allow_tf32 = True
21
+
22
+ local_rank = int(os.environ['LOCAL_RANK'])
23
+ world_size = int(os.environ['WORLD_SIZE'])
24
+ log = logging.getLogger()
25
+
26
+
27
+ @torch.inference_mode()
28
+ @hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
29
+ def main(cfg: DictConfig):
30
+ device = 'cuda'
31
+ torch.cuda.set_device(local_rank)
32
+
33
+ if cfg.model not in all_model_cfg:
34
+ raise ValueError(f'Unknown model variant: {cfg.model}')
35
+ model: ModelConfig = all_model_cfg[cfg.model]
36
+ model.download_if_needed()
37
+ seq_cfg = model.seq_cfg
38
+
39
+ run_dir = Path(HydraConfig.get().run.dir)
40
+ if cfg.output_name is None:
41
+ output_dir = run_dir / cfg.dataset
42
+ else:
43
+ output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
44
+ output_dir.mkdir(parents=True, exist_ok=True)
45
+
46
+ # load a pretrained model
47
+ seq_cfg.duration = cfg.duration_s
48
+ net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
49
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
50
+ log.info(f'Loaded weights from {model.model_path}')
51
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
52
+ log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
53
+ log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
54
+ log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
55
+
56
+ # misc setup
57
+ rng = torch.Generator(device=device)
58
+ rng.manual_seed(cfg.seed)
59
+ fm = FlowMatching(cfg.sampling.min_sigma,
60
+ inference_mode=cfg.sampling.method,
61
+ num_steps=cfg.sampling.num_steps)
62
+
63
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
64
+ synchformer_ckpt=model.synchformer_ckpt,
65
+ enable_conditions=True,
66
+ mode=model.mode,
67
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
68
+ need_vae_encoder=False)
69
+ feature_utils = feature_utils.to(device).eval()
70
+
71
+ if cfg.compile:
72
+ net.preprocess_conditions = torch.compile(net.preprocess_conditions)
73
+ net.predict_flow = torch.compile(net.predict_flow)
74
+ feature_utils.compile()
75
+
76
+ dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
77
+
78
+ with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
79
+ for batch in tqdm(loader):
80
+ audios = generate(batch.get('clip_video', None),
81
+ batch.get('sync_video', None),
82
+ batch.get('caption', None),
83
+ feature_utils=feature_utils,
84
+ net=net,
85
+ fm=fm,
86
+ rng=rng,
87
+ cfg_strength=cfg.cfg_strength,
88
+ clip_batch_size_multiplier=64,
89
+ sync_batch_size_multiplier=64)
90
+ audios = audios.float().cpu()
91
+ names = batch['name']
92
+ for audio, name in zip(audios, names):
93
+ torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
94
+
95
+
96
+ def distributed_setup():
97
+ distributed.init_process_group(backend="nccl")
98
+ local_rank = distributed.get_rank()
99
+ world_size = distributed.get_world_size()
100
+ log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
101
+ return local_rank, world_size
102
+
103
+
104
+ if __name__ == '__main__':
105
+ distributed_setup()
106
+
107
+ main()
108
+
109
+ # clean-up
110
+ distributed.destroy_process_group()
config/__init__.py ADDED
File without changes
config/base_config.yaml ADDED
@@ -0,0 +1,62 @@
1
+ defaults:
2
+ - data: base
3
+ - eval_data: base
4
+ - override hydra/job_logging: custom-simplest
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./output/${exp_id}
10
+ output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ enable_email: False
13
+
14
+ model: small_16k
15
+
16
+ exp_id: default
17
+ debug: False
18
+ cudnn_benchmark: True
19
+ compile: True
20
+ amp: True
21
+ weights: null
22
+ checkpoint: null
23
+ seed: 14159265
24
+ num_workers: 10 # per-GPU
25
+ pin_memory: False # set to True if your system can handle it, i.e., have enough memory
26
+
27
+ # NOTE: This DOES NOT affect the model during inference in any way
28
+ # they are just for the dataloader to fill in the missing data in multi-modal loading
29
+ # to change the sequence length for the model, see networks.py
30
+ data_dim:
31
+ text_seq_len: 77
32
+ clip_dim: 1024
33
+ sync_dim: 768
34
+ text_dim: 1024
35
+
36
+ # ema configuration
37
+ ema:
38
+ enable: True
39
+ sigma_rels: [0.05, 0.1]
40
+ update_every: 1
41
+ checkpoint_every: 5_000
42
+ checkpoint_folder: ${hydra:run.dir}/ema_ckpts
43
+ default_output_sigma: 0.05
44
+
45
+
46
+ # sampling
47
+ sampling:
48
+ mean: 0.0
49
+ scale: 1.0
50
+ min_sigma: 0.0
51
+ method: euler
52
+ num_steps: 25
53
+
54
+ # classifier-free guidance
55
+ null_condition_probability: 0.1
56
+ cfg_strength: 4.5
57
+
58
+ # checkpoint paths to external modules
59
+ vae_16k_ckpt: ./ext_weights/v1-16.pth
60
+ vae_44k_ckpt: ./ext_weights/v1-44.pth
61
+ bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
62
+ synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
config/data/base.yaml ADDED
@@ -0,0 +1,70 @@
1
+ VGGSound:
2
+ root: ../data/video
3
+ subset_name: sets/vgg3-train.tsv
4
+ fps: 8
5
+ height: 384
6
+ width: 384
7
+ sample_duration_sec: 8.0
8
+
9
+ VGGSound_test:
10
+ root: ../data/video
11
+ subset_name: sets/vgg3-test.tsv
12
+ fps: 8
13
+ height: 384
14
+ width: 384
15
+ sample_duration_sec: 8.0
16
+
17
+ VGGSound_val:
18
+ root: ../data/video
19
+ subset_name: sets/vgg3-val.tsv
20
+ fps: 8
21
+ height: 384
22
+ width: 384
23
+ sample_duration_sec: 8.0
24
+
25
+ ExtractedVGG:
26
+ tsv: ../data/v1-16-memmap/vgg-train.tsv
27
+ memmap_dir: ../data/v1-16-memmap/vgg-train
28
+
29
+ ExtractedVGG_test:
30
+ tag: test
31
+ gt_cache: ../data/eval-cache/vggsound-test
32
+ output_subdir: null
33
+ tsv: ../data/v1-16-memmap/vgg-test.tsv
34
+ memmap_dir: ../data/v1-16-memmap/vgg-test
35
+
36
+ ExtractedVGG_val:
37
+ tag: val
38
+ gt_cache: ../data/eval-cache/vggsound-val
39
+ output_subdir: val
40
+ tsv: ../data/v1-16-memmap/vgg-val.tsv
41
+ memmap_dir: ../data/v1-16-memmap/vgg-val
42
+
43
+ AudioCaps:
44
+ tsv: ../data/v1-16-memmap/audiocaps.tsv
45
+ memmap_dir: ../data/v1-16-memmap/audiocaps
46
+
47
+ AudioSetSL:
48
+ tsv: ../data/v1-16-memmap/audioset_sl.tsv
49
+ memmap_dir: ../data/v1-16-memmap/audioset_sl
50
+
51
+ BBCSound:
52
+ tsv: ../data/v1-16-memmap/bbcsound.tsv
53
+ memmap_dir: ../data/v1-16-memmap/bbcsound
54
+
55
+ FreeSound:
56
+ tsv: ../data/v1-16-memmap/freesound.tsv
57
+ memmap_dir: ../data/v1-16-memmap/freesound
58
+
59
+ Clotho:
60
+ tsv: ../data/v1-16-memmap/clotho.tsv
61
+ memmap_dir: ../data/v1-16-memmap/clotho
62
+
63
+ Example_video:
64
+ tsv: ./training/example_output/memmap/vgg-example.tsv
65
+ memmap_dir: ./training/example_output/memmap/vgg-example
66
+
67
+ Example_audio:
68
+ tsv: ./training/example_output/memmap/audio-example.tsv
69
+ memmap_dir: ./training/example_output/memmap/audio-example
70
+
config/eval_config.yaml ADDED
@@ -0,0 +1,17 @@
1
+ defaults:
2
+ - base_config
3
+ - override hydra/job_logging: custom-simplest
4
+ - _self_
5
+
6
+ hydra:
7
+ run:
8
+ dir: ./output/${exp_id}
9
+ output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
10
+
11
+ exp_id: ${model}
12
+ dataset: audiocaps
13
+ duration_s: 8.0
14
+
15
+ # for inference, this is the per-GPU batch size
16
+ batch_size: 16
17
+ output_name: null
config/eval_data/base.yaml ADDED
@@ -0,0 +1,22 @@
1
+ AudioCaps:
2
+ audio_path: ../data/AudioCaps-test-audioldm-ver
3
+ # a csv file, with a header row of 'name' and 'caption'
4
+ # name should match the audio file name without extension
5
+ # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
6
+ csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
7
+
8
+ AudioCaps_full:
9
+ audio_path: ../data/AudioCaps-test-full-ver
10
+ # a csv file, with a header row of 'name' and 'caption'
11
+ # name should match the audio file name without extension
12
+ # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
13
+ csv_path: ../data/AudioCaps-test-full-ver/data.csv
14
+
15
+ MovieGen:
16
+ video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
17
+ jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
18
+
19
+ VGGSound:
20
+ video_path: ../data/test-videos
21
+ # from the officially released csv file
22
+ csv_path: ../data/vggsound.csv
config/hydra/job_logging/custom-eval.yaml ADDED
@@ -0,0 +1,32 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ file:
23
+ class: logging.FileHandler
24
+ formatter: simple
25
+ # absolute file path
26
+ filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
27
+ mode: w
28
+ root:
29
+ level: INFO
30
+ handlers: [console, file]
31
+
32
+ disable_existing_loggers: false
config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,32 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ file:
23
+ class: logging.FileHandler
24
+ formatter: simple
25
+ # absolute file path
26
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
27
+ mode: w
28
+ root:
29
+ level: INFO
30
+ handlers: [console, file]
31
+
32
+ disable_existing_loggers: false
config/hydra/job_logging/custom-simplest.yaml ADDED
@@ -0,0 +1,26 @@
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ root:
23
+ level: INFO
24
+ handlers: [console]
25
+
26
+ disable_existing_loggers: false
config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,33 @@
1
+ # @package hydra.job_logging
2
+ # python logging configuration for tasks
3
+ version: 1
4
+ formatters:
5
+ simple:
6
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
7
+ datefmt: '%Y-%m-%d %H:%M:%S'
8
+ colorlog:
9
+ '()': 'colorlog.ColoredFormatter'
10
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
11
+ datefmt: '%Y-%m-%d %H:%M:%S'
12
+ log_colors:
13
+ DEBUG: purple
14
+ INFO: green
15
+ WARNING: yellow
16
+ ERROR: red
17
+ CRITICAL: red
18
+ handlers:
19
+ console:
20
+ class: logging.StreamHandler
21
+ formatter: colorlog
22
+ stream: ext://sys.stdout
23
+ file:
24
+ class: logging.FileHandler
25
+ formatter: simple
26
+ # absolute file path
27
+ filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
28
+ mode: w
29
+ root:
30
+ level: INFO
31
+ handlers: [console, file]
32
+
33
+ disable_existing_loggers: false
config/train_config.yaml ADDED
@@ -0,0 +1,41 @@
1
+ defaults:
2
+ - base_config
3
+ - override data: base
4
+ - override hydra/job_logging: custom
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./output/${exp_id}
10
+ output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ ema:
13
+ start: 0
14
+
15
+ mini_train: False
16
+ example_train: False
17
+ enable_grad_scaler: False
18
+ vgg_oversample_rate: 5
19
+
20
+ log_text_interval: 200
21
+ log_extra_interval: 20_000
22
+ val_interval: 5_000
23
+ eval_interval: 20_000
24
+ save_eval_interval: 40_000
25
+ save_weights_interval: 10_000
26
+ save_checkpoint_interval: 10_000
27
+ save_copy_iterations: []
28
+
29
+ batch_size: 512
30
+ eval_batch_size: 256 # per-GPU
31
+
32
+ num_iterations: 300_000
33
+ learning_rate: 1.0e-4
34
+ linear_warmup_steps: 1_000
35
+
36
+ lr_schedule: step
37
+ lr_schedule_steps: [240_000, 270_000]
38
+ lr_schedule_gamma: 0.1
39
+
40
+ clip_grad_norm: 1.0
41
+ weight_decay: 1.0e-6
demo.py ADDED
@@ -0,0 +1,141 @@
1
+ import logging
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torchaudio
7
+
8
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
9
+ setup_eval_logging)
10
+ from mmaudio.model.flow_matching import FlowMatching
11
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
12
+ from mmaudio.model.utils.features_utils import FeaturesUtils
13
+
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+
17
+ log = logging.getLogger()
18
+
19
+
20
+ @torch.inference_mode()
21
+ def main():
22
+ setup_eval_logging()
23
+
24
+ parser = ArgumentParser()
25
+ parser.add_argument('--variant',
26
+ type=str,
27
+ default='large_44k_v2',
28
+ help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
29
+ parser.add_argument('--video', type=Path, help='Path to the video file')
30
+ parser.add_argument('--prompt', type=str, help='Input prompt', default='')
31
+ parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
32
+ parser.add_argument('--duration', type=float, default=8.0)
33
+ parser.add_argument('--cfg_strength', type=float, default=4.5)
34
+ parser.add_argument('--num_steps', type=int, default=25)
35
+
36
+ parser.add_argument('--mask_away_clip', action='store_true')
37
+
38
+ parser.add_argument('--output', type=Path, help='Output directory', default='./output')
39
+ parser.add_argument('--seed', type=int, help='Random seed', default=42)
40
+ parser.add_argument('--skip_video_composite', action='store_true')
41
+ parser.add_argument('--full_precision', action='store_true')
42
+
43
+ args = parser.parse_args()
44
+
45
+ if args.variant not in all_model_cfg:
46
+ raise ValueError(f'Unknown model variant: {args.variant}')
47
+ model: ModelConfig = all_model_cfg[args.variant]
48
+ model.download_if_needed()
49
+ seq_cfg = model.seq_cfg
50
+
51
+ if args.video:
52
+ video_path: Path = Path(args.video).expanduser()
53
+ else:
54
+ video_path = None
55
+ prompt: str = args.prompt
56
+ negative_prompt: str = args.negative_prompt
57
+ output_dir: str = args.output.expanduser()
58
+ seed: int = args.seed
59
+ num_steps: int = args.num_steps
60
+ duration: float = args.duration
61
+ cfg_strength: float = args.cfg_strength
62
+ skip_video_composite: bool = args.skip_video_composite
63
+ mask_away_clip: bool = args.mask_away_clip
64
+
65
+ device = 'cpu'
66
+ if torch.cuda.is_available():
67
+ device = 'cuda'
68
+ elif torch.backends.mps.is_available():
69
+ device = 'mps'
70
+ else:
71
+ log.warning('CUDA/MPS are not available, running on CPU')
72
+ dtype = torch.float32 if args.full_precision else torch.bfloat16
73
+
74
+ output_dir.mkdir(parents=True, exist_ok=True)
75
+
76
+ # load a pretrained model
77
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
78
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
79
+ log.info(f'Loaded weights from {model.model_path}')
80
+
81
+ # misc setup
82
+ rng = torch.Generator(device=device)
83
+ rng.manual_seed(seed)
84
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
85
+
86
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
87
+ synchformer_ckpt=model.synchformer_ckpt,
88
+ enable_conditions=True,
89
+ mode=model.mode,
90
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
91
+ need_vae_encoder=False)
92
+ feature_utils = feature_utils.to(device, dtype).eval()
93
+
94
+ if video_path is not None:
95
+ log.info(f'Using video {video_path}')
96
+ video_info = load_video(video_path, duration)
97
+ clip_frames = video_info.clip_frames
98
+ sync_frames = video_info.sync_frames
99
+ duration = video_info.duration_sec
100
+ if mask_away_clip:
101
+ clip_frames = None
102
+ else:
103
+ clip_frames = clip_frames.unsqueeze(0)
104
+ sync_frames = sync_frames.unsqueeze(0)
105
+ else:
106
+ log.info('No video provided -- text-to-audio mode')
107
+ clip_frames = sync_frames = None
108
+
109
+ seq_cfg.duration = duration
110
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
111
+
112
+ log.info(f'Prompt: {prompt}')
113
+ log.info(f'Negative prompt: {negative_prompt}')
114
+
115
+ audios = generate(clip_frames,
116
+ sync_frames, [prompt],
117
+ negative_text=[negative_prompt],
118
+ feature_utils=feature_utils,
119
+ net=net,
120
+ fm=fm,
121
+ rng=rng,
122
+ cfg_strength=cfg_strength)
123
+ audio = audios.float().cpu()[0]
124
+ if video_path is not None:
125
+ save_path = output_dir / f'{video_path.stem}.flac'
126
+ else:
127
+ safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
128
+ save_path = output_dir / f'{safe_filename}.flac'
129
+ torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
130
+
131
+ log.info(f'Audio saved to {save_path}')
132
+ if video_path is not None and not skip_video_composite:
133
+ video_save_path = output_dir / f'{video_path.stem}.mp4'
134
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
135
+ log.info(f'Video saved to {video_save_path}')
136
+
137
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
138
+
139
+
140
+ if __name__ == '__main__':
141
+ main()
docs/EVAL.md ADDED
@@ -0,0 +1,22 @@
1
+ # Evaluation
2
+
3
+ ## Batch Evaluation
4
+
5
+ To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient than `demo.py` for large-scale evaluation, supporting batched inference, multi-GPU inference, torch compilation, and skipping video composition.
6
+
7
+ An example of running this script with four GPUs is as follows:
8
+
9
+ ```bash
10
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
11
+ ```
12
+
13
+ You may need to update the data paths in `config/eval_data/base.yaml`.
14
+ More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
15
+
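+ Since evaluation is configured with Hydra, the options in those files can also be overridden on the command line. A sketch (the override values are illustrative; the keys come from `config/base_config.yaml` and `config/eval_config.yaml`):
+
+ ```bash
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py \
+     duration_s=8 dataset=vggsound model=large_44k_v2 \
+     sampling.num_steps=50 cfg_strength=4.5 batch_size=8 output_name=steps50
+ ```
+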
16
+ ## Precomputed Results
17
+
18
+ Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
19
+
20
+ ## Obtaining Quantitative Metrics
21
+
22
+ Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
docs/MODELS.md ADDED
@@ -0,0 +1,50 @@
1
+ # Pretrained models
2
+
3
+ The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
4
+ The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
5
+
6
+ | Model | Download link | File size |
7
+ | -------- | ------- | ------- |
8
+ | Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
9
+ | Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
10
+ | Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
11
+ | Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
12
+ | Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
13
+ | 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
14
+ | 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
15
+ | 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
16
+ | Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
17
+
18
+ To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP; CLIP is downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz), not to the model size.
19
+ The 44.1kHz vocoder will be downloaded automatically.
20
+ The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.
21
+
22
+ The expected directory structure (full):
23
+
24
+ ```bash
25
+ MMAudio
26
+ ├── ext_weights
27
+ │ ├── best_netG.pt
28
+ │ ├── synchformer_state_dict.pth
29
+ │ ├── v1-16.pth
30
+ │ └── v1-44.pth
31
+ ├── weights
32
+ │ ├── mmaudio_small_16k.pth
33
+ │ ├── mmaudio_small_44k.pth
34
+ │ ├── mmaudio_medium_44k.pth
35
+ │ ├── mmaudio_large_44k.pth
36
+ │ └── mmaudio_large_44k_v2.pth
37
+ └── ...
38
+ ```
39
+
40
+ The expected directory structure (minimal, for the recommended model only):
41
+
42
+ ```bash
43
+ MMAudio
44
+ ├── ext_weights
45
+ │ ├── synchformer_state_dict.pth
46
+ │ └── v1-44.pth
47
+ ├── weights
48
+ │ └── mmaudio_large_44k_v2.pth
49
+ └── ...
50
+ ```
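+
+ Although the demo scripts download these weights automatically, the minimal layout above can also be reproduced manually. A sketch using the download links from the table (assuming `wget` is available):
+
+ ```bash
+ mkdir -p weights ext_weights
+ wget -P weights https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth
+ wget -P ext_weights https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth
+ wget -P ext_weights https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth
+ ```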
docs/TRAINING.md ADDED
@@ -0,0 +1,184 @@
1
+ # Training
2
+
3
+ ## Overview
4
+
5
+ We have put a large emphasis on making training as fast as possible.
6
+ Consequently, some pre-processing steps are required.
7
+
8
+ Namely, before starting any training, we
9
+
10
+ 1. Obtain training data as videos, audios, and captions.
11
+ 2. Encode the training audio into spectrograms and then, with the VAE, into latent mean/std
12
+ 3. Extract CLIP and synchronization features from videos
13
+ 4. Extract CLIP features from text (captions)
14
+ 5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
15
+
16
+ **NOTE:** for maximum training speed (e.g., when training the base model with 2×H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to keep up, and most consumer-grade SSDs would struggle. In my experience, the best bet is to have enough system memory that the OS can cache the data. This way, the data is read from RAM instead of disk.
17
+
18
+ The current training script does not support `_v2` training.
19
+
20
+ ## Recommended Hardware Configuration
21
+
22
+ These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
23
+
24
+ - Single-node machine. We did not implement multi-node training
25
+ - GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
26
+ - System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
27
+ - Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
28
+
29
+ ## Prerequisites
30
+
31
+ 1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
32
+ 2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
33
+
34
+ 3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
35
+
36
+ ```bash
37
+ conda install -c conda-forge 'ffmpeg<7'
38
+ ```
39
+
40
+ 4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), [WavCaps](https://arxiv.org/abs/2303.17395), and [Clotho](https://arxiv.org/abs/1910.09387) (paper to be updated). Note that the audio files in the Hugging Face release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling-rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio clips unsuitable for our training.
41
+
42
+ 5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and the Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md), and place them in `ext_weights/`.
43
+
44
+ ### Helpful links for downloading the datasets
45
+
46
+ We cannot redistribute the datasets for copyright reasons, but we found some links helpful, and they might be helpful to you as well.
47
+
48
+ - https://huggingface.co/datasets/Meranti/CLAP_freesound
49
+ - https://huggingface.co/datasets/agkphysics/AudioSet
50
+ - https://sound-effects.bbcrewind.co.uk/
51
+
52
+ For certain sources of VGGSound, you might notice desynchronization between the audio and the video. This happens because the video keyframes do not always align with the start of the audio, and playback behavior is player-dependent. We used PyTorch's decoder, which handles these cases correctly.
53
+
54
+ ## Preparing Audio-Video-Text Features
55
+
56
+ We have prepared some example data in `training/example_videos`.
57
+ `training/extract_video_training_latents.py` extracts audio, video, and text features and saves them as a `TensorDict`, along with a `.tsv` file containing metadata, to `output_dir`.
58
+
59
+ To run this script, use the `torchrun` utility:
60
+
61
+ ```bash
62
+ torchrun --standalone training/extract_video_training_latents.py
63
+ ```
64
+
65
+ You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
66
+ Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
67
+ Change the data path definitions in `data_cfg` if necessary.
68
+
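+ For example, a four-GPU extraction run (assuming four GPUs are available) looks like:
+
+ ```bash
+ torchrun --standalone --nproc_per_node=4 training/extract_video_training_latents.py
+ ```
+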
69
+ Arguments:
70
+
71
+ - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
72
+ - `output_dir` -- where TensorDict and the metadata file are saved.
73
+
74
+ Outputs produced in `output_dir`:
75
+
76
+ 1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
77
+ a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
78
+ b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
79
+ c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
80
+ d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
81
+ e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
82
+ f. `meta.json` that contains the metadata for the above memory mappings
83
+ 2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
84
+
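+ A quick way to sanity-check the extraction output (a sketch; `output/vgg-train` stands in for your actual `output_dir` and split name):
+
+ ```bash
+ ls output/vgg-train              # mean.memmap, std.memmap, text_features.memmap, clip_features.memmap, sync_features.memmap, meta.json
+ head -n 3 output/vgg-train.tsv   # id<TAB>label rows, one per video
+ ```
+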
85
+ ## Preparing Audio-Text Features
86
+
87
+ We have prepared some example data in `training/example_audios`.
88
+
89
+ 1. Run `training/partition_clips.py` to partition each audio file into clips (by finding start and end points; we do not write the partitioned audio to disk, to save space)
90
+ 2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
91
+
92
+ ### Partitioning the audio files
93
+
94
+ Run
95
+
96
+ ```bash
97
+ python training/partition_clips.py
98
+ ```
99
+
100
+ Arguments:
101
+
102
+ - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
103
+ - `output_dir` -- path to the output `.csv` file
104
+ - `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
105
+ - `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
106
+
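+ For example (a sketch; the exact flag spelling depends on the script's argument parser, so adjust if it differs):
+
+ ```bash
+ python training/partition_clips.py \
+     --data_dir=/path/to/audio_files \
+     --output_dir=./output/clips.csv \
+     --start=0 --end=10000   # start/end are optional; use them to shard the work across processes
+ ```
+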
107
+ ### Extracting audio and text features
108
+
109
+ Run
110
+
111
+ ```bash
112
+ torchrun --standalone training/extract_audio_training_latents.py
113
+ ```
114
+
115
+ You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
116
+ Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
117
+
118
+ Arguments:
119
+
120
+ - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
121
+ - `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file with at least the columns `id` and `caption`
122
+ - `clips_tsv` -- path to the clips file, generated in the last step
123
+ - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
124
+ - `output_dir` -- where TensorDict and the metadata file are saved.
125
+
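+ For example, a two-GPU run might look like this (a sketch; the GPU count is arbitrary and the flag spelling is assumed -- check the script's argument parser for the exact names):
+
+ ```bash
+ torchrun --standalone --nproc_per_node=2 training/extract_audio_training_latents.py \
+     --data_dir=/path/to/audio_files \
+     --captions_tsv=/path/to/captions.tsv \
+     --clips_tsv=./output/clips.csv \
+     --latent_dir=./output/audio_latents \
+     --output_dir=./output/audiocaps
+ ```
+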
126
+ Outputs produced in `output_dir`:
127
+
128
+ 1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
129
+ a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
130
+ b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
131
+ c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
132
+ d. `meta.json` that contains the metadata for the above memory mappings
133
+ 2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
134
+
135
+ ### Reference tsv files (with overlaps removed as mentioned in the paper)
136
+
137
+ The reference tsv files can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).
138
+
139
+ Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use them as the `captions_tsv` input, though -- the script will handle duplicates gracefully.
140
+ Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are part of WavCaps. These subsets might be smaller than the original datasets.
141
+ The Clotho data contains both the development set and the validation set.
142
+
143
+ **Update (Mar 9, 2025)**:
144
+ We have uploaded a corrected set of reference tsv files. The previous tsv files contained some (<1%) corrupted captions (i.e., mismatches between audio and caption; see https://github.com/hkchengrex/MMAudio/issues/56). The tsv files for VGGSound are unaffected. The reason for this error is unknown, and I cannot reproduce it in the latest version of the code. Our pre-trained models were trained with the **uncorrected** tsv files. For future training, I recommend using the corrected tsv files.
145
+
146
+ The error statistics are as follows:
147
+
148
+ - AudioCaps: (170/43824), 0.39%
149
+ - Freesound: (1670/180636), 0.92%
150
+ - AudioSet: (290/100776), 0.29%
151
+ - BBCSound: (3/29975), 0.01%
152
+ - Clotho: (8/24332), 0.03%
153
+
154
+ ## Training on Extracted Features
155
+
156
+ We use Distributed Data Parallel (DDP) for training.
157
+ First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
158
+
159
+ To run training on the example data, use the following command:
160
+
161
+ ```bash
162
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
163
+ ```
164
+
165
+ This will not train a useful model, but it will check if everything is set up correctly.
166
+
167
+ For full training on the base model with two GPUs, use the following command:
168
+
169
+ ```bash
170
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
171
+ ```
172
+
173
+ Any outputs from training will be stored in `output/<exp_id>`.
174
+
175
+ More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
176
+ For the medium and large models, specify `vgg_oversample_rate` to be `3` to reduce overfitting.
177
+
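+ For example (a sketch; `medium_44k` is assumed to be the config name of the medium model and the GPU count is arbitrary -- check `config/` for the exact model names):
+
+ ```bash
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=8 train.py \
+     exp_id=exp_2 model=medium_44k vgg_oversample_rate=3
+ ```
+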
178
+ ## Checkpoints
179
+
180
+ Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
181
+
182
+ ---
183
+
184
+ Godspeed!
docs/images/icon.png ADDED
docs/index.html ADDED
@@ -0,0 +1,149 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link rel="preconnect" href="https://fonts.googleapis.com">
14
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
15
+ <link href="https://fonts.googleapis.com/css2?family=Source+Sans+3&display=swap" rel="stylesheet">
16
+ <meta charset="UTF-8">
17
+ <title>MMAudio</title>
18
+
19
+ <link rel="icon" type="image/png" href="images/icon.png">
20
+
21
+ <meta name="viewport" content="width=device-width, initial-scale=1">
22
+ <!-- CSS only -->
23
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
24
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
25
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
26
+
27
+ <link rel="stylesheet" href="style.css">
28
+ </head>
29
+ <body>
30
+
31
+
32
+ <br><br><br><br>
33
+ <div class="container">
34
+ <div class="row text-center" style="font-size:38px">
35
+ <div class="col strong">
36
+ Taming Multimodal Joint Training for High-Quality <br>Video-to-Audio Synthesis
37
+ </div>
38
+ </div>
39
+
40
+ <br>
41
+ <div class="row text-center" style="font-size:28px">
42
+ <div class="col">
43
+ CVPR 2025
44
+ </div>
45
+ </div>
46
+ <br>
47
+
48
+ <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
49
+ <div class="col-sm-auto px-lg-2">
50
+ <a href="https://hkchengrex.github.io/">Ho Kei Cheng<sup>1</sup></a>
51
+ </div>
52
+ <div class="col-sm-auto px-lg-2">
53
+ <nobr><a href="https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ">Masato Ishii<sup>2</sup></a></nobr>
54
+ </div>
55
+ <div class="col-sm-auto px-lg-2">
56
+ <nobr><a href="https://scholar.google.com/citations?user=sXAjHFIAAAAJ">Akio Hayakawa<sup>2</sup></a></nobr>
57
+ </div>
58
+ <div class="col-sm-auto px-lg-2">
59
+ <nobr><a href="https://scholar.google.com/citations?user=XCRO260AAAAJ">Takashi Shibuya<sup>2</sup></a></nobr>
60
+ </div>
61
+ <div class="col-sm-auto px-lg-2">
62
+ <nobr><a href="https://www.alexander-schwing.de/">Alexander Schwing<sup>1</sup></a></nobr>
63
+ </div>
64
+ <div class="col-sm-auto px-lg-2" >
65
+ <nobr><a href="https://www.yukimitsufuji.com/">Yuki Mitsufuji<sup>2,3</sup></a></nobr>
66
+ </div>
67
+ </div>
68
+
69
+ <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
70
+ <div class="col-sm-auto px-lg-2">
71
+ <sup>1</sup>University of Illinois Urbana-Champaign
72
+ </div>
73
+ <div class="col-sm-auto px-lg-2">
74
+ <sup>2</sup>Sony AI
75
+ </div>
76
+ <div class="col-sm-auto px-lg-2">
77
+ <sup>3</sup>Sony Group Corporation
78
+ </div>
79
+ </div>
80
+
81
+ <br>
82
+
83
+ <br>
84
+
85
+ <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
86
+ <div class="col-sm-2">
87
+ <a href="https://arxiv.org/abs/2412.15322">[Paper]</a>
88
+ </div>
89
+ <div class="col-sm-2">
90
+ <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
91
+ </div>
92
+ <div class="col-sm-3">
93
+ <a href="https://huggingface.co/spaces/hkchengrex/MMAudio">[Huggingface Demo]</a>
94
+ </div>
95
+ <div class="col-sm-2">
96
+ <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
97
+ </div>
98
+ <div class="col-sm-3">
99
+ <a href="https://replicate.com/zsxkib/mmaudio">[Replicate Demo]</a>
100
+ </div>
101
+ </div>
102
+
103
+ <br>
104
+
105
+ <hr>
106
+
107
+ <div class="row" style="font-size:32px">
108
+ <div class="col strong">
109
+ TL;DR
110
+ </div>
111
+ </div>
112
+ <br>
113
+ <div class="row">
114
+ <div class="col">
115
+ <p class="light" style="text-align: left;">
116
+ MMAudio generates synchronized audio given video and/or text inputs.
117
+ </p>
118
+ </div>
119
+ </div>
120
+
121
+ <br>
122
+ <hr>
123
+ <br>
124
+
125
+ <div class="row" style="font-size:32px">
126
+ <div class="col strong">
127
+ Demo
128
+ </div>
129
+ </div>
130
+ <br>
131
+ <div class="row" style="font-size:48px">
132
+ <div class="col strong text-center">
133
+ <a href="video_main.html" style="text-decoration: underline;">&lt;More results&gt;</a>
134
+ </div>
135
+ </div>
136
+ <br>
137
+ <div class="video-container" style="text-align: center;">
138
+ <iframe src="https://youtube.com/embed/YElewUT2M4M"></iframe>
139
+ </div>
140
+
141
+ <br>
142
+
143
+ <br><br>
144
+ <br><br>
145
+
146
+ </div>
147
+
148
+ </body>
149
+ </html>
docs/style.css ADDED
@@ -0,0 +1,78 @@
1
+ body {
2
+ font-family: 'Source Sans 3', sans-serif;
3
+ font-size: 18px;
4
+ margin-left: auto;
5
+ margin-right: auto;
6
+ font-weight: 400;
7
+ height: 100%;
8
+ max-width: 1000px;
9
+ }
10
+
11
+ table {
12
+ width: 100%;
13
+ border-collapse: collapse;
14
+ }
15
+ th, td {
16
+ border: 1px solid #ddd;
17
+ padding: 8px;
18
+ text-align: center;
19
+ }
20
+ th {
21
+ background-color: #f2f2f2;
22
+ }
23
+ video {
24
+ width: 100%;
25
+ height: auto;
26
+ }
27
+ p {
28
+ font-size: 28px;
29
+ }
30
+ h2 {
31
+ font-size: 36px;
32
+ }
33
+
34
+ .strong {
35
+ font-weight: 700;
36
+ }
37
+
38
+ .light {
39
+ font-weight: 100;
40
+ }
41
+
42
+ .heavy {
43
+ font-weight: 900;
44
+ }
45
+
46
+ .column {
47
+ float: left;
48
+ }
49
+
50
+ a:link,
51
+ a:visited {
52
+ color: #05538f;
53
+ text-decoration: none;
54
+ }
55
+
56
+ a:hover {
57
+ color: #63cbdd;
58
+ }
59
+
60
+ hr {
61
+ border: 0;
62
+ height: 1px;
63
+ background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
64
+ }
65
+
66
+ .video-container {
67
+ position: relative;
68
+ padding-bottom: 56.25%; /* 16:9 */
69
+ height: 0;
70
+ }
71
+
72
+ .video-container iframe {
73
+ position: absolute;
74
+ top: 0;
75
+ left: 0;
76
+ width: 100%;
77
+ height: 100%;
78
+ }
docs/style_videos.css ADDED
@@ -0,0 +1,52 @@
1
+ body {
2
+ font-family: 'Source Sans 3', sans-serif;
3
+ font-size: 1.5vh;
4
+ font-weight: 400;
5
+ }
6
+
7
+ table {
8
+ width: 100%;
9
+ border-collapse: collapse;
10
+ }
11
+ th, td {
12
+ border: 1px solid #ddd;
13
+ padding: 8px;
14
+ text-align: center;
15
+ }
16
+ th {
17
+ background-color: #f2f2f2;
18
+ }
19
+ video {
20
+ width: 100%;
21
+ height: auto;
22
+ }
23
+ p {
24
+ font-size: 1.5vh;
25
+ font-weight: bold;
26
+ }
27
+ h2 {
28
+ font-size: 2vh;
29
+ font-weight: bold;
30
+ }
31
+
32
+ .video-container {
33
+ position: relative;
34
+ padding-bottom: 56.25%; /* 16:9 */
35
+ height: 0;
36
+ }
37
+
38
+ .video-container iframe {
39
+ position: absolute;
40
+ top: 0;
41
+ left: 0;
42
+ width: 100%;
43
+ height: 100%;
44
+ }
45
+
46
+ .video-header {
47
+ background-color: #f2f2f2;
48
+ text-align: center;
49
+ font-size: 1.5vh;
50
+ font-weight: bold;
51
+ padding: 8px;
52
+ }
docs/video_gen.html ADDED
@@ -0,0 +1,254 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <link rel="icon" type="image/png" href="images/icon.png">
18
+
19
+ <meta name="viewport" content="width=device-width, initial-scale=1">
20
+ <!-- CSS only -->
21
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
22
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
23
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
24
+
25
+ <link rel="stylesheet" href="style_videos.css">
26
+ </head>
27
+ <body>
28
+
29
+ <div id="moviegen_all">
30
+ <h2 id="moviegen" style="text-align: center;">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</h2>
31
+ <p id="moviegen1" style="overflow: hidden;">
32
+ Example 1: Ice cracking with sharp snapping sound, and metal tool scraping against the ice surface.
33
+ <span style="float: right;"><a href="#index">Back to index</a></span>
34
+ </p>
35
+
36
+ <div class="row g-1">
37
+ <div class="col-sm-6">
38
+ <div class="video-header">Movie Gen Audio</div>
39
+ <div class="video-container">
40
+ <iframe src="https://youtube.com/embed/d7Lb0ihtGcE"></iframe>
41
+ </div>
42
+ </div>
43
+ <div class="col-sm-6">
44
+ <div class="video-header">Ours</div>
45
+ <div class="video-container">
46
+ <iframe src="https://youtube.com/embed/F4JoJ2r2m8U"></iframe>
47
+ </div>
48
+ </div>
49
+ </div>
50
+ <br>
51
+
52
+ <!-- <p id="moviegen2">Example 2: Rhythmic splashing and lapping of water. <span style="float:right;"><a href="#index">Back to index</a></span> </p>
53
+
54
+ <table>
55
+ <thead>
56
+ <tr>
57
+ <th>Movie Gen Audio</th>
58
+ <th>Ours</th>
59
+ </tr>
60
+ </thead>
61
+ <tbody>
62
+ <tr>
63
+ <td width="50%">
64
+ <div class="video-container">
65
+ <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
66
+ </div>
67
+ </td>
68
+ <td width="50%">
69
+ <div class="video-container">
70
+ <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
71
+ </div>
72
+ </td>
73
+ </tr>
74
+ </tbody>
75
+ </table> -->
76
+
77
+ <p id="moviegen2" style="overflow: hidden;">
78
+ Example 2: Rhythmic splashing and lapping of water.
79
+ <span style="float:right;"><a href="#index">Back to index</a></span>
80
+ </p>
81
+ <div class="row g-1">
82
+ <div class="col-sm-6">
83
+ <div class="video-header">Movie Gen Audio</div>
84
+ <div class="video-container">
85
+ <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
86
+ </div>
87
+ </div>
88
+ <div class="col-sm-6">
89
+ <div class="video-header">Ours</div>
90
+ <div class="video-container">
91
+ <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
92
+ </div>
93
+ </div>
94
+ </div>
95
+ <br>
96
+
97
+ <p id="moviegen3" style="overflow: hidden;">
98
+ Example 3: Shovel scrapes against dry earth.
99
+ <span style="float:right;"><a href="#index">Back to index</a></span>
100
+ </p>
101
+ <div class="row g-1">
102
+ <div class="col-sm-6">
103
+ <div class="video-header">Movie Gen Audio</div>
104
+ <div class="video-container">
105
+ <iframe src="https://youtube.com/embed/PUKGyEve7XQ"></iframe>
106
+ </div>
107
+ </div>
108
+ <div class="col-sm-6">
109
+ <div class="video-header">Ours</div>
110
+ <div class="video-container">
111
+ <iframe src="https://youtube.com/embed/CNn7i8VNkdc"></iframe>
112
+ </div>
113
+ </div>
114
+ </div>
115
+ <br>
116
+
117
+
118
+ <p id="moviegen4" style="overflow: hidden;">
119
+ (Failure case) Example 4: Creamy sound of mashed potatoes being scooped.
120
+ <span style="float:right;"><a href="#index">Back to index</a></span>
121
+ </p>
122
+ <div class="row g-1">
123
+ <div class="col-sm-6">
124
+ <div class="video-header">Movie Gen Audio</div>
125
+ <div class="video-container">
126
+ <iframe src="https://youtube.com/embed/PJv1zxR9JjQ"></iframe>
127
+ </div>
128
+ </div>
129
+ <div class="col-sm-6">
130
+ <div class="video-header">Ours</div>
131
+ <div class="video-container">
132
+ <iframe src="https://youtube.com/embed/c3-LJ1lNsPQ"></iframe>
133
+ </div>
134
+ </div>
135
+ </div>
136
+ <br>
137
+
138
+ </div>
139
+
140
+ <div id="hunyuan_sora_all">
141
+
142
+ <h2 id="hunyuan" style="text-align: center;">Results on Videos Generated by Hunyuan</h2>
143
+ <p style="overflow: hidden;">
144
+ <span style="float:right;"><a href="#index">Back to index</a></span>
145
+ </p>
146
+ <div class="row g-1">
147
+ <div class="col-sm-6">
148
+ <div class="video-header">Typing</div>
149
+ <div class="video-container">
150
+ <iframe src="https://youtube.com/embed/8ln_9hhH_nk"></iframe>
151
+ </div>
152
+ </div>
153
+ <div class="col-sm-6">
154
+ <div class="video-header">Water is rushing down a stream and pouring</div>
155
+ <div class="video-container">
156
+ <iframe src="https://youtube.com/embed/5df1FZFQj30"></iframe>
157
+ </div>
158
+ </div>
159
+ </div>
160
+ <div class="row g-1">
161
+ <div class="col-sm-6">
162
+ <div class="video-header">Waves on beach</div>
163
+ <div class="video-container">
164
+ <iframe src="https://youtube.com/embed/7wQ9D5WgpFc"></iframe>
165
+ </div>
166
+ </div>
167
+ <div class="col-sm-6">
168
+ <div class="video-header">Water droplet</div>
169
+ <div class="video-container">
170
+ <iframe src="https://youtube.com/embed/q7M2nsalGjM"></iframe>
171
+ </div>
172
+ </div>
173
+ </div>
174
+ <br>
175
+
176
+ <h2 id="sora" style="text-align: center;">Results on Videos Generated by Sora</h2>
177
+ <p style="overflow: hidden;">
178
+ <span style="float:right;"><a href="#index">Back to index</a></span>
179
+ </p>
180
+ <div class="row g-1">
181
+ <div class="col-sm-6">
182
+ <div class="video-header">Ships riding waves</div>
183
+ <div class="video-container">
184
+ <iframe src="https://youtube.com/embed/JbgQzHHytk8"></iframe>
185
+ </div>
186
+ </div>
187
+ <div class="col-sm-6">
188
+ <div class="video-header">Train (no text prompt given)</div>
189
+ <div class="video-container">
190
+ <iframe src="https://youtube.com/embed/xOW7zrjpWC8"></iframe>
191
+ </div>
192
+ </div>
193
+ </div>
194
+ <div class="row g-1">
195
+ <div class="col-sm-6">
196
+ <div class="video-header">Seashore (no text prompt given)</div>
197
+ <div class="video-container">
198
+ <iframe src="https://youtube.com/embed/fIuw5Y8ZZ9E"></iframe>
199
+ </div>
200
+ </div>
201
+ <div class="col-sm-6">
202
+ <div class="video-header">Surfing (failure: unprompted music)</div>
203
+ <div class="video-container">
204
+ <iframe src="https://youtube.com/embed/UcSTk-v0M_s"></iframe>
205
+ </div>
206
+ </div>
207
+ </div>
208
+ <br>
209
+
210
+ <div id="mochi_ltx_all">
211
+ <h2 id="mochi" style="text-align: center;">Results on Videos Generated by Mochi 1</h2>
212
+ <p style="overflow: hidden;">
213
+ <span style="float:right;"><a href="#index">Back to index</a></span>
214
+ </p>
215
+ <div class="row g-1">
216
+ <div class="col-sm-6">
217
+ <div class="video-header">Magical fire and lightning (no text prompt given)</div>
218
+ <div class="video-container">
219
+ <iframe src="https://youtube.com/embed/tTlRZaSMNwY"></iframe>
220
+ </div>
221
+ </div>
222
+ <div class="col-sm-6">
223
+ <div class="video-header">Storm (no text prompt given)</div>
224
+ <div class="video-container">
225
+ <iframe src="https://youtube.com/embed/4hrZTMJUy3w"></iframe>
226
+ </div>
227
+ </div>
228
+ </div>
229
+ <br>
230
+
231
+ <h2 id="ltx" style="text-align: center;">Results on Videos Generated by LTX-Video</h2>
232
+ <p style="overflow: hidden;">
233
+ <span style="float:right;"><a href="#index">Back to index</a></span>
234
+ </p>
235
+ <div class="row g-1">
236
+ <div class="col-sm-6">
237
+ <div class="video-header">Firewood burning and cracking</div>
238
+ <div class="video-container">
239
+ <iframe src="https://youtube.com/embed/P7_DDpgev0g"></iframe>
240
+ </div>
241
+ </div>
242
+ <div class="col-sm-6">
243
+ <div class="video-header">Waterfall, water splashing</div>
244
+ <div class="video-container">
245
+ <iframe src="https://youtube.com/embed/4MvjceYnIO0"></iframe>
246
+ </div>
247
+ </div>
248
+ </div>
249
+ <br>
250
+
251
+ </div>
252
+
253
+ </body>
254
+ </html>
docs/video_main.html ADDED
@@ -0,0 +1,98 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <link rel="icon" type="image/png" href="images/icon.png">
18
+
19
+ <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
20
+ <!-- CSS only -->
21
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
22
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
23
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
24
+
25
+ <link rel="stylesheet" href="style_videos.css">
26
+
27
+ <script type="text/javascript">
28
+ $(document).ready(function(){
29
+ $("#content").load("video_gen.html #moviegen_all");
30
+ $("#load_moveigen").click(function(){
31
+ $("#content").load("video_gen.html #moviegen_all");
32
+ });
33
+ $("#load_hunyuan_sora").click(function(){
34
+ $("#content").load("video_gen.html #hunyuan_sora_all");
35
+ });
36
+ $("#load_mochi_ltx").click(function(){
37
+ $("#content").load("video_gen.html #mochi_ltx_all");
38
+ });
39
+ $("#load_vgg1").click(function(){
40
+ $("#content").load("video_vgg.html #vgg1");
41
+ });
42
+ $("#load_vgg2").click(function(){
43
+ $("#content").load("video_vgg.html #vgg2");
44
+ });
45
+ $("#load_vgg3").click(function(){
46
+ $("#content").load("video_vgg.html #vgg3");
47
+ });
48
+ $("#load_vgg4").click(function(){
49
+ $("#content").load("video_vgg.html #vgg4");
50
+ });
51
+ $("#load_vgg5").click(function(){
52
+ $("#content").load("video_vgg.html #vgg5");
53
+ });
54
+ $("#load_vgg6").click(function(){
55
+ $("#content").load("video_vgg.html #vgg6");
56
+ });
57
+ $("#load_vgg_extra").click(function(){
58
+ $("#content").load("video_vgg.html #vgg_extra");
59
+ });
60
+ });
61
+ </script>
62
+ </head>
63
+ <body>
64
+ <h1 id="index" style="text-align: center;">Index</h1>
65
+ <p><b>(Click on the links to load the corresponding videos)</b> <span style="float:right;"><a href="index.html">Back to project page</a></span></p>
66
+
67
+ <ol>
68
+ <li>
69
+ <a href="#" id="load_moveigen">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</a>
70
+ </li>
71
+ <li>
72
+ <a href="#" id="load_hunyuan_sora">Results on Videos Generated by Hunyuan and Sora</a>
73
+ </li>
74
+ <li>
75
+ <a href="#" id="load_mochi_ltx">Results on Videos Generated by Mochi 1 and LTX-Video</a>
76
+ </li>
77
+ <li>
78
+ On VGGSound
79
+ <ol>
80
+ <li><a id='load_vgg1' href="#">Example 1: Wolf howling</a></li>
81
+ <li><a id='load_vgg2' href="#">Example 2: Striking a golf ball</a></li>
82
+ <li><a id='load_vgg3' href="#">Example 3: Hitting a drum</a></li>
83
+ <li><a id='load_vgg4' href="#">Example 4: Dog barking</a></li>
84
+ <li><a id='load_vgg5' href="#">Example 5: Playing a string instrument</a></li>
85
+ <li><a id='load_vgg6' href="#">Example 6: A group of people playing tambourines</a></li>
86
+ <li><a id='load_vgg_extra' href="#">Extra results & failure cases</a></li>
87
+ </ol>
88
+ </li>
89
+ </ol>
90
+
91
+ <div id="content" class="container-fluid">
92
+
93
+ </div>
94
+ <br>
95
+ <br>
96
+
97
+ </body>
98
+ </html>
docs/video_vgg.html ADDED
@@ -0,0 +1,452 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <meta name="viewport" content="width=device-width, initial-scale=1">
18
+ <!-- CSS only -->
19
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
20
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
21
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
22
+
23
+ <link rel="stylesheet" href="style_videos.css">
24
+ </head>
25
+ <body>
26
+
27
+ <div id="vgg1">
28
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
29
+ <p style="overflow: hidden;">
30
+ Example 1: Wolf howling.
31
+ <span style="float:right;"><a href="#index">Back to index</a></span>
32
+ </p>
33
+ <div class="row g-1">
34
+ <div class="col-sm-3">
35
+ <div class="video-header">Ground-truth</div>
36
+ <div class="video-container">
37
+ <iframe src="https://youtube.com/embed/9J_V74gqMUA"></iframe>
38
+ </div>
39
+ </div>
40
+ <div class="col-sm-3">
41
+ <div class="video-header">Ours</div>
42
+ <div class="video-container">
43
+ <iframe src="https://youtube.com/embed/P6O8IpjErPc"></iframe>
44
+ </div>
45
+ </div>
46
+ <div class="col-sm-3">
47
+ <div class="video-header">V2A-Mapper</div>
48
+ <div class="video-container">
49
+ <iframe src="https://youtube.com/embed/w-5eyqepvTk"></iframe>
50
+ </div>
51
+ </div>
52
+ <div class="col-sm-3">
53
+ <div class="video-header">FoleyCrafter</div>
54
+ <div class="video-container">
55
+ <iframe src="https://youtube.com/embed/VOLfoZlRkzo"></iframe>
56
+ </div>
57
+ </div>
58
+ </div>
59
+ <div class="row g-1">
60
+ <div class="col-sm-3">
61
+ <div class="video-header">Frieren</div>
62
+ <div class="video-container">
63
+ <iframe src="https://youtube.com/embed/49owKyA5Pa8"></iframe>
64
+ </div>
65
+ </div>
66
+ <div class="col-sm-3">
67
+ <div class="video-header">VATT</div>
68
+ <div class="video-container">
69
+ <iframe src="https://youtube.com/embed/QVtrFgbeGDM"></iframe>
70
+ </div>
71
+ </div>
72
+ <div class="col-sm-3">
73
+ <div class="video-header">V-AURA</div>
74
+ <div class="video-container">
75
+ <iframe src="https://youtube.com/embed/8r0uEfSNjvI"></iframe>
76
+ </div>
77
+ </div>
78
+ <div class="col-sm-3">
79
+ <div class="video-header">Seeing and Hearing</div>
80
+ <div class="video-container">
81
+ <iframe src="https://youtube.com/embed/bn-sLg2qulk"></iframe>
82
+ </div>
83
+ </div>
84
+ </div>
85
+ </div>
86
+
87
+ <div id="vgg2">
88
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
89
+ <p style="overflow: hidden;">
90
+ Example 2: Striking a golf ball.
91
+ <span style="float:right;"><a href="#index">Back to index</a></span>
92
+ </p>
93
+
94
+ <div class="row g-1">
95
+ <div class="col-sm-3">
96
+ <div class="video-header">Ground-truth</div>
97
+ <div class="video-container">
98
+ <iframe src="https://youtube.com/embed/1hwSu42kkho"></iframe>
99
+ </div>
100
+ </div>
101
+ <div class="col-sm-3">
102
+ <div class="video-header">Ours</div>
103
+ <div class="video-container">
104
+ <iframe src="https://youtube.com/embed/kZibDoDCNxI"></iframe>
105
+ </div>
106
+ </div>
107
+ <div class="col-sm-3">
108
+ <div class="video-header">V2A-Mapper</div>
109
+ <div class="video-container">
110
+ <iframe src="https://youtube.com/embed/jgKfLBLhh7Y"></iframe>
111
+ </div>
112
+ </div>
113
+ <div class="col-sm-3">
114
+ <div class="video-header">FoleyCrafter</div>
115
+ <div class="video-container">
116
+ <iframe src="https://youtube.com/embed/Lfsx8mOPcJo"></iframe>
117
+ </div>
118
+ </div>
119
+ </div>
120
+ <div class="row g-1">
121
+ <div class="col-sm-3">
122
+ <div class="video-header">Frieren</div>
123
+ <div class="video-container">
124
+ <iframe src="https://youtube.com/embed/tz-LpbB0MBc"></iframe>
125
+ </div>
126
+ </div>
127
+ <div class="col-sm-3">
128
+ <div class="video-header">VATT</div>
129
+ <div class="video-container">
130
+ <iframe src="https://youtube.com/embed/RTDUHMi08n4"></iframe>
131
+ </div>
132
+ </div>
133
+ <div class="col-sm-3">
134
+ <div class="video-header">V-AURA</div>
135
+ <div class="video-container">
136
+ <iframe src="https://youtube.com/embed/N-3TDOsPnZQ"></iframe>
137
+ </div>
138
+ </div>
139
+ <div class="col-sm-3">
140
+ <div class="video-header">Seeing and Hearing</div>
141
+ <div class="video-container">
142
+ <iframe src="https://youtube.com/embed/QnsHnLn4gB0"></iframe>
143
+ </div>
144
+ </div>
145
+ </div>
146
+ </div>
147
+
148
+ <div id="vgg3">
149
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
150
+ <p style="overflow: hidden;">
151
+ Example 3: Hitting a drum.
152
+ <span style="float:right;"><a href="#index">Back to index</a></span>
153
+ </p>
154
+
155
+ <div class="row g-1">
156
+ <div class="col-sm-3">
157
+ <div class="video-header">Ground-truth</div>
158
+ <div class="video-container">
159
+ <iframe src="https://youtube.com/embed/0oeIwq77w0Q"></iframe>
160
+ </div>
161
+ </div>
162
+ <div class="col-sm-3">
163
+ <div class="video-header">Ours</div>
164
+ <div class="video-container">
165
+ <iframe src="https://youtube.com/embed/-UtPV9ohuIM"></iframe>
166
+ </div>
167
+ </div>
168
+ <div class="col-sm-3">
169
+ <div class="video-header">V2A-Mapper</div>
170
+ <div class="video-container">
171
+ <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
172
+ </div>
173
+ </div>
174
+ <div class="col-sm-3">
175
+ <div class="video-header">FoleyCrafter</div>
176
+ <div class="video-container">
177
+ <iframe src="https://youtube.com/embed/kkCsXPOlBvY"></iframe>
178
+ </div>
179
+ </div>
180
+ </div>
181
+ <div class="row g-1">
182
+ <div class="col-sm-3">
183
+ <div class="video-header">Frieren</div>
184
+ <div class="video-container">
185
+ <iframe src="https://youtube.com/embed/MbNKsVsuvig"></iframe>
186
+ </div>
187
+ </div>
188
+ <div class="col-sm-3">
189
+ <div class="video-header">VATT</div>
190
+ <div class="video-container">
191
+ <iframe src="https://youtube.com/embed/2yYviBjrpBw"></iframe>
192
+ </div>
193
+ </div>
194
+ <div class="col-sm-3">
195
+ <div class="video-header">V-AURA</div>
196
+ <div class="video-container">
197
+ <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
198
+ </div>
199
+ </div>
200
+ <div class="col-sm-3">
201
+ <div class="video-header">Seeing and Hearing</div>
202
+ <div class="video-container">
203
+ <iframe src="https://youtube.com/embed/6dnyQt4Fuhs"></iframe>
204
+ </div>
205
+ </div>
206
+ </div>
207
+ </div>
208
+ </div>
209
+
210
+ <div id="vgg4">
211
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
212
+ <p style="overflow: hidden;">
213
+ Example 4: Dog barking.
214
+ <span style="float:right;"><a href="#index">Back to index</a></span>
215
+ </p>
216
+
217
+ <div class="row g-1">
218
+ <div class="col-sm-3">
219
+ <div class="video-header">Ground-truth</div>
220
+ <div class="video-container">
221
+ <iframe src="https://youtube.com/embed/ckaqvTyMYAw"></iframe>
222
+ </div>
223
+ </div>
224
+ <div class="col-sm-3">
225
+ <div class="video-header">Ours</div>
226
+ <div class="video-container">
227
+ <iframe src="https://youtube.com/embed/_aRndFZzZ-I"></iframe>
228
+ </div>
229
+ </div>
230
+ <div class="col-sm-3">
231
+ <div class="video-header">V2A-Mapper</div>
232
+ <div class="video-container">
233
+ <iframe src="https://youtube.com/embed/mNCISP3LBl0"></iframe>
234
+ </div>
235
+ </div>
236
+ <div class="col-sm-3">
237
+ <div class="video-header">FoleyCrafter</div>
238
+ <div class="video-container">
239
+ <iframe src="https://youtube.com/embed/phZBQ3L7foE"></iframe>
240
+ </div>
241
+ </div>
242
+ </div>
243
+ <div class="row g-1">
244
+ <div class="col-sm-3">
245
+ <div class="video-header">Frieren</div>
246
+ <div class="video-container">
247
+ <iframe src="https://youtube.com/embed/Sb5Mg1-ORao"></iframe>
248
+ </div>
249
+ </div>
250
+ <div class="col-sm-3">
251
+ <div class="video-header">VATT</div>
252
+ <div class="video-container">
253
+ <iframe src="https://youtube.com/embed/eHmAGOmtDDg"></iframe>
254
+ </div>
255
+ </div>
256
+ <div class="col-sm-3">
257
+ <div class="video-header">V-AURA</div>
258
+ <div class="video-container">
259
+ <iframe src="https://youtube.com/embed/NEGa3krBrm0"></iframe>
260
+ </div>
261
+ </div>
262
+ <div class="col-sm-3">
263
+ <div class="video-header">Seeing and Hearing</div>
264
+ <div class="video-container">
265
+ <iframe src="https://youtube.com/embed/aO0EAXlwE7A"></iframe>
266
+ </div>
267
+ </div>
268
+ </div>
269
+ </div>
270
+
271
+ <div id="vgg5">
272
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
273
+ <p style="overflow: hidden;">
274
+ Example 5: Playing a string instrument.
275
+ <span style="float:right;"><a href="#index">Back to index</a></span>
276
+ </p>
277
+
278
+ <div class="row g-1">
279
+ <div class="col-sm-3">
280
+ <div class="video-header">Ground-truth</div>
281
+ <div class="video-container">
282
+ <iframe src="https://youtube.com/embed/KP1QhWauIOc"></iframe>
283
+ </div>
284
+ </div>
285
+ <div class="col-sm-3">
286
+ <div class="video-header">Ours</div>
287
+ <div class="video-container">
288
+ <iframe src="https://youtube.com/embed/ovaJhWSquYE"></iframe>
289
+ </div>
290
+ </div>
291
+ <div class="col-sm-3">
292
+ <div class="video-header">V2A-Mapper</div>
293
+ <div class="video-container">
294
+ <iframe src="https://youtube.com/embed/N723FS9lcy8"></iframe>
295
+ </div>
296
+ </div>
297
+ <div class="col-sm-3">
298
+ <div class="video-header">FoleyCrafter</div>
299
+ <div class="video-container">
300
+ <iframe src="https://youtube.com/embed/t0N4ZAAXo58"></iframe>
301
+ </div>
302
+ </div>
303
+ </div>
304
+ <div class="row g-1">
305
+ <div class="col-sm-3">
306
+ <div class="video-header">Frieren</div>
307
+ <div class="video-container">
308
+ <iframe src="https://youtube.com/embed/8YSRs03QNNA"></iframe>
309
+ </div>
310
+ </div>
311
+ <div class="col-sm-3">
312
+ <div class="video-header">VATT</div>
313
+ <div class="video-container">
314
+ <iframe src="https://youtube.com/embed/vOpMz55J1kY"></iframe>
315
+ </div>
316
+ </div>
317
+ <div class="col-sm-3">
318
+ <div class="video-header">V-AURA</div>
319
+ <div class="video-container">
320
+ <iframe src="https://youtube.com/embed/9JHC75vr9h0"></iframe>
321
+ </div>
322
+ </div>
323
+ <div class="col-sm-3">
324
+ <div class="video-header">Seeing and Hearing</div>
325
+ <div class="video-container">
326
+ <iframe src="https://youtube.com/embed/9w0JckNzXmY"></iframe>
327
+ </div>
328
+ </div>
329
+ </div>
330
+ </div>
331
+
332
+ <div id="vgg6">
333
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
334
+ <p style="overflow: hidden;">
335
+ Example 6: A group of people playing tambourines.
336
+ <span style="float:right;"><a href="#index">Back to index</a></span>
337
+ </p>
338
+
339
+ <div class="row g-1">
340
+ <div class="col-sm-3">
341
+ <div class="video-header">Ground-truth</div>
342
+ <div class="video-container">
343
+ <iframe src="https://youtube.com/embed/mx6JLxzUkRc"></iframe>
344
+ </div>
345
+ </div>
346
+ <div class="col-sm-3">
347
+ <div class="video-header">Ours</div>
348
+ <div class="video-container">
349
+ <iframe src="https://youtube.com/embed/oLirHhP9Su8"></iframe>
350
+ </div>
351
+ </div>
352
+ <div class="col-sm-3">
353
+ <div class="video-header">V2A-Mapper</div>
354
+ <div class="video-container">
355
+ <iframe src="https://youtube.com/embed/HkLkHMqptv0"></iframe>
356
+ </div>
357
+ </div>
358
+ <div class="col-sm-3">
359
+ <div class="video-header">FoleyCrafter</div>
360
+ <div class="video-container">
361
+ <iframe src="https://youtube.com/embed/rpHiiODjmNU"></iframe>
362
+ </div>
363
+ </div>
364
+ </div>
365
+ <div class="row g-1">
366
+ <div class="col-sm-3">
367
+ <div class="video-header">Frieren</div>
368
+ <div class="video-container">
369
+ <iframe src="https://youtube.com/embed/1mVD3fJ0LpM"></iframe>
370
+ </div>
371
+ </div>
372
+ <div class="col-sm-3">
373
+ <div class="video-header">VATT</div>
374
+ <div class="video-container">
375
+ <iframe src="https://youtube.com/embed/yjVFnJiEJlw"></iframe>
376
+ </div>
377
+ </div>
378
+ <div class="col-sm-3">
379
+ <div class="video-header">V-AURA</div>
380
+ <div class="video-container">
381
+ <iframe src="https://youtube.com/embed/neVeMSWtRkU"></iframe>
382
+ </div>
383
+ </div>
384
+ <div class="col-sm-3">
385
+ <div class="video-header">Seeing and Hearing</div>
386
+ <div class="video-container">
387
+ <iframe src="https://youtube.com/embed/EUE7YwyVWz8"></iframe>
388
+ </div>
389
+ </div>
390
+ </div>
391
+ </div>
392
+
393
+ <div id="vgg_extra">
394
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
395
+ <p style="overflow: hidden;">
396
+ <span style="float:right;"><a href="#index">Back to index</a></span>
397
+ </p>
398
+
399
+ <div class="row g-1">
400
+ <div class="col-sm-3">
401
+ <div class="video-header">Moving train</div>
402
+ <div class="video-container">
403
+ <iframe src="https://youtube.com/embed/Ta6H45rBzJc"></iframe>
404
+ </div>
405
+ </div>
406
+ <div class="col-sm-3">
407
+ <div class="video-header">Water splashing</div>
408
+ <div class="video-container">
409
+ <iframe src="https://youtube.com/embed/hl6AtgHXpb4"></iframe>
410
+ </div>
411
+ </div>
412
+ <div class="col-sm-3">
413
+ <div class="video-header">Skateboarding</div>
414
+ <div class="video-container">
415
+ <iframe src="https://youtube.com/embed/n4sCNi_9buI"></iframe>
416
+ </div>
417
+ </div>
418
+ <div class="col-sm-3">
419
+ <div class="video-header">Synchronized clapping</div>
420
+ <div class="video-container">
421
+ <iframe src="https://youtube.com/embed/oxexfpLn7FE"></iframe>
422
+ </div>
423
+ </div>
424
+ </div>
425
+
426
+ <br><br>
427
+
428
+ <div id="extra-failure">
429
+ <h2 style="text-align: center;">Failure cases</h2>
430
+ <p style="overflow: hidden;">
431
+ <span style="float:right;"><a href="#index">Back to index</a></span>
432
+ </p>
433
+
434
+ <div class="row g-1">
435
+ <div class="col-sm-6">
436
+ <div class="video-header">Human speech</div>
437
+ <div class="video-container">
438
+ <iframe src="https://youtube.com/embed/nx0CyrDu70Y"></iframe>
439
+ </div>
440
+ </div>
441
+ <div class="col-sm-6">
442
+ <div class="video-header">Unfamiliar vision input</div>
443
+ <div class="video-container">
444
+ <iframe src="https://youtube.com/embed/hfnAqmK3X7w"></iframe>
445
+ </div>
446
+ </div>
447
+ </div>
448
+ </div>
449
+ </div>
450
+
451
+ </body>
452
+ </html>
gradio_demo.py ADDED
@@ -0,0 +1,343 @@
1
+ import gc
2
+ import logging
3
+ from argparse import ArgumentParser
4
+ from datetime import datetime
5
+ from fractions import Fraction
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torchaudio
11
+
12
+ from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
13
+ load_video, make_video, setup_eval_logging)
14
+ from mmaudio.model.flow_matching import FlowMatching
15
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
16
+ from mmaudio.model.sequence_config import SequenceConfig
17
+ from mmaudio.model.utils.features_utils import FeaturesUtils
18
+
19
+ torch.backends.cuda.matmul.allow_tf32 = True
20
+ torch.backends.cudnn.allow_tf32 = True
21
+
22
+ log = logging.getLogger()
23
+
24
+ device = 'cpu'
25
+ if torch.cuda.is_available():
26
+ device = 'cuda'
27
+ elif torch.backends.mps.is_available():
28
+ device = 'mps'
29
+ else:
30
+ log.warning('CUDA/MPS are not available, running on CPU')
31
+ dtype = torch.bfloat16
32
+
33
+ model: ModelConfig = all_model_cfg['large_44k_v2']
34
+ model.download_if_needed()
35
+ output_dir = Path('./output/gradio')
36
+
37
+ setup_eval_logging()
38
+
39
+
40
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
41
+ seq_cfg = model.seq_cfg
42
+
43
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
44
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
45
+ log.info(f'Loaded weights from {model.model_path}')
46
+
47
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
48
+ synchformer_ckpt=model.synchformer_ckpt,
49
+ enable_conditions=True,
50
+ mode=model.mode,
51
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
52
+ need_vae_encoder=False)
53
+ feature_utils = feature_utils.to(device, dtype).eval()
54
+
55
+ return net, feature_utils, seq_cfg
56
+
57
+
58
+ net, feature_utils, seq_cfg = get_model()
59
+
60
+
61
+ @torch.inference_mode()
62
+ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
63
+ cfg_strength: float, duration: float):
64
+
65
+ rng = torch.Generator(device=device)
66
+ if seed >= 0:
67
+ rng.manual_seed(seed)
68
+ else:
69
+ rng.seed()
70
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
71
+
72
+ video_info = load_video(video, duration)
73
+ clip_frames = video_info.clip_frames
74
+ sync_frames = video_info.sync_frames
75
+ duration = video_info.duration_sec
76
+ clip_frames = clip_frames.unsqueeze(0)
77
+ sync_frames = sync_frames.unsqueeze(0)
78
+ seq_cfg.duration = duration
79
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
80
+
81
+ audios = generate(clip_frames,
82
+ sync_frames, [prompt],
83
+ negative_text=[negative_prompt],
84
+ feature_utils=feature_utils,
85
+ net=net,
86
+ fm=fm,
87
+ rng=rng,
88
+ cfg_strength=cfg_strength)
89
+ audio = audios.float().cpu()[0]
90
+
91
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
92
+ output_dir.mkdir(exist_ok=True, parents=True)
93
+ video_save_path = output_dir / f'{current_time_string}.mp4'
94
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
95
+ gc.collect()
96
+ return video_save_path
97
+
98
+
99
+ @torch.inference_mode()
100
+ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
101
+ cfg_strength: float, duration: float):
102
+
103
+ rng = torch.Generator(device=device)
104
+ if seed >= 0:
105
+ rng.manual_seed(seed)
106
+ else:
107
+ rng.seed()
108
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
+
110
+ image_info = load_image(image)
111
+ clip_frames = image_info.clip_frames
112
+ sync_frames = image_info.sync_frames
113
+ clip_frames = clip_frames.unsqueeze(0)
114
+ sync_frames = sync_frames.unsqueeze(0)
115
+ seq_cfg.duration = duration
116
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
117
+
118
+ audios = generate(clip_frames,
119
+ sync_frames, [prompt],
120
+ negative_text=[negative_prompt],
121
+ feature_utils=feature_utils,
122
+ net=net,
123
+ fm=fm,
124
+ rng=rng,
125
+ cfg_strength=cfg_strength,
126
+ image_input=True)
127
+ audio = audios.float().cpu()[0]
128
+
129
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
130
+ output_dir.mkdir(exist_ok=True, parents=True)
131
+ video_save_path = output_dir / f'{current_time_string}.mp4'
132
+ video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
133
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
134
+ gc.collect()
135
+ return video_save_path
136
+
137
+
138
+ @torch.inference_mode()
139
+ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
140
+ duration: float):
141
+
142
+ rng = torch.Generator(device=device)
143
+ if seed >= 0:
144
+ rng.manual_seed(seed)
145
+ else:
146
+ rng.seed()
147
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
148
+
149
+ clip_frames = sync_frames = None
150
+ seq_cfg.duration = duration
151
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
152
+
153
+ audios = generate(clip_frames,
154
+ sync_frames, [prompt],
155
+ negative_text=[negative_prompt],
156
+ feature_utils=feature_utils,
157
+ net=net,
158
+ fm=fm,
159
+ rng=rng,
160
+ cfg_strength=cfg_strength)
161
+ audio = audios.float().cpu()[0]
162
+
163
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
164
+ output_dir.mkdir(exist_ok=True, parents=True)
165
+ audio_save_path = output_dir / f'{current_time_string}.flac'
166
+ torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
167
+ gc.collect()
168
+ return audio_save_path
169
+
170
+
171
+ video_to_audio_tab = gr.Interface(
172
+ fn=video_to_audio,
173
+ description="""
174
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
175
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
176
+
177
+ NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
178
+ Doing so does not improve results.
179
+ """,
180
+ inputs=[
181
+ gr.Video(),
182
+ gr.Text(label='Prompt'),
183
+ gr.Text(label='Negative prompt', value='music'),
184
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
185
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
186
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
187
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
188
+ ],
189
+ outputs='playable_video',
190
+ cache_examples=False,
191
+ title='MMAudio — Video-to-Audio Synthesis',
192
+ examples=[
193
+ [
194
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
195
+ 'waves, seagulls',
196
+ '',
197
+ 0,
198
+ 25,
199
+ 4.5,
200
+ 10,
201
+ ],
202
+ [
203
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
204
+ '',
205
+ 'music',
206
+ 0,
207
+ 25,
208
+ 4.5,
209
+ 10,
210
+ ],
211
+ [
212
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
213
+ 'bubbles',
214
+ '',
215
+ 0,
216
+ 25,
217
+ 4.5,
218
+ 10,
219
+ ],
220
+ [
221
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
222
+ 'Indian holy music',
223
+ '',
224
+ 0,
225
+ 25,
226
+ 4.5,
227
+ 10,
228
+ ],
229
+ [
230
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
231
+ 'galloping',
232
+ '',
233
+ 0,
234
+ 25,
235
+ 4.5,
236
+ 10,
237
+ ],
238
+ [
239
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
240
+ 'waves, storm',
241
+ '',
242
+ 0,
243
+ 25,
244
+ 4.5,
245
+ 10,
246
+ ],
247
+ [
248
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
249
+ 'storm',
250
+ '',
251
+ 0,
252
+ 25,
253
+ 4.5,
254
+ 10,
255
+ ],
256
+ [
257
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
258
+ '',
259
+ '',
260
+ 0,
261
+ 25,
262
+ 4.5,
263
+ 10,
264
+ ],
265
+ [
266
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
267
+ 'typing',
268
+ '',
269
+ 0,
270
+ 25,
271
+ 4.5,
272
+ 10,
273
+ ],
274
+ [
275
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
276
+ '',
277
+ '',
278
+ 0,
279
+ 25,
280
+ 4.5,
281
+ 10,
282
+ ],
283
+ [
284
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
285
+ '',
286
+ '',
287
+ 0,
288
+ 25,
289
+ 4.5,
290
+ 10,
291
+ ],
292
+ ])
293
+
294
+ text_to_audio_tab = gr.Interface(
295
+ fn=text_to_audio,
296
+ description="""
297
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
298
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
299
+ """,
300
+ inputs=[
301
+ gr.Text(label='Prompt'),
302
+ gr.Text(label='Negative prompt'),
303
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
304
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
305
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
306
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
307
+ ],
308
+ outputs='audio',
309
+ cache_examples=False,
310
+ title='MMAudio — Text-to-Audio Synthesis',
311
+ )
312
+
313
+ image_to_audio_tab = gr.Interface(
314
+ fn=image_to_audio,
315
+ description="""
316
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
317
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
318
+
319
+ NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
320
+ Doing so does not improve results.
321
+ """,
322
+ inputs=[
323
+ gr.Image(type='filepath'),
324
+ gr.Text(label='Prompt'),
325
+ gr.Text(label='Negative prompt'),
326
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
327
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
328
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
329
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
330
+ ],
331
+ outputs='playable_video',
332
+ cache_examples=False,
333
+ title='MMAudio — Image-to-Audio Synthesis (experimental)',
334
+ )
335
+
336
+ if __name__ == "__main__":
337
+ parser = ArgumentParser()
338
+ parser.add_argument('--port', type=int, default=7860)
339
+ args = parser.parse_args()
340
+
341
+ gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
342
+ ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
343
+ server_port=args.port, allowed_paths=[output_dir])
mmaudio/__init__.py ADDED
File without changes
mmaudio/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes). View file
 
mmaudio/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (185 Bytes). View file
 
mmaudio/__pycache__/eval_utils.cpython-310.pyc ADDED
Binary file (7.07 kB). View file
 
mmaudio/__pycache__/eval_utils.cpython-38.pyc ADDED
Binary file (7.03 kB). View file
 
mmaudio/data/__init__.py ADDED
File without changes
mmaudio/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (192 Bytes). View file
 
mmaudio/data/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (190 Bytes). View file
 
mmaudio/data/__pycache__/av_utils.cpython-310.pyc ADDED
Binary file (4.91 kB). View file
 
mmaudio/data/__pycache__/av_utils.cpython-38.pyc ADDED
Binary file (4.89 kB). View file
 
mmaudio/data/av_utils.py ADDED
@@ -0,0 +1,162 @@
1
+ from dataclasses import dataclass
2
+ from fractions import Fraction
3
+ from pathlib import Path
4
+ from typing import Optional, List, Tuple
5
+
6
+ import av
7
+ import numpy as np
8
+ import torch
9
+ from av import AudioFrame
10
+
11
+
12
+ @dataclass
13
+ class VideoInfo:
14
+ duration_sec: float
15
+ fps: Fraction
16
+ clip_frames: torch.Tensor
17
+ sync_frames: torch.Tensor
18
+ all_frames: Optional[List[np.ndarray]]
19
+
20
+ @property
21
+ def height(self):
22
+ return self.all_frames[0].shape[0]
23
+
24
+ @property
25
+ def width(self):
26
+ return self.all_frames[0].shape[1]
27
+
28
+ @classmethod
29
+ def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
30
+ fps: Fraction) -> 'VideoInfo':
31
+ num_frames = int(duration_sec * fps)
32
+ all_frames = [image_info.original_frame] * num_frames
33
+ return cls(duration_sec=duration_sec,
34
+ fps=fps,
35
+ clip_frames=image_info.clip_frames,
36
+ sync_frames=image_info.sync_frames,
37
+ all_frames=all_frames)
38
+
39
+
40
+ @dataclass
41
+ class ImageInfo:
42
+ clip_frames: torch.Tensor
43
+ sync_frames: torch.Tensor
44
+ original_frame: Optional[np.ndarray]
45
+
46
+ @property
47
+ def height(self):
48
+ return self.original_frame.shape[0]
49
+
50
+ @property
51
+ def width(self):
52
+ return self.original_frame.shape[1]
53
+
54
+
55
+ def read_frames(video_path: Path, list_of_fps: List[float], start_sec: float, end_sec: float,
56
+ need_all_frames: bool) -> Tuple[List[np.ndarray], List[np.ndarray], Fraction]:
57
+ output_frames = [[] for _ in list_of_fps]
58
+ next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
59
+ time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
60
+ all_frames = []
61
+
62
+ # container = av.open(video_path)
63
+ with av.open(video_path) as container:
64
+ stream = container.streams.video[0]
65
+ fps = stream.guessed_rate
66
+ stream.thread_type = 'AUTO'
67
+ for packet in container.demux(stream):
68
+ for frame in packet.decode():
69
+ frame_time = frame.time
70
+ if frame_time < start_sec:
71
+ continue
72
+ if frame_time > end_sec:
73
+ break
74
+
75
+ frame_np = None
76
+ if need_all_frames:
77
+ frame_np = frame.to_ndarray(format='rgb24')
78
+ all_frames.append(frame_np)
79
+
80
+ for i, _ in enumerate(list_of_fps):
81
+ this_time = frame_time
82
+ while this_time >= next_frame_time_for_each_fps[i]:
83
+ if frame_np is None:
84
+ frame_np = frame.to_ndarray(format='rgb24')
85
+
86
+ output_frames[i].append(frame_np)
87
+ next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
88
+
89
+ output_frames = [np.stack(frames) for frames in output_frames]
90
+ return output_frames, all_frames, fps
91
+
92
+
93
+ def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
94
+ sampling_rate: int):
95
+ container = av.open(output_path, 'w')
96
+ output_video_stream = container.add_stream('h264', video_info.fps)
97
+ output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
98
+ output_video_stream.width = video_info.width
99
+ output_video_stream.height = video_info.height
100
+ output_video_stream.pix_fmt = 'yuv420p'
101
+
102
+ output_audio_stream = container.add_stream('aac', sampling_rate)
103
+
104
+ # encode video
105
+ for image in video_info.all_frames:
106
+ image = av.VideoFrame.from_ndarray(image)
107
+ packet = output_video_stream.encode(image)
108
+ container.mux(packet)
109
+
110
+ for packet in output_video_stream.encode():
111
+ container.mux(packet)
112
+
113
+ # convert float tensor audio to numpy array
114
+ audio_np = audio.numpy().astype(np.float32)
115
+ audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
116
+ audio_frame.sample_rate = sampling_rate
117
+
118
+ for packet in output_audio_stream.encode(audio_frame):
119
+ container.mux(packet)
120
+
121
+ for packet in output_audio_stream.encode():
122
+ container.mux(packet)
123
+
124
+ container.close()
125
+
126
+
127
+ def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
128
+ """
129
+ NOTE: we cannot get the exact video duration right without re-encoding,
130
+ so this function is unused; it is kept here for reference.
131
+ """
132
+ video = av.open(video_path)
133
+ output = av.open(output_path, 'w')
134
+ input_video_stream = video.streams.video[0]
135
+ output_video_stream = output.add_stream(template=input_video_stream)
136
+ output_audio_stream = output.add_stream('aac', sampling_rate)
137
+
138
+ duration_sec = audio.shape[-1] / sampling_rate
139
+
140
+ for packet in video.demux(input_video_stream):
141
+ # We need to skip the "flushing" packets that `demux` generates.
142
+ if packet.dts is None:
143
+ continue
144
+ # We need to assign the packet to the new stream.
145
+ packet.stream = output_video_stream
146
+ output.mux(packet)
147
+
148
+ # convert float tensor audio to numpy array
149
+ audio_np = audio.numpy().astype(np.float32)
150
+ audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
151
+ audio_frame.sample_rate = sampling_rate
152
+
153
+ for packet in output_audio_stream.encode(audio_frame):
154
+ output.mux(packet)
155
+
156
+ for packet in output_audio_stream.encode():
157
+ output.mux(packet)
158
+
159
+ video.close()
160
+ output.close()
161
+
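For reference, a sketch (not part of the commit) of how the helpers above fit together: decode frames at the two frame rates used elsewhere in this commit, then mux a waveform back into a new video. The file names and the silent placeholder waveform are illustrative.

# Sketch: read_frames() at CLIP (8 fps) and Synchformer (25 fps) rates, then
# reencode_with_audio() to write a video with a placeholder audio track.
from pathlib import Path

import torch

from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio

video_path = Path('video.mp4')  # assumed to exist
(clip_np, sync_np), all_frames, fps = read_frames(video_path,
                                                  list_of_fps=[8.0, 25.0],
                                                  start_sec=0,
                                                  end_sec=8,
                                                  need_all_frames=True)
info = VideoInfo(duration_sec=8.0,
                 fps=fps,
                 clip_frames=torch.from_numpy(clip_np),
                 sync_frames=torch.from_numpy(sync_np),
                 all_frames=all_frames)
waveform = torch.zeros(1, 16_000 * 8)  # 8 s of silence at 16 kHz, (channels, samples)
reencode_with_audio(info, Path('with_audio.mp4'), waveform, sampling_rate=16_000)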
mmaudio/data/data_setup.py ADDED
@@ -0,0 +1,174 @@
1
+ import logging
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from torch.utils.data.dataloader import default_collate
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from mmaudio.data.eval.audiocaps import AudioCapsData
12
+ from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
13
+ from mmaudio.data.extracted_audio import ExtractedAudio
14
+ from mmaudio.data.extracted_vgg import ExtractedVGG
15
+ from mmaudio.data.mm_dataset import MultiModalDataset
16
+ from mmaudio.utils.dist_utils import local_rank
17
+
18
+ log = logging.getLogger()
19
+
20
+
21
+ # Re-seed randomness every time we start a worker
22
+ def worker_init_fn(worker_id: int):
23
+ worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
24
+ np.random.seed(worker_seed)
25
+ random.seed(worker_seed)
26
+ log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
27
+
28
+
29
+ def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
30
+ dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
31
+ data_dim=cfg.data_dim,
32
+ premade_mmap_dir=data_cfg.memmap_dir)
33
+
34
+ return dataset
35
+
36
+
37
+ def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
38
+ dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
39
+ data_dim=cfg.data_dim,
40
+ premade_mmap_dir=data_cfg.memmap_dir)
41
+
42
+ return dataset
43
+
44
+
45
+ def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
46
+ if cfg.mini_train:
47
+ vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
48
+ audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
49
+ dataset = MultiModalDataset([vgg], [audiocaps])
50
+ if cfg.example_train:
51
+ video = load_vgg_data(cfg, cfg.data.Example_video)
52
+ audio = load_audio_data(cfg, cfg.data.Example_audio)
53
+ dataset = MultiModalDataset([video], [audio])
54
+ else:
55
+ # load the largest one first
56
+ freesound = load_audio_data(cfg, cfg.data.FreeSound)
57
+ vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
58
+ audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
59
+ audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
60
+ bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
61
+ clotho = load_audio_data(cfg, cfg.data.Clotho)
62
+ dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
63
+ [audiocaps, audioset_sl, bbcsound, freesound, clotho])
64
+
65
+ batch_size = cfg.batch_size
66
+ num_workers = cfg.num_workers
67
+ pin_memory = cfg.pin_memory
68
+ sampler, loader = construct_loader(dataset,
69
+ batch_size,
70
+ num_workers,
71
+ shuffle=True,
72
+ drop_last=True,
73
+ pin_memory=pin_memory)
74
+
75
+ return dataset, sampler, loader
76
+
77
+
78
+ def setup_test_datasets(cfg):
79
+ dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
80
+
81
+ batch_size = cfg.batch_size
82
+ num_workers = cfg.num_workers
83
+ pin_memory = cfg.pin_memory
84
+ sampler, loader = construct_loader(dataset,
85
+ batch_size,
86
+ num_workers,
87
+ shuffle=False,
88
+ drop_last=False,
89
+ pin_memory=pin_memory)
90
+
91
+ return dataset, sampler, loader
92
+
93
+
94
+ def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
95
+ if cfg.example_train:
96
+ dataset = load_vgg_data(cfg, cfg.data.Example_video)
97
+ else:
98
+ dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
99
+
100
+ val_batch_size = cfg.batch_size
101
+ val_eval_batch_size = cfg.eval_batch_size
102
+ num_workers = cfg.num_workers
103
+ pin_memory = cfg.pin_memory
104
+ _, val_loader = construct_loader(dataset,
105
+ val_batch_size,
106
+ num_workers,
107
+ shuffle=False,
108
+ drop_last=False,
109
+ pin_memory=pin_memory)
110
+ _, eval_loader = construct_loader(dataset,
111
+ val_eval_batch_size,
112
+ num_workers,
113
+ shuffle=False,
114
+ drop_last=False,
115
+ pin_memory=pin_memory)
116
+
117
+ return dataset, val_loader, eval_loader
118
+
119
+
120
+ def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
121
+ if dataset_name.startswith('audiocaps_full'):
122
+ dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
123
+ cfg.eval_data.AudioCaps_full.csv_path)
124
+ elif dataset_name.startswith('audiocaps'):
125
+ dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
126
+ cfg.eval_data.AudioCaps.csv_path)
127
+ elif dataset_name.startswith('moviegen'):
128
+ dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
129
+ cfg.eval_data.MovieGen.jsonl_path,
130
+ duration_sec=cfg.duration_s)
131
+ elif dataset_name.startswith('vggsound'):
132
+ dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
133
+ cfg.eval_data.VGGSound.csv_path,
134
+ duration_sec=cfg.duration_s)
135
+ else:
136
+ raise ValueError(f'Invalid dataset name: {dataset_name}')
137
+
138
+ batch_size = cfg.batch_size
139
+ num_workers = cfg.num_workers
140
+ pin_memory = cfg.pin_memory
141
+ _, loader = construct_loader(dataset,
142
+ batch_size,
143
+ num_workers,
144
+ shuffle=False,
145
+ drop_last=False,
146
+ pin_memory=pin_memory,
147
+ error_avoidance=True)
148
+ return dataset, loader
149
+
150
+
151
+ def error_avoidance_collate(batch):
152
+ batch = list(filter(lambda x: x is not None, batch))
153
+ return default_collate(batch)
154
+
155
+
156
+ def construct_loader(dataset: Dataset,
157
+ batch_size: int,
158
+ num_workers: int,
159
+ *,
160
+ shuffle: bool = True,
161
+ drop_last: bool = True,
162
+ pin_memory: bool = False,
163
+ error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
164
+ train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
165
+ train_loader = DataLoader(dataset,
166
+ batch_size,
167
+ sampler=train_sampler,
168
+ num_workers=num_workers,
169
+ worker_init_fn=worker_init_fn,
170
+ drop_last=drop_last,
171
+ persistent_workers=num_workers > 0,
172
+ pin_memory=pin_memory,
173
+ collate_fn=error_avoidance_collate if error_avoidance else None)
174
+ return train_sampler, train_loader
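As a self-contained illustration (not part of the commit) of why error_avoidance_collate exists: several datasets later in this commit return None when decoding fails, and the collate function simply drops those entries so a standard DataLoader keeps running, as in this sketch with a toy dataset:

# Sketch: dropping failed samples (None) before collation, mirroring
# error_avoidance_collate above; ToyDataset stands in for the video datasets.
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class ToyDataset(Dataset):

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        if idx % 3 == 0:  # pretend every third sample fails to decode
            return None
        return {'x': torch.tensor(idx)}


def drop_none_collate(batch):
    return default_collate([b for b in batch if b is not None])


loader = DataLoader(ToyDataset(), batch_size=4, collate_fn=drop_none_collate)
for batch in loader:
    print(batch['x'])  # batches are simply smaller whenever samples failed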
mmaudio/data/eval/__init__.py ADDED
File without changes
mmaudio/data/eval/audiocaps.py ADDED
@@ -0,0 +1,39 @@
1
+ import logging
2
+ import os
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ log = logging.getLogger()
12
+
13
+
14
+ class AudioCapsData(Dataset):
15
+
16
+ def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
17
+ df = pd.read_csv(csv_path).to_dict(orient='records')
18
+
19
+ audio_files = sorted(os.listdir(audio_path))
20
+ audio_files = set(
21
+ [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
22
+
23
+ self.data = []
24
+ for row in df:
25
+ self.data.append({
26
+ 'name': row['name'],
27
+ 'caption': row['caption'],
28
+ })
29
+
30
+ self.audio_path = Path(audio_path)
31
+ self.csv_path = Path(csv_path)
32
+
33
+ log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
34
+
35
+ def __getitem__(self, idx: int) -> dict[str, str]:
36
+ return self.data[idx]
37
+
38
+ def __len__(self):
39
+ return len(self.data)
mmaudio/data/eval/moviegen.py ADDED
@@ -0,0 +1,131 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import torch
8
+ from torch.utils.data.dataset import Dataset
9
+ from torchvision.transforms import v2
10
+ from torio.io import StreamingMediaDecoder
11
+
12
+ from mmaudio.utils.dist_utils import local_rank
13
+
14
+ log = logging.getLogger()
15
+
16
+ _CLIP_SIZE = 384
17
+ _CLIP_FPS = 8.0
18
+
19
+ _SYNC_SIZE = 224
20
+ _SYNC_FPS = 25.0
21
+
22
+
23
+ class MovieGenData(Dataset):
24
+
25
+ def __init__(
26
+ self,
27
+ video_root: Union[str, Path],
28
+ sync_root: Union[str, Path],
29
+ jsonl_root: Union[str, Path],
30
+ *,
31
+ duration_sec: float = 10.0,
32
+ read_clip: bool = True,
33
+ ):
34
+ self.video_root = Path(video_root)
35
+ self.sync_root = Path(sync_root)
36
+ self.jsonl_root = Path(jsonl_root)
37
+ self.read_clip = read_clip
38
+
39
+ videos = sorted(os.listdir(self.video_root))
40
+ videos = [v[:-4] for v in videos] # remove extensions
41
+ self.captions = {}
42
+
43
+ for v in videos:
44
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
45
+ data = json.load(f)
46
+ self.captions[v] = data['audio_prompt']
47
+
48
+ if local_rank == 0:
49
+ log.info(f'{len(videos)} videos found in {video_root}')
50
+
51
+ self.duration_sec = duration_sec
52
+
53
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
54
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
55
+
56
+ self.clip_augment = v2.Compose([
57
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
58
+ v2.ToImage(),
59
+ v2.ToDtype(torch.float32, scale=True),
60
+ ])
61
+
62
+ self.sync_augment = v2.Compose([
63
+ v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
64
+ v2.CenterCrop(_SYNC_SIZE),
65
+ v2.ToImage(),
66
+ v2.ToDtype(torch.float32, scale=True),
67
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
68
+ ])
69
+
70
+ self.videos = videos
71
+
72
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
73
+ video_id = self.videos[idx]
74
+ caption = self.captions[video_id]
75
+
76
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
77
+ reader.add_basic_video_stream(
78
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
79
+ frame_rate=_CLIP_FPS,
80
+ format='rgb24',
81
+ )
82
+ reader.add_basic_video_stream(
83
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
84
+ frame_rate=_SYNC_FPS,
85
+ format='rgb24',
86
+ )
87
+
88
+ reader.fill_buffer()
89
+ data_chunk = reader.pop_chunks()
90
+
91
+ clip_chunk = data_chunk[0]
92
+ sync_chunk = data_chunk[1]
93
+ if clip_chunk is None:
94
+ raise RuntimeError(f'CLIP video returned None {video_id}')
95
+ if clip_chunk.shape[0] < self.clip_expected_length:
96
+ raise RuntimeError(f'CLIP video too short {video_id}')
97
+
98
+ if sync_chunk is None:
99
+ raise RuntimeError(f'Sync video returned None {video_id}')
100
+ if sync_chunk.shape[0] < self.sync_expected_length:
101
+ raise RuntimeError(f'Sync video too short {video_id}')
102
+
103
+ # truncate the video
104
+ clip_chunk = clip_chunk[:self.clip_expected_length]
105
+ if clip_chunk.shape[0] != self.clip_expected_length:
106
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
107
+ f'expected {self.clip_expected_length}, '
108
+ f'got {clip_chunk.shape[0]}')
109
+ clip_chunk = self.clip_augment(clip_chunk)
110
+
111
+ sync_chunk = sync_chunk[:self.sync_expected_length]
112
+ if sync_chunk.shape[0] != self.sync_expected_length:
113
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
114
+ f'expected {self.sync_expected_length}, '
115
+ f'got {sync_chunk.shape[0]}')
116
+ sync_chunk = self.sync_augment(sync_chunk)
117
+
118
+ data = {
119
+ 'name': video_id,
120
+ 'caption': caption,
121
+ 'clip_video': clip_chunk,
122
+ 'sync_video': sync_chunk,
123
+ }
124
+
125
+ return data
126
+
127
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
128
+ return self.sample(idx)
129
+
130
+ def __len__(self):
131
+ return len(self.captions)
mmaudio/data/eval/video_dataset.py ADDED
@@ -0,0 +1,197 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from mmaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VideoDataset(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ video_root: Union[str, Path],
29
+ *,
30
+ duration_sec: float = 8.0,
31
+ ):
32
+ self.video_root = Path(video_root)
33
+
34
+ self.duration_sec = duration_sec
35
+
36
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
37
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
38
+
39
+ self.clip_transform = v2.Compose([
40
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
41
+ v2.ToImage(),
42
+ v2.ToDtype(torch.float32, scale=True),
43
+ ])
44
+
45
+ self.sync_transform = v2.Compose([
46
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
47
+ v2.CenterCrop(_SYNC_SIZE),
48
+ v2.ToImage(),
49
+ v2.ToDtype(torch.float32, scale=True),
50
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
51
+ ])
52
+
53
+ # to be implemented by subclasses
54
+ self.captions = {}
55
+ self.videos = sorted(list(self.captions.keys()))
56
+
57
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
58
+ video_id = self.videos[idx]
59
+ caption = self.captions[video_id]
60
+
61
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
62
+ reader.add_basic_video_stream(
63
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
64
+ frame_rate=_CLIP_FPS,
65
+ format='rgb24',
66
+ )
67
+ reader.add_basic_video_stream(
68
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
69
+ frame_rate=_SYNC_FPS,
70
+ format='rgb24',
71
+ )
72
+
73
+ reader.fill_buffer()
74
+ data_chunk = reader.pop_chunks()
75
+
76
+ clip_chunk = data_chunk[0]
77
+ sync_chunk = data_chunk[1]
78
+ if clip_chunk is None:
79
+ raise RuntimeError(f'CLIP video returned None {video_id}')
80
+ if clip_chunk.shape[0] < self.clip_expected_length:
81
+ raise RuntimeError(
82
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
83
+ )
84
+
85
+ if sync_chunk is None:
86
+ raise RuntimeError(f'Sync video returned None {video_id}')
87
+ if sync_chunk.shape[0] < self.sync_expected_length:
88
+ raise RuntimeError(
89
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
90
+ )
91
+
92
+ # truncate the video
93
+ clip_chunk = clip_chunk[:self.clip_expected_length]
94
+ if clip_chunk.shape[0] != self.clip_expected_length:
95
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
96
+ f'expected {self.clip_expected_length}, '
97
+ f'got {clip_chunk.shape[0]}')
98
+ clip_chunk = self.clip_transform(clip_chunk)
99
+
100
+ sync_chunk = sync_chunk[:self.sync_expected_length]
101
+ if sync_chunk.shape[0] != self.sync_expected_length:
102
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
103
+ f'expected {self.sync_expected_length}, '
104
+ f'got {sync_chunk.shape[0]}')
105
+ sync_chunk = self.sync_transform(sync_chunk)
106
+
107
+ data = {
108
+ 'name': video_id,
109
+ 'caption': caption,
110
+ 'clip_video': clip_chunk,
111
+ 'sync_video': sync_chunk,
112
+ }
113
+
114
+ return data
115
+
116
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
117
+ try:
118
+ return self.sample(idx)
119
+ except Exception as e:
120
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
121
+ return None
122
+
123
+ def __len__(self):
124
+ return len(self.captions)
125
+
126
+
127
+ class VGGSound(VideoDataset):
128
+
129
+ def __init__(
130
+ self,
131
+ video_root: Union[str, Path],
132
+ csv_path: Union[str, Path],
133
+ *,
134
+ duration_sec: float = 8.0,
135
+ ):
136
+ super().__init__(video_root, duration_sec=duration_sec)
137
+ self.video_root = Path(video_root)
138
+ self.csv_path = Path(csv_path)
139
+
140
+ videos = sorted(os.listdir(self.video_root))
141
+ if local_rank == 0:
142
+ log.info(f'{len(videos)} videos found in {video_root}')
143
+ self.captions = {}
144
+
145
+ df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
146
+ 'split']).to_dict(orient='records')
147
+
148
+ videos_no_found = []
149
+ for row in df:
150
+ if row['split'] == 'test':
151
+ start_sec = int(row['sec'])
152
+ video_id = str(row['id'])
153
+ # this is how our videos are named
154
+ video_name = f'{video_id}_{start_sec:06d}'
155
+ if video_name + '.mp4' not in videos:
156
+ videos_no_found.append(video_name)
157
+ continue
158
+
159
+ self.captions[video_name] = row['caption']
160
+
161
+ if local_rank == 0:
162
+ log.info(f'{len(videos)} videos found in {video_root}')
163
+ log.info(f'{len(self.captions)} useable videos found')
164
+ if videos_no_found:
165
+ log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
166
+ log.info(
167
+ 'A small amount is expected, as not all videos are still available on YouTube')
168
+
169
+ self.videos = sorted(list(self.captions.keys()))
170
+
171
+
172
+ class MovieGen(VideoDataset):
173
+
174
+ def __init__(
175
+ self,
176
+ video_root: Union[str, Path],
177
+ jsonl_root: Union[str, Path],
178
+ *,
179
+ duration_sec: float = 10.0,
180
+ ):
181
+ super().__init__(video_root, duration_sec=duration_sec)
182
+ self.video_root = Path(video_root)
183
+ self.jsonl_root = Path(jsonl_root)
184
+
185
+ videos = sorted(os.listdir(self.video_root))
186
+ videos = [v[:-4] for v in videos] # remove extensions
187
+ self.captions = {}
188
+
189
+ for v in videos:
190
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
191
+ data = json.load(f)
192
+ self.captions[v] = data['audio_prompt']
193
+
194
+ if local_rank == 0:
195
+ log.info(f'{len(videos)} videos found in {video_root}')
196
+
197
+ self.videos = videos
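For reference, a sketch (not part of the commit) of the inputs the VGGSound evaluation dataset above expects: a header-less CSV with id, sec, caption, split columns and videos named `{id}_{start_sec:06d}.mp4` under video_root. All values below are illustrative.

# Sketch: a tiny VGGSound-style CSV plus the matching video file name,
# following the column order and naming convention used by the class above.
import csv

rows = [
    # id             sec  caption               split
    ('abc123xyz00',   30, 'dog barking',        'test'),
    ('def456uvw11',  120, 'playing accordion',  'train'),  # skipped: not 'test'
]
with open('vggsound_test.csv', 'w', newline='') as f:
    csv.writer(f).writerows(rows)

video_id, start_sec = rows[0][0], rows[0][1]
print(f'{video_id}_{start_sec:06d}.mp4')  # abc123xyz00_000030.mp4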
mmaudio/data/extracted_audio.py ADDED
@@ -0,0 +1,88 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ from mmaudio.utils.dist_utils import local_rank
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class ExtractedAudio(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ tsv_path: Union[str, Path],
20
+ *,
21
+ premade_mmap_dir: Union[str, Path],
22
+ data_dim: dict[str, int],
23
+ ):
24
+ super().__init__()
25
+
26
+ self.data_dim = data_dim
27
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
+ self.ids = [str(d['id']) for d in self.df_list]
29
+
30
+ log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
+ # load precomputed memory mapped tensors
32
+ premade_mmap_dir = Path(premade_mmap_dir)
33
+ td = TensorDict.load_memmap(premade_mmap_dir)
34
+ log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
+ self.mean = td['mean']
36
+ self.std = td['std']
37
+ self.text_features = td['text_features']
38
+
39
+ log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
40
+ log.info(f'Loaded mean: {self.mean.shape}.')
41
+ log.info(f'Loaded std: {self.std.shape}.')
42
+ log.info(f'Loaded text features: {self.text_features.shape}.')
43
+
44
+ assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
45
+ f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
46
+ assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
47
+ f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
48
+
49
+ assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
50
+ f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
51
+ assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
52
+ f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
53
+
54
+ self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
55
+ self.data_dim['clip_dim'])
56
+ self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
57
+ self.data_dim['sync_dim'])
58
+ self.video_exist = torch.tensor(0, dtype=torch.bool)
59
+ self.text_exist = torch.tensor(1, dtype=torch.bool)
60
+
61
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
62
+ latents = self.mean
63
+ return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
64
+
65
+ def get_memory_mapped_tensor(self) -> TensorDict:
66
+ td = TensorDict({
67
+ 'mean': self.mean,
68
+ 'std': self.std,
69
+ 'text_features': self.text_features,
70
+ })
71
+ return td
72
+
73
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
74
+ data = {
75
+ 'id': str(self.df_list[idx]['id']),
76
+ 'a_mean': self.mean[idx],
77
+ 'a_std': self.std[idx],
78
+ 'clip_features': self.fake_clip_features,
79
+ 'sync_features': self.fake_sync_features,
80
+ 'text_features': self.text_features[idx],
81
+ 'caption': self.df_list[idx]['caption'],
82
+ 'video_exist': self.video_exist,
83
+ 'text_exist': self.text_exist,
84
+ }
85
+ return data
86
+
87
+ def __len__(self):
88
+ return len(self.ids)
mmaudio/data/extracted_vgg.py ADDED
@@ -0,0 +1,101 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ from mmaudio.utils.dist_utils import local_rank
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class ExtractedVGG(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ tsv_path: Union[str, Path],
20
+ *,
21
+ premade_mmap_dir: Union[str, Path],
22
+ data_dim: dict[str, int],
23
+ ):
24
+ super().__init__()
25
+
26
+ self.data_dim = data_dim
27
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
+ self.ids = [d['id'] for d in self.df_list]
29
+
30
+ log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
+ # load precomputed memory mapped tensors
32
+ premade_mmap_dir = Path(premade_mmap_dir)
33
+ td = TensorDict.load_memmap(premade_mmap_dir)
34
+ log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
+ self.mean = td['mean']
36
+ self.std = td['std']
37
+ self.clip_features = td['clip_features']
38
+ self.sync_features = td['sync_features']
39
+ self.text_features = td['text_features']
40
+
41
+ if local_rank == 0:
42
+ log.info(f'Loaded {len(self)} samples.')
43
+ log.info(f'Loaded mean: {self.mean.shape}.')
44
+ log.info(f'Loaded std: {self.std.shape}.')
45
+ log.info(f'Loaded clip_features: {self.clip_features.shape}.')
46
+ log.info(f'Loaded sync_features: {self.sync_features.shape}.')
47
+ log.info(f'Loaded text_features: {self.text_features.shape}.')
48
+
49
+ assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
50
+ f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
51
+ assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
52
+ f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
53
+
54
+ assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
55
+ f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
56
+ assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
57
+ f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
58
+ assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
59
+ f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
60
+
61
+ assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
62
+ f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
63
+ assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
64
+ f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
65
+ assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
66
+ f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
67
+
68
+ self.video_exist = torch.tensor(1, dtype=torch.bool)
69
+ self.text_exist = torch.tensor(1, dtype=torch.bool)
70
+
71
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
72
+ latents = self.mean
73
+ return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
74
+
75
+ def get_memory_mapped_tensor(self) -> TensorDict:
76
+ td = TensorDict({
77
+ 'mean': self.mean,
78
+ 'std': self.std,
79
+ 'clip_features': self.clip_features,
80
+ 'sync_features': self.sync_features,
81
+ 'text_features': self.text_features,
82
+ })
83
+ return td
84
+
85
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
86
+ data = {
87
+ 'id': self.df_list[idx]['id'],
88
+ 'a_mean': self.mean[idx],
89
+ 'a_std': self.std[idx],
90
+ 'clip_features': self.clip_features[idx],
91
+ 'sync_features': self.sync_features[idx],
92
+ 'text_features': self.text_features[idx],
93
+ 'caption': self.df_list[idx]['label'],
94
+ 'video_exist': self.video_exist,
95
+ 'text_exist': self.text_exist,
96
+ }
97
+
98
+ return data
99
+
100
+ def __len__(self):
101
+ return len(self.ids)
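For reference, a sketch (not part of the commit) of the on-disk layout ExtractedVGG consumes: a TensorDict memory-mapped to a directory with the five keys asserted in the constructor, alongside a TSV with id and label columns. The shapes below are stand-ins rather than the real data_dim values, and a tensordict version providing load_memmap/memmap_ is assumed.

# Sketch: build a tiny memmap directory that ExtractedVGG-style code could
# load back with TensorDict.load_memmap(); shapes are placeholders.
import torch
from tensordict import TensorDict

n = 4  # number of clips
td = TensorDict(
    {
        'mean': torch.zeros(n, 32, 8),          # (clips, latent_seq_len, latent_dim)
        'std': torch.ones(n, 32, 8),
        'clip_features': torch.zeros(n, 16, 1024),
        'sync_features': torch.zeros(n, 50, 768),
        'text_features': torch.zeros(n, 77, 1024),
    },
    batch_size=[n])
td.memmap_('example_memmap_dir')  # directory later read by load_memmap()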
mmaudio/data/extraction/__init__.py ADDED
File without changes
mmaudio/data/extraction/vgg_sound.py ADDED
@@ -0,0 +1,193 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from mmaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VGGSound(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ root: Union[str, Path],
29
+ *,
30
+ tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
31
+ sample_rate: int = 16_000,
32
+ duration_sec: float = 8.0,
33
+ audio_samples: Optional[int] = None,
34
+ normalize_audio: bool = False,
35
+ ):
36
+ self.root = Path(root)
37
+ self.normalize_audio = normalize_audio
38
+ if audio_samples is None:
39
+ self.audio_samples = int(sample_rate * duration_sec)
40
+ else:
41
+ self.audio_samples = audio_samples
42
+ effective_duration = audio_samples / sample_rate
43
+ # make sure the duration is close enough, within 15ms
44
+ assert abs(effective_duration - duration_sec) < 0.015, \
45
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
46
+
47
+ videos = sorted(os.listdir(self.root))
48
+ videos = set([Path(v).stem for v in videos]) # remove extensions
49
+ self.labels = {}
50
+ self.videos = []
51
+ missing_videos = []
52
+
53
+ # read the tsv for subset information
54
+ df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
55
+ for record in df_list:
56
+ id = record['id']
57
+ label = record['label']
58
+ if id in videos:
59
+ self.labels[id] = label
60
+ self.videos.append(id)
61
+ else:
62
+ missing_videos.append(id)
63
+
64
+ if local_rank == 0:
65
+ log.info(f'{len(videos)} videos found in {root}')
66
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
67
+ log.info(f'{len(missing_videos)} videos missing in {root}')
68
+
69
+ self.sample_rate = sample_rate
70
+ self.duration_sec = duration_sec
71
+
72
+ self.expected_audio_length = audio_samples
73
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
74
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
75
+
76
+ self.clip_transform = v2.Compose([
77
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
78
+ v2.ToImage(),
79
+ v2.ToDtype(torch.float32, scale=True),
80
+ ])
81
+
82
+ self.sync_transform = v2.Compose([
83
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
84
+ v2.CenterCrop(_SYNC_SIZE),
85
+ v2.ToImage(),
86
+ v2.ToDtype(torch.float32, scale=True),
87
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
88
+ ])
89
+
90
+ self.resampler = {}
91
+
92
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
93
+ video_id = self.videos[idx]
94
+ label = self.labels[video_id]
95
+
96
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
97
+ reader.add_basic_video_stream(
98
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
99
+ frame_rate=_CLIP_FPS,
100
+ format='rgb24',
101
+ )
102
+ reader.add_basic_video_stream(
103
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
104
+ frame_rate=_SYNC_FPS,
105
+ format='rgb24',
106
+ )
107
+ reader.add_basic_audio_stream(frames_per_chunk=2**30, )
108
+
109
+ reader.fill_buffer()
110
+ data_chunk = reader.pop_chunks()
111
+
112
+ clip_chunk = data_chunk[0]
113
+ sync_chunk = data_chunk[1]
114
+ audio_chunk = data_chunk[2]
115
+
116
+ if clip_chunk is None:
117
+ raise RuntimeError(f'CLIP video returned None {video_id}')
118
+ if clip_chunk.shape[0] < self.clip_expected_length:
119
+ raise RuntimeError(
120
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
121
+ )
122
+
123
+ if sync_chunk is None:
124
+ raise RuntimeError(f'Sync video returned None {video_id}')
125
+ if sync_chunk.shape[0] < self.sync_expected_length:
126
+ raise RuntimeError(
127
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
128
+ )
129
+
130
+ # process audio
131
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
132
+ audio_chunk = audio_chunk.transpose(0, 1)
133
+ audio_chunk = audio_chunk.mean(dim=0) # mono
134
+ if self.normalize_audio:
135
+ abs_max = audio_chunk.abs().max()
136
+ if abs_max <= 1e-6:
137
+ raise RuntimeError(f'Audio is silent {video_id}')
138
+ audio_chunk = audio_chunk / abs_max * 0.95
139
+
140
+ # resample
141
+ if sample_rate == self.sample_rate:
142
+ audio_chunk = audio_chunk
143
+ else:
144
+ if sample_rate not in self.resampler:
145
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
146
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
147
+ sample_rate,
148
+ self.sample_rate,
149
+ lowpass_filter_width=64,
150
+ rolloff=0.9475937167399596,
151
+ resampling_method='sinc_interp_kaiser',
152
+ beta=14.769656459379492,
153
+ )
154
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
155
+
156
+ if audio_chunk.shape[0] < self.expected_audio_length:
157
+ raise RuntimeError(f'Audio too short {video_id}')
158
+ audio_chunk = audio_chunk[:self.expected_audio_length]
159
+
160
+ # truncate the video
161
+ clip_chunk = clip_chunk[:self.clip_expected_length]
162
+ if clip_chunk.shape[0] != self.clip_expected_length:
163
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
164
+ f'expected {self.clip_expected_length}, '
165
+ f'got {clip_chunk.shape[0]}')
166
+ clip_chunk = self.clip_transform(clip_chunk)
167
+
168
+ sync_chunk = sync_chunk[:self.sync_expected_length]
169
+ if sync_chunk.shape[0] != self.sync_expected_length:
170
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
171
+ f'expected {self.sync_expected_length}, '
172
+ f'got {sync_chunk.shape[0]}')
173
+ sync_chunk = self.sync_transform(sync_chunk)
174
+
175
+ data = {
176
+ 'id': video_id,
177
+ 'caption': label,
178
+ 'audio': audio_chunk,
179
+ 'clip_video': clip_chunk,
180
+ 'sync_video': sync_chunk,
181
+ }
182
+
183
+ return data
184
+
185
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
186
+ try:
187
+ return self.sample(idx)
188
+ except Exception as e:
189
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
190
+ return None
191
+
192
+ def __len__(self):
193
+ return len(self.labels)
mmaudio/data/extraction/wav_dataset.py ADDED
@@ -0,0 +1,132 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import open_clip
7
+ import pandas as pd
8
+ import torch
9
+ import torchaudio
10
+ from torch.utils.data.dataset import Dataset
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class WavTextClipsDataset(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ root: Union[str, Path],
20
+ *,
21
+ captions_tsv: Union[str, Path],
22
+ clips_tsv: Union[str, Path],
23
+ sample_rate: int,
24
+ num_samples: int,
25
+ normalize_audio: bool = False,
26
+ reject_silent: bool = False,
27
+ tokenizer_id: str = 'ViT-H-14-378-quickgelu',
28
+ ):
29
+ self.root = Path(root)
30
+ self.sample_rate = sample_rate
31
+ self.num_samples = num_samples
32
+ self.normalize_audio = normalize_audio
33
+ self.reject_silent = reject_silent
34
+ self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
35
+
36
+ audios = sorted(os.listdir(self.root))
37
+ audios = set([
38
+ Path(audio).stem for audio in audios
39
+ if audio.endswith('.wav') or audio.endswith('.flac')
40
+ ])
41
+ self.captions = {}
42
+
43
+ # read the caption tsv
44
+ df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
45
+ for record in df_list:
46
+ id = record['id']
47
+ caption = record['caption']
48
+ self.captions[id] = caption
49
+
50
+ # read the clip tsv
51
+ df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
52
+ 'id': str,
53
+ 'name': str
54
+ }).to_dict('records')
55
+ self.clips = []
56
+ for record in df_list:
57
+ record['id'] = record['id']
58
+ record['name'] = record['name']
59
+ id = record['id']
60
+ name = record['name']
61
+ if name not in self.captions:
62
+ log.warning(f'Audio {name} not found in {captions_tsv}')
63
+ continue
64
+ record['caption'] = self.captions[name]
65
+ self.clips.append(record)
66
+
67
+ log.info(f'Found {len(self.clips)} audio files in {self.root}')
68
+
69
+ self.resampler = {}
70
+
71
+ def __getitem__(self, idx: int) -> torch.Tensor:
72
+ try:
73
+ clip = self.clips[idx]
74
+ audio_name = clip['name']
75
+ audio_id = clip['id']
76
+ caption = clip['caption']
77
+ start_sample = clip['start_sample']
78
+ end_sample = clip['end_sample']
79
+
80
+ audio_path = self.root / f'{audio_name}.flac'
81
+ if not audio_path.exists():
82
+ audio_path = self.root / f'{audio_name}.wav'
83
+ assert audio_path.exists()
84
+
85
+ audio_chunk, sample_rate = torchaudio.load(audio_path)
86
+ audio_chunk = audio_chunk.mean(dim=0) # mono
87
+ abs_max = audio_chunk.abs().max()
88
+ if self.normalize_audio:
89
+ audio_chunk = audio_chunk / abs_max * 0.95
90
+
91
+ if self.reject_silent and abs_max < 1e-6:
92
+ log.warning(f'Rejecting silent audio')
93
+ return None
94
+
95
+ audio_chunk = audio_chunk[start_sample:end_sample]
96
+
97
+ # resample
98
+ if sample_rate == self.sample_rate:
99
+ audio_chunk = audio_chunk
100
+ else:
101
+ if sample_rate not in self.resampler:
102
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
103
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
104
+ sample_rate,
105
+ self.sample_rate,
106
+ lowpass_filter_width=64,
107
+ rolloff=0.9475937167399596,
108
+ resampling_method='sinc_interp_kaiser',
109
+ beta=14.769656459379492,
110
+ )
111
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
112
+
113
+ if audio_chunk.shape[0] < self.num_samples:
114
+ raise ValueError('Audio is too short')
115
+ audio_chunk = audio_chunk[:self.num_samples]
116
+
117
+ tokens = self.tokenizer([caption])[0]
118
+
119
+ output = {
120
+ 'waveform': audio_chunk,
121
+ 'id': audio_id,
122
+ 'caption': caption,
123
+ 'tokens': tokens,
124
+ }
125
+
126
+ return output
127
+ except Exception as e:
128
+ log.error(f'Error reading {audio_path}: {e}')
129
+ return None
130
+
131
+ def __len__(self):
132
+ return len(self.clips)
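For reference, a sketch (not part of the commit) of the two TSVs WavTextClipsDataset reads: a captions table keyed by id/caption and a clips table with id, name, start_sample and end_sample, where name must match both a caption id and a .wav/.flac file under the dataset root. Paths and numbers are illustrative (128 000 samples is 8 s at 16 kHz).

# Sketch: minimal captions.tsv / clips.tsv pair matching the columns read above.
import pandas as pd

pd.DataFrame([
    {'id': 'dog_bark_001', 'caption': 'a dog barking twice'},
]).to_csv('captions.tsv', sep='\t', index=False)

pd.DataFrame([
    {'id': 'dog_bark_001_clip0', 'name': 'dog_bark_001',
     'start_sample': 0, 'end_sample': 128_000},
]).to_csv('clips.tsv', sep='\t', index=False)
# The dataset then looks for dog_bark_001.flac (or .wav) under its root.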
mmaudio/data/mm_dataset.py ADDED
@@ -0,0 +1,45 @@
1
+ import bisect
2
+
3
+ import torch
4
+ from torch.utils.data.dataset import Dataset
5
+
6
+
7
+ # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
8
+ class MultiModalDataset(Dataset):
9
+ datasets: list[Dataset]
10
+ cumulative_sizes: list[int]
11
+
12
+ @staticmethod
13
+ def cumsum(sequence):
14
+ r, s = [], 0
15
+ for e in sequence:
16
+ l = len(e)
17
+ r.append(l + s)
18
+ s += l
19
+ return r
20
+
21
+ def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
22
+ super().__init__()
23
+ self.video_datasets = list(video_datasets)
24
+ self.audio_datasets = list(audio_datasets)
25
+ self.datasets = self.video_datasets + self.audio_datasets
26
+
27
+ self.cumulative_sizes = self.cumsum(self.datasets)
28
+
29
+ def __len__(self):
30
+ return self.cumulative_sizes[-1]
31
+
32
+ def __getitem__(self, idx):
33
+ if idx < 0:
34
+ if -idx > len(self):
35
+ raise ValueError("absolute value of index should not exceed dataset length")
36
+ idx = len(self) + idx
37
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
38
+ if dataset_idx == 0:
39
+ sample_idx = idx
40
+ else:
41
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
42
+ return self.datasets[dataset_idx][sample_idx]
43
+
44
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
45
+ return self.video_datasets[0].compute_latent_stats()
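The indexing above follows torch's ConcatDataset: cumulative sizes plus bisect_right map a global index to a (dataset, local index) pair. A small self-contained sketch of that mapping (not part of the commit):

# Sketch: the cumsum + bisect_right index mapping used by MultiModalDataset.
import bisect

sizes = [5, 3, 4]          # e.g. one video dataset and two audio datasets
cumulative, running = [], 0
for n in sizes:
    running += n
    cumulative.append(running)          # -> [5, 8, 12]

for idx in (0, 4, 5, 11):
    dataset_idx = bisect.bisect_right(cumulative, idx)
    local_idx = idx if dataset_idx == 0 else idx - cumulative[dataset_idx - 1]
    print(f'global {idx} -> dataset {dataset_idx}, local {local_idx}')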
mmaudio/data/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tensordict import MemoryMappedTensor
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data.dataset import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from mmaudio.utils.dist_utils import local_rank, world_size
16
+
17
+ scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
18
+ shm_path = Path('/dev/shm')
19
+
20
+ log = logging.getLogger()
21
+
22
+
23
+ def reseed(seed):
24
+ random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+
28
+ def local_scatter_torch(obj: Optional[Any]):
29
+ if world_size == 1:
30
+ # Just one worker. Do nothing.
31
+ return obj
32
+
33
+ array = [obj] * world_size
34
+ target_array = [None]
35
+ if local_rank == 0:
36
+ dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
37
+ else:
38
+ dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
39
+ return target_array[0]
40
+
41
+
42
+ class ShardDataset(Dataset):
43
+
44
+ def __init__(self, root):
45
+ self.root = root
46
+ self.shards = sorted(os.listdir(root))
47
+
48
+ def __len__(self):
49
+ return len(self.shards)
50
+
51
+ def __getitem__(self, idx):
52
+ return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
53
+
54
+
55
+ def get_tmp_dir(in_memory: bool) -> Path:
56
+ return shm_path if in_memory else scratch_path
57
+
58
+
59
+ def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
60
+ in_memory: bool) -> MemoryMappedTensor:
61
+ if local_rank == 0:
62
+ with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
63
+ log.info(f'Loading shards from {data_path} into {f.name}...')
64
+ data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
65
+ data = share_tensor_to_all(data)
66
+ torch.distributed.barrier()
67
+ f.close() # why does the context manager not close the file for me?
68
+ else:
69
+ log.info('Waiting for the data to be shared with me...')
70
+ data = share_tensor_to_all(None)
71
+ torch.distributed.barrier()
72
+
73
+ return data
74
+
75
+
76
+ def load_shards(
77
+ data_path: Union[str, Path],
78
+ ids: list[int],
79
+ *,
80
+ tmp_file_path: str,
81
+ ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
82
+
83
+ id_set = set(ids)
84
+ shards = sorted(os.listdir(data_path))
85
+ log.info(f'Found {len(shards)} shards in {data_path}.')
86
+ first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
87
+
88
+ log.info(f'Rank {local_rank} created file {tmp_file_path}')
89
+ first_item = next(iter(first_shard.values()))
90
+ log.info(f'First item shape: {first_item.shape}')
91
+ mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
92
+ dtype=torch.float32,
93
+ filename=tmp_file_path,
94
+ existsok=True)
95
+ total_count = 0
96
+ used_index = set()
97
+ id_indexing = {i: idx for idx, i in enumerate(ids)}
98
+ # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
99
+ loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
100
+ for data in tqdm(loader, desc='Loading shards'):
101
+ for i, v in data.items():
102
+ if i not in id_set:
103
+ continue
104
+
105
+ # tensor_index = ids.index(i)
106
+ tensor_index = id_indexing[i]
107
+ if tensor_index in used_index:
108
+ raise ValueError(f'Duplicate id {i} found in {data_path}.')
109
+ used_index.add(tensor_index)
110
+ mm_tensor[tensor_index] = v
111
+ total_count += 1
112
+
113
+ assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
114
+ log.info(f'Loaded {total_count} tensors from {data_path}.')
115
+
116
+ return mm_tensor
117
+
118
+
119
+ def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
120
+ """
121
+ x: the tensor to be shared; None if local_rank != 0
122
+ return: the shared tensor
123
+ """
124
+
125
+ # there is no need to share your stuff with anyone if you are alone; must be in memory
126
+ if world_size == 1:
127
+ return x
128
+
129
+ if local_rank == 0:
130
+ assert x is not None, 'x must not be None if local_rank == 0'
131
+ else:
132
+ assert x is None, 'x must be None if local_rank != 0'
133
+
134
+ if local_rank == 0:
135
+ filename = x.filename
136
+ meta_information = (filename, x.shape, x.dtype)
137
+ else:
138
+ meta_information = None
139
+
140
+ filename, data_shape, data_type = local_scatter_torch(meta_information)
141
+ if local_rank == 0:
142
+ data = x
143
+ else:
144
+ data = MemoryMappedTensor.from_filename(filename=filename,
145
+ dtype=data_type,
146
+ shape=data_shape)
147
+
148
+ return data
mmaudio/eval_utils.py ADDED
@@ -0,0 +1,255 @@
1
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional, Tuple, List, Dict
5
+
6
+ import numpy as np
7
+ import torch
8
+ from colorlog import ColoredFormatter
9
+ from PIL import Image
10
+ from torchvision.transforms import v2
11
+
12
+ from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
13
+ from mmaudio.model.flow_matching import FlowMatching
14
+ from mmaudio.model.networks import MMAudio
15
+ from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
16
+ from mmaudio.model.utils.features_utils import FeaturesUtils
17
+ from mmaudio.utils.download_utils import download_model_if_needed
18
+
19
+ log = logging.getLogger()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class ModelConfig:
24
+ model_name: str
25
+ model_path: Path
26
+ vae_path: Path
27
+ bigvgan_16k_path: Optional[Path]
28
+ mode: str
29
+ synchformer_ckpt: Path = Path('./pretrained/v2a/mmaudio/ext_weights/synchformer_state_dict.pth')
30
+
31
+ @property
32
+ def seq_cfg(self) -> SequenceConfig:
33
+ if self.mode == '16k':
34
+ return CONFIG_16K
35
+ elif self.mode == '44k':
36
+ return CONFIG_44K
37
+
38
+ def download_if_needed(self):
39
+ download_model_if_needed(self.model_path)
40
+ download_model_if_needed(self.vae_path)
41
+ if self.bigvgan_16k_path is not None:
42
+ download_model_if_needed(self.bigvgan_16k_path)
43
+ download_model_if_needed(self.synchformer_ckpt)
44
+
45
+
46
+ small_16k = ModelConfig(model_name='small_16k',
47
+ model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_16k.pth'),
48
+ vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-16.pth'),
49
+ bigvgan_16k_path=Path('./pretrained/v2a/mmaudio/ext_weights/best_netG.pt'),
50
+ mode='16k')
51
+ small_44k = ModelConfig(model_name='small_44k',
52
+ model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_44k.pth'),
53
+ vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
54
+ bigvgan_16k_path=None,
55
+ mode='44k')
56
+ medium_44k = ModelConfig(model_name='medium_44k',
57
+ model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_medium_44k.pth'),
58
+ vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
59
+ bigvgan_16k_path=None,
60
+ mode='44k')
61
+ large_44k = ModelConfig(model_name='large_44k',
62
+ model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k.pth'),
63
+ vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
64
+ bigvgan_16k_path=None,
65
+ mode='44k')
66
+ large_44k_v2 = ModelConfig(model_name='large_44k_v2',
67
+ model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k_v2.pth'),
68
+ vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
69
+ bigvgan_16k_path=None,
70
+ mode='44k')
71
+ all_model_cfg: Dict[str, ModelConfig] = {
72
+ 'small_16k': small_16k,
73
+ 'small_44k': small_44k,
74
+ 'medium_44k': medium_44k,
75
+ 'large_44k': large_44k,
76
+ 'large_44k_v2': large_44k_v2,
77
+ }
78
+
79
+
80
+ def generate(
81
+ clip_video: Optional[torch.Tensor],
82
+ sync_video: Optional[torch.Tensor],
83
+ text: Optional[List[str]],
84
+ *,
85
+ negative_text: Optional[List[str]] = None,
86
+ feature_utils: FeaturesUtils,
87
+ net: MMAudio,
88
+ fm: FlowMatching,
89
+ rng: torch.Generator,
90
+ cfg_strength: float,
91
+ clip_batch_size_multiplier: int = 40,
92
+ sync_batch_size_multiplier: int = 40,
93
+ image_input: bool = False,
94
+ ) -> torch.Tensor:
95
+ device = feature_utils.device
96
+ dtype = feature_utils.dtype
97
+
98
+ bs = len(text)
99
+ if clip_video is not None:
100
+ clip_video = clip_video.to(device, dtype, non_blocking=True)
101
+ clip_features = feature_utils.encode_video_with_clip(clip_video,
102
+ batch_size=bs *
103
+ clip_batch_size_multiplier)
104
+ if image_input:
105
+ clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
106
+ else:
107
+ clip_features = net.get_empty_clip_sequence(bs)
108
+
109
+ if sync_video is not None and not image_input:
110
+ sync_video = sync_video.to(device, dtype, non_blocking=True)
111
+ sync_features = feature_utils.encode_video_with_sync(sync_video,
112
+ batch_size=bs *
113
+ sync_batch_size_multiplier)
114
+ else:
115
+ sync_features = net.get_empty_sync_sequence(bs)
116
+
117
+ if text is not None:
118
+ text_features = feature_utils.encode_text(text)
119
+ else:
120
+ text_features = net.get_empty_string_sequence(bs)
121
+
122
+ if negative_text is not None:
123
+ assert len(negative_text) == bs
124
+ negative_text_features = feature_utils.encode_text(negative_text)
125
+ else:
126
+ negative_text_features = net.get_empty_string_sequence(bs)
127
+
128
+ x0 = torch.randn(bs,
129
+ net.latent_seq_len,
130
+ net.latent_dim,
131
+ device=device,
132
+ dtype=dtype,
133
+ generator=rng)
134
+ preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features, text_features)
135
+ empty_conditions = net.get_empty_conditions(
136
+ bs, negative_text_features=negative_text_features if negative_text is not None else None)
137
+
138
+ cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
139
+ cfg_strength)
140
+ x1 = fm.to_data(cfg_ode_wrapper, x0)
141
+ x1 = net.unnormalize(x1)
142
+ spec = feature_utils.decode(x1)
143
+ audio = feature_utils.vocode(spec)
144
+ return audio
145
+
146
+
147
+ LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
148
+
149
+
150
+ def setup_eval_logging(log_level: int = logging.INFO):
151
+ logging.root.setLevel(log_level)
152
+ formatter = ColoredFormatter(LOGFORMAT)
153
+ stream = logging.StreamHandler()
154
+ stream.setLevel(log_level)
155
+ stream.setFormatter(formatter)
156
+ log = logging.getLogger()
157
+ log.setLevel(log_level)
158
+ log.addHandler(stream)
159
+
160
+
161
+ _CLIP_SIZE = 384
162
+ _CLIP_FPS = 8.0
163
+
164
+ _SYNC_SIZE = 224
165
+ _SYNC_FPS = 25.0
166
+
167
+
168
+ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
169
+
170
+ clip_transform = v2.Compose([
171
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
172
+ v2.ToImage(),
173
+ v2.ToDtype(torch.float32, scale=True),
174
+ ])
175
+
176
+ sync_transform = v2.Compose([
177
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
178
+ v2.CenterCrop(_SYNC_SIZE),
179
+ v2.ToImage(),
180
+ v2.ToDtype(torch.float32, scale=True),
181
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
182
+ ])
183
+
184
+ output_frames, all_frames, orig_fps = read_frames(video_path,
185
+ list_of_fps=[_CLIP_FPS, _SYNC_FPS],
186
+ start_sec=0,
187
+ end_sec=duration_sec,
188
+ need_all_frames=load_all_frames)
189
+
190
+ clip_chunk, sync_chunk = output_frames
191
+ clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
192
+ sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
193
+
194
+ clip_frames = clip_transform(clip_chunk)
195
+ sync_frames = sync_transform(sync_chunk)
196
+
197
+ clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
198
+ sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
199
+
200
+ if clip_length_sec < duration_sec:
201
+ log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
202
+ log.warning(f'Truncating to {clip_length_sec:.2f} sec')
203
+ duration_sec = clip_length_sec
204
+
205
+ if sync_length_sec < duration_sec:
206
+ log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
207
+ log.warning(f'Truncating to {sync_length_sec:.2f} sec')
208
+ duration_sec = sync_length_sec
209
+
210
+ clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
211
+ sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
212
+
213
+ video_info = VideoInfo(
214
+ duration_sec=duration_sec,
215
+ fps=orig_fps,
216
+ clip_frames=clip_frames,
217
+ sync_frames=sync_frames,
218
+ all_frames=all_frames if load_all_frames else None,
219
+ )
220
+ return video_info
221
+
222
+
223
+ def load_image(image_path: Path) -> ImageInfo:
224
+ clip_transform = v2.Compose([
225
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
226
+ v2.ToImage(),
227
+ v2.ToDtype(torch.float32, scale=True),
228
+ ])
229
+
230
+ sync_transform = v2.Compose([
231
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
232
+ v2.CenterCrop(_SYNC_SIZE),
233
+ v2.ToImage(),
234
+ v2.ToDtype(torch.float32, scale=True),
235
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
236
+ ])
237
+
238
+ frame = np.array(Image.open(image_path))
239
+
240
+ clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
241
+ sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
242
+
243
+ clip_frames = clip_transform(clip_chunk)
244
+ sync_frames = sync_transform(sync_chunk)
245
+
246
+ video_info = ImageInfo(
247
+ clip_frames=clip_frames,
248
+ sync_frames=sync_frames,
249
+ original_frame=frame,
250
+ )
251
+ return video_info
252
+
253
+
254
+ def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
255
+ reencode_with_audio(video_info, output_path, audio, sampling_rate)
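Closing the loop on the image path (not part of the commit): load_image above returns an ImageInfo, and VideoInfo.from_image_info from av_utils.py earlier in this commit turns it into a constant-frame video so the same muxing code can be reused. A minimal sketch, assuming an RGB image at a placeholder path:

# Sketch: image -> ImageInfo -> VideoInfo, using only helpers defined in this
# commit; 'photo.jpg' is a placeholder for any RGB image on disk.
from fractions import Fraction
from pathlib import Path

from mmaudio.data.av_utils import VideoInfo
from mmaudio.eval_utils import load_image, setup_eval_logging

setup_eval_logging()
image_info = load_image(Path('photo.jpg'))
video_info = VideoInfo.from_image_info(image_info, duration_sec=8.0, fps=Fraction(25))
print(video_info.width, video_info.height, len(video_info.all_frames))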