Spaces:
Running
Running
lym0302
commited on
Commit
·
eedfa8e
1
Parent(s):
bafca5a
try exp
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +21 -0
- README.md +185 -14
- app.py +343 -0
- batch_eval.py +110 -0
- config/__init__.py +0 -0
- config/base_config.yaml +62 -0
- config/data/base.yaml +70 -0
- config/eval_config.yaml +17 -0
- config/eval_data/base.yaml +22 -0
- config/hydra/job_logging/custom-eval.yaml +32 -0
- config/hydra/job_logging/custom-no-rank.yaml +32 -0
- config/hydra/job_logging/custom-simplest.yaml +26 -0
- config/hydra/job_logging/custom.yaml +33 -0
- config/train_config.yaml +41 -0
- demo.py +141 -0
- docs/EVAL.md +22 -0
- docs/MODELS.md +50 -0
- docs/TRAINING.md +184 -0
- docs/images/icon.png +0 -0
- docs/index.html +149 -0
- docs/style.css +78 -0
- docs/style_videos.css +52 -0
- docs/video_gen.html +254 -0
- docs/video_main.html +98 -0
- docs/video_vgg.html +452 -0
- gradio_demo.py +343 -0
- mmaudio/__init__.py +0 -0
- mmaudio/__pycache__/__init__.cpython-310.pyc +0 -0
- mmaudio/__pycache__/__init__.cpython-38.pyc +0 -0
- mmaudio/__pycache__/eval_utils.cpython-310.pyc +0 -0
- mmaudio/__pycache__/eval_utils.cpython-38.pyc +0 -0
- mmaudio/data/__init__.py +0 -0
- mmaudio/data/__pycache__/__init__.cpython-310.pyc +0 -0
- mmaudio/data/__pycache__/__init__.cpython-38.pyc +0 -0
- mmaudio/data/__pycache__/av_utils.cpython-310.pyc +0 -0
- mmaudio/data/__pycache__/av_utils.cpython-38.pyc +0 -0
- mmaudio/data/av_utils.py +162 -0
- mmaudio/data/data_setup.py +174 -0
- mmaudio/data/eval/__init__.py +0 -0
- mmaudio/data/eval/audiocaps.py +39 -0
- mmaudio/data/eval/moviegen.py +131 -0
- mmaudio/data/eval/video_dataset.py +197 -0
- mmaudio/data/extracted_audio.py +88 -0
- mmaudio/data/extracted_vgg.py +101 -0
- mmaudio/data/extraction/__init__.py +0 -0
- mmaudio/data/extraction/vgg_sound.py +193 -0
- mmaudio/data/extraction/wav_dataset.py +132 -0
- mmaudio/data/mm_dataset.py +45 -0
- mmaudio/data/utils.py +148 -0
- mmaudio/eval_utils.py +255 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Sony Research Inc.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,14 +1,185 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<p align="center">
|
3 |
+
<h2>MMAudio</h2>
|
4 |
+
<a href="https://arxiv.org/abs/2412.15322">Paper</a> | <a href="https://hkchengrex.github.io/MMAudio">Webpage</a> | <a href="https://huggingface.co/hkchengrex/MMAudio/tree/main">Models</a> | <a href="https://huggingface.co/spaces/hkchengrex/MMAudio"> Huggingface Demo</a> | <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">Colab Demo</a> | <a href="https://replicate.com/zsxkib/mmaudio">Replicate Demo</a>
|
5 |
+
</p>
|
6 |
+
</div>
|
7 |
+
|
8 |
+
## [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio)
|
9 |
+
|
10 |
+
[Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
|
11 |
+
|
12 |
+
University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
|
13 |
+
|
14 |
+
CVPR 2025
|
15 |
+
|
16 |
+
## Highlight
|
17 |
+
|
18 |
+
MMAudio generates synchronized audio given video and/or text inputs.
|
19 |
+
Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
|
20 |
+
Moreover, a synchronization module aligns the generated audio with the video frames.
|
21 |
+
|
22 |
+
## Results
|
23 |
+
|
24 |
+
(All audio from our algorithm MMAudio)
|
25 |
+
|
26 |
+
Videos from Sora:
|
27 |
+
|
28 |
+
https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
|
29 |
+
|
30 |
+
Videos from Veo 2:
|
31 |
+
|
32 |
+
https://github.com/user-attachments/assets/8a11419e-fee2-46e0-9e67-dfb03c48d00e
|
33 |
+
|
34 |
+
Videos from MovieGen/Hunyuan Video/VGGSound:
|
35 |
+
|
36 |
+
https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
|
37 |
+
|
38 |
+
For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
|
39 |
+
|
40 |
+
|
41 |
+
## Installation
|
42 |
+
|
43 |
+
We have only tested this on Ubuntu.
|
44 |
+
|
45 |
+
### Prerequisites
|
46 |
+
|
47 |
+
We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
|
48 |
+
|
49 |
+
- Python 3.9+
|
50 |
+
- PyTorch **2.5.1+** and corresponding torchvision/torchaudio (pick your CUDA version https://pytorch.org/, pip install recommended)
|
51 |
+
<!-- - ffmpeg<7 ([this is required by torchaudio](https://pytorch.org/audio/master/installation.html#optional-dependencies), you can install it in a miniforge environment with `conda install -c conda-forge 'ffmpeg<7'`) -->
|
52 |
+
|
53 |
+
**1. Install prerequisite if not yet met:**
|
54 |
+
|
55 |
+
```bash
|
56 |
+
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade
|
57 |
+
```
|
58 |
+
|
59 |
+
(Or any other CUDA versions that your GPUs/driver support)
|
60 |
+
|
61 |
+
<!-- ```
|
62 |
+
conda install -c conda-forge 'ffmpeg<7
|
63 |
+
```
|
64 |
+
(Optional, if you use miniforge and don't already have the appropriate ffmpeg) -->
|
65 |
+
|
66 |
+
**2. Clone our repository:**
|
67 |
+
|
68 |
+
```bash
|
69 |
+
git clone https://github.com/hkchengrex/MMAudio.git
|
70 |
+
```
|
71 |
+
|
72 |
+
**3. Install with pip (install pytorch first before attempting this!):**
|
73 |
+
|
74 |
+
```bash
|
75 |
+
cd MMAudio
|
76 |
+
pip install -e .
|
77 |
+
```
|
78 |
+
|
79 |
+
(If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip)
|
80 |
+
|
81 |
+
|
82 |
+
**Pretrained models:**
|
83 |
+
|
84 |
+
The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
|
85 |
+
The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
|
86 |
+
See [MODELS.md](docs/MODELS.md) for more details.
|
87 |
+
|
88 |
+
## Demo
|
89 |
+
|
90 |
+
By default, these scripts use the `large_44k_v2` model.
|
91 |
+
In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs.
|
92 |
+
|
93 |
+
### Command-line interface
|
94 |
+
|
95 |
+
With `demo.py`
|
96 |
+
|
97 |
+
```bash
|
98 |
+
python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
|
99 |
+
```
|
100 |
+
|
101 |
+
The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
|
102 |
+
See the file for more options.
|
103 |
+
Simply omit the `--video` option for text-to-audio synthesis.
|
104 |
+
The default output (and training) duration is 8 seconds. Longer/shorter durations could also work, but a large deviation from the training duration may result in a lower quality.
|
105 |
+
|
106 |
+
### Gradio interface
|
107 |
+
|
108 |
+
Supports video-to-audio and text-to-audio synthesis.
|
109 |
+
You can also try experimental image-to-audio synthesis which duplicates the input image to a video for processing. This might be interesting to some but it is not something MMAudio has been trained for.
|
110 |
+
Use [port forwarding](https://unix.stackexchange.com/questions/115897/whats-ssh-port-forwarding-and-whats-the-difference-between-ssh-local-and-remot) (e.g., `ssh -L 7860:localhost:7860 server`) if necessary. The default port is `7860` which you can specify with `--port`.
|
111 |
+
|
112 |
+
```bash
|
113 |
+
python gradio_demo.py
|
114 |
+
```
|
115 |
+
|
116 |
+
### FAQ
|
117 |
+
|
118 |
+
1. Video processing
|
119 |
+
- Processing higher-resolution videos takes longer due to encoding and decoding (which can take >95% of the processing time!), but it does not improve the quality of results.
|
120 |
+
- The CLIP encoder resizes input frames to 384×384 pixels.
|
121 |
+
- Synchformer resizes the shorter edge to 224 pixels and applies a center crop, focusing only on the central square of each frame.
|
122 |
+
2. Frame rates
|
123 |
+
- The CLIP model operates at 8 FPS, while Synchformer works at 25 FPS.
|
124 |
+
- Frame rate conversion happens on-the-fly via the video reader.
|
125 |
+
- For input videos with a frame rate below 25 FPS, frames will be duplicated to match the required rate.
|
126 |
+
3. Failure cases
|
127 |
+
As with most models of this type, failures can occur, and the reasons are not always clear. Below are some known failure modes. If you notice a failure mode or believe there’s a bug, feel free to open an issue in the repository.
|
128 |
+
4. Performance variations
|
129 |
+
We notice that there can be subtle performance variations in different hardware and software environments. Some of the reasons include using/not using `torch.compile`, video reader library/backend, inference precision, batch sizes, random seeds, etc. We (will) provide pre-computed results on standard benchmark for reference. Results obtained from this codebase should be similar but might not be exactly the same.
|
130 |
+
|
131 |
+
### Known limitations
|
132 |
+
|
133 |
+
1. The model sometimes generates unintelligible human speech-like sounds
|
134 |
+
2. The model sometimes generates background music (without explicit training, it would not be high quality)
|
135 |
+
3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing".
|
136 |
+
|
137 |
+
We believe all of these three limitations can be addressed with more high-quality training data.
|
138 |
+
|
139 |
+
## Training
|
140 |
+
|
141 |
+
See [TRAINING.md](docs/TRAINING.md).
|
142 |
+
|
143 |
+
## Evaluation
|
144 |
+
|
145 |
+
See [EVAL.md](docs/EVAL.md).
|
146 |
+
|
147 |
+
## Training Datasets
|
148 |
+
|
149 |
+
MMAudio was trained on several datasets, including [AudioSet](https://research.google.com/audioset/), [Freesound](https://github.com/LAION-AI/audio-dataset/blob/main/laion-audio-630k/README.md), [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://github.com/XinhaoMei/WavCaps). These datasets are subject to specific licenses, which can be accessed on their respective websites. We do not guarantee that the pre-trained models are suitable for commercial use. Please use them at your own risk.
|
150 |
+
|
151 |
+
## Update Logs
|
152 |
+
|
153 |
+
- 2025-03-09: Uploaded the corrected tsv files. See [TRAINING.md](docs/TRAINING.md).
|
154 |
+
- 2025-02-27: Disabled the GradScaler by default to improve training stability. See #49.
|
155 |
+
- 2024-12-23: Added training and batch evaluation scripts.
|
156 |
+
- 2024-12-14: Removed the `ffmpeg<7` requirement for the demos by replacing `torio.io.StreamingMediaDecoder` with `pyav` for reading frames. The read frames are also cached, so we are not reading the same frames again during reconstruction. This should speed things up and make installation less of a hassle.
|
157 |
+
- 2024-12-13: Improved for-loop processing in CLIP/Sync feature extraction by introducing a batch size multiplier. We can approximately use 40x batch size for CLIP/Sync without using more memory, thereby speeding up processing. Removed VAE encoder during inference -- we don't need it.
|
158 |
+
- 2024-12-11: Replaced `torio.io.StreamingMediaDecoder` with `pyav` for reading framerate when reconstructing the input video. `torio.io.StreamingMediaDecoder` does not work reliably in huggingface ZeroGPU's environment, and I suspect that it might not work in some other environments as well.
|
159 |
+
|
160 |
+
## Citation
|
161 |
+
|
162 |
+
```bibtex
|
163 |
+
@inproceedings{cheng2025taming,
|
164 |
+
title={Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis},
|
165 |
+
author={Cheng, Ho Kei and Ishii, Masato and Hayakawa, Akio and Shibuya, Takashi and Schwing, Alexander and Mitsufuji, Yuki},
|
166 |
+
booktitle={CVPR},
|
167 |
+
year={2025}
|
168 |
+
}
|
169 |
+
```
|
170 |
+
|
171 |
+
## Relevant Repositories
|
172 |
+
|
173 |
+
- [av-benchmark](https://github.com/hkchengrex/av-benchmark) for benchmarking results.
|
174 |
+
|
175 |
+
## Disclaimer
|
176 |
+
|
177 |
+
We have no affiliation with and have no knowledge of the party behind the domain "mmaudio.net".
|
178 |
+
|
179 |
+
## Acknowledgement
|
180 |
+
|
181 |
+
Many thanks to:
|
182 |
+
- [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model and the VAE architecture
|
183 |
+
- [BigVGAN](https://github.com/NVIDIA/BigVGAN)
|
184 |
+
- [Synchformer](https://github.com/v-iashin/Synchformer)
|
185 |
+
- [EDM2](https://github.com/NVlabs/edm2) for the magnitude-preserving VAE network architecture
|
app.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
from datetime import datetime
|
5 |
+
from fractions import Fraction
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
|
13 |
+
load_video, make_video, setup_eval_logging)
|
14 |
+
from mmaudio.model.flow_matching import FlowMatching
|
15 |
+
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
16 |
+
from mmaudio.model.sequence_config import SequenceConfig
|
17 |
+
from mmaudio.model.utils.features_utils import FeaturesUtils
|
18 |
+
|
19 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
+
torch.backends.cudnn.allow_tf32 = True
|
21 |
+
|
22 |
+
log = logging.getLogger()
|
23 |
+
|
24 |
+
device = 'cpu'
|
25 |
+
if torch.cuda.is_available():
|
26 |
+
device = 'cuda'
|
27 |
+
elif torch.backends.mps.is_available():
|
28 |
+
device = 'mps'
|
29 |
+
else:
|
30 |
+
log.warning('CUDA/MPS are not available, running on CPU')
|
31 |
+
dtype = torch.bfloat16
|
32 |
+
|
33 |
+
model: ModelConfig = all_model_cfg['large_44k_v2']
|
34 |
+
model.download_if_needed()
|
35 |
+
output_dir = Path('./output/gradio')
|
36 |
+
|
37 |
+
setup_eval_logging()
|
38 |
+
|
39 |
+
|
40 |
+
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
41 |
+
seq_cfg = model.seq_cfg
|
42 |
+
|
43 |
+
net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
|
44 |
+
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
45 |
+
log.info(f'Loaded weights from {model.model_path}')
|
46 |
+
|
47 |
+
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
48 |
+
synchformer_ckpt=model.synchformer_ckpt,
|
49 |
+
enable_conditions=True,
|
50 |
+
mode=model.mode,
|
51 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
52 |
+
need_vae_encoder=False)
|
53 |
+
feature_utils = feature_utils.to(device, dtype).eval()
|
54 |
+
|
55 |
+
return net, feature_utils, seq_cfg
|
56 |
+
|
57 |
+
|
58 |
+
net, feature_utils, seq_cfg = get_model()
|
59 |
+
|
60 |
+
|
61 |
+
@torch.inference_mode()
|
62 |
+
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
63 |
+
cfg_strength: float, duration: float):
|
64 |
+
|
65 |
+
rng = torch.Generator(device=device)
|
66 |
+
if seed >= 0:
|
67 |
+
rng.manual_seed(seed)
|
68 |
+
else:
|
69 |
+
rng.seed()
|
70 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
71 |
+
|
72 |
+
video_info = load_video(video, duration)
|
73 |
+
clip_frames = video_info.clip_frames
|
74 |
+
sync_frames = video_info.sync_frames
|
75 |
+
duration = video_info.duration_sec
|
76 |
+
clip_frames = clip_frames.unsqueeze(0)
|
77 |
+
sync_frames = sync_frames.unsqueeze(0)
|
78 |
+
seq_cfg.duration = duration
|
79 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
80 |
+
|
81 |
+
audios = generate(clip_frames,
|
82 |
+
sync_frames, [prompt],
|
83 |
+
negative_text=[negative_prompt],
|
84 |
+
feature_utils=feature_utils,
|
85 |
+
net=net,
|
86 |
+
fm=fm,
|
87 |
+
rng=rng,
|
88 |
+
cfg_strength=cfg_strength)
|
89 |
+
audio = audios.float().cpu()[0]
|
90 |
+
|
91 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
92 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
93 |
+
video_save_path = output_dir / f'{current_time_string}.mp4'
|
94 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
95 |
+
gc.collect()
|
96 |
+
return video_save_path
|
97 |
+
|
98 |
+
|
99 |
+
@torch.inference_mode()
|
100 |
+
def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
101 |
+
cfg_strength: float, duration: float):
|
102 |
+
|
103 |
+
rng = torch.Generator(device=device)
|
104 |
+
if seed >= 0:
|
105 |
+
rng.manual_seed(seed)
|
106 |
+
else:
|
107 |
+
rng.seed()
|
108 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
109 |
+
|
110 |
+
image_info = load_image(image)
|
111 |
+
clip_frames = image_info.clip_frames
|
112 |
+
sync_frames = image_info.sync_frames
|
113 |
+
clip_frames = clip_frames.unsqueeze(0)
|
114 |
+
sync_frames = sync_frames.unsqueeze(0)
|
115 |
+
seq_cfg.duration = duration
|
116 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
117 |
+
|
118 |
+
audios = generate(clip_frames,
|
119 |
+
sync_frames, [prompt],
|
120 |
+
negative_text=[negative_prompt],
|
121 |
+
feature_utils=feature_utils,
|
122 |
+
net=net,
|
123 |
+
fm=fm,
|
124 |
+
rng=rng,
|
125 |
+
cfg_strength=cfg_strength,
|
126 |
+
image_input=True)
|
127 |
+
audio = audios.float().cpu()[0]
|
128 |
+
|
129 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
130 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
131 |
+
video_save_path = output_dir / f'{current_time_string}.mp4'
|
132 |
+
video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
|
133 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
134 |
+
gc.collect()
|
135 |
+
return video_save_path
|
136 |
+
|
137 |
+
|
138 |
+
@torch.inference_mode()
|
139 |
+
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
|
140 |
+
duration: float):
|
141 |
+
|
142 |
+
rng = torch.Generator(device=device)
|
143 |
+
if seed >= 0:
|
144 |
+
rng.manual_seed(seed)
|
145 |
+
else:
|
146 |
+
rng.seed()
|
147 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
148 |
+
|
149 |
+
clip_frames = sync_frames = None
|
150 |
+
seq_cfg.duration = duration
|
151 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
152 |
+
|
153 |
+
audios = generate(clip_frames,
|
154 |
+
sync_frames, [prompt],
|
155 |
+
negative_text=[negative_prompt],
|
156 |
+
feature_utils=feature_utils,
|
157 |
+
net=net,
|
158 |
+
fm=fm,
|
159 |
+
rng=rng,
|
160 |
+
cfg_strength=cfg_strength)
|
161 |
+
audio = audios.float().cpu()[0]
|
162 |
+
|
163 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
164 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
165 |
+
audio_save_path = output_dir / f'{current_time_string}.flac'
|
166 |
+
torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
|
167 |
+
gc.collect()
|
168 |
+
return audio_save_path
|
169 |
+
|
170 |
+
|
171 |
+
video_to_audio_tab = gr.Interface(
|
172 |
+
fn=video_to_audio,
|
173 |
+
description="""
|
174 |
+
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
175 |
+
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
176 |
+
|
177 |
+
NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
|
178 |
+
Doing so does not improve results.
|
179 |
+
""",
|
180 |
+
inputs=[
|
181 |
+
gr.Video(),
|
182 |
+
gr.Text(label='Prompt'),
|
183 |
+
gr.Text(label='Negative prompt', value='music'),
|
184 |
+
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
185 |
+
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
186 |
+
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
187 |
+
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
188 |
+
],
|
189 |
+
outputs='playable_video',
|
190 |
+
cache_examples=False,
|
191 |
+
title='MMAudio — Video-to-Audio Synthesis',
|
192 |
+
examples=[
|
193 |
+
[
|
194 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
|
195 |
+
'waves, seagulls',
|
196 |
+
'',
|
197 |
+
0,
|
198 |
+
25,
|
199 |
+
4.5,
|
200 |
+
10,
|
201 |
+
],
|
202 |
+
[
|
203 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
|
204 |
+
'',
|
205 |
+
'music',
|
206 |
+
0,
|
207 |
+
25,
|
208 |
+
4.5,
|
209 |
+
10,
|
210 |
+
],
|
211 |
+
[
|
212 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
|
213 |
+
'bubbles',
|
214 |
+
'',
|
215 |
+
0,
|
216 |
+
25,
|
217 |
+
4.5,
|
218 |
+
10,
|
219 |
+
],
|
220 |
+
[
|
221 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
|
222 |
+
'Indian holy music',
|
223 |
+
'',
|
224 |
+
0,
|
225 |
+
25,
|
226 |
+
4.5,
|
227 |
+
10,
|
228 |
+
],
|
229 |
+
[
|
230 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
|
231 |
+
'galloping',
|
232 |
+
'',
|
233 |
+
0,
|
234 |
+
25,
|
235 |
+
4.5,
|
236 |
+
10,
|
237 |
+
],
|
238 |
+
[
|
239 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
|
240 |
+
'waves, storm',
|
241 |
+
'',
|
242 |
+
0,
|
243 |
+
25,
|
244 |
+
4.5,
|
245 |
+
10,
|
246 |
+
],
|
247 |
+
[
|
248 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
|
249 |
+
'storm',
|
250 |
+
'',
|
251 |
+
0,
|
252 |
+
25,
|
253 |
+
4.5,
|
254 |
+
10,
|
255 |
+
],
|
256 |
+
[
|
257 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
|
258 |
+
'',
|
259 |
+
'',
|
260 |
+
0,
|
261 |
+
25,
|
262 |
+
4.5,
|
263 |
+
10,
|
264 |
+
],
|
265 |
+
[
|
266 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
|
267 |
+
'typing',
|
268 |
+
'',
|
269 |
+
0,
|
270 |
+
25,
|
271 |
+
4.5,
|
272 |
+
10,
|
273 |
+
],
|
274 |
+
[
|
275 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
|
276 |
+
'',
|
277 |
+
'',
|
278 |
+
0,
|
279 |
+
25,
|
280 |
+
4.5,
|
281 |
+
10,
|
282 |
+
],
|
283 |
+
[
|
284 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
|
285 |
+
'',
|
286 |
+
'',
|
287 |
+
0,
|
288 |
+
25,
|
289 |
+
4.5,
|
290 |
+
10,
|
291 |
+
],
|
292 |
+
])
|
293 |
+
|
294 |
+
text_to_audio_tab = gr.Interface(
|
295 |
+
fn=text_to_audio,
|
296 |
+
description="""
|
297 |
+
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
298 |
+
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
299 |
+
""",
|
300 |
+
inputs=[
|
301 |
+
gr.Text(label='Prompt'),
|
302 |
+
gr.Text(label='Negative prompt'),
|
303 |
+
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
304 |
+
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
305 |
+
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
306 |
+
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
307 |
+
],
|
308 |
+
outputs='audio',
|
309 |
+
cache_examples=False,
|
310 |
+
title='MMAudio — Text-to-Audio Synthesis',
|
311 |
+
)
|
312 |
+
|
313 |
+
image_to_audio_tab = gr.Interface(
|
314 |
+
fn=image_to_audio,
|
315 |
+
description="""
|
316 |
+
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
317 |
+
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
318 |
+
|
319 |
+
NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
|
320 |
+
Doing so does not improve results.
|
321 |
+
""",
|
322 |
+
inputs=[
|
323 |
+
gr.Image(type='filepath'),
|
324 |
+
gr.Text(label='Prompt'),
|
325 |
+
gr.Text(label='Negative prompt'),
|
326 |
+
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
327 |
+
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
328 |
+
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
329 |
+
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
330 |
+
],
|
331 |
+
outputs='playable_video',
|
332 |
+
cache_examples=False,
|
333 |
+
title='MMAudio — Image-to-Audio Synthesis (experimental)',
|
334 |
+
)
|
335 |
+
|
336 |
+
if __name__ == "__main__":
|
337 |
+
parser = ArgumentParser()
|
338 |
+
parser.add_argument('--port', type=int, default=7860)
|
339 |
+
args = parser.parse_args()
|
340 |
+
|
341 |
+
gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
|
342 |
+
['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
|
343 |
+
server_port=args.port, allowed_paths=[output_dir])
|
batch_eval.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import hydra
|
6 |
+
import torch
|
7 |
+
import torch.distributed as distributed
|
8 |
+
import torchaudio
|
9 |
+
from hydra.core.hydra_config import HydraConfig
|
10 |
+
from omegaconf import DictConfig
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from mmaudio.data.data_setup import setup_eval_dataset
|
14 |
+
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
|
15 |
+
from mmaudio.model.flow_matching import FlowMatching
|
16 |
+
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
17 |
+
from mmaudio.model.utils.features_utils import FeaturesUtils
|
18 |
+
|
19 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
+
torch.backends.cudnn.allow_tf32 = True
|
21 |
+
|
22 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
23 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
24 |
+
log = logging.getLogger()
|
25 |
+
|
26 |
+
|
27 |
+
@torch.inference_mode()
|
28 |
+
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
|
29 |
+
def main(cfg: DictConfig):
|
30 |
+
device = 'cuda'
|
31 |
+
torch.cuda.set_device(local_rank)
|
32 |
+
|
33 |
+
if cfg.model not in all_model_cfg:
|
34 |
+
raise ValueError(f'Unknown model variant: {cfg.model}')
|
35 |
+
model: ModelConfig = all_model_cfg[cfg.model]
|
36 |
+
model.download_if_needed()
|
37 |
+
seq_cfg = model.seq_cfg
|
38 |
+
|
39 |
+
run_dir = Path(HydraConfig.get().run.dir)
|
40 |
+
if cfg.output_name is None:
|
41 |
+
output_dir = run_dir / cfg.dataset
|
42 |
+
else:
|
43 |
+
output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
|
44 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
45 |
+
|
46 |
+
# load a pretrained model
|
47 |
+
seq_cfg.duration = cfg.duration_s
|
48 |
+
net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
|
49 |
+
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
50 |
+
log.info(f'Loaded weights from {model.model_path}')
|
51 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
52 |
+
log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
|
53 |
+
log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
|
54 |
+
log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
|
55 |
+
|
56 |
+
# misc setup
|
57 |
+
rng = torch.Generator(device=device)
|
58 |
+
rng.manual_seed(cfg.seed)
|
59 |
+
fm = FlowMatching(cfg.sampling.min_sigma,
|
60 |
+
inference_mode=cfg.sampling.method,
|
61 |
+
num_steps=cfg.sampling.num_steps)
|
62 |
+
|
63 |
+
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
64 |
+
synchformer_ckpt=model.synchformer_ckpt,
|
65 |
+
enable_conditions=True,
|
66 |
+
mode=model.mode,
|
67 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
68 |
+
need_vae_encoder=False)
|
69 |
+
feature_utils = feature_utils.to(device).eval()
|
70 |
+
|
71 |
+
if cfg.compile:
|
72 |
+
net.preprocess_conditions = torch.compile(net.preprocess_conditions)
|
73 |
+
net.predict_flow = torch.compile(net.predict_flow)
|
74 |
+
feature_utils.compile()
|
75 |
+
|
76 |
+
dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
|
77 |
+
|
78 |
+
with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
|
79 |
+
for batch in tqdm(loader):
|
80 |
+
audios = generate(batch.get('clip_video', None),
|
81 |
+
batch.get('sync_video', None),
|
82 |
+
batch.get('caption', None),
|
83 |
+
feature_utils=feature_utils,
|
84 |
+
net=net,
|
85 |
+
fm=fm,
|
86 |
+
rng=rng,
|
87 |
+
cfg_strength=cfg.cfg_strength,
|
88 |
+
clip_batch_size_multiplier=64,
|
89 |
+
sync_batch_size_multiplier=64)
|
90 |
+
audios = audios.float().cpu()
|
91 |
+
names = batch['name']
|
92 |
+
for audio, name in zip(audios, names):
|
93 |
+
torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
|
94 |
+
|
95 |
+
|
96 |
+
def distributed_setup():
|
97 |
+
distributed.init_process_group(backend="nccl")
|
98 |
+
local_rank = distributed.get_rank()
|
99 |
+
world_size = distributed.get_world_size()
|
100 |
+
log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
|
101 |
+
return local_rank, world_size
|
102 |
+
|
103 |
+
|
104 |
+
if __name__ == '__main__':
|
105 |
+
distributed_setup()
|
106 |
+
|
107 |
+
main()
|
108 |
+
|
109 |
+
# clean-up
|
110 |
+
distributed.destroy_process_group()
|
config/__init__.py
ADDED
File without changes
|
config/base_config.yaml
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- data: base
|
3 |
+
- eval_data: base
|
4 |
+
- override hydra/job_logging: custom-simplest
|
5 |
+
- _self_
|
6 |
+
|
7 |
+
hydra:
|
8 |
+
run:
|
9 |
+
dir: ./output/${exp_id}
|
10 |
+
output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
|
11 |
+
|
12 |
+
enable_email: False
|
13 |
+
|
14 |
+
model: small_16k
|
15 |
+
|
16 |
+
exp_id: default
|
17 |
+
debug: False
|
18 |
+
cudnn_benchmark: True
|
19 |
+
compile: True
|
20 |
+
amp: True
|
21 |
+
weights: null
|
22 |
+
checkpoint: null
|
23 |
+
seed: 14159265
|
24 |
+
num_workers: 10 # per-GPU
|
25 |
+
pin_memory: False # set to True if your system can handle it, i.e., have enough memory
|
26 |
+
|
27 |
+
# NOTE: This DOSE NOT affect the model during inference in any way
|
28 |
+
# they are just for the dataloader to fill in the missing data in multi-modal loading
|
29 |
+
# to change the sequence length for the model, see networks.py
|
30 |
+
data_dim:
|
31 |
+
text_seq_len: 77
|
32 |
+
clip_dim: 1024
|
33 |
+
sync_dim: 768
|
34 |
+
text_dim: 1024
|
35 |
+
|
36 |
+
# ema configuration
|
37 |
+
ema:
|
38 |
+
enable: True
|
39 |
+
sigma_rels: [0.05, 0.1]
|
40 |
+
update_every: 1
|
41 |
+
checkpoint_every: 5_000
|
42 |
+
checkpoint_folder: ${hydra:run.dir}/ema_ckpts
|
43 |
+
default_output_sigma: 0.05
|
44 |
+
|
45 |
+
|
46 |
+
# sampling
|
47 |
+
sampling:
|
48 |
+
mean: 0.0
|
49 |
+
scale: 1.0
|
50 |
+
min_sigma: 0.0
|
51 |
+
method: euler
|
52 |
+
num_steps: 25
|
53 |
+
|
54 |
+
# classifier-free guidance
|
55 |
+
null_condition_probability: 0.1
|
56 |
+
cfg_strength: 4.5
|
57 |
+
|
58 |
+
# checkpoint paths to external modules
|
59 |
+
vae_16k_ckpt: ./ext_weights/v1-16.pth
|
60 |
+
vae_44k_ckpt: ./ext_weights/v1-44.pth
|
61 |
+
bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
|
62 |
+
synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
|
config/data/base.yaml
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
VGGSound:
|
2 |
+
root: ../data/video
|
3 |
+
subset_name: sets/vgg3-train.tsv
|
4 |
+
fps: 8
|
5 |
+
height: 384
|
6 |
+
width: 384
|
7 |
+
sample_duration_sec: 8.0
|
8 |
+
|
9 |
+
VGGSound_test:
|
10 |
+
root: ../data/video
|
11 |
+
subset_name: sets/vgg3-test.tsv
|
12 |
+
fps: 8
|
13 |
+
height: 384
|
14 |
+
width: 384
|
15 |
+
sample_duration_sec: 8.0
|
16 |
+
|
17 |
+
VGGSound_val:
|
18 |
+
root: ../data/video
|
19 |
+
subset_name: sets/vgg3-val.tsv
|
20 |
+
fps: 8
|
21 |
+
height: 384
|
22 |
+
width: 384
|
23 |
+
sample_duration_sec: 8.0
|
24 |
+
|
25 |
+
ExtractedVGG:
|
26 |
+
tsv: ../data/v1-16-memmap/vgg-train.tsv
|
27 |
+
memmap_dir: ../data/v1-16-memmap/vgg-train
|
28 |
+
|
29 |
+
ExtractedVGG_test:
|
30 |
+
tag: test
|
31 |
+
gt_cache: ../data/eval-cache/vggsound-test
|
32 |
+
output_subdir: null
|
33 |
+
tsv: ../data/v1-16-memmap/vgg-test.tsv
|
34 |
+
memmap_dir: ../data/v1-16-memmap/vgg-test
|
35 |
+
|
36 |
+
ExtractedVGG_val:
|
37 |
+
tag: val
|
38 |
+
gt_cache: ../data/eval-cache/vggsound-val
|
39 |
+
output_subdir: val
|
40 |
+
tsv: ../data/v1-16-memmap/vgg-val.tsv
|
41 |
+
memmap_dir: ../data/v1-16-memmap/vgg-val
|
42 |
+
|
43 |
+
AudioCaps:
|
44 |
+
tsv: ../data/v1-16-memmap/audiocaps.tsv
|
45 |
+
memmap_dir: ../data/v1-16-memmap/audiocaps
|
46 |
+
|
47 |
+
AudioSetSL:
|
48 |
+
tsv: ../data/v1-16-memmap/audioset_sl.tsv
|
49 |
+
memmap_dir: ../data/v1-16-memmap/audioset_sl
|
50 |
+
|
51 |
+
BBCSound:
|
52 |
+
tsv: ../data/v1-16-memmap/bbcsound.tsv
|
53 |
+
memmap_dir: ../data/v1-16-memmap/bbcsound
|
54 |
+
|
55 |
+
FreeSound:
|
56 |
+
tsv: ../data/v1-16-memmap/freesound.tsv
|
57 |
+
memmap_dir: ../data/v1-16-memmap/freesound
|
58 |
+
|
59 |
+
Clotho:
|
60 |
+
tsv: ../data/v1-16-memmap/clotho.tsv
|
61 |
+
memmap_dir: ../data/v1-16-memmap/clotho
|
62 |
+
|
63 |
+
Example_video:
|
64 |
+
tsv: ./training/example_output/memmap/vgg-example.tsv
|
65 |
+
memmap_dir: ./training/example_output/memmap/vgg-example
|
66 |
+
|
67 |
+
Example_audio:
|
68 |
+
tsv: ./training/example_output/memmap/audio-example.tsv
|
69 |
+
memmap_dir: ./training/example_output/memmap/audio-example
|
70 |
+
|
config/eval_config.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- base_config
|
3 |
+
- override hydra/job_logging: custom-simplest
|
4 |
+
- _self_
|
5 |
+
|
6 |
+
hydra:
|
7 |
+
run:
|
8 |
+
dir: ./output/${exp_id}
|
9 |
+
output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
|
10 |
+
|
11 |
+
exp_id: ${model}
|
12 |
+
dataset: audiocaps
|
13 |
+
duration_s: 8.0
|
14 |
+
|
15 |
+
# for inference, this is the per-GPU batch size
|
16 |
+
batch_size: 16
|
17 |
+
output_name: null
|
config/eval_data/base.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioCaps:
|
2 |
+
audio_path: ../data/AudioCaps-test-audioldm-ver
|
3 |
+
# a csv file, with a header row of 'name' and 'caption'
|
4 |
+
# name should match the audio file name without extension
|
5 |
+
# Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
|
6 |
+
csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
|
7 |
+
|
8 |
+
AudioCaps_full:
|
9 |
+
audio_path: ../data/AudioCaps-test-full-ver
|
10 |
+
# a csv file, with a header row of 'name' and 'caption'
|
11 |
+
# name should match the audio file name without extension
|
12 |
+
# Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
|
13 |
+
csv_path: ../data/AudioCaps-test-full-ver/data.csv
|
14 |
+
|
15 |
+
MovieGen:
|
16 |
+
video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
|
17 |
+
jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
|
18 |
+
|
19 |
+
VGGSound:
|
20 |
+
video_path: ../data/test-videos
|
21 |
+
# from the officially released csv file
|
22 |
+
csv_path: ../data/vggsound.csv
|
config/hydra/job_logging/custom-eval.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python logging configuration for tasks
|
2 |
+
version: 1
|
3 |
+
formatters:
|
4 |
+
simple:
|
5 |
+
format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
|
6 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
7 |
+
colorlog:
|
8 |
+
'()': 'colorlog.ColoredFormatter'
|
9 |
+
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
10 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
11 |
+
log_colors:
|
12 |
+
DEBUG: purple
|
13 |
+
INFO: green
|
14 |
+
WARNING: yellow
|
15 |
+
ERROR: red
|
16 |
+
CRITICAL: red
|
17 |
+
handlers:
|
18 |
+
console:
|
19 |
+
class: logging.StreamHandler
|
20 |
+
formatter: colorlog
|
21 |
+
stream: ext://sys.stdout
|
22 |
+
file:
|
23 |
+
class: logging.FileHandler
|
24 |
+
formatter: simple
|
25 |
+
# absolute file path
|
26 |
+
filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
|
27 |
+
mode: w
|
28 |
+
root:
|
29 |
+
level: INFO
|
30 |
+
handlers: [console, file]
|
31 |
+
|
32 |
+
disable_existing_loggers: false
|
config/hydra/job_logging/custom-no-rank.yaml
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python logging configuration for tasks
|
2 |
+
version: 1
|
3 |
+
formatters:
|
4 |
+
simple:
|
5 |
+
format: '[%(asctime)s][%(levelname)s] - %(message)s'
|
6 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
7 |
+
colorlog:
|
8 |
+
'()': 'colorlog.ColoredFormatter'
|
9 |
+
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
10 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
11 |
+
log_colors:
|
12 |
+
DEBUG: purple
|
13 |
+
INFO: green
|
14 |
+
WARNING: yellow
|
15 |
+
ERROR: red
|
16 |
+
CRITICAL: red
|
17 |
+
handlers:
|
18 |
+
console:
|
19 |
+
class: logging.StreamHandler
|
20 |
+
formatter: colorlog
|
21 |
+
stream: ext://sys.stdout
|
22 |
+
file:
|
23 |
+
class: logging.FileHandler
|
24 |
+
formatter: simple
|
25 |
+
# absolute file path
|
26 |
+
filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
|
27 |
+
mode: w
|
28 |
+
root:
|
29 |
+
level: INFO
|
30 |
+
handlers: [console, file]
|
31 |
+
|
32 |
+
disable_existing_loggers: false
|
config/hydra/job_logging/custom-simplest.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python logging configuration for tasks
|
2 |
+
version: 1
|
3 |
+
formatters:
|
4 |
+
simple:
|
5 |
+
format: '[%(asctime)s][%(levelname)s] - %(message)s'
|
6 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
7 |
+
colorlog:
|
8 |
+
'()': 'colorlog.ColoredFormatter'
|
9 |
+
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
10 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
11 |
+
log_colors:
|
12 |
+
DEBUG: purple
|
13 |
+
INFO: green
|
14 |
+
WARNING: yellow
|
15 |
+
ERROR: red
|
16 |
+
CRITICAL: red
|
17 |
+
handlers:
|
18 |
+
console:
|
19 |
+
class: logging.StreamHandler
|
20 |
+
formatter: colorlog
|
21 |
+
stream: ext://sys.stdout
|
22 |
+
root:
|
23 |
+
level: INFO
|
24 |
+
handlers: [console]
|
25 |
+
|
26 |
+
disable_existing_loggers: false
|
config/hydra/job_logging/custom.yaml
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# @package hydra.job_logging
|
2 |
+
# python logging configuration for tasks
|
3 |
+
version: 1
|
4 |
+
formatters:
|
5 |
+
simple:
|
6 |
+
format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
|
7 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
8 |
+
colorlog:
|
9 |
+
'()': 'colorlog.ColoredFormatter'
|
10 |
+
format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
11 |
+
datefmt: '%Y-%m-%d %H:%M:%S'
|
12 |
+
log_colors:
|
13 |
+
DEBUG: purple
|
14 |
+
INFO: green
|
15 |
+
WARNING: yellow
|
16 |
+
ERROR: red
|
17 |
+
CRITICAL: red
|
18 |
+
handlers:
|
19 |
+
console:
|
20 |
+
class: logging.StreamHandler
|
21 |
+
formatter: colorlog
|
22 |
+
stream: ext://sys.stdout
|
23 |
+
file:
|
24 |
+
class: logging.FileHandler
|
25 |
+
formatter: simple
|
26 |
+
# absolute file path
|
27 |
+
filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
|
28 |
+
mode: w
|
29 |
+
root:
|
30 |
+
level: INFO
|
31 |
+
handlers: [console, file]
|
32 |
+
|
33 |
+
disable_existing_loggers: false
|
config/train_config.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- base_config
|
3 |
+
- override data: base
|
4 |
+
- override hydra/job_logging: custom
|
5 |
+
- _self_
|
6 |
+
|
7 |
+
hydra:
|
8 |
+
run:
|
9 |
+
dir: ./output/${exp_id}
|
10 |
+
output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
|
11 |
+
|
12 |
+
ema:
|
13 |
+
start: 0
|
14 |
+
|
15 |
+
mini_train: False
|
16 |
+
example_train: False
|
17 |
+
enable_grad_scaler: False
|
18 |
+
vgg_oversample_rate: 5
|
19 |
+
|
20 |
+
log_text_interval: 200
|
21 |
+
log_extra_interval: 20_000
|
22 |
+
val_interval: 5_000
|
23 |
+
eval_interval: 20_000
|
24 |
+
save_eval_interval: 40_000
|
25 |
+
save_weights_interval: 10_000
|
26 |
+
save_checkpoint_interval: 10_000
|
27 |
+
save_copy_iterations: []
|
28 |
+
|
29 |
+
batch_size: 512
|
30 |
+
eval_batch_size: 256 # per-GPU
|
31 |
+
|
32 |
+
num_iterations: 300_000
|
33 |
+
learning_rate: 1.0e-4
|
34 |
+
linear_warmup_steps: 1_000
|
35 |
+
|
36 |
+
lr_schedule: step
|
37 |
+
lr_schedule_steps: [240_000, 270_000]
|
38 |
+
lr_schedule_gamma: 0.1
|
39 |
+
|
40 |
+
clip_grad_norm: 1.0
|
41 |
+
weight_decay: 1.0e-6
|
demo.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from argparse import ArgumentParser
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torchaudio
|
7 |
+
|
8 |
+
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
9 |
+
setup_eval_logging)
|
10 |
+
from mmaudio.model.flow_matching import FlowMatching
|
11 |
+
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
12 |
+
from mmaudio.model.utils.features_utils import FeaturesUtils
|
13 |
+
|
14 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
15 |
+
torch.backends.cudnn.allow_tf32 = True
|
16 |
+
|
17 |
+
log = logging.getLogger()
|
18 |
+
|
19 |
+
|
20 |
+
@torch.inference_mode()
|
21 |
+
def main():
|
22 |
+
setup_eval_logging()
|
23 |
+
|
24 |
+
parser = ArgumentParser()
|
25 |
+
parser.add_argument('--variant',
|
26 |
+
type=str,
|
27 |
+
default='large_44k_v2',
|
28 |
+
help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
|
29 |
+
parser.add_argument('--video', type=Path, help='Path to the video file')
|
30 |
+
parser.add_argument('--prompt', type=str, help='Input prompt', default='')
|
31 |
+
parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
|
32 |
+
parser.add_argument('--duration', type=float, default=8.0)
|
33 |
+
parser.add_argument('--cfg_strength', type=float, default=4.5)
|
34 |
+
parser.add_argument('--num_steps', type=int, default=25)
|
35 |
+
|
36 |
+
parser.add_argument('--mask_away_clip', action='store_true')
|
37 |
+
|
38 |
+
parser.add_argument('--output', type=Path, help='Output directory', default='./output')
|
39 |
+
parser.add_argument('--seed', type=int, help='Random seed', default=42)
|
40 |
+
parser.add_argument('--skip_video_composite', action='store_true')
|
41 |
+
parser.add_argument('--full_precision', action='store_true')
|
42 |
+
|
43 |
+
args = parser.parse_args()
|
44 |
+
|
45 |
+
if args.variant not in all_model_cfg:
|
46 |
+
raise ValueError(f'Unknown model variant: {args.variant}')
|
47 |
+
model: ModelConfig = all_model_cfg[args.variant]
|
48 |
+
model.download_if_needed()
|
49 |
+
seq_cfg = model.seq_cfg
|
50 |
+
|
51 |
+
if args.video:
|
52 |
+
video_path: Path = Path(args.video).expanduser()
|
53 |
+
else:
|
54 |
+
video_path = None
|
55 |
+
prompt: str = args.prompt
|
56 |
+
negative_prompt: str = args.negative_prompt
|
57 |
+
output_dir: str = args.output.expanduser()
|
58 |
+
seed: int = args.seed
|
59 |
+
num_steps: int = args.num_steps
|
60 |
+
duration: float = args.duration
|
61 |
+
cfg_strength: float = args.cfg_strength
|
62 |
+
skip_video_composite: bool = args.skip_video_composite
|
63 |
+
mask_away_clip: bool = args.mask_away_clip
|
64 |
+
|
65 |
+
device = 'cpu'
|
66 |
+
if torch.cuda.is_available():
|
67 |
+
device = 'cuda'
|
68 |
+
elif torch.backends.mps.is_available():
|
69 |
+
device = 'mps'
|
70 |
+
else:
|
71 |
+
log.warning('CUDA/MPS are not available, running on CPU')
|
72 |
+
dtype = torch.float32 if args.full_precision else torch.bfloat16
|
73 |
+
|
74 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
75 |
+
|
76 |
+
# load a pretrained model
|
77 |
+
net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
|
78 |
+
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
79 |
+
log.info(f'Loaded weights from {model.model_path}')
|
80 |
+
|
81 |
+
# misc setup
|
82 |
+
rng = torch.Generator(device=device)
|
83 |
+
rng.manual_seed(seed)
|
84 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
85 |
+
|
86 |
+
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
87 |
+
synchformer_ckpt=model.synchformer_ckpt,
|
88 |
+
enable_conditions=True,
|
89 |
+
mode=model.mode,
|
90 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
91 |
+
need_vae_encoder=False)
|
92 |
+
feature_utils = feature_utils.to(device, dtype).eval()
|
93 |
+
|
94 |
+
if video_path is not None:
|
95 |
+
log.info(f'Using video {video_path}')
|
96 |
+
video_info = load_video(video_path, duration)
|
97 |
+
clip_frames = video_info.clip_frames
|
98 |
+
sync_frames = video_info.sync_frames
|
99 |
+
duration = video_info.duration_sec
|
100 |
+
if mask_away_clip:
|
101 |
+
clip_frames = None
|
102 |
+
else:
|
103 |
+
clip_frames = clip_frames.unsqueeze(0)
|
104 |
+
sync_frames = sync_frames.unsqueeze(0)
|
105 |
+
else:
|
106 |
+
log.info('No video provided -- text-to-audio mode')
|
107 |
+
clip_frames = sync_frames = None
|
108 |
+
|
109 |
+
seq_cfg.duration = duration
|
110 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
111 |
+
|
112 |
+
log.info(f'Prompt: {prompt}')
|
113 |
+
log.info(f'Negative prompt: {negative_prompt}')
|
114 |
+
|
115 |
+
audios = generate(clip_frames,
|
116 |
+
sync_frames, [prompt],
|
117 |
+
negative_text=[negative_prompt],
|
118 |
+
feature_utils=feature_utils,
|
119 |
+
net=net,
|
120 |
+
fm=fm,
|
121 |
+
rng=rng,
|
122 |
+
cfg_strength=cfg_strength)
|
123 |
+
audio = audios.float().cpu()[0]
|
124 |
+
if video_path is not None:
|
125 |
+
save_path = output_dir / f'{video_path.stem}.flac'
|
126 |
+
else:
|
127 |
+
safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
|
128 |
+
save_path = output_dir / f'{safe_filename}.flac'
|
129 |
+
torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
|
130 |
+
|
131 |
+
log.info(f'Audio saved to {save_path}')
|
132 |
+
if video_path is not None and not skip_video_composite:
|
133 |
+
video_save_path = output_dir / f'{video_path.stem}.mp4'
|
134 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
135 |
+
log.info(f'Video saved to {output_dir / video_save_path}')
|
136 |
+
|
137 |
+
log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
|
138 |
+
|
139 |
+
|
140 |
+
if __name__ == '__main__':
|
141 |
+
main()
|
docs/EVAL.md
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluation
|
2 |
+
|
3 |
+
## Batch Evaluation
|
4 |
+
|
5 |
+
To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient in large-scale evaluation compared to `demo.py`, supporting batched inference, multi-GPU inference, torch compilation, and skipping video compositions.
|
6 |
+
|
7 |
+
An example of running this script with four GPUs is as follows:
|
8 |
+
|
9 |
+
```bash
|
10 |
+
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
|
11 |
+
```
|
12 |
+
|
13 |
+
You may need to update the data paths in `config/eval_data/base.yaml`.
|
14 |
+
More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
|
15 |
+
|
16 |
+
## Precomputed Results
|
17 |
+
|
18 |
+
Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
|
19 |
+
|
20 |
+
## Obtaining Quantitative Metrics
|
21 |
+
|
22 |
+
Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
|
docs/MODELS.md
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pretrained models
|
2 |
+
|
3 |
+
The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
|
4 |
+
The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
|
5 |
+
|
6 |
+
| Model | Download link | File size |
|
7 |
+
| -------- | ------- | ------- |
|
8 |
+
| Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
|
9 |
+
| Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
|
10 |
+
| Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
|
11 |
+
| Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
|
12 |
+
| Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
|
13 |
+
| 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
|
14 |
+
| 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
|
15 |
+
| 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
|
16 |
+
| Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
|
17 |
+
|
18 |
+
To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP, CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz) and not model sizes.
|
19 |
+
The 44.1kHz vocoder will be downloaded automatically.
|
20 |
+
The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.
|
21 |
+
|
22 |
+
The expected directory structure (full):
|
23 |
+
|
24 |
+
```bash
|
25 |
+
MMAudio
|
26 |
+
├── ext_weights
|
27 |
+
│ ├── best_netG.pt
|
28 |
+
│ ├── synchformer_state_dict.pth
|
29 |
+
│ ├── v1-16.pth
|
30 |
+
│ └── v1-44.pth
|
31 |
+
├── weights
|
32 |
+
│ ├── mmaudio_small_16k.pth
|
33 |
+
│ ├── mmaudio_small_44k.pth
|
34 |
+
│ ├── mmaudio_medium_44k.pth
|
35 |
+
│ ├── mmaudio_large_44k.pth
|
36 |
+
│ └── mmaudio_large_44k_v2.pth
|
37 |
+
└── ...
|
38 |
+
```
|
39 |
+
|
40 |
+
The expected directory structure (minimal, for the recommended model only):
|
41 |
+
|
42 |
+
```bash
|
43 |
+
MMAudio
|
44 |
+
├── ext_weights
|
45 |
+
│ ├── synchformer_state_dict.pth
|
46 |
+
│ └── v1-44.pth
|
47 |
+
├── weights
|
48 |
+
│ └── mmaudio_large_44k_v2.pth
|
49 |
+
└── ...
|
50 |
+
```
|
docs/TRAINING.md
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Training
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
We have put a large emphasis on making training as fast as possible.
|
6 |
+
Consequently, some pre-processing steps are required.
|
7 |
+
|
8 |
+
Namely, before starting any training, we
|
9 |
+
|
10 |
+
1. Obtain training data as videos, audios, and captions.
|
11 |
+
2. Encode training audios into spectrograms and then with VAE into mean/std
|
12 |
+
3. Extract CLIP and synchronization features from videos
|
13 |
+
4. Extract CLIP features from text (captions)
|
14 |
+
5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
|
15 |
+
|
16 |
+
**NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.
|
17 |
+
|
18 |
+
The current training script does not support `_v2` training.
|
19 |
+
|
20 |
+
## Recommended Hardware Configuration
|
21 |
+
|
22 |
+
These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
|
23 |
+
|
24 |
+
- Single-node machine. We did not implement multi-node training
|
25 |
+
- GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
|
26 |
+
- System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
|
27 |
+
- Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
|
28 |
+
|
29 |
+
## Prerequisites
|
30 |
+
|
31 |
+
1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
|
32 |
+
2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
|
33 |
+
|
34 |
+
3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
|
35 |
+
|
36 |
+
```bash
|
37 |
+
conda install -c conda-forge 'ffmpeg<7'
|
38 |
+
```
|
39 |
+
|
40 |
+
4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), [WavCaps](https://arxiv.org/abs/2303.17395), and [Clotho](https://arxiv.org/abs/1910.09387) (paper to be updated). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio unsuitable for our training.
|
41 |
+
|
42 |
+
5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md) place them in `ext_weights/`.
|
43 |
+
|
44 |
+
### Helpful links for downloading the datasets
|
45 |
+
|
46 |
+
We cannot redistribute the datasets for copyright reasons, but we do find some links helpful and they might be helpful to you as well.
|
47 |
+
|
48 |
+
- https://huggingface.co/datasets/Meranti/CLAP_freesound
|
49 |
+
- https://huggingface.co/datasets/agkphysics/AudioSet
|
50 |
+
- https://sound-effects.bbcrewind.co.uk/
|
51 |
+
|
52 |
+
For certain sources of VGGSound, you might notice desychronization between the audio and the video. This happens the video keyframes do not always align with the start of the audio and what happens during playbacks is player-dependent. We used PyTorch's decoder which can correctly handle these cases.
|
53 |
+
|
54 |
+
## Preparing Audio-Video-Text Features
|
55 |
+
|
56 |
+
We have prepared some example data in `training/example_videos`.
|
57 |
+
`training/extract_video_training_latents.py` extracts audio, video, and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
|
58 |
+
|
59 |
+
To run this script, use the `torchrun` utility:
|
60 |
+
|
61 |
+
```bash
|
62 |
+
torchrun --standalone training/extract_video_training_latents.py
|
63 |
+
```
|
64 |
+
|
65 |
+
You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
|
66 |
+
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
|
67 |
+
Change the data path definitions in `data_cfg` if necessary.
|
68 |
+
|
69 |
+
Arguments:
|
70 |
+
|
71 |
+
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
|
72 |
+
- `output_dir` -- where TensorDict and the metadata file are saved.
|
73 |
+
|
74 |
+
Outputs produced in `output_dir`:
|
75 |
+
|
76 |
+
1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
|
77 |
+
a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
|
78 |
+
b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
|
79 |
+
c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
|
80 |
+
d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
|
81 |
+
e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
|
82 |
+
f. `meta.json` that contains the metadata for the above memory mappings
|
83 |
+
2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
|
84 |
+
|
85 |
+
## Preparing Audio-Text Features
|
86 |
+
|
87 |
+
We have prepared some example data in `training/example_audios`.
|
88 |
+
|
89 |
+
1. Run `training/partition_clips` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio onto the disk to save disk space)
|
90 |
+
2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
|
91 |
+
|
92 |
+
### Partitioning the audio files
|
93 |
+
|
94 |
+
Run
|
95 |
+
|
96 |
+
```bash
|
97 |
+
python training/partition_clips.py
|
98 |
+
```
|
99 |
+
|
100 |
+
Arguments:
|
101 |
+
|
102 |
+
- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
|
103 |
+
- `output_dir` -- path to the output `.csv` file
|
104 |
+
- `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
|
105 |
+
- `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
|
106 |
+
|
107 |
+
### Extracting audio and text features
|
108 |
+
|
109 |
+
Run
|
110 |
+
|
111 |
+
```bash
|
112 |
+
torchrun --standalone training/extract_audio_training_latents.py
|
113 |
+
```
|
114 |
+
|
115 |
+
You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
|
116 |
+
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
|
117 |
+
|
118 |
+
Arguments:
|
119 |
+
|
120 |
+
- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
|
121 |
+
- `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file at least with columns `id` and `caption`
|
122 |
+
- `clips_tsv` -- path to the clips file, generated in the last step
|
123 |
+
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
|
124 |
+
- `output_dir` -- where TensorDict and the metadata file are saved.
|
125 |
+
|
126 |
+
Outputs produced in `output_dir`:
|
127 |
+
|
128 |
+
1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
|
129 |
+
a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
|
130 |
+
b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
|
131 |
+
c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
|
132 |
+
f. `meta.json` that contains the metadata for the above memory mappings
|
133 |
+
2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
|
134 |
+
|
135 |
+
### Reference tsv files (with overlaps removed as mentioned in the paper)
|
136 |
+
|
137 |
+
The reference tsv files can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).
|
138 |
+
|
139 |
+
Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use it as the `captions_tsv` input though -- the script will handle duplicates gracefully.
|
140 |
+
Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are parts of WavCaps. These subsets might be smaller than the original datasets.
|
141 |
+
The Clotho data contains both the development set and the validation set.
|
142 |
+
|
143 |
+
**Update (Mar 9, 2025)**:
|
144 |
+
We have updated a corrected set of reference tsv files. The previous tsv files contained some (<1%) corrupted captions (ie, mismatch between audio and caption, see https://github.com/hkchengrex/MMAudio/issues/56). The tsv files for VGGSound are unaffected. This reason for this error is unknown, but I cannot reproduce this error in the latest version of the code. Our pre-trained models are trained with **uncorrected** tsv files. For future training, I recommend using the corrected tsv files.
|
145 |
+
|
146 |
+
The error statistics are as follows:
|
147 |
+
|
148 |
+
- AudioCaps (170/43824), 0.39%
|
149 |
+
- Freesound: (1670/180636), 0.92%
|
150 |
+
- AudioSet: (290/100776), 0.29%
|
151 |
+
- BBCSound: (3/29975), 0.01%
|
152 |
+
- Clotho: (8/24332), 0.03%
|
153 |
+
|
154 |
+
## Training on Extracted Features
|
155 |
+
|
156 |
+
We use Distributed Data Parallel (DDP) for training.
|
157 |
+
First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
|
158 |
+
|
159 |
+
To run training on the example data, use the following command:
|
160 |
+
|
161 |
+
```bash
|
162 |
+
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
|
163 |
+
```
|
164 |
+
|
165 |
+
This will not train a useful model, but it will check if everything is set up correctly.
|
166 |
+
|
167 |
+
For full training on the base model with two GPUs, use the following command:
|
168 |
+
|
169 |
+
```bash
|
170 |
+
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
|
171 |
+
```
|
172 |
+
|
173 |
+
Any outputs from training will be stored in `output/<exp_id>`.
|
174 |
+
|
175 |
+
More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
|
176 |
+
For the medium and large models, specify `vgg_oversample_rate` to be `3` to reduce overfitting.
|
177 |
+
|
178 |
+
## Checkpoints
|
179 |
+
|
180 |
+
Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
|
181 |
+
|
182 |
+
---
|
183 |
+
|
184 |
+
Godspeed!
|
docs/images/icon.png
ADDED
![]() |
docs/index.html
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<!-- Google tag (gtag.js) -->
|
5 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
|
6 |
+
<script>
|
7 |
+
window.dataLayer = window.dataLayer || [];
|
8 |
+
function gtag(){dataLayer.push(arguments);}
|
9 |
+
gtag('js', new Date());
|
10 |
+
gtag('config', 'G-0JKBJ3WRJZ');
|
11 |
+
</script>
|
12 |
+
|
13 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
14 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
15 |
+
<link href="https://fonts.googleapis.com/css2?family=Source+Sans+3&display=swap" rel="stylesheet">
|
16 |
+
<meta charset="UTF-8">
|
17 |
+
<title>MMAudio</title>
|
18 |
+
|
19 |
+
<link rel="icon" type="image/png" href="images/icon.png">
|
20 |
+
|
21 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
22 |
+
<!-- CSS only -->
|
23 |
+
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
|
24 |
+
integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
|
25 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
|
26 |
+
|
27 |
+
<link rel="stylesheet" href="style.css">
|
28 |
+
</head>
|
29 |
+
<body>
|
30 |
+
|
31 |
+
<body>
|
32 |
+
<br><br><br><br>
|
33 |
+
<div class="container">
|
34 |
+
<div class="row text-center" style="font-size:38px">
|
35 |
+
<div class="col strong">
|
36 |
+
Taming Multimodal Joint Training for High-Quality <br>Video-to-Audio Synthesis
|
37 |
+
</div>
|
38 |
+
</div>
|
39 |
+
|
40 |
+
<br>
|
41 |
+
<div class="row text-center" style="font-size:28px">
|
42 |
+
<div class="col">
|
43 |
+
CVPR 2025
|
44 |
+
</div>
|
45 |
+
</div>
|
46 |
+
<br>
|
47 |
+
|
48 |
+
<div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
|
49 |
+
<div class="col-sm-auto px-lg-2">
|
50 |
+
<a href="https://hkchengrex.github.io/">Ho Kei Cheng<sup>1</sup></a>
|
51 |
+
</div>
|
52 |
+
<div class="col-sm-auto px-lg-2">
|
53 |
+
<nobr><a href="https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ">Masato Ishii<sup>2</sup></a></nobr>
|
54 |
+
</div>
|
55 |
+
<div class="col-sm-auto px-lg-2">
|
56 |
+
<nobr><a href="https://scholar.google.com/citations?user=sXAjHFIAAAAJ">Akio Hayakawa<sup>2</sup></a></nobr>
|
57 |
+
</div>
|
58 |
+
<div class="col-sm-auto px-lg-2">
|
59 |
+
<nobr><a href="https://scholar.google.com/citations?user=XCRO260AAAAJ">Takashi Shibuya<sup>2</sup></a></nobr>
|
60 |
+
</div>
|
61 |
+
<div class="col-sm-auto px-lg-2">
|
62 |
+
<nobr><a href="https://www.alexander-schwing.de/">Alexander Schwing<sup>1</sup></a></nobr>
|
63 |
+
</div>
|
64 |
+
<div class="col-sm-auto px-lg-2" >
|
65 |
+
<nobr><a href="https://www.yukimitsufuji.com/">Yuki Mitsufuji<sup>2,3</sup></a></nobr>
|
66 |
+
</div>
|
67 |
+
</div>
|
68 |
+
|
69 |
+
<div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
|
70 |
+
<div class="col-sm-auto px-lg-2">
|
71 |
+
<sup>1</sup>University of Illinois Urbana-Champaign
|
72 |
+
</div>
|
73 |
+
<div class="col-sm-auto px-lg-2">
|
74 |
+
<sup>2</sup>Sony AI
|
75 |
+
</div>
|
76 |
+
<div class="col-sm-auto px-lg-2">
|
77 |
+
<sup>3</sup>Sony Group Corporation
|
78 |
+
</div>
|
79 |
+
</div>
|
80 |
+
|
81 |
+
<br>
|
82 |
+
|
83 |
+
<br>
|
84 |
+
|
85 |
+
<div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
|
86 |
+
<div class="col-sm-2">
|
87 |
+
<a href="https://arxiv.org/abs/2412.15322">[Paper]</a>
|
88 |
+
</div>
|
89 |
+
<div class="col-sm-2">
|
90 |
+
<a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
|
91 |
+
</div>
|
92 |
+
<div class="col-sm-3">
|
93 |
+
<a href="https://huggingface.co/spaces/hkchengrex/MMAudio">[Huggingface Demo]</a>
|
94 |
+
</div>
|
95 |
+
<div class="col-sm-2">
|
96 |
+
<a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
|
97 |
+
</div>
|
98 |
+
<div class="col-sm-3">
|
99 |
+
<a href="https://replicate.com/zsxkib/mmaudio">[Replicate Demo]</a>
|
100 |
+
</div>
|
101 |
+
</div>
|
102 |
+
|
103 |
+
<br>
|
104 |
+
|
105 |
+
<hr>
|
106 |
+
|
107 |
+
<div class="row" style="font-size:32px">
|
108 |
+
<div class="col strong">
|
109 |
+
TL;DR
|
110 |
+
</div>
|
111 |
+
</div>
|
112 |
+
<br>
|
113 |
+
<div class="row">
|
114 |
+
<div class="col">
|
115 |
+
<p class="light" style="text-align: left;">
|
116 |
+
MMAudio generates synchronized audio given video and/or text inputs.
|
117 |
+
</p>
|
118 |
+
</div>
|
119 |
+
</div>
|
120 |
+
|
121 |
+
<br>
|
122 |
+
<hr>
|
123 |
+
<br>
|
124 |
+
|
125 |
+
<div class="row" style="font-size:32px">
|
126 |
+
<div class="col strong">
|
127 |
+
Demo
|
128 |
+
</div>
|
129 |
+
</div>
|
130 |
+
<br>
|
131 |
+
<div class="row" style="font-size:48px">
|
132 |
+
<div class="col strong text-center">
|
133 |
+
<a href="video_main.html" style="text-decoration: underline;"><More results></a>
|
134 |
+
</div>
|
135 |
+
</div>
|
136 |
+
<br>
|
137 |
+
<div class="video-container" style="text-align: center;">
|
138 |
+
<iframe src="https://youtube.com/embed/YElewUT2M4M"></iframe>
|
139 |
+
</div>
|
140 |
+
|
141 |
+
<br>
|
142 |
+
|
143 |
+
<br><br>
|
144 |
+
<br><br>
|
145 |
+
|
146 |
+
</div>
|
147 |
+
|
148 |
+
</body>
|
149 |
+
</html>
|
docs/style.css
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
body {
|
2 |
+
font-family: 'Source Sans 3', sans-serif;
|
3 |
+
font-size: 18px;
|
4 |
+
margin-left: auto;
|
5 |
+
margin-right: auto;
|
6 |
+
font-weight: 400;
|
7 |
+
height: 100%;
|
8 |
+
max-width: 1000px;
|
9 |
+
}
|
10 |
+
|
11 |
+
table {
|
12 |
+
width: 100%;
|
13 |
+
border-collapse: collapse;
|
14 |
+
}
|
15 |
+
th, td {
|
16 |
+
border: 1px solid #ddd;
|
17 |
+
padding: 8px;
|
18 |
+
text-align: center;
|
19 |
+
}
|
20 |
+
th {
|
21 |
+
background-color: #f2f2f2;
|
22 |
+
}
|
23 |
+
video {
|
24 |
+
width: 100%;
|
25 |
+
height: auto;
|
26 |
+
}
|
27 |
+
p {
|
28 |
+
font-size: 28px;
|
29 |
+
}
|
30 |
+
h2 {
|
31 |
+
font-size: 36px;
|
32 |
+
}
|
33 |
+
|
34 |
+
.strong {
|
35 |
+
font-weight: 700;
|
36 |
+
}
|
37 |
+
|
38 |
+
.light {
|
39 |
+
font-weight: 100;
|
40 |
+
}
|
41 |
+
|
42 |
+
.heavy {
|
43 |
+
font-weight: 900;
|
44 |
+
}
|
45 |
+
|
46 |
+
.column {
|
47 |
+
float: left;
|
48 |
+
}
|
49 |
+
|
50 |
+
a:link,
|
51 |
+
a:visited {
|
52 |
+
color: #05538f;
|
53 |
+
text-decoration: none;
|
54 |
+
}
|
55 |
+
|
56 |
+
a:hover {
|
57 |
+
color: #63cbdd;
|
58 |
+
}
|
59 |
+
|
60 |
+
hr {
|
61 |
+
border: 0;
|
62 |
+
height: 1px;
|
63 |
+
background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
|
64 |
+
}
|
65 |
+
|
66 |
+
.video-container {
|
67 |
+
position: relative;
|
68 |
+
padding-bottom: 56.25%; /* 16:9 */
|
69 |
+
height: 0;
|
70 |
+
}
|
71 |
+
|
72 |
+
.video-container iframe {
|
73 |
+
position: absolute;
|
74 |
+
top: 0;
|
75 |
+
left: 0;
|
76 |
+
width: 100%;
|
77 |
+
height: 100%;
|
78 |
+
}
|
docs/style_videos.css
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
body {
|
2 |
+
font-family: 'Source Sans 3', sans-serif;
|
3 |
+
font-size: 1.5vh;
|
4 |
+
font-weight: 400;
|
5 |
+
}
|
6 |
+
|
7 |
+
table {
|
8 |
+
width: 100%;
|
9 |
+
border-collapse: collapse;
|
10 |
+
}
|
11 |
+
th, td {
|
12 |
+
border: 1px solid #ddd;
|
13 |
+
padding: 8px;
|
14 |
+
text-align: center;
|
15 |
+
}
|
16 |
+
th {
|
17 |
+
background-color: #f2f2f2;
|
18 |
+
}
|
19 |
+
video {
|
20 |
+
width: 100%;
|
21 |
+
height: auto;
|
22 |
+
}
|
23 |
+
p {
|
24 |
+
font-size: 1.5vh;
|
25 |
+
font-weight: bold;
|
26 |
+
}
|
27 |
+
h2 {
|
28 |
+
font-size: 2vh;
|
29 |
+
font-weight: bold;
|
30 |
+
}
|
31 |
+
|
32 |
+
.video-container {
|
33 |
+
position: relative;
|
34 |
+
padding-bottom: 56.25%; /* 16:9 */
|
35 |
+
height: 0;
|
36 |
+
}
|
37 |
+
|
38 |
+
.video-container iframe {
|
39 |
+
position: absolute;
|
40 |
+
top: 0;
|
41 |
+
left: 0;
|
42 |
+
width: 100%;
|
43 |
+
height: 100%;
|
44 |
+
}
|
45 |
+
|
46 |
+
.video-header {
|
47 |
+
background-color: #f2f2f2;
|
48 |
+
text-align: center;
|
49 |
+
font-size: 1.5vh;
|
50 |
+
font-weight: bold;
|
51 |
+
padding: 8px;
|
52 |
+
}
|
docs/video_gen.html
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<!-- Google tag (gtag.js) -->
|
5 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
|
6 |
+
<script>
|
7 |
+
window.dataLayer = window.dataLayer || [];
|
8 |
+
function gtag(){dataLayer.push(arguments);}
|
9 |
+
gtag('js', new Date());
|
10 |
+
gtag('config', 'G-0JKBJ3WRJZ');
|
11 |
+
</script>
|
12 |
+
|
13 |
+
<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
|
14 |
+
<meta charset="UTF-8">
|
15 |
+
<title>MMAudio</title>
|
16 |
+
|
17 |
+
<link rel="icon" type="image/png" href="images/icon.png">
|
18 |
+
|
19 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
20 |
+
<!-- CSS only -->
|
21 |
+
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
|
22 |
+
integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
|
23 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
|
24 |
+
|
25 |
+
<link rel="stylesheet" href="style_videos.css">
|
26 |
+
</head>
|
27 |
+
<body>
|
28 |
+
|
29 |
+
<div id="moviegen_all">
|
30 |
+
<h2 id="moviegen" style="text-align: center;">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</h2>
|
31 |
+
<p id="moviegen1" style="overflow: hidden;">
|
32 |
+
Example 1: Ice cracking with sharp snapping sound, and metal tool scraping against the ice surface.
|
33 |
+
<span style="float: right;"><a href="#index">Back to index</a></span>
|
34 |
+
</p>
|
35 |
+
|
36 |
+
<div class="row g-1">
|
37 |
+
<div class="col-sm-6">
|
38 |
+
<div class="video-header">Movie Gen Audio</div>
|
39 |
+
<div class="video-container">
|
40 |
+
<iframe src="https://youtube.com/embed/d7Lb0ihtGcE"></iframe>
|
41 |
+
</div>
|
42 |
+
</div>
|
43 |
+
<div class="col-sm-6">
|
44 |
+
<div class="video-header">Ours</div>
|
45 |
+
<div class="video-container">
|
46 |
+
<iframe src="https://youtube.com/embed/F4JoJ2r2m8U"></iframe>
|
47 |
+
</div>
|
48 |
+
</div>
|
49 |
+
</div>
|
50 |
+
<br>
|
51 |
+
|
52 |
+
<!-- <p id="moviegen2">Example 2: Rhythmic splashing and lapping of water. <span style="float:right;"><a href="#index">Back to index</a></span> </p>
|
53 |
+
|
54 |
+
<table>
|
55 |
+
<thead>
|
56 |
+
<tr>
|
57 |
+
<th>Movie Gen Audio</th>
|
58 |
+
<th>Ours</th>
|
59 |
+
</tr>
|
60 |
+
</thead>
|
61 |
+
<tbody>
|
62 |
+
<tr>
|
63 |
+
<td width="50%">
|
64 |
+
<div class="video-container">
|
65 |
+
<iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
|
66 |
+
</div>
|
67 |
+
</td>
|
68 |
+
<td width="50%">
|
69 |
+
<div class="video-container">
|
70 |
+
<iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
|
71 |
+
</div>
|
72 |
+
</td>
|
73 |
+
</tr>
|
74 |
+
</tbody>
|
75 |
+
</table> -->
|
76 |
+
|
77 |
+
<p id="moviegen2" style="overflow: hidden;">
|
78 |
+
Example 2: Rhythmic splashing and lapping of water.
|
79 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
80 |
+
</p>
|
81 |
+
<div class="row g-1">
|
82 |
+
<div class="col-sm-6">
|
83 |
+
<div class="video-header">Movie Gen Audio</div>
|
84 |
+
<div class="video-container">
|
85 |
+
<iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
|
86 |
+
</div>
|
87 |
+
</div>
|
88 |
+
<div class="col-sm-6">
|
89 |
+
<div class="video-header">Ours</div>
|
90 |
+
<div class="video-container">
|
91 |
+
<iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
|
92 |
+
</div>
|
93 |
+
</div>
|
94 |
+
</div>
|
95 |
+
<br>
|
96 |
+
|
97 |
+
<p id="moviegen3" style="overflow: hidden;">
|
98 |
+
Example 3: Shovel scrapes against dry earth.
|
99 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
100 |
+
</p>
|
101 |
+
<div class="row g-1">
|
102 |
+
<div class="col-sm-6">
|
103 |
+
<div class="video-header">Movie Gen Audio</div>
|
104 |
+
<div class="video-container">
|
105 |
+
<iframe src="https://youtube.com/embed/PUKGyEve7XQ"></iframe>
|
106 |
+
</div>
|
107 |
+
</div>
|
108 |
+
<div class="col-sm-6">
|
109 |
+
<div class="video-header">Ours</div>
|
110 |
+
<div class="video-container">
|
111 |
+
<iframe src="https://youtube.com/embed/CNn7i8VNkdc"></iframe>
|
112 |
+
</div>
|
113 |
+
</div>
|
114 |
+
</div>
|
115 |
+
<br>
|
116 |
+
|
117 |
+
|
118 |
+
<p id="moviegen4" style="overflow: hidden;">
|
119 |
+
(Failure case) Example 4: Creamy sound of mashed potatoes being scooped.
|
120 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
121 |
+
</p>
|
122 |
+
<div class="row g-1">
|
123 |
+
<div class="col-sm-6">
|
124 |
+
<div class="video-header">Movie Gen Audio</div>
|
125 |
+
<div class="video-container">
|
126 |
+
<iframe src="https://youtube.com/embed/PJv1zxR9JjQ"></iframe>
|
127 |
+
</div>
|
128 |
+
</div>
|
129 |
+
<div class="col-sm-6">
|
130 |
+
<div class="video-header">Ours</div>
|
131 |
+
<div class="video-container">
|
132 |
+
<iframe src="https://youtube.com/embed/c3-LJ1lNsPQ"></iframe>
|
133 |
+
</div>
|
134 |
+
</div>
|
135 |
+
</div>
|
136 |
+
<br>
|
137 |
+
|
138 |
+
</div>
|
139 |
+
|
140 |
+
<div id="hunyuan_sora_all">
|
141 |
+
|
142 |
+
<h2 id="hunyuan" style="text-align: center;">Results on Videos Generated by Hunyuan</h2>
|
143 |
+
<p style="overflow: hidden;">
|
144 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
145 |
+
</p>
|
146 |
+
<div class="row g-1">
|
147 |
+
<div class="col-sm-6">
|
148 |
+
<div class="video-header">Typing</div>
|
149 |
+
<div class="video-container">
|
150 |
+
<iframe src="https://youtube.com/embed/8ln_9hhH_nk"></iframe>
|
151 |
+
</div>
|
152 |
+
</div>
|
153 |
+
<div class="col-sm-6">
|
154 |
+
<div class="video-header">Water is rushing down a stream and pouring</div>
|
155 |
+
<div class="video-container">
|
156 |
+
<iframe src="https://youtube.com/embed/5df1FZFQj30"></iframe>
|
157 |
+
</div>
|
158 |
+
</div>
|
159 |
+
</div>
|
160 |
+
<div class="row g-1">
|
161 |
+
<div class="col-sm-6">
|
162 |
+
<div class="video-header">Waves on beach</div>
|
163 |
+
<div class="video-container">
|
164 |
+
<iframe src="https://youtube.com/embed/7wQ9D5WgpFc"></iframe>
|
165 |
+
</div>
|
166 |
+
</div>
|
167 |
+
<div class="col-sm-6">
|
168 |
+
<div class="video-header">Water droplet</div>
|
169 |
+
<div class="video-container">
|
170 |
+
<iframe src="https://youtube.com/embed/q7M2nsalGjM"></iframe>
|
171 |
+
</div>
|
172 |
+
</div>
|
173 |
+
</div>
|
174 |
+
<br>
|
175 |
+
|
176 |
+
<h2 id="sora" style="text-align: center;">Results on Videos Generated by Sora</h2>
|
177 |
+
<p style="overflow: hidden;">
|
178 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
179 |
+
</p>
|
180 |
+
<div class="row g-1">
|
181 |
+
<div class="col-sm-6">
|
182 |
+
<div class="video-header">Ships riding waves</div>
|
183 |
+
<div class="video-container">
|
184 |
+
<iframe src="https://youtube.com/embed/JbgQzHHytk8"></iframe>
|
185 |
+
</div>
|
186 |
+
</div>
|
187 |
+
<div class="col-sm-6">
|
188 |
+
<div class="video-header">Train (no text prompt given)</div>
|
189 |
+
<div class="video-container">
|
190 |
+
<iframe src="https://youtube.com/embed/xOW7zrjpWC8"></iframe>
|
191 |
+
</div>
|
192 |
+
</div>
|
193 |
+
</div>
|
194 |
+
<div class="row g-1">
|
195 |
+
<div class="col-sm-6">
|
196 |
+
<div class="video-header">Seashore (no text prompt given)</div>
|
197 |
+
<div class="video-container">
|
198 |
+
<iframe src="https://youtube.com/embed/fIuw5Y8ZZ9E"></iframe>
|
199 |
+
</div>
|
200 |
+
</div>
|
201 |
+
<div class="col-sm-6">
|
202 |
+
<div class="video-header">Surfing (failure: unprompted music)</div>
|
203 |
+
<div class="video-container">
|
204 |
+
<iframe src="https://youtube.com/embed/UcSTk-v0M_s"></iframe>
|
205 |
+
</div>
|
206 |
+
</div>
|
207 |
+
</div>
|
208 |
+
<br>
|
209 |
+
|
210 |
+
<div id="mochi_ltx_all">
|
211 |
+
<h2 id="mochi" style="text-align: center;">Results on Videos Generated by Mochi 1</h2>
|
212 |
+
<p style="overflow: hidden;">
|
213 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
214 |
+
</p>
|
215 |
+
<div class="row g-1">
|
216 |
+
<div class="col-sm-6">
|
217 |
+
<div class="video-header">Magical fire and lightning (no text prompt given)</div>
|
218 |
+
<div class="video-container">
|
219 |
+
<iframe src="https://youtube.com/embed/tTlRZaSMNwY"></iframe>
|
220 |
+
</div>
|
221 |
+
</div>
|
222 |
+
<div class="col-sm-6">
|
223 |
+
<div class="video-header">Storm (no text prompt given)</div>
|
224 |
+
<div class="video-container">
|
225 |
+
<iframe src="https://youtube.com/embed/4hrZTMJUy3w"></iframe>
|
226 |
+
</div>
|
227 |
+
</div>
|
228 |
+
</div>
|
229 |
+
<br>
|
230 |
+
|
231 |
+
<h2 id="ltx" style="text-align: center;">Results on Videos Generated by LTX-Video</h2>
|
232 |
+
<p style="overflow: hidden;">
|
233 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
234 |
+
</p>
|
235 |
+
<div class="row g-1">
|
236 |
+
<div class="col-sm-6">
|
237 |
+
<div class="video-header">Firewood burning and cracking</div>
|
238 |
+
<div class="video-container">
|
239 |
+
<iframe src="https://youtube.com/embed/P7_DDpgev0g"></iframe>
|
240 |
+
</div>
|
241 |
+
</div>
|
242 |
+
<div class="col-sm-6">
|
243 |
+
<div class="video-header">Waterfall, water splashing</div>
|
244 |
+
<div class="video-container">
|
245 |
+
<iframe src="https://youtube.com/embed/4MvjceYnIO0"></iframe>
|
246 |
+
</div>
|
247 |
+
</div>
|
248 |
+
</div>
|
249 |
+
<br>
|
250 |
+
|
251 |
+
</div>
|
252 |
+
|
253 |
+
</body>
|
254 |
+
</html>
|
docs/video_main.html
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<!-- Google tag (gtag.js) -->
|
5 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
|
6 |
+
<script>
|
7 |
+
window.dataLayer = window.dataLayer || [];
|
8 |
+
function gtag(){dataLayer.push(arguments);}
|
9 |
+
gtag('js', new Date());
|
10 |
+
gtag('config', 'G-0JKBJ3WRJZ');
|
11 |
+
</script>
|
12 |
+
|
13 |
+
<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
|
14 |
+
<meta charset="UTF-8">
|
15 |
+
<title>MMAudio</title>
|
16 |
+
|
17 |
+
<link rel="icon" type="image/png" href="images/icon.png">
|
18 |
+
|
19 |
+
<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
|
20 |
+
<!-- CSS only -->
|
21 |
+
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
|
22 |
+
integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
|
23 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
|
24 |
+
|
25 |
+
<link rel="stylesheet" href="style_videos.css">
|
26 |
+
|
27 |
+
<script type="text/javascript">
|
28 |
+
$(document).ready(function(){
|
29 |
+
$("#content").load("video_gen.html #moviegen_all");
|
30 |
+
$("#load_moveigen").click(function(){
|
31 |
+
$("#content").load("video_gen.html #moviegen_all");
|
32 |
+
});
|
33 |
+
$("#load_hunyuan_sora").click(function(){
|
34 |
+
$("#content").load("video_gen.html #hunyuan_sora_all");
|
35 |
+
});
|
36 |
+
$("#load_mochi_ltx").click(function(){
|
37 |
+
$("#content").load("video_gen.html #mochi_ltx_all");
|
38 |
+
});
|
39 |
+
$("#load_vgg1").click(function(){
|
40 |
+
$("#content").load("video_vgg.html #vgg1");
|
41 |
+
});
|
42 |
+
$("#load_vgg2").click(function(){
|
43 |
+
$("#content").load("video_vgg.html #vgg2");
|
44 |
+
});
|
45 |
+
$("#load_vgg3").click(function(){
|
46 |
+
$("#content").load("video_vgg.html #vgg3");
|
47 |
+
});
|
48 |
+
$("#load_vgg4").click(function(){
|
49 |
+
$("#content").load("video_vgg.html #vgg4");
|
50 |
+
});
|
51 |
+
$("#load_vgg5").click(function(){
|
52 |
+
$("#content").load("video_vgg.html #vgg5");
|
53 |
+
});
|
54 |
+
$("#load_vgg6").click(function(){
|
55 |
+
$("#content").load("video_vgg.html #vgg6");
|
56 |
+
});
|
57 |
+
$("#load_vgg_extra").click(function(){
|
58 |
+
$("#content").load("video_vgg.html #vgg_extra");
|
59 |
+
});
|
60 |
+
});
|
61 |
+
</script>
|
62 |
+
</head>
|
63 |
+
<body>
|
64 |
+
<h1 id="index" style="text-align: center;">Index</h1>
|
65 |
+
<p><b>(Click on the links to load the corresponding videos)</b> <span style="float:right;"><a href="index.html">Back to project page</a></span></p>
|
66 |
+
|
67 |
+
<ol>
|
68 |
+
<li>
|
69 |
+
<a href="#" id="load_moveigen">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</a>
|
70 |
+
</li>
|
71 |
+
<li>
|
72 |
+
<a href="#" id="load_hunyuan_sora">Results on Videos Generated by Hunyuan and Sora</a>
|
73 |
+
</li>
|
74 |
+
<li>
|
75 |
+
<a href="#" id="load_mochi_ltx">Results on Videos Generated by Mochi 1 and LTX-Video</a>
|
76 |
+
</li>
|
77 |
+
<li>
|
78 |
+
On VGGSound
|
79 |
+
<ol>
|
80 |
+
<li><a id='load_vgg1' href="#">Example 1: Wolf howling</a></li>
|
81 |
+
<li><a id='load_vgg2' href="#">Example 2: Striking a golf ball</a></li>
|
82 |
+
<li><a id='load_vgg3' href="#">Example 3: Hitting a drum</a></li>
|
83 |
+
<li><a id='load_vgg4' href="#">Example 4: Dog barking</a></li>
|
84 |
+
<li><a id='load_vgg5' href="#">Example 5: Playing a string instrument</a></li>
|
85 |
+
<li><a id='load_vgg6' href="#">Example 6: A group of people playing tambourines</a></li>
|
86 |
+
<li><a id='load_vgg_extra' href="#">Extra results & failure cases</a></li>
|
87 |
+
</ol>
|
88 |
+
</li>
|
89 |
+
</ol>
|
90 |
+
|
91 |
+
<div id="content" class="container-fluid">
|
92 |
+
|
93 |
+
</div>
|
94 |
+
<br>
|
95 |
+
<br>
|
96 |
+
|
97 |
+
</body>
|
98 |
+
</html>
|
docs/video_vgg.html
ADDED
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<!-- Google tag (gtag.js) -->
|
5 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
|
6 |
+
<script>
|
7 |
+
window.dataLayer = window.dataLayer || [];
|
8 |
+
function gtag(){dataLayer.push(arguments);}
|
9 |
+
gtag('js', new Date());
|
10 |
+
gtag('config', 'G-0JKBJ3WRJZ');
|
11 |
+
</script>
|
12 |
+
|
13 |
+
<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
|
14 |
+
<meta charset="UTF-8">
|
15 |
+
<title>MMAudio</title>
|
16 |
+
|
17 |
+
<meta name="viewport" content="width=device-width, initial-scale=1">
|
18 |
+
<!-- CSS only -->
|
19 |
+
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
|
20 |
+
integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
|
21 |
+
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
|
22 |
+
|
23 |
+
<link rel="stylesheet" href="style_videos.css">
|
24 |
+
</head>
|
25 |
+
<body>
|
26 |
+
|
27 |
+
<div id="vgg1">
|
28 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
29 |
+
<p style="overflow: hidden;">
|
30 |
+
Example 1: Wolf howling.
|
31 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
32 |
+
</p>
|
33 |
+
<div class="row g-1">
|
34 |
+
<div class="col-sm-3">
|
35 |
+
<div class="video-header">Ground-truth</div>
|
36 |
+
<div class="video-container">
|
37 |
+
<iframe src="https://youtube.com/embed/9J_V74gqMUA"></iframe>
|
38 |
+
</div>
|
39 |
+
</div>
|
40 |
+
<div class="col-sm-3">
|
41 |
+
<div class="video-header">Ours</div>
|
42 |
+
<div class="video-container">
|
43 |
+
<iframe src="https://youtube.com/embed/P6O8IpjErPc"></iframe>
|
44 |
+
</div>
|
45 |
+
</div>
|
46 |
+
<div class="col-sm-3">
|
47 |
+
<div class="video-header">V2A-Mapper</div>
|
48 |
+
<div class="video-container">
|
49 |
+
<iframe src="https://youtube.com/embed/w-5eyqepvTk"></iframe>
|
50 |
+
</div>
|
51 |
+
</div>
|
52 |
+
<div class="col-sm-3">
|
53 |
+
<div class="video-header">FoleyCrafter</div>
|
54 |
+
<div class="video-container">
|
55 |
+
<iframe src="https://youtube.com/embed/VOLfoZlRkzo"></iframe>
|
56 |
+
</div>
|
57 |
+
</div>
|
58 |
+
</div>
|
59 |
+
<div class="row g-1">
|
60 |
+
<div class="col-sm-3">
|
61 |
+
<div class="video-header">Frieren</div>
|
62 |
+
<div class="video-container">
|
63 |
+
<iframe src="https://youtube.com/embed/49owKyA5Pa8"></iframe>
|
64 |
+
</div>
|
65 |
+
</div>
|
66 |
+
<div class="col-sm-3">
|
67 |
+
<div class="video-header">VATT</div>
|
68 |
+
<div class="video-container">
|
69 |
+
<iframe src="https://youtube.com/embed/QVtrFgbeGDM"></iframe>
|
70 |
+
</div>
|
71 |
+
</div>
|
72 |
+
<div class="col-sm-3">
|
73 |
+
<div class="video-header">V-AURA</div>
|
74 |
+
<div class="video-container">
|
75 |
+
<iframe src="https://youtube.com/embed/8r0uEfSNjvI"></iframe>
|
76 |
+
</div>
|
77 |
+
</div>
|
78 |
+
<div class="col-sm-3">
|
79 |
+
<div class="video-header">Seeing and Hearing</div>
|
80 |
+
<div class="video-container">
|
81 |
+
<iframe src="https://youtube.com/embed/bn-sLg2qulk"></iframe>
|
82 |
+
</div>
|
83 |
+
</div>
|
84 |
+
</div>
|
85 |
+
</div>
|
86 |
+
|
87 |
+
<div id="vgg2">
|
88 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
89 |
+
<p style="overflow: hidden;">
|
90 |
+
Example 2: Striking a golf ball.
|
91 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
92 |
+
</p>
|
93 |
+
|
94 |
+
<div class="row g-1">
|
95 |
+
<div class="col-sm-3">
|
96 |
+
<div class="video-header">Ground-truth</div>
|
97 |
+
<div class="video-container">
|
98 |
+
<iframe src="https://youtube.com/embed/1hwSu42kkho"></iframe>
|
99 |
+
</div>
|
100 |
+
</div>
|
101 |
+
<div class="col-sm-3">
|
102 |
+
<div class="video-header">Ours</div>
|
103 |
+
<div class="video-container">
|
104 |
+
<iframe src="https://youtube.com/embed/kZibDoDCNxI"></iframe>
|
105 |
+
</div>
|
106 |
+
</div>
|
107 |
+
<div class="col-sm-3">
|
108 |
+
<div class="video-header">V2A-Mapper</div>
|
109 |
+
<div class="video-container">
|
110 |
+
<iframe src="https://youtube.com/embed/jgKfLBLhh7Y"></iframe>
|
111 |
+
</div>
|
112 |
+
</div>
|
113 |
+
<div class="col-sm-3">
|
114 |
+
<div class="video-header">FoleyCrafter</div>
|
115 |
+
<div class="video-container">
|
116 |
+
<iframe src="https://youtube.com/embed/Lfsx8mOPcJo"></iframe>
|
117 |
+
</div>
|
118 |
+
</div>
|
119 |
+
</div>
|
120 |
+
<div class="row g-1">
|
121 |
+
<div class="col-sm-3">
|
122 |
+
<div class="video-header">Frieren</div>
|
123 |
+
<div class="video-container">
|
124 |
+
<iframe src="https://youtube.com/embed/tz-LpbB0MBc"></iframe>
|
125 |
+
</div>
|
126 |
+
</div>
|
127 |
+
<div class="col-sm-3">
|
128 |
+
<div class="video-header">VATT</div>
|
129 |
+
<div class="video-container">
|
130 |
+
<iframe src="https://youtube.com/embed/RTDUHMi08n4"></iframe>
|
131 |
+
</div>
|
132 |
+
</div>
|
133 |
+
<div class="col-sm-3">
|
134 |
+
<div class="video-header">V-AURA</div>
|
135 |
+
<div class="video-container">
|
136 |
+
<iframe src="https://youtube.com/embed/N-3TDOsPnZQ"></iframe>
|
137 |
+
</div>
|
138 |
+
</div>
|
139 |
+
<div class="col-sm-3">
|
140 |
+
<div class="video-header">Seeing and Hearing</div>
|
141 |
+
<div class="video-container">
|
142 |
+
<iframe src="https://youtube.com/embed/QnsHnLn4gB0"></iframe>
|
143 |
+
</div>
|
144 |
+
</div>
|
145 |
+
</div>
|
146 |
+
</div>
|
147 |
+
|
148 |
+
<div id="vgg3">
|
149 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
150 |
+
<p style="overflow: hidden;">
|
151 |
+
Example 3: Hitting a drum.
|
152 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
153 |
+
</p>
|
154 |
+
|
155 |
+
<div class="row g-1">
|
156 |
+
<div class="col-sm-3">
|
157 |
+
<div class="video-header">Ground-truth</div>
|
158 |
+
<div class="video-container">
|
159 |
+
<iframe src="https://youtube.com/embed/0oeIwq77w0Q"></iframe>
|
160 |
+
</div>
|
161 |
+
</div>
|
162 |
+
<div class="col-sm-3">
|
163 |
+
<div class="video-header">Ours</div>
|
164 |
+
<div class="video-container">
|
165 |
+
<iframe src="https://youtube.com/embed/-UtPV9ohuIM"></iframe>
|
166 |
+
</div>
|
167 |
+
</div>
|
168 |
+
<div class="col-sm-3">
|
169 |
+
<div class="video-header">V2A-Mapper</div>
|
170 |
+
<div class="video-container">
|
171 |
+
<iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
|
172 |
+
</div>
|
173 |
+
</div>
|
174 |
+
<div class="col-sm-3">
|
175 |
+
<div class="video-header">FoleyCrafter</div>
|
176 |
+
<div class="video-container">
|
177 |
+
<iframe src="https://youtube.com/embed/kkCsXPOlBvY"></iframe>
|
178 |
+
</div>
|
179 |
+
</div>
|
180 |
+
</div>
|
181 |
+
<div class="row g-1">
|
182 |
+
<div class="col-sm-3">
|
183 |
+
<div class="video-header">Frieren</div>
|
184 |
+
<div class="video-container">
|
185 |
+
<iframe src="https://youtube.com/embed/MbNKsVsuvig"></iframe>
|
186 |
+
</div>
|
187 |
+
</div>
|
188 |
+
<div class="col-sm-3">
|
189 |
+
<div class="video-header">VATT</div>
|
190 |
+
<div class="video-container">
|
191 |
+
<iframe src="https://youtube.com/embed/2yYviBjrpBw"></iframe>
|
192 |
+
</div>
|
193 |
+
</div>
|
194 |
+
<div class="col-sm-3">
|
195 |
+
<div class="video-header">V-AURA</div>
|
196 |
+
<div class="video-container">
|
197 |
+
<iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
|
198 |
+
</div>
|
199 |
+
</div>
|
200 |
+
<div class="col-sm-3">
|
201 |
+
<div class="video-header">Seeing and Hearing</div>
|
202 |
+
<div class="video-container">
|
203 |
+
<iframe src="https://youtube.com/embed/6dnyQt4Fuhs"></iframe>
|
204 |
+
</div>
|
205 |
+
</div>
|
206 |
+
</div>
|
207 |
+
</div>
|
208 |
+
</div>
|
209 |
+
|
210 |
+
<div id="vgg4">
|
211 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
212 |
+
<p style="overflow: hidden;">
|
213 |
+
Example 4: Dog barking.
|
214 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
215 |
+
</p>
|
216 |
+
|
217 |
+
<div class="row g-1">
|
218 |
+
<div class="col-sm-3">
|
219 |
+
<div class="video-header">Ground-truth</div>
|
220 |
+
<div class="video-container">
|
221 |
+
<iframe src="https://youtube.com/embed/ckaqvTyMYAw"></iframe>
|
222 |
+
</div>
|
223 |
+
</div>
|
224 |
+
<div class="col-sm-3">
|
225 |
+
<div class="video-header">Ours</div>
|
226 |
+
<div class="video-container">
|
227 |
+
<iframe src="https://youtube.com/embed/_aRndFZzZ-I"></iframe>
|
228 |
+
</div>
|
229 |
+
</div>
|
230 |
+
<div class="col-sm-3">
|
231 |
+
<div class="video-header">V2A-Mapper</div>
|
232 |
+
<div class="video-container">
|
233 |
+
<iframe src="https://youtube.com/embed/mNCISP3LBl0"></iframe>
|
234 |
+
</div>
|
235 |
+
</div>
|
236 |
+
<div class="col-sm-3">
|
237 |
+
<div class="video-header">FoleyCrafter</div>
|
238 |
+
<div class="video-container">
|
239 |
+
<iframe src="https://youtube.com/embed/phZBQ3L7foE"></iframe>
|
240 |
+
</div>
|
241 |
+
</div>
|
242 |
+
</div>
|
243 |
+
<div class="row g-1">
|
244 |
+
<div class="col-sm-3">
|
245 |
+
<div class="video-header">Frieren</div>
|
246 |
+
<div class="video-container">
|
247 |
+
<iframe src="https://youtube.com/embed/Sb5Mg1-ORao"></iframe>
|
248 |
+
</div>
|
249 |
+
</div>
|
250 |
+
<div class="col-sm-3">
|
251 |
+
<div class="video-header">VATT</div>
|
252 |
+
<div class="video-container">
|
253 |
+
<iframe src="https://youtube.com/embed/eHmAGOmtDDg"></iframe>
|
254 |
+
</div>
|
255 |
+
</div>
|
256 |
+
<div class="col-sm-3">
|
257 |
+
<div class="video-header">V-AURA</div>
|
258 |
+
<div class="video-container">
|
259 |
+
<iframe src="https://youtube.com/embed/NEGa3krBrm0"></iframe>
|
260 |
+
</div>
|
261 |
+
</div>
|
262 |
+
<div class="col-sm-3">
|
263 |
+
<div class="video-header">Seeing and Hearing</div>
|
264 |
+
<div class="video-container">
|
265 |
+
<iframe src="https://youtube.com/embed/aO0EAXlwE7A"></iframe>
|
266 |
+
</div>
|
267 |
+
</div>
|
268 |
+
</div>
|
269 |
+
</div>
|
270 |
+
|
271 |
+
<div id="vgg5">
|
272 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
273 |
+
<p style="overflow: hidden;">
|
274 |
+
Example 5: Playing a string instrument.
|
275 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
276 |
+
</p>
|
277 |
+
|
278 |
+
<div class="row g-1">
|
279 |
+
<div class="col-sm-3">
|
280 |
+
<div class="video-header">Ground-truth</div>
|
281 |
+
<div class="video-container">
|
282 |
+
<iframe src="https://youtube.com/embed/KP1QhWauIOc"></iframe>
|
283 |
+
</div>
|
284 |
+
</div>
|
285 |
+
<div class="col-sm-3">
|
286 |
+
<div class="video-header">Ours</div>
|
287 |
+
<div class="video-container">
|
288 |
+
<iframe src="https://youtube.com/embed/ovaJhWSquYE"></iframe>
|
289 |
+
</div>
|
290 |
+
</div>
|
291 |
+
<div class="col-sm-3">
|
292 |
+
<div class="video-header">V2A-Mapper</div>
|
293 |
+
<div class="video-container">
|
294 |
+
<iframe src="https://youtube.com/embed/N723FS9lcy8"></iframe>
|
295 |
+
</div>
|
296 |
+
</div>
|
297 |
+
<div class="col-sm-3">
|
298 |
+
<div class="video-header">FoleyCrafter</div>
|
299 |
+
<div class="video-container">
|
300 |
+
<iframe src="https://youtube.com/embed/t0N4ZAAXo58"></iframe>
|
301 |
+
</div>
|
302 |
+
</div>
|
303 |
+
</div>
|
304 |
+
<div class="row g-1">
|
305 |
+
<div class="col-sm-3">
|
306 |
+
<div class="video-header">Frieren</div>
|
307 |
+
<div class="video-container">
|
308 |
+
<iframe src="https://youtube.com/embed/8YSRs03QNNA"></iframe>
|
309 |
+
</div>
|
310 |
+
</div>
|
311 |
+
<div class="col-sm-3">
|
312 |
+
<div class="video-header">VATT</div>
|
313 |
+
<div class="video-container">
|
314 |
+
<iframe src="https://youtube.com/embed/vOpMz55J1kY"></iframe>
|
315 |
+
</div>
|
316 |
+
</div>
|
317 |
+
<div class="col-sm-3">
|
318 |
+
<div class="video-header">V-AURA</div>
|
319 |
+
<div class="video-container">
|
320 |
+
<iframe src="https://youtube.com/embed/9JHC75vr9h0"></iframe>
|
321 |
+
</div>
|
322 |
+
</div>
|
323 |
+
<div class="col-sm-3">
|
324 |
+
<div class="video-header">Seeing and Hearing</div>
|
325 |
+
<div class="video-container">
|
326 |
+
<iframe src="https://youtube.com/embed/9w0JckNzXmY"></iframe>
|
327 |
+
</div>
|
328 |
+
</div>
|
329 |
+
</div>
|
330 |
+
</div>
|
331 |
+
|
332 |
+
<div id="vgg6">
|
333 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
334 |
+
<p style="overflow: hidden;">
|
335 |
+
Example 6: A group of people playing tambourines.
|
336 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
337 |
+
</p>
|
338 |
+
|
339 |
+
<div class="row g-1">
|
340 |
+
<div class="col-sm-3">
|
341 |
+
<div class="video-header">Ground-truth</div>
|
342 |
+
<div class="video-container">
|
343 |
+
<iframe src="https://youtube.com/embed/mx6JLxzUkRc"></iframe>
|
344 |
+
</div>
|
345 |
+
</div>
|
346 |
+
<div class="col-sm-3">
|
347 |
+
<div class="video-header">Ours</div>
|
348 |
+
<div class="video-container">
|
349 |
+
<iframe src="https://youtube.com/embed/oLirHhP9Su8"></iframe>
|
350 |
+
</div>
|
351 |
+
</div>
|
352 |
+
<div class="col-sm-3">
|
353 |
+
<div class="video-header">V2A-Mapper</div>
|
354 |
+
<div class="video-container">
|
355 |
+
<iframe src="https://youtube.com/embed/HkLkHMqptv0"></iframe>
|
356 |
+
</div>
|
357 |
+
</div>
|
358 |
+
<div class="col-sm-3">
|
359 |
+
<div class="video-header">FoleyCrafter</div>
|
360 |
+
<div class="video-container">
|
361 |
+
<iframe src="https://youtube.com/embed/rpHiiODjmNU"></iframe>
|
362 |
+
</div>
|
363 |
+
</div>
|
364 |
+
</div>
|
365 |
+
<div class="row g-1">
|
366 |
+
<div class="col-sm-3">
|
367 |
+
<div class="video-header">Frieren</div>
|
368 |
+
<div class="video-container">
|
369 |
+
<iframe src="https://youtube.com/embed/1mVD3fJ0LpM"></iframe>
|
370 |
+
</div>
|
371 |
+
</div>
|
372 |
+
<div class="col-sm-3">
|
373 |
+
<div class="video-header">VATT</div>
|
374 |
+
<div class="video-container">
|
375 |
+
<iframe src="https://youtube.com/embed/yjVFnJiEJlw"></iframe>
|
376 |
+
</div>
|
377 |
+
</div>
|
378 |
+
<div class="col-sm-3">
|
379 |
+
<div class="video-header">V-AURA</div>
|
380 |
+
<div class="video-container">
|
381 |
+
<iframe src="https://youtube.com/embed/neVeMSWtRkU"></iframe>
|
382 |
+
</div>
|
383 |
+
</div>
|
384 |
+
<div class="col-sm-3">
|
385 |
+
<div class="video-header">Seeing and Hearing</div>
|
386 |
+
<div class="video-container">
|
387 |
+
<iframe src="https://youtube.com/embed/EUE7YwyVWz8"></iframe>
|
388 |
+
</div>
|
389 |
+
</div>
|
390 |
+
</div>
|
391 |
+
</div>
|
392 |
+
|
393 |
+
<div id="vgg_extra">
|
394 |
+
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
|
395 |
+
<p style="overflow: hidden;">
|
396 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
397 |
+
</p>
|
398 |
+
|
399 |
+
<div class="row g-1">
|
400 |
+
<div class="col-sm-3">
|
401 |
+
<div class="video-header">Moving train</div>
|
402 |
+
<div class="video-container">
|
403 |
+
<iframe src="https://youtube.com/embed/Ta6H45rBzJc"></iframe>
|
404 |
+
</div>
|
405 |
+
</div>
|
406 |
+
<div class="col-sm-3">
|
407 |
+
<div class="video-header">Water splashing</div>
|
408 |
+
<div class="video-container">
|
409 |
+
<iframe src="https://youtube.com/embed/hl6AtgHXpb4"></iframe>
|
410 |
+
</div>
|
411 |
+
</div>
|
412 |
+
<div class="col-sm-3">
|
413 |
+
<div class="video-header">Skateboarding</div>
|
414 |
+
<div class="video-container">
|
415 |
+
<iframe src="https://youtube.com/embed/n4sCNi_9buI"></iframe>
|
416 |
+
</div>
|
417 |
+
</div>
|
418 |
+
<div class="col-sm-3">
|
419 |
+
<div class="video-header">Synchronized clapping</div>
|
420 |
+
<div class="video-container">
|
421 |
+
<iframe src="https://youtube.com/embed/oxexfpLn7FE"></iframe>
|
422 |
+
</div>
|
423 |
+
</div>
|
424 |
+
</div>
|
425 |
+
|
426 |
+
<br><br>
|
427 |
+
|
428 |
+
<div id="extra-failure">
|
429 |
+
<h2 style="text-align: center;">Failure cases</h2>
|
430 |
+
<p style="overflow: hidden;">
|
431 |
+
<span style="float:right;"><a href="#index">Back to index</a></span>
|
432 |
+
</p>
|
433 |
+
|
434 |
+
<div class="row g-1">
|
435 |
+
<div class="col-sm-6">
|
436 |
+
<div class="video-header">Human speech</div>
|
437 |
+
<div class="video-container">
|
438 |
+
<iframe src="https://youtube.com/embed/nx0CyrDu70Y"></iframe>
|
439 |
+
</div>
|
440 |
+
</div>
|
441 |
+
<div class="col-sm-6">
|
442 |
+
<div class="video-header">Unfamiliar vision input</div>
|
443 |
+
<div class="video-container">
|
444 |
+
<iframe src="https://youtube.com/embed/hfnAqmK3X7w"></iframe>
|
445 |
+
</div>
|
446 |
+
</div>
|
447 |
+
</div>
|
448 |
+
</div>
|
449 |
+
</div>
|
450 |
+
|
451 |
+
</body>
|
452 |
+
</html>
|
gradio_demo.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
from datetime import datetime
|
5 |
+
from fractions import Fraction
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
|
13 |
+
load_video, make_video, setup_eval_logging)
|
14 |
+
from mmaudio.model.flow_matching import FlowMatching
|
15 |
+
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
16 |
+
from mmaudio.model.sequence_config import SequenceConfig
|
17 |
+
from mmaudio.model.utils.features_utils import FeaturesUtils
|
18 |
+
|
19 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
+
torch.backends.cudnn.allow_tf32 = True
|
21 |
+
|
22 |
+
log = logging.getLogger()
|
23 |
+
|
24 |
+
device = 'cpu'
|
25 |
+
if torch.cuda.is_available():
|
26 |
+
device = 'cuda'
|
27 |
+
elif torch.backends.mps.is_available():
|
28 |
+
device = 'mps'
|
29 |
+
else:
|
30 |
+
log.warning('CUDA/MPS are not available, running on CPU')
|
31 |
+
dtype = torch.bfloat16
|
32 |
+
|
33 |
+
model: ModelConfig = all_model_cfg['large_44k_v2']
|
34 |
+
model.download_if_needed()
|
35 |
+
output_dir = Path('./output/gradio')
|
36 |
+
|
37 |
+
setup_eval_logging()
|
38 |
+
|
39 |
+
|
40 |
+
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
41 |
+
seq_cfg = model.seq_cfg
|
42 |
+
|
43 |
+
net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
|
44 |
+
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
45 |
+
log.info(f'Loaded weights from {model.model_path}')
|
46 |
+
|
47 |
+
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
48 |
+
synchformer_ckpt=model.synchformer_ckpt,
|
49 |
+
enable_conditions=True,
|
50 |
+
mode=model.mode,
|
51 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
52 |
+
need_vae_encoder=False)
|
53 |
+
feature_utils = feature_utils.to(device, dtype).eval()
|
54 |
+
|
55 |
+
return net, feature_utils, seq_cfg
|
56 |
+
|
57 |
+
|
58 |
+
net, feature_utils, seq_cfg = get_model()
|
59 |
+
|
60 |
+
|
61 |
+
@torch.inference_mode()
|
62 |
+
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
63 |
+
cfg_strength: float, duration: float):
|
64 |
+
|
65 |
+
rng = torch.Generator(device=device)
|
66 |
+
if seed >= 0:
|
67 |
+
rng.manual_seed(seed)
|
68 |
+
else:
|
69 |
+
rng.seed()
|
70 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
71 |
+
|
72 |
+
video_info = load_video(video, duration)
|
73 |
+
clip_frames = video_info.clip_frames
|
74 |
+
sync_frames = video_info.sync_frames
|
75 |
+
duration = video_info.duration_sec
|
76 |
+
clip_frames = clip_frames.unsqueeze(0)
|
77 |
+
sync_frames = sync_frames.unsqueeze(0)
|
78 |
+
seq_cfg.duration = duration
|
79 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
80 |
+
|
81 |
+
audios = generate(clip_frames,
|
82 |
+
sync_frames, [prompt],
|
83 |
+
negative_text=[negative_prompt],
|
84 |
+
feature_utils=feature_utils,
|
85 |
+
net=net,
|
86 |
+
fm=fm,
|
87 |
+
rng=rng,
|
88 |
+
cfg_strength=cfg_strength)
|
89 |
+
audio = audios.float().cpu()[0]
|
90 |
+
|
91 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
92 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
93 |
+
video_save_path = output_dir / f'{current_time_string}.mp4'
|
94 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
95 |
+
gc.collect()
|
96 |
+
return video_save_path
|
97 |
+
|
98 |
+
|
99 |
+
@torch.inference_mode()
|
100 |
+
def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
101 |
+
cfg_strength: float, duration: float):
|
102 |
+
|
103 |
+
rng = torch.Generator(device=device)
|
104 |
+
if seed >= 0:
|
105 |
+
rng.manual_seed(seed)
|
106 |
+
else:
|
107 |
+
rng.seed()
|
108 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
109 |
+
|
110 |
+
image_info = load_image(image)
|
111 |
+
clip_frames = image_info.clip_frames
|
112 |
+
sync_frames = image_info.sync_frames
|
113 |
+
clip_frames = clip_frames.unsqueeze(0)
|
114 |
+
sync_frames = sync_frames.unsqueeze(0)
|
115 |
+
seq_cfg.duration = duration
|
116 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
117 |
+
|
118 |
+
audios = generate(clip_frames,
|
119 |
+
sync_frames, [prompt],
|
120 |
+
negative_text=[negative_prompt],
|
121 |
+
feature_utils=feature_utils,
|
122 |
+
net=net,
|
123 |
+
fm=fm,
|
124 |
+
rng=rng,
|
125 |
+
cfg_strength=cfg_strength,
|
126 |
+
image_input=True)
|
127 |
+
audio = audios.float().cpu()[0]
|
128 |
+
|
129 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
130 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
131 |
+
video_save_path = output_dir / f'{current_time_string}.mp4'
|
132 |
+
video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
|
133 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
134 |
+
gc.collect()
|
135 |
+
return video_save_path
|
136 |
+
|
137 |
+
|
138 |
+
@torch.inference_mode()
|
139 |
+
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
|
140 |
+
duration: float):
|
141 |
+
|
142 |
+
rng = torch.Generator(device=device)
|
143 |
+
if seed >= 0:
|
144 |
+
rng.manual_seed(seed)
|
145 |
+
else:
|
146 |
+
rng.seed()
|
147 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
148 |
+
|
149 |
+
clip_frames = sync_frames = None
|
150 |
+
seq_cfg.duration = duration
|
151 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
152 |
+
|
153 |
+
audios = generate(clip_frames,
|
154 |
+
sync_frames, [prompt],
|
155 |
+
negative_text=[negative_prompt],
|
156 |
+
feature_utils=feature_utils,
|
157 |
+
net=net,
|
158 |
+
fm=fm,
|
159 |
+
rng=rng,
|
160 |
+
cfg_strength=cfg_strength)
|
161 |
+
audio = audios.float().cpu()[0]
|
162 |
+
|
163 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
164 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
165 |
+
audio_save_path = output_dir / f'{current_time_string}.flac'
|
166 |
+
torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
|
167 |
+
gc.collect()
|
168 |
+
return audio_save_path
|
169 |
+
|
170 |
+
|
171 |
+
video_to_audio_tab = gr.Interface(
|
172 |
+
fn=video_to_audio,
|
173 |
+
description="""
|
174 |
+
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
175 |
+
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
176 |
+
|
177 |
+
NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
|
178 |
+
Doing so does not improve results.
|
179 |
+
""",
|
180 |
+
inputs=[
|
181 |
+
gr.Video(),
|
182 |
+
gr.Text(label='Prompt'),
|
183 |
+
gr.Text(label='Negative prompt', value='music'),
|
184 |
+
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
185 |
+
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
186 |
+
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
187 |
+
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
188 |
+
],
|
189 |
+
outputs='playable_video',
|
190 |
+
cache_examples=False,
|
191 |
+
title='MMAudio — Video-to-Audio Synthesis',
|
192 |
+
examples=[
|
193 |
+
[
|
194 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
|
195 |
+
'waves, seagulls',
|
196 |
+
'',
|
197 |
+
0,
|
198 |
+
25,
|
199 |
+
4.5,
|
200 |
+
10,
|
201 |
+
],
|
202 |
+
[
|
203 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
|
204 |
+
'',
|
205 |
+
'music',
|
206 |
+
0,
|
207 |
+
25,
|
208 |
+
4.5,
|
209 |
+
10,
|
210 |
+
],
|
211 |
+
[
|
212 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
|
213 |
+
'bubbles',
|
214 |
+
'',
|
215 |
+
0,
|
216 |
+
25,
|
217 |
+
4.5,
|
218 |
+
10,
|
219 |
+
],
|
220 |
+
[
|
221 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
|
222 |
+
'Indian holy music',
|
223 |
+
'',
|
224 |
+
0,
|
225 |
+
25,
|
226 |
+
4.5,
|
227 |
+
10,
|
228 |
+
],
|
229 |
+
[
|
230 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
|
231 |
+
'galloping',
|
232 |
+
'',
|
233 |
+
0,
|
234 |
+
25,
|
235 |
+
4.5,
|
236 |
+
10,
|
237 |
+
],
|
238 |
+
[
|
239 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
|
240 |
+
'waves, storm',
|
241 |
+
'',
|
242 |
+
0,
|
243 |
+
25,
|
244 |
+
4.5,
|
245 |
+
10,
|
246 |
+
],
|
247 |
+
[
|
248 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
|
249 |
+
'storm',
|
250 |
+
'',
|
251 |
+
0,
|
252 |
+
25,
|
253 |
+
4.5,
|
254 |
+
10,
|
255 |
+
],
|
256 |
+
[
|
257 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
|
258 |
+
'',
|
259 |
+
'',
|
260 |
+
0,
|
261 |
+
25,
|
262 |
+
4.5,
|
263 |
+
10,
|
264 |
+
],
|
265 |
+
[
|
266 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
|
267 |
+
'typing',
|
268 |
+
'',
|
269 |
+
0,
|
270 |
+
25,
|
271 |
+
4.5,
|
272 |
+
10,
|
273 |
+
],
|
274 |
+
[
|
275 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
|
276 |
+
'',
|
277 |
+
'',
|
278 |
+
0,
|
279 |
+
25,
|
280 |
+
4.5,
|
281 |
+
10,
|
282 |
+
],
|
283 |
+
[
|
284 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
|
285 |
+
'',
|
286 |
+
'',
|
287 |
+
0,
|
288 |
+
25,
|
289 |
+
4.5,
|
290 |
+
10,
|
291 |
+
],
|
292 |
+
])
|
293 |
+
|
294 |
+
text_to_audio_tab = gr.Interface(
|
295 |
+
fn=text_to_audio,
|
296 |
+
description="""
|
297 |
+
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
298 |
+
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
299 |
+
""",
|
300 |
+
inputs=[
|
301 |
+
gr.Text(label='Prompt'),
|
302 |
+
gr.Text(label='Negative prompt'),
|
303 |
+
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
304 |
+
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
305 |
+
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
306 |
+
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
307 |
+
],
|
308 |
+
outputs='audio',
|
309 |
+
cache_examples=False,
|
310 |
+
title='MMAudio — Text-to-Audio Synthesis',
|
311 |
+
)
|
312 |
+
|
313 |
+
image_to_audio_tab = gr.Interface(
|
314 |
+
fn=image_to_audio,
|
315 |
+
description="""
|
316 |
+
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
317 |
+
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
318 |
+
|
319 |
+
NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
|
320 |
+
Doing so does not improve results.
|
321 |
+
""",
|
322 |
+
inputs=[
|
323 |
+
gr.Image(type='filepath'),
|
324 |
+
gr.Text(label='Prompt'),
|
325 |
+
gr.Text(label='Negative prompt'),
|
326 |
+
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
327 |
+
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
328 |
+
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
329 |
+
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
330 |
+
],
|
331 |
+
outputs='playable_video',
|
332 |
+
cache_examples=False,
|
333 |
+
title='MMAudio — Image-to-Audio Synthesis (experimental)',
|
334 |
+
)
|
335 |
+
|
336 |
+
if __name__ == "__main__":
|
337 |
+
parser = ArgumentParser()
|
338 |
+
parser.add_argument('--port', type=int, default=7860)
|
339 |
+
args = parser.parse_args()
|
340 |
+
|
341 |
+
gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
|
342 |
+
['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
|
343 |
+
server_port=args.port, allowed_paths=[output_dir])
|
mmaudio/__init__.py
ADDED
File without changes
|
mmaudio/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (187 Bytes). View file
|
|
mmaudio/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (185 Bytes). View file
|
|
mmaudio/__pycache__/eval_utils.cpython-310.pyc
ADDED
Binary file (7.07 kB). View file
|
|
mmaudio/__pycache__/eval_utils.cpython-38.pyc
ADDED
Binary file (7.03 kB). View file
|
|
mmaudio/data/__init__.py
ADDED
File without changes
|
mmaudio/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (192 Bytes). View file
|
|
mmaudio/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (190 Bytes). View file
|
|
mmaudio/data/__pycache__/av_utils.cpython-310.pyc
ADDED
Binary file (4.91 kB). View file
|
|
mmaudio/data/__pycache__/av_utils.cpython-38.pyc
ADDED
Binary file (4.89 kB). View file
|
|
mmaudio/data/av_utils.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from fractions import Fraction
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, List, Tuple
|
5 |
+
|
6 |
+
import av
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from av import AudioFrame
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class VideoInfo:
|
14 |
+
duration_sec: float
|
15 |
+
fps: Fraction
|
16 |
+
clip_frames: torch.Tensor
|
17 |
+
sync_frames: torch.Tensor
|
18 |
+
all_frames: Optional[List[np.ndarray]]
|
19 |
+
|
20 |
+
@property
|
21 |
+
def height(self):
|
22 |
+
return self.all_frames[0].shape[0]
|
23 |
+
|
24 |
+
@property
|
25 |
+
def width(self):
|
26 |
+
return self.all_frames[0].shape[1]
|
27 |
+
|
28 |
+
@classmethod
|
29 |
+
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
|
30 |
+
fps: Fraction) -> 'VideoInfo':
|
31 |
+
num_frames = int(duration_sec * fps)
|
32 |
+
all_frames = [image_info.original_frame] * num_frames
|
33 |
+
return cls(duration_sec=duration_sec,
|
34 |
+
fps=fps,
|
35 |
+
clip_frames=image_info.clip_frames,
|
36 |
+
sync_frames=image_info.sync_frames,
|
37 |
+
all_frames=all_frames)
|
38 |
+
|
39 |
+
|
40 |
+
@dataclass
|
41 |
+
class ImageInfo:
|
42 |
+
clip_frames: torch.Tensor
|
43 |
+
sync_frames: torch.Tensor
|
44 |
+
original_frame: Optional[np.ndarray]
|
45 |
+
|
46 |
+
@property
|
47 |
+
def height(self):
|
48 |
+
return self.original_frame.shape[0]
|
49 |
+
|
50 |
+
@property
|
51 |
+
def width(self):
|
52 |
+
return self.original_frame.shape[1]
|
53 |
+
|
54 |
+
|
55 |
+
def read_frames(video_path: Path, list_of_fps: List[float], start_sec: float, end_sec: float,
|
56 |
+
need_all_frames: bool) -> Tuple[List[np.ndarray], List[np.ndarray], Fraction]:
|
57 |
+
output_frames = [[] for _ in list_of_fps]
|
58 |
+
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
59 |
+
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
60 |
+
all_frames = []
|
61 |
+
|
62 |
+
# container = av.open(video_path)
|
63 |
+
with av.open(video_path) as container:
|
64 |
+
stream = container.streams.video[0]
|
65 |
+
fps = stream.guessed_rate
|
66 |
+
stream.thread_type = 'AUTO'
|
67 |
+
for packet in container.demux(stream):
|
68 |
+
for frame in packet.decode():
|
69 |
+
frame_time = frame.time
|
70 |
+
if frame_time < start_sec:
|
71 |
+
continue
|
72 |
+
if frame_time > end_sec:
|
73 |
+
break
|
74 |
+
|
75 |
+
frame_np = None
|
76 |
+
if need_all_frames:
|
77 |
+
frame_np = frame.to_ndarray(format='rgb24')
|
78 |
+
all_frames.append(frame_np)
|
79 |
+
|
80 |
+
for i, _ in enumerate(list_of_fps):
|
81 |
+
this_time = frame_time
|
82 |
+
while this_time >= next_frame_time_for_each_fps[i]:
|
83 |
+
if frame_np is None:
|
84 |
+
frame_np = frame.to_ndarray(format='rgb24')
|
85 |
+
|
86 |
+
output_frames[i].append(frame_np)
|
87 |
+
next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
|
88 |
+
|
89 |
+
output_frames = [np.stack(frames) for frames in output_frames]
|
90 |
+
return output_frames, all_frames, fps
|
91 |
+
|
92 |
+
|
93 |
+
def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
|
94 |
+
sampling_rate: int):
|
95 |
+
container = av.open(output_path, 'w')
|
96 |
+
output_video_stream = container.add_stream('h264', video_info.fps)
|
97 |
+
output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
|
98 |
+
output_video_stream.width = video_info.width
|
99 |
+
output_video_stream.height = video_info.height
|
100 |
+
output_video_stream.pix_fmt = 'yuv420p'
|
101 |
+
|
102 |
+
output_audio_stream = container.add_stream('aac', sampling_rate)
|
103 |
+
|
104 |
+
# encode video
|
105 |
+
for image in video_info.all_frames:
|
106 |
+
image = av.VideoFrame.from_ndarray(image)
|
107 |
+
packet = output_video_stream.encode(image)
|
108 |
+
container.mux(packet)
|
109 |
+
|
110 |
+
for packet in output_video_stream.encode():
|
111 |
+
container.mux(packet)
|
112 |
+
|
113 |
+
# convert float tensor audio to numpy array
|
114 |
+
audio_np = audio.numpy().astype(np.float32)
|
115 |
+
audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
116 |
+
audio_frame.sample_rate = sampling_rate
|
117 |
+
|
118 |
+
for packet in output_audio_stream.encode(audio_frame):
|
119 |
+
container.mux(packet)
|
120 |
+
|
121 |
+
for packet in output_audio_stream.encode():
|
122 |
+
container.mux(packet)
|
123 |
+
|
124 |
+
container.close()
|
125 |
+
|
126 |
+
|
127 |
+
def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
|
128 |
+
"""
|
129 |
+
NOTE: I don't think we can get the exact video duration right without re-encoding
|
130 |
+
so we are not using this but keeping it here for reference
|
131 |
+
"""
|
132 |
+
video = av.open(video_path)
|
133 |
+
output = av.open(output_path, 'w')
|
134 |
+
input_video_stream = video.streams.video[0]
|
135 |
+
output_video_stream = output.add_stream(template=input_video_stream)
|
136 |
+
output_audio_stream = output.add_stream('aac', sampling_rate)
|
137 |
+
|
138 |
+
duration_sec = audio.shape[-1] / sampling_rate
|
139 |
+
|
140 |
+
for packet in video.demux(input_video_stream):
|
141 |
+
# We need to skip the "flushing" packets that `demux` generates.
|
142 |
+
if packet.dts is None:
|
143 |
+
continue
|
144 |
+
# We need to assign the packet to the new stream.
|
145 |
+
packet.stream = output_video_stream
|
146 |
+
output.mux(packet)
|
147 |
+
|
148 |
+
# convert float tensor audio to numpy array
|
149 |
+
audio_np = audio.numpy().astype(np.float32)
|
150 |
+
audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
|
151 |
+
audio_frame.sample_rate = sampling_rate
|
152 |
+
|
153 |
+
for packet in output_audio_stream.encode(audio_frame):
|
154 |
+
output.mux(packet)
|
155 |
+
|
156 |
+
for packet in output_audio_stream.encode():
|
157 |
+
output.mux(packet)
|
158 |
+
|
159 |
+
video.close()
|
160 |
+
output.close()
|
161 |
+
|
162 |
+
output.close()
|
mmaudio/data/data_setup.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from omegaconf import DictConfig
|
7 |
+
from torch.utils.data import DataLoader, Dataset
|
8 |
+
from torch.utils.data.dataloader import default_collate
|
9 |
+
from torch.utils.data.distributed import DistributedSampler
|
10 |
+
|
11 |
+
from mmaudio.data.eval.audiocaps import AudioCapsData
|
12 |
+
from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
|
13 |
+
from mmaudio.data.extracted_audio import ExtractedAudio
|
14 |
+
from mmaudio.data.extracted_vgg import ExtractedVGG
|
15 |
+
from mmaudio.data.mm_dataset import MultiModalDataset
|
16 |
+
from mmaudio.utils.dist_utils import local_rank
|
17 |
+
|
18 |
+
log = logging.getLogger()
|
19 |
+
|
20 |
+
|
21 |
+
# Re-seed randomness every time we start a worker
|
22 |
+
def worker_init_fn(worker_id: int):
|
23 |
+
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
|
24 |
+
np.random.seed(worker_seed)
|
25 |
+
random.seed(worker_seed)
|
26 |
+
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
|
27 |
+
|
28 |
+
|
29 |
+
def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
30 |
+
dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
|
31 |
+
data_dim=cfg.data_dim,
|
32 |
+
premade_mmap_dir=data_cfg.memmap_dir)
|
33 |
+
|
34 |
+
return dataset
|
35 |
+
|
36 |
+
|
37 |
+
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
38 |
+
dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
|
39 |
+
data_dim=cfg.data_dim,
|
40 |
+
premade_mmap_dir=data_cfg.memmap_dir)
|
41 |
+
|
42 |
+
return dataset
|
43 |
+
|
44 |
+
|
45 |
+
def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
46 |
+
if cfg.mini_train:
|
47 |
+
vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
|
48 |
+
audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
|
49 |
+
dataset = MultiModalDataset([vgg], [audiocaps])
|
50 |
+
if cfg.example_train:
|
51 |
+
video = load_vgg_data(cfg, cfg.data.Example_video)
|
52 |
+
audio = load_audio_data(cfg, cfg.data.Example_audio)
|
53 |
+
dataset = MultiModalDataset([video], [audio])
|
54 |
+
else:
|
55 |
+
# load the largest one first
|
56 |
+
freesound = load_audio_data(cfg, cfg.data.FreeSound)
|
57 |
+
vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
|
58 |
+
audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
|
59 |
+
audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
|
60 |
+
bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
|
61 |
+
clotho = load_audio_data(cfg, cfg.data.Clotho)
|
62 |
+
dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
|
63 |
+
[audiocaps, audioset_sl, bbcsound, freesound, clotho])
|
64 |
+
|
65 |
+
batch_size = cfg.batch_size
|
66 |
+
num_workers = cfg.num_workers
|
67 |
+
pin_memory = cfg.pin_memory
|
68 |
+
sampler, loader = construct_loader(dataset,
|
69 |
+
batch_size,
|
70 |
+
num_workers,
|
71 |
+
shuffle=True,
|
72 |
+
drop_last=True,
|
73 |
+
pin_memory=pin_memory)
|
74 |
+
|
75 |
+
return dataset, sampler, loader
|
76 |
+
|
77 |
+
|
78 |
+
def setup_test_datasets(cfg):
|
79 |
+
dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
|
80 |
+
|
81 |
+
batch_size = cfg.batch_size
|
82 |
+
num_workers = cfg.num_workers
|
83 |
+
pin_memory = cfg.pin_memory
|
84 |
+
sampler, loader = construct_loader(dataset,
|
85 |
+
batch_size,
|
86 |
+
num_workers,
|
87 |
+
shuffle=False,
|
88 |
+
drop_last=False,
|
89 |
+
pin_memory=pin_memory)
|
90 |
+
|
91 |
+
return dataset, sampler, loader
|
92 |
+
|
93 |
+
|
94 |
+
def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
|
95 |
+
if cfg.example_train:
|
96 |
+
dataset = load_vgg_data(cfg, cfg.data.Example_video)
|
97 |
+
else:
|
98 |
+
dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
|
99 |
+
|
100 |
+
val_batch_size = cfg.batch_size
|
101 |
+
val_eval_batch_size = cfg.eval_batch_size
|
102 |
+
num_workers = cfg.num_workers
|
103 |
+
pin_memory = cfg.pin_memory
|
104 |
+
_, val_loader = construct_loader(dataset,
|
105 |
+
val_batch_size,
|
106 |
+
num_workers,
|
107 |
+
shuffle=False,
|
108 |
+
drop_last=False,
|
109 |
+
pin_memory=pin_memory)
|
110 |
+
_, eval_loader = construct_loader(dataset,
|
111 |
+
val_eval_batch_size,
|
112 |
+
num_workers,
|
113 |
+
shuffle=False,
|
114 |
+
drop_last=False,
|
115 |
+
pin_memory=pin_memory)
|
116 |
+
|
117 |
+
return dataset, val_loader, eval_loader
|
118 |
+
|
119 |
+
|
120 |
+
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
|
121 |
+
if dataset_name.startswith('audiocaps_full'):
|
122 |
+
dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
|
123 |
+
cfg.eval_data.AudioCaps_full.csv_path)
|
124 |
+
elif dataset_name.startswith('audiocaps'):
|
125 |
+
dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
|
126 |
+
cfg.eval_data.AudioCaps.csv_path)
|
127 |
+
elif dataset_name.startswith('moviegen'):
|
128 |
+
dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
|
129 |
+
cfg.eval_data.MovieGen.jsonl_path,
|
130 |
+
duration_sec=cfg.duration_s)
|
131 |
+
elif dataset_name.startswith('vggsound'):
|
132 |
+
dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
|
133 |
+
cfg.eval_data.VGGSound.csv_path,
|
134 |
+
duration_sec=cfg.duration_s)
|
135 |
+
else:
|
136 |
+
raise ValueError(f'Invalid dataset name: {dataset_name}')
|
137 |
+
|
138 |
+
batch_size = cfg.batch_size
|
139 |
+
num_workers = cfg.num_workers
|
140 |
+
pin_memory = cfg.pin_memory
|
141 |
+
_, loader = construct_loader(dataset,
|
142 |
+
batch_size,
|
143 |
+
num_workers,
|
144 |
+
shuffle=False,
|
145 |
+
drop_last=False,
|
146 |
+
pin_memory=pin_memory,
|
147 |
+
error_avoidance=True)
|
148 |
+
return dataset, loader
|
149 |
+
|
150 |
+
|
151 |
+
def error_avoidance_collate(batch):
|
152 |
+
batch = list(filter(lambda x: x is not None, batch))
|
153 |
+
return default_collate(batch)
|
154 |
+
|
155 |
+
|
156 |
+
def construct_loader(dataset: Dataset,
|
157 |
+
batch_size: int,
|
158 |
+
num_workers: int,
|
159 |
+
*,
|
160 |
+
shuffle: bool = True,
|
161 |
+
drop_last: bool = True,
|
162 |
+
pin_memory: bool = False,
|
163 |
+
error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
|
164 |
+
train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
|
165 |
+
train_loader = DataLoader(dataset,
|
166 |
+
batch_size,
|
167 |
+
sampler=train_sampler,
|
168 |
+
num_workers=num_workers,
|
169 |
+
worker_init_fn=worker_init_fn,
|
170 |
+
drop_last=drop_last,
|
171 |
+
persistent_workers=num_workers > 0,
|
172 |
+
pin_memory=pin_memory,
|
173 |
+
collate_fn=error_avoidance_collate if error_avoidance else None)
|
174 |
+
return train_sampler, train_loader
|
mmaudio/data/eval/__init__.py
ADDED
File without changes
|
mmaudio/data/eval/audiocaps.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from collections import defaultdict
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
|
11 |
+
log = logging.getLogger()
|
12 |
+
|
13 |
+
|
14 |
+
class AudioCapsData(Dataset):
|
15 |
+
|
16 |
+
def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
|
17 |
+
df = pd.read_csv(csv_path).to_dict(orient='records')
|
18 |
+
|
19 |
+
audio_files = sorted(os.listdir(audio_path))
|
20 |
+
audio_files = set(
|
21 |
+
[Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
|
22 |
+
|
23 |
+
self.data = []
|
24 |
+
for row in df:
|
25 |
+
self.data.append({
|
26 |
+
'name': row['name'],
|
27 |
+
'caption': row['caption'],
|
28 |
+
})
|
29 |
+
|
30 |
+
self.audio_path = Path(audio_path)
|
31 |
+
self.csv_path = Path(csv_path)
|
32 |
+
|
33 |
+
log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
|
34 |
+
|
35 |
+
def __getitem__(self, idx: int) -> torch.Tensor:
|
36 |
+
return self.data[idx]
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.data)
|
mmaudio/data/eval/moviegen.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch.utils.data.dataset import Dataset
|
9 |
+
from torchvision.transforms import v2
|
10 |
+
from torio.io import StreamingMediaDecoder
|
11 |
+
|
12 |
+
from mmaudio.utils.dist_utils import local_rank
|
13 |
+
|
14 |
+
log = logging.getLogger()
|
15 |
+
|
16 |
+
_CLIP_SIZE = 384
|
17 |
+
_CLIP_FPS = 8.0
|
18 |
+
|
19 |
+
_SYNC_SIZE = 224
|
20 |
+
_SYNC_FPS = 25.0
|
21 |
+
|
22 |
+
|
23 |
+
class MovieGenData(Dataset):
|
24 |
+
|
25 |
+
def __init__(
|
26 |
+
self,
|
27 |
+
video_root: Union[str, Path],
|
28 |
+
sync_root: Union[str, Path],
|
29 |
+
jsonl_root: Union[str, Path],
|
30 |
+
*,
|
31 |
+
duration_sec: float = 10.0,
|
32 |
+
read_clip: bool = True,
|
33 |
+
):
|
34 |
+
self.video_root = Path(video_root)
|
35 |
+
self.sync_root = Path(sync_root)
|
36 |
+
self.jsonl_root = Path(jsonl_root)
|
37 |
+
self.read_clip = read_clip
|
38 |
+
|
39 |
+
videos = sorted(os.listdir(self.video_root))
|
40 |
+
videos = [v[:-4] for v in videos] # remove extensions
|
41 |
+
self.captions = {}
|
42 |
+
|
43 |
+
for v in videos:
|
44 |
+
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
45 |
+
data = json.load(f)
|
46 |
+
self.captions[v] = data['audio_prompt']
|
47 |
+
|
48 |
+
if local_rank == 0:
|
49 |
+
log.info(f'{len(videos)} videos found in {video_root}')
|
50 |
+
|
51 |
+
self.duration_sec = duration_sec
|
52 |
+
|
53 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
54 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
55 |
+
|
56 |
+
self.clip_augment = v2.Compose([
|
57 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
58 |
+
v2.ToImage(),
|
59 |
+
v2.ToDtype(torch.float32, scale=True),
|
60 |
+
])
|
61 |
+
|
62 |
+
self.sync_augment = v2.Compose([
|
63 |
+
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
64 |
+
v2.CenterCrop(_SYNC_SIZE),
|
65 |
+
v2.ToImage(),
|
66 |
+
v2.ToDtype(torch.float32, scale=True),
|
67 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
68 |
+
])
|
69 |
+
|
70 |
+
self.videos = videos
|
71 |
+
|
72 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
73 |
+
video_id = self.videos[idx]
|
74 |
+
caption = self.captions[video_id]
|
75 |
+
|
76 |
+
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
77 |
+
reader.add_basic_video_stream(
|
78 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
79 |
+
frame_rate=_CLIP_FPS,
|
80 |
+
format='rgb24',
|
81 |
+
)
|
82 |
+
reader.add_basic_video_stream(
|
83 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
84 |
+
frame_rate=_SYNC_FPS,
|
85 |
+
format='rgb24',
|
86 |
+
)
|
87 |
+
|
88 |
+
reader.fill_buffer()
|
89 |
+
data_chunk = reader.pop_chunks()
|
90 |
+
|
91 |
+
clip_chunk = data_chunk[0]
|
92 |
+
sync_chunk = data_chunk[1]
|
93 |
+
if clip_chunk is None:
|
94 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
95 |
+
if clip_chunk.shape[0] < self.clip_expected_length:
|
96 |
+
raise RuntimeError(f'CLIP video too short {video_id}')
|
97 |
+
|
98 |
+
if sync_chunk is None:
|
99 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
100 |
+
if sync_chunk.shape[0] < self.sync_expected_length:
|
101 |
+
raise RuntimeError(f'Sync video too short {video_id}')
|
102 |
+
|
103 |
+
# truncate the video
|
104 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
105 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
106 |
+
raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
107 |
+
f'expected {self.clip_expected_length}, '
|
108 |
+
f'got {clip_chunk.shape[0]}')
|
109 |
+
clip_chunk = self.clip_augment(clip_chunk)
|
110 |
+
|
111 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
112 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
113 |
+
raise RuntimeError(f'Sync video wrong length {video_id}, '
|
114 |
+
f'expected {self.sync_expected_length}, '
|
115 |
+
f'got {sync_chunk.shape[0]}')
|
116 |
+
sync_chunk = self.sync_augment(sync_chunk)
|
117 |
+
|
118 |
+
data = {
|
119 |
+
'name': video_id,
|
120 |
+
'caption': caption,
|
121 |
+
'clip_video': clip_chunk,
|
122 |
+
'sync_video': sync_chunk,
|
123 |
+
}
|
124 |
+
|
125 |
+
return data
|
126 |
+
|
127 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
128 |
+
return self.sample(idx)
|
129 |
+
|
130 |
+
def __len__(self):
|
131 |
+
return len(self.captions)
|
mmaudio/data/eval/video_dataset.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Union
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
|
13 |
+
from mmaudio.utils.dist_utils import local_rank
|
14 |
+
|
15 |
+
log = logging.getLogger()
|
16 |
+
|
17 |
+
_CLIP_SIZE = 384
|
18 |
+
_CLIP_FPS = 8.0
|
19 |
+
|
20 |
+
_SYNC_SIZE = 224
|
21 |
+
_SYNC_FPS = 25.0
|
22 |
+
|
23 |
+
|
24 |
+
class VideoDataset(Dataset):
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
video_root: Union[str, Path],
|
29 |
+
*,
|
30 |
+
duration_sec: float = 8.0,
|
31 |
+
):
|
32 |
+
self.video_root = Path(video_root)
|
33 |
+
|
34 |
+
self.duration_sec = duration_sec
|
35 |
+
|
36 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
37 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
38 |
+
|
39 |
+
self.clip_transform = v2.Compose([
|
40 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
41 |
+
v2.ToImage(),
|
42 |
+
v2.ToDtype(torch.float32, scale=True),
|
43 |
+
])
|
44 |
+
|
45 |
+
self.sync_transform = v2.Compose([
|
46 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
47 |
+
v2.CenterCrop(_SYNC_SIZE),
|
48 |
+
v2.ToImage(),
|
49 |
+
v2.ToDtype(torch.float32, scale=True),
|
50 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
51 |
+
])
|
52 |
+
|
53 |
+
# to be implemented by subclasses
|
54 |
+
self.captions = {}
|
55 |
+
self.videos = sorted(list(self.captions.keys()))
|
56 |
+
|
57 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
58 |
+
video_id = self.videos[idx]
|
59 |
+
caption = self.captions[video_id]
|
60 |
+
|
61 |
+
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
62 |
+
reader.add_basic_video_stream(
|
63 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
64 |
+
frame_rate=_CLIP_FPS,
|
65 |
+
format='rgb24',
|
66 |
+
)
|
67 |
+
reader.add_basic_video_stream(
|
68 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
69 |
+
frame_rate=_SYNC_FPS,
|
70 |
+
format='rgb24',
|
71 |
+
)
|
72 |
+
|
73 |
+
reader.fill_buffer()
|
74 |
+
data_chunk = reader.pop_chunks()
|
75 |
+
|
76 |
+
clip_chunk = data_chunk[0]
|
77 |
+
sync_chunk = data_chunk[1]
|
78 |
+
if clip_chunk is None:
|
79 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
80 |
+
if clip_chunk.shape[0] < self.clip_expected_length:
|
81 |
+
raise RuntimeError(
|
82 |
+
f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
83 |
+
)
|
84 |
+
|
85 |
+
if sync_chunk is None:
|
86 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
87 |
+
if sync_chunk.shape[0] < self.sync_expected_length:
|
88 |
+
raise RuntimeError(
|
89 |
+
f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
90 |
+
)
|
91 |
+
|
92 |
+
# truncate the video
|
93 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
94 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
95 |
+
raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
96 |
+
f'expected {self.clip_expected_length}, '
|
97 |
+
f'got {clip_chunk.shape[0]}')
|
98 |
+
clip_chunk = self.clip_transform(clip_chunk)
|
99 |
+
|
100 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
101 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
102 |
+
raise RuntimeError(f'Sync video wrong length {video_id}, '
|
103 |
+
f'expected {self.sync_expected_length}, '
|
104 |
+
f'got {sync_chunk.shape[0]}')
|
105 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
106 |
+
|
107 |
+
data = {
|
108 |
+
'name': video_id,
|
109 |
+
'caption': caption,
|
110 |
+
'clip_video': clip_chunk,
|
111 |
+
'sync_video': sync_chunk,
|
112 |
+
}
|
113 |
+
|
114 |
+
return data
|
115 |
+
|
116 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
117 |
+
try:
|
118 |
+
return self.sample(idx)
|
119 |
+
except Exception as e:
|
120 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
121 |
+
return None
|
122 |
+
|
123 |
+
def __len__(self):
|
124 |
+
return len(self.captions)
|
125 |
+
|
126 |
+
|
127 |
+
class VGGSound(VideoDataset):
|
128 |
+
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
video_root: Union[str, Path],
|
132 |
+
csv_path: Union[str, Path],
|
133 |
+
*,
|
134 |
+
duration_sec: float = 8.0,
|
135 |
+
):
|
136 |
+
super().__init__(video_root, duration_sec=duration_sec)
|
137 |
+
self.video_root = Path(video_root)
|
138 |
+
self.csv_path = Path(csv_path)
|
139 |
+
|
140 |
+
videos = sorted(os.listdir(self.video_root))
|
141 |
+
if local_rank == 0:
|
142 |
+
log.info(f'{len(videos)} videos found in {video_root}')
|
143 |
+
self.captions = {}
|
144 |
+
|
145 |
+
df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
|
146 |
+
'split']).to_dict(orient='records')
|
147 |
+
|
148 |
+
videos_no_found = []
|
149 |
+
for row in df:
|
150 |
+
if row['split'] == 'test':
|
151 |
+
start_sec = int(row['sec'])
|
152 |
+
video_id = str(row['id'])
|
153 |
+
# this is how our videos are named
|
154 |
+
video_name = f'{video_id}_{start_sec:06d}'
|
155 |
+
if video_name + '.mp4' not in videos:
|
156 |
+
videos_no_found.append(video_name)
|
157 |
+
continue
|
158 |
+
|
159 |
+
self.captions[video_name] = row['caption']
|
160 |
+
|
161 |
+
if local_rank == 0:
|
162 |
+
log.info(f'{len(videos)} videos found in {video_root}')
|
163 |
+
log.info(f'{len(self.captions)} useable videos found')
|
164 |
+
if videos_no_found:
|
165 |
+
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
|
166 |
+
log.info(
|
167 |
+
'A small amount is expected, as not all videos are still available on YouTube')
|
168 |
+
|
169 |
+
self.videos = sorted(list(self.captions.keys()))
|
170 |
+
|
171 |
+
|
172 |
+
class MovieGen(VideoDataset):
|
173 |
+
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
video_root: Union[str, Path],
|
177 |
+
jsonl_root: Union[str, Path],
|
178 |
+
*,
|
179 |
+
duration_sec: float = 10.0,
|
180 |
+
):
|
181 |
+
super().__init__(video_root, duration_sec=duration_sec)
|
182 |
+
self.video_root = Path(video_root)
|
183 |
+
self.jsonl_root = Path(jsonl_root)
|
184 |
+
|
185 |
+
videos = sorted(os.listdir(self.video_root))
|
186 |
+
videos = [v[:-4] for v in videos] # remove extensions
|
187 |
+
self.captions = {}
|
188 |
+
|
189 |
+
for v in videos:
|
190 |
+
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
191 |
+
data = json.load(f)
|
192 |
+
self.captions[v] = data['audio_prompt']
|
193 |
+
|
194 |
+
if local_rank == 0:
|
195 |
+
log.info(f'{len(videos)} videos found in {video_root}')
|
196 |
+
|
197 |
+
self.videos = videos
|
mmaudio/data/extracted_audio.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from tensordict import TensorDict
|
8 |
+
from torch.utils.data.dataset import Dataset
|
9 |
+
|
10 |
+
from mmaudio.utils.dist_utils import local_rank
|
11 |
+
|
12 |
+
log = logging.getLogger()
|
13 |
+
|
14 |
+
|
15 |
+
class ExtractedAudio(Dataset):
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
tsv_path: Union[str, Path],
|
20 |
+
*,
|
21 |
+
premade_mmap_dir: Union[str, Path],
|
22 |
+
data_dim: dict[str, int],
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.data_dim = data_dim
|
27 |
+
self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
|
28 |
+
self.ids = [str(d['id']) for d in self.df_list]
|
29 |
+
|
30 |
+
log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
|
31 |
+
# load precomputed memory mapped tensors
|
32 |
+
premade_mmap_dir = Path(premade_mmap_dir)
|
33 |
+
td = TensorDict.load_memmap(premade_mmap_dir)
|
34 |
+
log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
|
35 |
+
self.mean = td['mean']
|
36 |
+
self.std = td['std']
|
37 |
+
self.text_features = td['text_features']
|
38 |
+
|
39 |
+
log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
|
40 |
+
log.info(f'Loaded mean: {self.mean.shape}.')
|
41 |
+
log.info(f'Loaded std: {self.std.shape}.')
|
42 |
+
log.info(f'Loaded text features: {self.text_features.shape}.')
|
43 |
+
|
44 |
+
assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
|
45 |
+
f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
46 |
+
assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
|
47 |
+
f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
48 |
+
|
49 |
+
assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
|
50 |
+
f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
|
51 |
+
assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
|
52 |
+
f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
|
53 |
+
|
54 |
+
self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
|
55 |
+
self.data_dim['clip_dim'])
|
56 |
+
self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
|
57 |
+
self.data_dim['sync_dim'])
|
58 |
+
self.video_exist = torch.tensor(0, dtype=torch.bool)
|
59 |
+
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
60 |
+
|
61 |
+
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
62 |
+
latents = self.mean
|
63 |
+
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
64 |
+
|
65 |
+
def get_memory_mapped_tensor(self) -> TensorDict:
|
66 |
+
td = TensorDict({
|
67 |
+
'mean': self.mean,
|
68 |
+
'std': self.std,
|
69 |
+
'text_features': self.text_features,
|
70 |
+
})
|
71 |
+
return td
|
72 |
+
|
73 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
74 |
+
data = {
|
75 |
+
'id': str(self.df_list[idx]['id']),
|
76 |
+
'a_mean': self.mean[idx],
|
77 |
+
'a_std': self.std[idx],
|
78 |
+
'clip_features': self.fake_clip_features,
|
79 |
+
'sync_features': self.fake_sync_features,
|
80 |
+
'text_features': self.text_features[idx],
|
81 |
+
'caption': self.df_list[idx]['caption'],
|
82 |
+
'video_exist': self.video_exist,
|
83 |
+
'text_exist': self.text_exist,
|
84 |
+
}
|
85 |
+
return data
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return len(self.ids)
|
mmaudio/data/extracted_vgg.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from pathlib import Path
|
3 |
+
from typing import Union
|
4 |
+
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from tensordict import TensorDict
|
8 |
+
from torch.utils.data.dataset import Dataset
|
9 |
+
|
10 |
+
from mmaudio.utils.dist_utils import local_rank
|
11 |
+
|
12 |
+
log = logging.getLogger()
|
13 |
+
|
14 |
+
|
15 |
+
class ExtractedVGG(Dataset):
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
tsv_path: Union[str, Path],
|
20 |
+
*,
|
21 |
+
premade_mmap_dir: Union[str, Path],
|
22 |
+
data_dim: dict[str, int],
|
23 |
+
):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.data_dim = data_dim
|
27 |
+
self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
|
28 |
+
self.ids = [d['id'] for d in self.df_list]
|
29 |
+
|
30 |
+
log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
|
31 |
+
# load precomputed memory mapped tensors
|
32 |
+
premade_mmap_dir = Path(premade_mmap_dir)
|
33 |
+
td = TensorDict.load_memmap(premade_mmap_dir)
|
34 |
+
log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
|
35 |
+
self.mean = td['mean']
|
36 |
+
self.std = td['std']
|
37 |
+
self.clip_features = td['clip_features']
|
38 |
+
self.sync_features = td['sync_features']
|
39 |
+
self.text_features = td['text_features']
|
40 |
+
|
41 |
+
if local_rank == 0:
|
42 |
+
log.info(f'Loaded {len(self)} samples.')
|
43 |
+
log.info(f'Loaded mean: {self.mean.shape}.')
|
44 |
+
log.info(f'Loaded std: {self.std.shape}.')
|
45 |
+
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
|
46 |
+
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
|
47 |
+
log.info(f'Loaded text_features: {self.text_features.shape}.')
|
48 |
+
|
49 |
+
assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
|
50 |
+
f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
51 |
+
assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
|
52 |
+
f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
53 |
+
|
54 |
+
assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
|
55 |
+
f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
|
56 |
+
assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
|
57 |
+
f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
|
58 |
+
assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
|
59 |
+
f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
|
60 |
+
|
61 |
+
assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
|
62 |
+
f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
|
63 |
+
assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
|
64 |
+
f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
|
65 |
+
assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
|
66 |
+
f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
|
67 |
+
|
68 |
+
self.video_exist = torch.tensor(1, dtype=torch.bool)
|
69 |
+
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
70 |
+
|
71 |
+
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
72 |
+
latents = self.mean
|
73 |
+
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
74 |
+
|
75 |
+
def get_memory_mapped_tensor(self) -> TensorDict:
|
76 |
+
td = TensorDict({
|
77 |
+
'mean': self.mean,
|
78 |
+
'std': self.std,
|
79 |
+
'clip_features': self.clip_features,
|
80 |
+
'sync_features': self.sync_features,
|
81 |
+
'text_features': self.text_features,
|
82 |
+
})
|
83 |
+
return td
|
84 |
+
|
85 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
86 |
+
data = {
|
87 |
+
'id': self.df_list[idx]['id'],
|
88 |
+
'a_mean': self.mean[idx],
|
89 |
+
'a_std': self.std[idx],
|
90 |
+
'clip_features': self.clip_features[idx],
|
91 |
+
'sync_features': self.sync_features[idx],
|
92 |
+
'text_features': self.text_features[idx],
|
93 |
+
'caption': self.df_list[idx]['label'],
|
94 |
+
'video_exist': self.video_exist,
|
95 |
+
'text_exist': self.text_exist,
|
96 |
+
}
|
97 |
+
|
98 |
+
return data
|
99 |
+
|
100 |
+
def __len__(self):
|
101 |
+
return len(self.ids)
|
mmaudio/data/extraction/__init__.py
ADDED
File without changes
|
mmaudio/data/extraction/vgg_sound.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Union
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import torch
|
8 |
+
import torchaudio
|
9 |
+
from torch.utils.data.dataset import Dataset
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
from torio.io import StreamingMediaDecoder
|
12 |
+
|
13 |
+
from mmaudio.utils.dist_utils import local_rank
|
14 |
+
|
15 |
+
log = logging.getLogger()
|
16 |
+
|
17 |
+
_CLIP_SIZE = 384
|
18 |
+
_CLIP_FPS = 8.0
|
19 |
+
|
20 |
+
_SYNC_SIZE = 224
|
21 |
+
_SYNC_FPS = 25.0
|
22 |
+
|
23 |
+
|
24 |
+
class VGGSound(Dataset):
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
root: Union[str, Path],
|
29 |
+
*,
|
30 |
+
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
31 |
+
sample_rate: int = 16_000,
|
32 |
+
duration_sec: float = 8.0,
|
33 |
+
audio_samples: Optional[int] = None,
|
34 |
+
normalize_audio: bool = False,
|
35 |
+
):
|
36 |
+
self.root = Path(root)
|
37 |
+
self.normalize_audio = normalize_audio
|
38 |
+
if audio_samples is None:
|
39 |
+
self.audio_samples = int(sample_rate * duration_sec)
|
40 |
+
else:
|
41 |
+
self.audio_samples = audio_samples
|
42 |
+
effective_duration = audio_samples / sample_rate
|
43 |
+
# make sure the duration is close enough, within 15ms
|
44 |
+
assert abs(effective_duration - duration_sec) < 0.015, \
|
45 |
+
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
46 |
+
|
47 |
+
videos = sorted(os.listdir(self.root))
|
48 |
+
videos = set([Path(v).stem for v in videos]) # remove extensions
|
49 |
+
self.labels = {}
|
50 |
+
self.videos = []
|
51 |
+
missing_videos = []
|
52 |
+
|
53 |
+
# read the tsv for subset information
|
54 |
+
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
55 |
+
for record in df_list:
|
56 |
+
id = record['id']
|
57 |
+
label = record['label']
|
58 |
+
if id in videos:
|
59 |
+
self.labels[id] = label
|
60 |
+
self.videos.append(id)
|
61 |
+
else:
|
62 |
+
missing_videos.append(id)
|
63 |
+
|
64 |
+
if local_rank == 0:
|
65 |
+
log.info(f'{len(videos)} videos found in {root}')
|
66 |
+
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
67 |
+
log.info(f'{len(missing_videos)} videos missing in {root}')
|
68 |
+
|
69 |
+
self.sample_rate = sample_rate
|
70 |
+
self.duration_sec = duration_sec
|
71 |
+
|
72 |
+
self.expected_audio_length = audio_samples
|
73 |
+
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
74 |
+
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
75 |
+
|
76 |
+
self.clip_transform = v2.Compose([
|
77 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
78 |
+
v2.ToImage(),
|
79 |
+
v2.ToDtype(torch.float32, scale=True),
|
80 |
+
])
|
81 |
+
|
82 |
+
self.sync_transform = v2.Compose([
|
83 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
84 |
+
v2.CenterCrop(_SYNC_SIZE),
|
85 |
+
v2.ToImage(),
|
86 |
+
v2.ToDtype(torch.float32, scale=True),
|
87 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
88 |
+
])
|
89 |
+
|
90 |
+
self.resampler = {}
|
91 |
+
|
92 |
+
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
93 |
+
video_id = self.videos[idx]
|
94 |
+
label = self.labels[video_id]
|
95 |
+
|
96 |
+
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
97 |
+
reader.add_basic_video_stream(
|
98 |
+
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
99 |
+
frame_rate=_CLIP_FPS,
|
100 |
+
format='rgb24',
|
101 |
+
)
|
102 |
+
reader.add_basic_video_stream(
|
103 |
+
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
104 |
+
frame_rate=_SYNC_FPS,
|
105 |
+
format='rgb24',
|
106 |
+
)
|
107 |
+
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
108 |
+
|
109 |
+
reader.fill_buffer()
|
110 |
+
data_chunk = reader.pop_chunks()
|
111 |
+
|
112 |
+
clip_chunk = data_chunk[0]
|
113 |
+
sync_chunk = data_chunk[1]
|
114 |
+
audio_chunk = data_chunk[2]
|
115 |
+
|
116 |
+
if clip_chunk is None:
|
117 |
+
raise RuntimeError(f'CLIP video returned None {video_id}')
|
118 |
+
if clip_chunk.shape[0] < self.clip_expected_length:
|
119 |
+
raise RuntimeError(
|
120 |
+
f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
121 |
+
)
|
122 |
+
|
123 |
+
if sync_chunk is None:
|
124 |
+
raise RuntimeError(f'Sync video returned None {video_id}')
|
125 |
+
if sync_chunk.shape[0] < self.sync_expected_length:
|
126 |
+
raise RuntimeError(
|
127 |
+
f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
128 |
+
)
|
129 |
+
|
130 |
+
# process audio
|
131 |
+
sample_rate = int(reader.get_out_stream_info(2).sample_rate)
|
132 |
+
audio_chunk = audio_chunk.transpose(0, 1)
|
133 |
+
audio_chunk = audio_chunk.mean(dim=0) # mono
|
134 |
+
if self.normalize_audio:
|
135 |
+
abs_max = audio_chunk.abs().max()
|
136 |
+
audio_chunk = audio_chunk / abs_max * 0.95
|
137 |
+
if abs_max <= 1e-6:
|
138 |
+
raise RuntimeError(f'Audio is silent {video_id}')
|
139 |
+
|
140 |
+
# resample
|
141 |
+
if sample_rate == self.sample_rate:
|
142 |
+
audio_chunk = audio_chunk
|
143 |
+
else:
|
144 |
+
if sample_rate not in self.resampler:
|
145 |
+
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
146 |
+
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
147 |
+
sample_rate,
|
148 |
+
self.sample_rate,
|
149 |
+
lowpass_filter_width=64,
|
150 |
+
rolloff=0.9475937167399596,
|
151 |
+
resampling_method='sinc_interp_kaiser',
|
152 |
+
beta=14.769656459379492,
|
153 |
+
)
|
154 |
+
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
155 |
+
|
156 |
+
if audio_chunk.shape[0] < self.expected_audio_length:
|
157 |
+
raise RuntimeError(f'Audio too short {video_id}')
|
158 |
+
audio_chunk = audio_chunk[:self.expected_audio_length]
|
159 |
+
|
160 |
+
# truncate the video
|
161 |
+
clip_chunk = clip_chunk[:self.clip_expected_length]
|
162 |
+
if clip_chunk.shape[0] != self.clip_expected_length:
|
163 |
+
raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
164 |
+
f'expected {self.clip_expected_length}, '
|
165 |
+
f'got {clip_chunk.shape[0]}')
|
166 |
+
clip_chunk = self.clip_transform(clip_chunk)
|
167 |
+
|
168 |
+
sync_chunk = sync_chunk[:self.sync_expected_length]
|
169 |
+
if sync_chunk.shape[0] != self.sync_expected_length:
|
170 |
+
raise RuntimeError(f'Sync video wrong length {video_id}, '
|
171 |
+
f'expected {self.sync_expected_length}, '
|
172 |
+
f'got {sync_chunk.shape[0]}')
|
173 |
+
sync_chunk = self.sync_transform(sync_chunk)
|
174 |
+
|
175 |
+
data = {
|
176 |
+
'id': video_id,
|
177 |
+
'caption': label,
|
178 |
+
'audio': audio_chunk,
|
179 |
+
'clip_video': clip_chunk,
|
180 |
+
'sync_video': sync_chunk,
|
181 |
+
}
|
182 |
+
|
183 |
+
return data
|
184 |
+
|
185 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
186 |
+
try:
|
187 |
+
return self.sample(idx)
|
188 |
+
except Exception as e:
|
189 |
+
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
190 |
+
return None
|
191 |
+
|
192 |
+
def __len__(self):
|
193 |
+
return len(self.labels)
|
mmaudio/data/extraction/wav_dataset.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import open_clip
|
7 |
+
import pandas as pd
|
8 |
+
import torch
|
9 |
+
import torchaudio
|
10 |
+
from torch.utils.data.dataset import Dataset
|
11 |
+
|
12 |
+
log = logging.getLogger()
|
13 |
+
|
14 |
+
|
15 |
+
class WavTextClipsDataset(Dataset):
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
root: Union[str, Path],
|
20 |
+
*,
|
21 |
+
captions_tsv: Union[str, Path],
|
22 |
+
clips_tsv: Union[str, Path],
|
23 |
+
sample_rate: int,
|
24 |
+
num_samples: int,
|
25 |
+
normalize_audio: bool = False,
|
26 |
+
reject_silent: bool = False,
|
27 |
+
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
|
28 |
+
):
|
29 |
+
self.root = Path(root)
|
30 |
+
self.sample_rate = sample_rate
|
31 |
+
self.num_samples = num_samples
|
32 |
+
self.normalize_audio = normalize_audio
|
33 |
+
self.reject_silent = reject_silent
|
34 |
+
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
|
35 |
+
|
36 |
+
audios = sorted(os.listdir(self.root))
|
37 |
+
audios = set([
|
38 |
+
Path(audio).stem for audio in audios
|
39 |
+
if audio.endswith('.wav') or audio.endswith('.flac')
|
40 |
+
])
|
41 |
+
self.captions = {}
|
42 |
+
|
43 |
+
# read the caption tsv
|
44 |
+
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
|
45 |
+
for record in df_list:
|
46 |
+
id = record['id']
|
47 |
+
caption = record['caption']
|
48 |
+
self.captions[id] = caption
|
49 |
+
|
50 |
+
# read the clip tsv
|
51 |
+
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
|
52 |
+
'id': str,
|
53 |
+
'name': str
|
54 |
+
}).to_dict('records')
|
55 |
+
self.clips = []
|
56 |
+
for record in df_list:
|
57 |
+
record['id'] = record['id']
|
58 |
+
record['name'] = record['name']
|
59 |
+
id = record['id']
|
60 |
+
name = record['name']
|
61 |
+
if name not in self.captions:
|
62 |
+
log.warning(f'Audio {name} not found in {captions_tsv}')
|
63 |
+
continue
|
64 |
+
record['caption'] = self.captions[name]
|
65 |
+
self.clips.append(record)
|
66 |
+
|
67 |
+
log.info(f'Found {len(self.clips)} audio files in {self.root}')
|
68 |
+
|
69 |
+
self.resampler = {}
|
70 |
+
|
71 |
+
def __getitem__(self, idx: int) -> torch.Tensor:
|
72 |
+
try:
|
73 |
+
clip = self.clips[idx]
|
74 |
+
audio_name = clip['name']
|
75 |
+
audio_id = clip['id']
|
76 |
+
caption = clip['caption']
|
77 |
+
start_sample = clip['start_sample']
|
78 |
+
end_sample = clip['end_sample']
|
79 |
+
|
80 |
+
audio_path = self.root / f'{audio_name}.flac'
|
81 |
+
if not audio_path.exists():
|
82 |
+
audio_path = self.root / f'{audio_name}.wav'
|
83 |
+
assert audio_path.exists()
|
84 |
+
|
85 |
+
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
86 |
+
audio_chunk = audio_chunk.mean(dim=0) # mono
|
87 |
+
abs_max = audio_chunk.abs().max()
|
88 |
+
if self.normalize_audio:
|
89 |
+
audio_chunk = audio_chunk / abs_max * 0.95
|
90 |
+
|
91 |
+
if self.reject_silent and abs_max < 1e-6:
|
92 |
+
log.warning(f'Rejecting silent audio')
|
93 |
+
return None
|
94 |
+
|
95 |
+
audio_chunk = audio_chunk[start_sample:end_sample]
|
96 |
+
|
97 |
+
# resample
|
98 |
+
if sample_rate == self.sample_rate:
|
99 |
+
audio_chunk = audio_chunk
|
100 |
+
else:
|
101 |
+
if sample_rate not in self.resampler:
|
102 |
+
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
103 |
+
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
104 |
+
sample_rate,
|
105 |
+
self.sample_rate,
|
106 |
+
lowpass_filter_width=64,
|
107 |
+
rolloff=0.9475937167399596,
|
108 |
+
resampling_method='sinc_interp_kaiser',
|
109 |
+
beta=14.769656459379492,
|
110 |
+
)
|
111 |
+
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
112 |
+
|
113 |
+
if audio_chunk.shape[0] < self.num_samples:
|
114 |
+
raise ValueError('Audio is too short')
|
115 |
+
audio_chunk = audio_chunk[:self.num_samples]
|
116 |
+
|
117 |
+
tokens = self.tokenizer([caption])[0]
|
118 |
+
|
119 |
+
output = {
|
120 |
+
'waveform': audio_chunk,
|
121 |
+
'id': audio_id,
|
122 |
+
'caption': caption,
|
123 |
+
'tokens': tokens,
|
124 |
+
}
|
125 |
+
|
126 |
+
return output
|
127 |
+
except Exception as e:
|
128 |
+
log.error(f'Error reading {audio_path}: {e}')
|
129 |
+
return None
|
130 |
+
|
131 |
+
def __len__(self):
|
132 |
+
return len(self.clips)
|
mmaudio/data/mm_dataset.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bisect
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch.utils.data.dataset import Dataset
|
5 |
+
|
6 |
+
|
7 |
+
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
|
8 |
+
class MultiModalDataset(Dataset):
|
9 |
+
datasets: list[Dataset]
|
10 |
+
cumulative_sizes: list[int]
|
11 |
+
|
12 |
+
@staticmethod
|
13 |
+
def cumsum(sequence):
|
14 |
+
r, s = [], 0
|
15 |
+
for e in sequence:
|
16 |
+
l = len(e)
|
17 |
+
r.append(l + s)
|
18 |
+
s += l
|
19 |
+
return r
|
20 |
+
|
21 |
+
def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
|
22 |
+
super().__init__()
|
23 |
+
self.video_datasets = list(video_datasets)
|
24 |
+
self.audio_datasets = list(audio_datasets)
|
25 |
+
self.datasets = self.video_datasets + self.audio_datasets
|
26 |
+
|
27 |
+
self.cumulative_sizes = self.cumsum(self.datasets)
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return self.cumulative_sizes[-1]
|
31 |
+
|
32 |
+
def __getitem__(self, idx):
|
33 |
+
if idx < 0:
|
34 |
+
if -idx > len(self):
|
35 |
+
raise ValueError("absolute value of index should not exceed dataset length")
|
36 |
+
idx = len(self) + idx
|
37 |
+
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
38 |
+
if dataset_idx == 0:
|
39 |
+
sample_idx = idx
|
40 |
+
else:
|
41 |
+
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
42 |
+
return self.datasets[dataset_idx][sample_idx]
|
43 |
+
|
44 |
+
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
45 |
+
return self.video_datasets[0].compute_latent_stats()
|
mmaudio/data/utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import tempfile
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import Any, Optional, Union
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
from tensordict import MemoryMappedTensor
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
from torch.utils.data.dataset import Dataset
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
from mmaudio.utils.dist_utils import local_rank, world_size
|
16 |
+
|
17 |
+
scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
|
18 |
+
shm_path = Path('/dev/shm')
|
19 |
+
|
20 |
+
log = logging.getLogger()
|
21 |
+
|
22 |
+
|
23 |
+
def reseed(seed):
|
24 |
+
random.seed(seed)
|
25 |
+
torch.manual_seed(seed)
|
26 |
+
|
27 |
+
|
28 |
+
def local_scatter_torch(obj: Optional[Any]):
|
29 |
+
if world_size == 1:
|
30 |
+
# Just one worker. Do nothing.
|
31 |
+
return obj
|
32 |
+
|
33 |
+
array = [obj] * world_size
|
34 |
+
target_array = [None]
|
35 |
+
if local_rank == 0:
|
36 |
+
dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
|
37 |
+
else:
|
38 |
+
dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
|
39 |
+
return target_array[0]
|
40 |
+
|
41 |
+
|
42 |
+
class ShardDataset(Dataset):
|
43 |
+
|
44 |
+
def __init__(self, root):
|
45 |
+
self.root = root
|
46 |
+
self.shards = sorted(os.listdir(root))
|
47 |
+
|
48 |
+
def __len__(self):
|
49 |
+
return len(self.shards)
|
50 |
+
|
51 |
+
def __getitem__(self, idx):
|
52 |
+
return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
|
53 |
+
|
54 |
+
|
55 |
+
def get_tmp_dir(in_memory: bool) -> Path:
|
56 |
+
return shm_path if in_memory else scratch_path
|
57 |
+
|
58 |
+
|
59 |
+
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
|
60 |
+
in_memory: bool) -> MemoryMappedTensor:
|
61 |
+
if local_rank == 0:
|
62 |
+
with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
|
63 |
+
log.info(f'Loading shards from {data_path} into {f.name}...')
|
64 |
+
data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
|
65 |
+
data = share_tensor_to_all(data)
|
66 |
+
torch.distributed.barrier()
|
67 |
+
f.close() # why does the context manager not close the file for me?
|
68 |
+
else:
|
69 |
+
log.info('Waiting for the data to be shared with me...')
|
70 |
+
data = share_tensor_to_all(None)
|
71 |
+
torch.distributed.barrier()
|
72 |
+
|
73 |
+
return data
|
74 |
+
|
75 |
+
|
76 |
+
def load_shards(
|
77 |
+
data_path: Union[str, Path],
|
78 |
+
ids: list[int],
|
79 |
+
*,
|
80 |
+
tmp_file_path: str,
|
81 |
+
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
|
82 |
+
|
83 |
+
id_set = set(ids)
|
84 |
+
shards = sorted(os.listdir(data_path))
|
85 |
+
log.info(f'Found {len(shards)} shards in {data_path}.')
|
86 |
+
first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
|
87 |
+
|
88 |
+
log.info(f'Rank {local_rank} created file {tmp_file_path}')
|
89 |
+
first_item = next(iter(first_shard.values()))
|
90 |
+
log.info(f'First item shape: {first_item.shape}')
|
91 |
+
mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
|
92 |
+
dtype=torch.float32,
|
93 |
+
filename=tmp_file_path,
|
94 |
+
existsok=True)
|
95 |
+
total_count = 0
|
96 |
+
used_index = set()
|
97 |
+
id_indexing = {i: idx for idx, i in enumerate(ids)}
|
98 |
+
# faster with no workers; otherwise we need to set_sharing_strategy('file_system')
|
99 |
+
loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
|
100 |
+
for data in tqdm(loader, desc='Loading shards'):
|
101 |
+
for i, v in data.items():
|
102 |
+
if i not in id_set:
|
103 |
+
continue
|
104 |
+
|
105 |
+
# tensor_index = ids.index(i)
|
106 |
+
tensor_index = id_indexing[i]
|
107 |
+
if tensor_index in used_index:
|
108 |
+
raise ValueError(f'Duplicate id {i} found in {data_path}.')
|
109 |
+
used_index.add(tensor_index)
|
110 |
+
mm_tensor[tensor_index] = v
|
111 |
+
total_count += 1
|
112 |
+
|
113 |
+
assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
|
114 |
+
log.info(f'Loaded {total_count} tensors from {data_path}.')
|
115 |
+
|
116 |
+
return mm_tensor
|
117 |
+
|
118 |
+
|
119 |
+
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
|
120 |
+
"""
|
121 |
+
x: the tensor to be shared; None if local_rank != 0
|
122 |
+
return: the shared tensor
|
123 |
+
"""
|
124 |
+
|
125 |
+
# there is no need to share your stuff with anyone if you are alone; must be in memory
|
126 |
+
if world_size == 1:
|
127 |
+
return x
|
128 |
+
|
129 |
+
if local_rank == 0:
|
130 |
+
assert x is not None, 'x must not be None if local_rank == 0'
|
131 |
+
else:
|
132 |
+
assert x is None, 'x must be None if local_rank != 0'
|
133 |
+
|
134 |
+
if local_rank == 0:
|
135 |
+
filename = x.filename
|
136 |
+
meta_information = (filename, x.shape, x.dtype)
|
137 |
+
else:
|
138 |
+
meta_information = None
|
139 |
+
|
140 |
+
filename, data_shape, data_type = local_scatter_torch(meta_information)
|
141 |
+
if local_rank == 0:
|
142 |
+
data = x
|
143 |
+
else:
|
144 |
+
data = MemoryMappedTensor.from_filename(filename=filename,
|
145 |
+
dtype=data_type,
|
146 |
+
shape=data_shape)
|
147 |
+
|
148 |
+
return data
|
mmaudio/eval_utils.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
from typing import Optional, Tuple, List, Dict
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from colorlog import ColoredFormatter
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision.transforms import v2
|
11 |
+
|
12 |
+
from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
|
13 |
+
from mmaudio.model.flow_matching import FlowMatching
|
14 |
+
from mmaudio.model.networks import MMAudio
|
15 |
+
from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
|
16 |
+
from mmaudio.model.utils.features_utils import FeaturesUtils
|
17 |
+
from mmaudio.utils.download_utils import download_model_if_needed
|
18 |
+
|
19 |
+
log = logging.getLogger()
|
20 |
+
|
21 |
+
|
22 |
+
@dataclasses.dataclass
|
23 |
+
class ModelConfig:
|
24 |
+
model_name: str
|
25 |
+
model_path: Path
|
26 |
+
vae_path: Path
|
27 |
+
bigvgan_16k_path: Optional[Path]
|
28 |
+
mode: str
|
29 |
+
synchformer_ckpt: Path = Path('./pretrained/v2a/mmaudio/ext_weights/synchformer_state_dict.pth')
|
30 |
+
|
31 |
+
@property
|
32 |
+
def seq_cfg(self) -> SequenceConfig:
|
33 |
+
if self.mode == '16k':
|
34 |
+
return CONFIG_16K
|
35 |
+
elif self.mode == '44k':
|
36 |
+
return CONFIG_44K
|
37 |
+
|
38 |
+
def download_if_needed(self):
|
39 |
+
download_model_if_needed(self.model_path)
|
40 |
+
download_model_if_needed(self.vae_path)
|
41 |
+
if self.bigvgan_16k_path is not None:
|
42 |
+
download_model_if_needed(self.bigvgan_16k_path)
|
43 |
+
download_model_if_needed(self.synchformer_ckpt)
|
44 |
+
|
45 |
+
|
46 |
+
small_16k = ModelConfig(model_name='small_16k',
|
47 |
+
model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_16k.pth'),
|
48 |
+
vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-16.pth'),
|
49 |
+
bigvgan_16k_path=Path('./pretrained/v2a/mmaudio/ext_weights/best_netG.pt'),
|
50 |
+
mode='16k')
|
51 |
+
small_44k = ModelConfig(model_name='small_44k',
|
52 |
+
model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_44k.pth'),
|
53 |
+
vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
|
54 |
+
bigvgan_16k_path=None,
|
55 |
+
mode='44k')
|
56 |
+
medium_44k = ModelConfig(model_name='medium_44k',
|
57 |
+
model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_medium_44k.pth'),
|
58 |
+
vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
|
59 |
+
bigvgan_16k_path=None,
|
60 |
+
mode='44k')
|
61 |
+
large_44k = ModelConfig(model_name='large_44k',
|
62 |
+
model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k.pth'),
|
63 |
+
vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
|
64 |
+
bigvgan_16k_path=None,
|
65 |
+
mode='44k')
|
66 |
+
large_44k_v2 = ModelConfig(model_name='large_44k_v2',
|
67 |
+
model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k_v2.pth'),
|
68 |
+
vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
|
69 |
+
bigvgan_16k_path=None,
|
70 |
+
mode='44k')
|
71 |
+
all_model_cfg: Dict[str, ModelConfig] = {
|
72 |
+
'small_16k': small_16k,
|
73 |
+
'small_44k': small_44k,
|
74 |
+
'medium_44k': medium_44k,
|
75 |
+
'large_44k': large_44k,
|
76 |
+
'large_44k_v2': large_44k_v2,
|
77 |
+
}
|
78 |
+
|
79 |
+
|
80 |
+
def generate(
|
81 |
+
clip_video: Optional[torch.Tensor],
|
82 |
+
sync_video: Optional[torch.Tensor],
|
83 |
+
text: Optional[List[str]],
|
84 |
+
*,
|
85 |
+
negative_text: Optional[List[str]] = None,
|
86 |
+
feature_utils: FeaturesUtils,
|
87 |
+
net: MMAudio,
|
88 |
+
fm: FlowMatching,
|
89 |
+
rng: torch.Generator,
|
90 |
+
cfg_strength: float,
|
91 |
+
clip_batch_size_multiplier: int = 40,
|
92 |
+
sync_batch_size_multiplier: int = 40,
|
93 |
+
image_input: bool = False,
|
94 |
+
) -> torch.Tensor:
|
95 |
+
device = feature_utils.device
|
96 |
+
dtype = feature_utils.dtype
|
97 |
+
|
98 |
+
bs = len(text)
|
99 |
+
if clip_video is not None:
|
100 |
+
clip_video = clip_video.to(device, dtype, non_blocking=True)
|
101 |
+
clip_features = feature_utils.encode_video_with_clip(clip_video,
|
102 |
+
batch_size=bs *
|
103 |
+
clip_batch_size_multiplier)
|
104 |
+
if image_input:
|
105 |
+
clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
|
106 |
+
else:
|
107 |
+
clip_features = net.get_empty_clip_sequence(bs)
|
108 |
+
|
109 |
+
if sync_video is not None and not image_input:
|
110 |
+
sync_video = sync_video.to(device, dtype, non_blocking=True)
|
111 |
+
sync_features = feature_utils.encode_video_with_sync(sync_video,
|
112 |
+
batch_size=bs *
|
113 |
+
sync_batch_size_multiplier)
|
114 |
+
else:
|
115 |
+
sync_features = net.get_empty_sync_sequence(bs)
|
116 |
+
|
117 |
+
if text is not None:
|
118 |
+
text_features = feature_utils.encode_text(text)
|
119 |
+
else:
|
120 |
+
text_features = net.get_empty_string_sequence(bs)
|
121 |
+
|
122 |
+
if negative_text is not None:
|
123 |
+
assert len(negative_text) == bs
|
124 |
+
negative_text_features = feature_utils.encode_text(negative_text)
|
125 |
+
else:
|
126 |
+
negative_text_features = net.get_empty_string_sequence(bs)
|
127 |
+
|
128 |
+
x0 = torch.randn(bs,
|
129 |
+
net.latent_seq_len,
|
130 |
+
net.latent_dim,
|
131 |
+
device=device,
|
132 |
+
dtype=dtype,
|
133 |
+
generator=rng)
|
134 |
+
preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features, text_features)
|
135 |
+
empty_conditions = net.get_empty_conditions(
|
136 |
+
bs, negative_text_features=negative_text_features if negative_text is not None else None)
|
137 |
+
|
138 |
+
cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
|
139 |
+
cfg_strength)
|
140 |
+
x1 = fm.to_data(cfg_ode_wrapper, x0)
|
141 |
+
x1 = net.unnormalize(x1)
|
142 |
+
spec = feature_utils.decode(x1)
|
143 |
+
audio = feature_utils.vocode(spec)
|
144 |
+
return audio
|
145 |
+
|
146 |
+
|
147 |
+
LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
|
148 |
+
|
149 |
+
|
150 |
+
def setup_eval_logging(log_level: int = logging.INFO):
|
151 |
+
logging.root.setLevel(log_level)
|
152 |
+
formatter = ColoredFormatter(LOGFORMAT)
|
153 |
+
stream = logging.StreamHandler()
|
154 |
+
stream.setLevel(log_level)
|
155 |
+
stream.setFormatter(formatter)
|
156 |
+
log = logging.getLogger()
|
157 |
+
log.setLevel(log_level)
|
158 |
+
log.addHandler(stream)
|
159 |
+
|
160 |
+
|
161 |
+
_CLIP_SIZE = 384
|
162 |
+
_CLIP_FPS = 8.0
|
163 |
+
|
164 |
+
_SYNC_SIZE = 224
|
165 |
+
_SYNC_FPS = 25.0
|
166 |
+
|
167 |
+
|
168 |
+
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
|
169 |
+
|
170 |
+
clip_transform = v2.Compose([
|
171 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
172 |
+
v2.ToImage(),
|
173 |
+
v2.ToDtype(torch.float32, scale=True),
|
174 |
+
])
|
175 |
+
|
176 |
+
sync_transform = v2.Compose([
|
177 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
178 |
+
v2.CenterCrop(_SYNC_SIZE),
|
179 |
+
v2.ToImage(),
|
180 |
+
v2.ToDtype(torch.float32, scale=True),
|
181 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
182 |
+
])
|
183 |
+
|
184 |
+
output_frames, all_frames, orig_fps = read_frames(video_path,
|
185 |
+
list_of_fps=[_CLIP_FPS, _SYNC_FPS],
|
186 |
+
start_sec=0,
|
187 |
+
end_sec=duration_sec,
|
188 |
+
need_all_frames=load_all_frames)
|
189 |
+
|
190 |
+
clip_chunk, sync_chunk = output_frames
|
191 |
+
clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
|
192 |
+
sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
|
193 |
+
|
194 |
+
clip_frames = clip_transform(clip_chunk)
|
195 |
+
sync_frames = sync_transform(sync_chunk)
|
196 |
+
|
197 |
+
clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
|
198 |
+
sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
|
199 |
+
|
200 |
+
if clip_length_sec < duration_sec:
|
201 |
+
log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
|
202 |
+
log.warning(f'Truncating to {clip_length_sec:.2f} sec')
|
203 |
+
duration_sec = clip_length_sec
|
204 |
+
|
205 |
+
if sync_length_sec < duration_sec:
|
206 |
+
log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
|
207 |
+
log.warning(f'Truncating to {sync_length_sec:.2f} sec')
|
208 |
+
duration_sec = sync_length_sec
|
209 |
+
|
210 |
+
clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
|
211 |
+
sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
|
212 |
+
|
213 |
+
video_info = VideoInfo(
|
214 |
+
duration_sec=duration_sec,
|
215 |
+
fps=orig_fps,
|
216 |
+
clip_frames=clip_frames,
|
217 |
+
sync_frames=sync_frames,
|
218 |
+
all_frames=all_frames if load_all_frames else None,
|
219 |
+
)
|
220 |
+
return video_info
|
221 |
+
|
222 |
+
|
223 |
+
def load_image(image_path: Path) -> VideoInfo:
|
224 |
+
clip_transform = v2.Compose([
|
225 |
+
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
226 |
+
v2.ToImage(),
|
227 |
+
v2.ToDtype(torch.float32, scale=True),
|
228 |
+
])
|
229 |
+
|
230 |
+
sync_transform = v2.Compose([
|
231 |
+
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
232 |
+
v2.CenterCrop(_SYNC_SIZE),
|
233 |
+
v2.ToImage(),
|
234 |
+
v2.ToDtype(torch.float32, scale=True),
|
235 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
236 |
+
])
|
237 |
+
|
238 |
+
frame = np.array(Image.open(image_path))
|
239 |
+
|
240 |
+
clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
|
241 |
+
sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
|
242 |
+
|
243 |
+
clip_frames = clip_transform(clip_chunk)
|
244 |
+
sync_frames = sync_transform(sync_chunk)
|
245 |
+
|
246 |
+
video_info = ImageInfo(
|
247 |
+
clip_frames=clip_frames,
|
248 |
+
sync_frames=sync_frames,
|
249 |
+
original_frame=frame,
|
250 |
+
)
|
251 |
+
return video_info
|
252 |
+
|
253 |
+
|
254 |
+
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
255 |
+
reencode_with_audio(video_info, output_path, audio, sampling_rate)
|