Spaces:
Running
Running
lym0302
commited on
Commit
·
0163d98
1
Parent(s):
4cfc7ca
new exp
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +69 -101
- app.py +34 -102
- batch_eval.py +0 -110
- config/__init__.py +0 -0
- config/base_config.yaml +0 -62
- config/data/base.yaml +0 -70
- config/eval_config.yaml +0 -17
- config/eval_data/base.yaml +0 -22
- config/hydra/job_logging/custom-eval.yaml +0 -32
- config/hydra/job_logging/custom-no-rank.yaml +0 -32
- config/hydra/job_logging/custom-simplest.yaml +0 -26
- config/hydra/job_logging/custom.yaml +0 -33
- config/train_config.yaml +0 -41
- demo.py +1 -7
- docs/EVAL.md +0 -22
- docs/MODELS.md +0 -50
- docs/TRAINING.md +0 -184
- docs/index.html +10 -12
- gradio_demo.py +0 -343
- mmaudio/__pycache__/__init__.cpython-310.pyc +0 -0
- mmaudio/__pycache__/__init__.cpython-38.pyc +0 -0
- mmaudio/__pycache__/eval_utils.cpython-310.pyc +0 -0
- mmaudio/__pycache__/eval_utils.cpython-38.pyc +0 -0
- mmaudio/data/__pycache__/__init__.cpython-310.pyc +0 -0
- mmaudio/data/__pycache__/__init__.cpython-38.pyc +0 -0
- mmaudio/data/__pycache__/av_utils.cpython-310.pyc +0 -0
- mmaudio/data/__pycache__/av_utils.cpython-38.pyc +0 -0
- mmaudio/data/av_utils.py +4 -30
- mmaudio/data/data_setup.py +0 -174
- mmaudio/data/eval/__init__.py +0 -0
- mmaudio/data/eval/audiocaps.py +0 -39
- mmaudio/data/eval/moviegen.py +0 -131
- mmaudio/data/eval/video_dataset.py +0 -197
- mmaudio/data/extracted_audio.py +0 -88
- mmaudio/data/extracted_vgg.py +0 -101
- mmaudio/data/extraction/__init__.py +0 -0
- mmaudio/data/extraction/vgg_sound.py +0 -193
- mmaudio/data/extraction/wav_dataset.py +0 -132
- mmaudio/data/mm_dataset.py +0 -45
- mmaudio/data/utils.py +0 -148
- mmaudio/eval_utils.py +25 -63
- mmaudio/ext/__pycache__/__init__.cpython-310.pyc +0 -0
- mmaudio/ext/__pycache__/__init__.cpython-38.pyc +0 -0
- mmaudio/ext/__pycache__/mel_converter.cpython-310.pyc +0 -0
- mmaudio/ext/__pycache__/mel_converter.cpython-38.pyc +0 -0
- mmaudio/ext/__pycache__/rotary_embeddings.cpython-310.pyc +0 -0
- mmaudio/ext/__pycache__/rotary_embeddings.cpython-38.pyc +0 -0
- mmaudio/ext/autoencoder/__pycache__/__init__.cpython-310.pyc +0 -0
- mmaudio/ext/autoencoder/__pycache__/__init__.cpython-38.pyc +0 -0
- mmaudio/ext/autoencoder/__pycache__/autoencoder.cpython-310.pyc +0 -0
README.md
CHANGED
@@ -1,27 +1,25 @@
|
|
1 |
---
|
2 |
title: DeepSound-V1
|
3 |
-
|
4 |
-
|
|
|
5 |
sdk: gradio
|
6 |
-
sdk_version: 5.22.0
|
7 |
app_file: app.py
|
8 |
pinned: false
|
9 |
---
|
10 |
|
11 |
-
<div align="center">
|
12 |
-
<p align="center">
|
13 |
-
<h2>MMAudio</h2>
|
14 |
-
<a href="https://arxiv.org/abs/2412.15322">Paper</a> | <a href="https://hkchengrex.github.io/MMAudio">Webpage</a> | <a href="https://huggingface.co/hkchengrex/MMAudio/tree/main">Models</a> | <a href="https://huggingface.co/spaces/hkchengrex/MMAudio"> Huggingface Demo</a> | <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">Colab Demo</a> | <a href="https://replicate.com/zsxkib/mmaudio">Replicate Demo</a>
|
15 |
-
</p>
|
16 |
-
</div>
|
17 |
|
18 |
-
|
19 |
|
20 |
[Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
|
21 |
|
22 |
University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
|
23 |
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
|
26 |
## Highlight
|
27 |
|
@@ -29,25 +27,22 @@ MMAudio generates synchronized audio given video and/or text inputs.
|
|
29 |
Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
|
30 |
Moreover, a synchronization module aligns the generated audio with the video frames.
|
31 |
|
|
|
32 |
## Results
|
33 |
|
34 |
(All audio from our algorithm MMAudio)
|
35 |
|
36 |
-
Videos from Sora:
|
37 |
|
38 |
https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
|
39 |
|
40 |
-
Videos from Veo 2:
|
41 |
-
|
42 |
-
https://github.com/user-attachments/assets/8a11419e-fee2-46e0-9e67-dfb03c48d00e
|
43 |
|
44 |
-
Videos from MovieGen/Hunyuan Video/VGGSound:
|
45 |
|
46 |
https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
|
47 |
|
48 |
For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
|
49 |
|
50 |
-
|
51 |
## Installation
|
52 |
|
53 |
We have only tested this on Ubuntu.
|
@@ -56,30 +51,17 @@ We have only tested this on Ubuntu.
|
|
56 |
|
57 |
We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
|
58 |
|
59 |
-
- Python 3.
|
60 |
-
- PyTorch **2.5.1+** and corresponding torchvision/torchaudio (pick your CUDA version https://pytorch.org
|
61 |
-
|
62 |
|
63 |
-
**
|
64 |
-
|
65 |
-
```bash
|
66 |
-
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade
|
67 |
-
```
|
68 |
-
|
69 |
-
(Or any other CUDA versions that your GPUs/driver support)
|
70 |
-
|
71 |
-
<!-- ```
|
72 |
-
conda install -c conda-forge 'ffmpeg<7
|
73 |
-
```
|
74 |
-
(Optional, if you use miniforge and don't already have the appropriate ffmpeg) -->
|
75 |
-
|
76 |
-
**2. Clone our repository:**
|
77 |
|
78 |
```bash
|
79 |
git clone https://github.com/hkchengrex/MMAudio.git
|
80 |
```
|
81 |
|
82 |
-
**
|
83 |
|
84 |
```bash
|
85 |
cd MMAudio
|
@@ -88,108 +70,94 @@ pip install -e .
|
|
88 |
|
89 |
(If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip)
|
90 |
|
91 |
-
|
92 |
**Pretrained models:**
|
93 |
|
94 |
-
The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py
|
95 |
-
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
## Demo
|
99 |
|
100 |
-
By default, these scripts use the `
|
101 |
In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs.
|
102 |
|
103 |
### Command-line interface
|
104 |
|
105 |
With `demo.py`
|
106 |
-
|
107 |
```bash
|
108 |
python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
|
109 |
```
|
110 |
-
|
111 |
The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
|
112 |
See the file for more options.
|
113 |
Simply omit the `--video` option for text-to-audio synthesis.
|
114 |
The default output (and training) duration is 8 seconds. Longer/shorter durations could also work, but a large deviation from the training duration may result in a lower quality.
|
115 |
|
|
|
116 |
### Gradio interface
|
117 |
|
118 |
Supports video-to-audio and text-to-audio synthesis.
|
119 |
-
You can also try experimental image-to-audio synthesis which duplicates the input image to a video for processing. This might be interesting to some but it is not something MMAudio has been trained for.
|
120 |
-
Use [port forwarding](https://unix.stackexchange.com/questions/115897/whats-ssh-port-forwarding-and-whats-the-difference-between-ssh-local-and-remot) (e.g., `ssh -L 7860:localhost:7860 server`) if necessary. The default port is `7860` which you can specify with `--port`.
|
121 |
|
122 |
-
```
|
123 |
python gradio_demo.py
|
124 |
```
|
125 |
|
126 |
-
### FAQ
|
127 |
-
|
128 |
-
1. Video processing
|
129 |
-
- Processing higher-resolution videos takes longer due to encoding and decoding (which can take >95% of the processing time!), but it does not improve the quality of results.
|
130 |
-
- The CLIP encoder resizes input frames to 384×384 pixels.
|
131 |
-
- Synchformer resizes the shorter edge to 224 pixels and applies a center crop, focusing only on the central square of each frame.
|
132 |
-
2. Frame rates
|
133 |
-
- The CLIP model operates at 8 FPS, while Synchformer works at 25 FPS.
|
134 |
-
- Frame rate conversion happens on-the-fly via the video reader.
|
135 |
-
- For input videos with a frame rate below 25 FPS, frames will be duplicated to match the required rate.
|
136 |
-
3. Failure cases
|
137 |
-
As with most models of this type, failures can occur, and the reasons are not always clear. Below are some known failure modes. If you notice a failure mode or believe there’s a bug, feel free to open an issue in the repository.
|
138 |
-
4. Performance variations
|
139 |
-
We notice that there can be subtle performance variations in different hardware and software environments. Some of the reasons include using/not using `torch.compile`, video reader library/backend, inference precision, batch sizes, random seeds, etc. We (will) provide pre-computed results on standard benchmark for reference. Results obtained from this codebase should be similar but might not be exactly the same.
|
140 |
-
|
141 |
### Known limitations
|
142 |
|
143 |
-
1. The model sometimes generates unintelligible human speech-like sounds
|
144 |
-
2. The model sometimes generates background music
|
145 |
3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing".
|
146 |
|
147 |
We believe all of these three limitations can be addressed with more high-quality training data.
|
148 |
|
149 |
## Training
|
150 |
-
|
151 |
-
See [TRAINING.md](docs/TRAINING.md).
|
152 |
|
153 |
## Evaluation
|
154 |
-
|
155 |
-
See [EVAL.md](docs/EVAL.md).
|
156 |
-
|
157 |
-
## Training Datasets
|
158 |
-
|
159 |
-
MMAudio was trained on several datasets, including [AudioSet](https://research.google.com/audioset/), [Freesound](https://github.com/LAION-AI/audio-dataset/blob/main/laion-audio-630k/README.md), [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://github.com/XinhaoMei/WavCaps). These datasets are subject to specific licenses, which can be accessed on their respective websites. We do not guarantee that the pre-trained models are suitable for commercial use. Please use them at your own risk.
|
160 |
-
|
161 |
-
## Update Logs
|
162 |
-
|
163 |
-
- 2025-03-09: Uploaded the corrected tsv files. See [TRAINING.md](docs/TRAINING.md).
|
164 |
-
- 2025-02-27: Disabled the GradScaler by default to improve training stability. See #49.
|
165 |
-
- 2024-12-23: Added training and batch evaluation scripts.
|
166 |
-
- 2024-12-14: Removed the `ffmpeg<7` requirement for the demos by replacing `torio.io.StreamingMediaDecoder` with `pyav` for reading frames. The read frames are also cached, so we are not reading the same frames again during reconstruction. This should speed things up and make installation less of a hassle.
|
167 |
-
- 2024-12-13: Improved for-loop processing in CLIP/Sync feature extraction by introducing a batch size multiplier. We can approximately use 40x batch size for CLIP/Sync without using more memory, thereby speeding up processing. Removed VAE encoder during inference -- we don't need it.
|
168 |
-
- 2024-12-11: Replaced `torio.io.StreamingMediaDecoder` with `pyav` for reading framerate when reconstructing the input video. `torio.io.StreamingMediaDecoder` does not work reliably in huggingface ZeroGPU's environment, and I suspect that it might not work in some other environments as well.
|
169 |
-
|
170 |
-
## Citation
|
171 |
-
|
172 |
-
```bibtex
|
173 |
-
@inproceedings{cheng2025taming,
|
174 |
-
title={Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis},
|
175 |
-
author={Cheng, Ho Kei and Ishii, Masato and Hayakawa, Akio and Shibuya, Takashi and Schwing, Alexander and Mitsufuji, Yuki},
|
176 |
-
booktitle={CVPR},
|
177 |
-
year={2025}
|
178 |
-
}
|
179 |
-
```
|
180 |
-
|
181 |
-
## Relevant Repositories
|
182 |
-
|
183 |
-
- [av-benchmark](https://github.com/hkchengrex/av-benchmark) for benchmarking results.
|
184 |
-
|
185 |
-
## Disclaimer
|
186 |
-
|
187 |
-
We have no affiliation with and have no knowledge of the party behind the domain "mmaudio.net".
|
188 |
|
189 |
## Acknowledgement
|
190 |
-
|
191 |
Many thanks to:
|
192 |
-
- [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model
|
193 |
- [BigVGAN](https://github.com/NVIDIA/BigVGAN)
|
194 |
- [Synchformer](https://github.com/v-iashin/Synchformer)
|
195 |
-
|
|
|
1 |
---
|
2 |
title: DeepSound-V1
|
3 |
+
emoji: 🔊
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: indigo
|
6 |
sdk: gradio
|
|
|
7 |
app_file: app.py
|
8 |
pinned: false
|
9 |
---
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
# [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio)
|
13 |
|
14 |
[Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
|
15 |
|
16 |
University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
|
17 |
|
18 |
+
|
19 |
+
[[Paper (being prepared)]](https://hkchengrex.github.io/MMAudio) [[Project Page]](https://hkchengrex.github.io/MMAudio)
|
20 |
+
|
21 |
+
|
22 |
+
**Note: This repository is still under construction. Single-example inference should work as expected. The training code will be added. Code is subject to non-backward-compatible changes.**
|
23 |
|
24 |
## Highlight
|
25 |
|
|
|
27 |
Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
|
28 |
Moreover, a synchronization module aligns the generated audio with the video frames.
|
29 |
|
30 |
+
|
31 |
## Results
|
32 |
|
33 |
(All audio from our algorithm MMAudio)
|
34 |
|
35 |
+
Videos from Sora:
|
36 |
|
37 |
https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
|
38 |
|
|
|
|
|
|
|
39 |
|
40 |
+
Videos from MovieGen/Hunyuan Video/VGGSound:
|
41 |
|
42 |
https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
|
43 |
|
44 |
For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
|
45 |
|
|
|
46 |
## Installation
|
47 |
|
48 |
We have only tested this on Ubuntu.
|
|
|
51 |
|
52 |
We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
|
53 |
|
54 |
+
- Python 3.8+
|
55 |
+
- PyTorch **2.5.1+** and corresponding torchvision/torchaudio (pick your CUDA version https://pytorch.org/)
|
56 |
+
- ffmpeg<7 ([this is required by torchaudio](https://pytorch.org/audio/master/installation.html#optional-dependencies), you can install it in a miniforge environment with `conda install -c conda-forge 'ffmpeg<7'`)
|
57 |
|
58 |
+
**Clone our repository:**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
```bash
|
61 |
git clone https://github.com/hkchengrex/MMAudio.git
|
62 |
```
|
63 |
|
64 |
+
**Install with pip:**
|
65 |
|
66 |
```bash
|
67 |
cd MMAudio
|
|
|
70 |
|
71 |
(If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip)
|
72 |
|
|
|
73 |
**Pretrained models:**
|
74 |
|
75 |
+
The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`
|
76 |
+
|
77 |
+
| Model | Download link | File size |
|
78 |
+
| -------- | ------- | ------- |
|
79 |
+
| Flow prediction network, small 16kHz | <a href="https://databank.illinois.edu/datafiles/k6jve/download" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
|
80 |
+
| Flow prediction network, small 44.1kHz | <a href="https://databank.illinois.edu/datafiles/864ya/download" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
|
81 |
+
| Flow prediction network, medium 44.1kHz | <a href="https://databank.illinois.edu/datafiles/pa94t/download" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
|
82 |
+
| Flow prediction network, large 44.1kHz **(recommended)** | <a href="https://databank.illinois.edu/datafiles/4jx76/download" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
|
83 |
+
| 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
|
84 |
+
| 16kHz BigVGAN vocoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
|
85 |
+
| 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
|
86 |
+
| Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
|
87 |
+
|
88 |
+
The 44.1kHz vocoder will be downloaded automatically.
|
89 |
+
|
90 |
+
The expected directory structure (full):
|
91 |
+
|
92 |
+
```bash
|
93 |
+
MMAudio
|
94 |
+
├── ext_weights
|
95 |
+
│ ├── best_netG.pt
|
96 |
+
│ ├── synchformer_state_dict.pth
|
97 |
+
│ ├── v1-16.pth
|
98 |
+
│ └── v1-44.pth
|
99 |
+
├── weights
|
100 |
+
│ ├── mmaudio_small_16k.pth
|
101 |
+
│ ├── mmaudio_small_44k.pth
|
102 |
+
│ ├── mmaudio_medium_44k.pth
|
103 |
+
│ └── mmaudio_large_44k.pth
|
104 |
+
└── ...
|
105 |
+
```
|
106 |
+
|
107 |
+
The expected directory structure (minimal, for the recommended model only):
|
108 |
+
|
109 |
+
```bash
|
110 |
+
MMAudio
|
111 |
+
├── ext_weights
|
112 |
+
│ ├── synchformer_state_dict.pth
|
113 |
+
│ └── v1-44.pth
|
114 |
+
├── weights
|
115 |
+
│ └── mmaudio_large_44k.pth
|
116 |
+
└── ...
|
117 |
+
```
|
118 |
|
119 |
## Demo
|
120 |
|
121 |
+
By default, these scripts use the `large_44k` model.
|
122 |
In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs.
|
123 |
|
124 |
### Command-line interface
|
125 |
|
126 |
With `demo.py`
|
|
|
127 |
```bash
|
128 |
python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
|
129 |
```
|
|
|
130 |
The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
|
131 |
See the file for more options.
|
132 |
Simply omit the `--video` option for text-to-audio synthesis.
|
133 |
The default output (and training) duration is 8 seconds. Longer/shorter durations could also work, but a large deviation from the training duration may result in a lower quality.
|
134 |
|
135 |
+
|
136 |
### Gradio interface
|
137 |
|
138 |
Supports video-to-audio and text-to-audio synthesis.
|
|
|
|
|
139 |
|
140 |
+
```
|
141 |
python gradio_demo.py
|
142 |
```
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
### Known limitations
|
145 |
|
146 |
+
1. The model sometimes generates undesired unintelligible human speech-like sounds
|
147 |
+
2. The model sometimes generates undesired background music
|
148 |
3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing".
|
149 |
|
150 |
We believe all of these three limitations can be addressed with more high-quality training data.
|
151 |
|
152 |
## Training
|
153 |
+
Work in progress.
|
|
|
154 |
|
155 |
## Evaluation
|
156 |
+
Work in progress.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
## Acknowledgement
|
|
|
159 |
Many thanks to:
|
160 |
+
- [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model
|
161 |
- [BigVGAN](https://github.com/NVIDIA/BigVGAN)
|
162 |
- [Synchformer](https://github.com/v-iashin/Synchformer)
|
163 |
+
|
app.py
CHANGED
@@ -1,33 +1,33 @@
|
|
1 |
-
import
|
2 |
import logging
|
3 |
-
from argparse import ArgumentParser
|
4 |
from datetime import datetime
|
5 |
-
from fractions import Fraction
|
6 |
from pathlib import Path
|
7 |
|
8 |
import gradio as gr
|
9 |
import torch
|
10 |
import torchaudio
|
|
|
11 |
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
from mmaudio.model.flow_matching import FlowMatching
|
15 |
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
16 |
from mmaudio.model.sequence_config import SequenceConfig
|
17 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
|
|
18 |
|
19 |
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
torch.backends.cudnn.allow_tf32 = True
|
21 |
|
22 |
log = logging.getLogger()
|
23 |
|
24 |
-
device = '
|
25 |
-
if torch.cuda.is_available():
|
26 |
-
device = 'cuda'
|
27 |
-
elif torch.backends.mps.is_available():
|
28 |
-
device = 'mps'
|
29 |
-
else:
|
30 |
-
log.warning('CUDA/MPS are not available, running on CPU')
|
31 |
dtype = torch.bfloat16
|
32 |
|
33 |
model: ModelConfig = all_model_cfg['large_44k_v2']
|
@@ -58,6 +58,7 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
|
58 |
net, feature_utils, seq_cfg = get_model()
|
59 |
|
60 |
|
|
|
61 |
@torch.inference_mode()
|
62 |
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
63 |
cfg_strength: float, duration: float):
|
@@ -88,53 +89,16 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
|
|
88 |
cfg_strength=cfg_strength)
|
89 |
audio = audios.float().cpu()[0]
|
90 |
|
91 |
-
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
gc.collect()
|
96 |
-
return video_save_path
|
97 |
-
|
98 |
-
|
99 |
-
@torch.inference_mode()
|
100 |
-
def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
101 |
-
cfg_strength: float, duration: float):
|
102 |
-
|
103 |
-
rng = torch.Generator(device=device)
|
104 |
-
if seed >= 0:
|
105 |
-
rng.manual_seed(seed)
|
106 |
-
else:
|
107 |
-
rng.seed()
|
108 |
-
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
109 |
-
|
110 |
-
image_info = load_image(image)
|
111 |
-
clip_frames = image_info.clip_frames
|
112 |
-
sync_frames = image_info.sync_frames
|
113 |
-
clip_frames = clip_frames.unsqueeze(0)
|
114 |
-
sync_frames = sync_frames.unsqueeze(0)
|
115 |
-
seq_cfg.duration = duration
|
116 |
-
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
117 |
-
|
118 |
-
audios = generate(clip_frames,
|
119 |
-
sync_frames, [prompt],
|
120 |
-
negative_text=[negative_prompt],
|
121 |
-
feature_utils=feature_utils,
|
122 |
-
net=net,
|
123 |
-
fm=fm,
|
124 |
-
rng=rng,
|
125 |
-
cfg_strength=cfg_strength,
|
126 |
-
image_input=True)
|
127 |
-
audio = audios.float().cpu()[0]
|
128 |
-
|
129 |
-
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
130 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
131 |
-
video_save_path = output_dir / f'{current_time_string}.mp4'
|
132 |
-
video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
|
133 |
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
134 |
-
|
135 |
return video_save_path
|
136 |
|
137 |
|
|
|
138 |
@torch.inference_mode()
|
139 |
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
|
140 |
duration: float):
|
@@ -160,11 +124,9 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
|
160 |
cfg_strength=cfg_strength)
|
161 |
audio = audios.float().cpu()[0]
|
162 |
|
163 |
-
|
164 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
165 |
-
audio_save_path = output_dir / f'{current_time_string}.flac'
|
166 |
torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
|
167 |
-
|
168 |
return audio_save_path
|
169 |
|
170 |
|
@@ -176,6 +138,8 @@ video_to_audio_tab = gr.Interface(
|
|
176 |
|
177 |
NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
|
178 |
Doing so does not improve results.
|
|
|
|
|
179 |
""",
|
180 |
inputs=[
|
181 |
gr.Video(),
|
@@ -245,8 +209,8 @@ video_to_audio_tab = gr.Interface(
|
|
245 |
10,
|
246 |
],
|
247 |
[
|
248 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/
|
249 |
-
'
|
250 |
'',
|
251 |
0,
|
252 |
25,
|
@@ -254,8 +218,8 @@ video_to_audio_tab = gr.Interface(
|
|
254 |
10,
|
255 |
],
|
256 |
[
|
257 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/
|
258 |
-
'',
|
259 |
'',
|
260 |
0,
|
261 |
25,
|
@@ -263,8 +227,8 @@ video_to_audio_tab = gr.Interface(
|
|
263 |
10,
|
264 |
],
|
265 |
[
|
266 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/
|
267 |
-
'
|
268 |
'',
|
269 |
0,
|
270 |
25,
|
@@ -272,8 +236,8 @@ video_to_audio_tab = gr.Interface(
|
|
272 |
10,
|
273 |
],
|
274 |
[
|
275 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/
|
276 |
-
'',
|
277 |
'',
|
278 |
0,
|
279 |
25,
|
@@ -281,7 +245,7 @@ video_to_audio_tab = gr.Interface(
|
|
281 |
10,
|
282 |
],
|
283 |
[
|
284 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/
|
285 |
'',
|
286 |
'',
|
287 |
0,
|
@@ -293,10 +257,6 @@ video_to_audio_tab = gr.Interface(
|
|
293 |
|
294 |
text_to_audio_tab = gr.Interface(
|
295 |
fn=text_to_audio,
|
296 |
-
description="""
|
297 |
-
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
298 |
-
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
299 |
-
""",
|
300 |
inputs=[
|
301 |
gr.Text(label='Prompt'),
|
302 |
gr.Text(label='Negative prompt'),
|
@@ -310,34 +270,6 @@ text_to_audio_tab = gr.Interface(
|
|
310 |
title='MMAudio — Text-to-Audio Synthesis',
|
311 |
)
|
312 |
|
313 |
-
image_to_audio_tab = gr.Interface(
|
314 |
-
fn=image_to_audio,
|
315 |
-
description="""
|
316 |
-
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
317 |
-
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
318 |
-
|
319 |
-
NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
|
320 |
-
Doing so does not improve results.
|
321 |
-
""",
|
322 |
-
inputs=[
|
323 |
-
gr.Image(type='filepath'),
|
324 |
-
gr.Text(label='Prompt'),
|
325 |
-
gr.Text(label='Negative prompt'),
|
326 |
-
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
327 |
-
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
328 |
-
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
329 |
-
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
330 |
-
],
|
331 |
-
outputs='playable_video',
|
332 |
-
cache_examples=False,
|
333 |
-
title='MMAudio — Image-to-Audio Synthesis (experimental)',
|
334 |
-
)
|
335 |
-
|
336 |
if __name__ == "__main__":
|
337 |
-
|
338 |
-
|
339 |
-
args = parser.parse_args()
|
340 |
-
|
341 |
-
gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
|
342 |
-
['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
|
343 |
-
server_port=args.port, allowed_paths=[output_dir])
|
|
|
1 |
+
import spaces
|
2 |
import logging
|
|
|
3 |
from datetime import datetime
|
|
|
4 |
from pathlib import Path
|
5 |
|
6 |
import gradio as gr
|
7 |
import torch
|
8 |
import torchaudio
|
9 |
+
import os
|
10 |
|
11 |
+
try:
|
12 |
+
import mmaudio
|
13 |
+
except ImportError:
|
14 |
+
os.system("pip install -e .")
|
15 |
+
import mmaudio
|
16 |
+
|
17 |
+
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
18 |
+
setup_eval_logging)
|
19 |
from mmaudio.model.flow_matching import FlowMatching
|
20 |
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
21 |
from mmaudio.model.sequence_config import SequenceConfig
|
22 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
23 |
+
import tempfile
|
24 |
|
25 |
torch.backends.cuda.matmul.allow_tf32 = True
|
26 |
torch.backends.cudnn.allow_tf32 = True
|
27 |
|
28 |
log = logging.getLogger()
|
29 |
|
30 |
+
device = 'cuda'
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
dtype = torch.bfloat16
|
32 |
|
33 |
model: ModelConfig = all_model_cfg['large_44k_v2']
|
|
|
58 |
net, feature_utils, seq_cfg = get_model()
|
59 |
|
60 |
|
61 |
+
@spaces.GPU(duration=120)
|
62 |
@torch.inference_mode()
|
63 |
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
64 |
cfg_strength: float, duration: float):
|
|
|
89 |
cfg_strength=cfg_strength)
|
90 |
audio = audios.float().cpu()[0]
|
91 |
|
92 |
+
# current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
93 |
+
video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
94 |
+
# output_dir.mkdir(exist_ok=True, parents=True)
|
95 |
+
# video_save_path = output_dir / f'{current_time_string}.mp4'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
97 |
+
log.info(f'Saved video to {video_save_path}')
|
98 |
return video_save_path
|
99 |
|
100 |
|
101 |
+
@spaces.GPU(duration=120)
|
102 |
@torch.inference_mode()
|
103 |
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
|
104 |
duration: float):
|
|
|
124 |
cfg_strength=cfg_strength)
|
125 |
audio = audios.float().cpu()[0]
|
126 |
|
127 |
+
audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
|
|
|
|
|
128 |
torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
|
129 |
+
log.info(f'Saved audio to {audio_save_path}')
|
130 |
return audio_save_path
|
131 |
|
132 |
|
|
|
138 |
|
139 |
NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
|
140 |
Doing so does not improve results.
|
141 |
+
|
142 |
+
The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
|
143 |
""",
|
144 |
inputs=[
|
145 |
gr.Video(),
|
|
|
209 |
10,
|
210 |
],
|
211 |
[
|
212 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
|
213 |
+
'',
|
214 |
'',
|
215 |
0,
|
216 |
25,
|
|
|
218 |
10,
|
219 |
],
|
220 |
[
|
221 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
|
222 |
+
'storm',
|
223 |
'',
|
224 |
0,
|
225 |
25,
|
|
|
227 |
10,
|
228 |
],
|
229 |
[
|
230 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
|
231 |
+
'',
|
232 |
'',
|
233 |
0,
|
234 |
25,
|
|
|
236 |
10,
|
237 |
],
|
238 |
[
|
239 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
|
240 |
+
'typing',
|
241 |
'',
|
242 |
0,
|
243 |
25,
|
|
|
245 |
10,
|
246 |
],
|
247 |
[
|
248 |
+
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
|
249 |
'',
|
250 |
'',
|
251 |
0,
|
|
|
257 |
|
258 |
text_to_audio_tab = gr.Interface(
|
259 |
fn=text_to_audio,
|
|
|
|
|
|
|
|
|
260 |
inputs=[
|
261 |
gr.Text(label='Prompt'),
|
262 |
gr.Text(label='Negative prompt'),
|
|
|
270 |
title='MMAudio — Text-to-Audio Synthesis',
|
271 |
)
|
272 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
if __name__ == "__main__":
|
274 |
+
gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],
|
275 |
+
['Video-to-Audio', 'Text-to-Audio']).launch(allowed_paths=[output_dir])
|
|
|
|
|
|
|
|
|
|
batch_eval.py
DELETED
@@ -1,110 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
|
5 |
-
import hydra
|
6 |
-
import torch
|
7 |
-
import torch.distributed as distributed
|
8 |
-
import torchaudio
|
9 |
-
from hydra.core.hydra_config import HydraConfig
|
10 |
-
from omegaconf import DictConfig
|
11 |
-
from tqdm import tqdm
|
12 |
-
|
13 |
-
from mmaudio.data.data_setup import setup_eval_dataset
|
14 |
-
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
|
15 |
-
from mmaudio.model.flow_matching import FlowMatching
|
16 |
-
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
17 |
-
from mmaudio.model.utils.features_utils import FeaturesUtils
|
18 |
-
|
19 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
-
torch.backends.cudnn.allow_tf32 = True
|
21 |
-
|
22 |
-
local_rank = int(os.environ['LOCAL_RANK'])
|
23 |
-
world_size = int(os.environ['WORLD_SIZE'])
|
24 |
-
log = logging.getLogger()
|
25 |
-
|
26 |
-
|
27 |
-
@torch.inference_mode()
|
28 |
-
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
|
29 |
-
def main(cfg: DictConfig):
|
30 |
-
device = 'cuda'
|
31 |
-
torch.cuda.set_device(local_rank)
|
32 |
-
|
33 |
-
if cfg.model not in all_model_cfg:
|
34 |
-
raise ValueError(f'Unknown model variant: {cfg.model}')
|
35 |
-
model: ModelConfig = all_model_cfg[cfg.model]
|
36 |
-
model.download_if_needed()
|
37 |
-
seq_cfg = model.seq_cfg
|
38 |
-
|
39 |
-
run_dir = Path(HydraConfig.get().run.dir)
|
40 |
-
if cfg.output_name is None:
|
41 |
-
output_dir = run_dir / cfg.dataset
|
42 |
-
else:
|
43 |
-
output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
|
44 |
-
output_dir.mkdir(parents=True, exist_ok=True)
|
45 |
-
|
46 |
-
# load a pretrained model
|
47 |
-
seq_cfg.duration = cfg.duration_s
|
48 |
-
net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
|
49 |
-
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
50 |
-
log.info(f'Loaded weights from {model.model_path}')
|
51 |
-
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
52 |
-
log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
|
53 |
-
log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
|
54 |
-
log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
|
55 |
-
|
56 |
-
# misc setup
|
57 |
-
rng = torch.Generator(device=device)
|
58 |
-
rng.manual_seed(cfg.seed)
|
59 |
-
fm = FlowMatching(cfg.sampling.min_sigma,
|
60 |
-
inference_mode=cfg.sampling.method,
|
61 |
-
num_steps=cfg.sampling.num_steps)
|
62 |
-
|
63 |
-
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
64 |
-
synchformer_ckpt=model.synchformer_ckpt,
|
65 |
-
enable_conditions=True,
|
66 |
-
mode=model.mode,
|
67 |
-
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
68 |
-
need_vae_encoder=False)
|
69 |
-
feature_utils = feature_utils.to(device).eval()
|
70 |
-
|
71 |
-
if cfg.compile:
|
72 |
-
net.preprocess_conditions = torch.compile(net.preprocess_conditions)
|
73 |
-
net.predict_flow = torch.compile(net.predict_flow)
|
74 |
-
feature_utils.compile()
|
75 |
-
|
76 |
-
dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
|
77 |
-
|
78 |
-
with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
|
79 |
-
for batch in tqdm(loader):
|
80 |
-
audios = generate(batch.get('clip_video', None),
|
81 |
-
batch.get('sync_video', None),
|
82 |
-
batch.get('caption', None),
|
83 |
-
feature_utils=feature_utils,
|
84 |
-
net=net,
|
85 |
-
fm=fm,
|
86 |
-
rng=rng,
|
87 |
-
cfg_strength=cfg.cfg_strength,
|
88 |
-
clip_batch_size_multiplier=64,
|
89 |
-
sync_batch_size_multiplier=64)
|
90 |
-
audios = audios.float().cpu()
|
91 |
-
names = batch['name']
|
92 |
-
for audio, name in zip(audios, names):
|
93 |
-
torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
|
94 |
-
|
95 |
-
|
96 |
-
def distributed_setup():
|
97 |
-
distributed.init_process_group(backend="nccl")
|
98 |
-
local_rank = distributed.get_rank()
|
99 |
-
world_size = distributed.get_world_size()
|
100 |
-
log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
|
101 |
-
return local_rank, world_size
|
102 |
-
|
103 |
-
|
104 |
-
if __name__ == '__main__':
|
105 |
-
distributed_setup()
|
106 |
-
|
107 |
-
main()
|
108 |
-
|
109 |
-
# clean-up
|
110 |
-
distributed.destroy_process_group()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/__init__.py
DELETED
File without changes
|
config/base_config.yaml
DELETED
@@ -1,62 +0,0 @@
|
|
1 |
-
defaults:
|
2 |
-
- data: base
|
3 |
-
- eval_data: base
|
4 |
-
- override hydra/job_logging: custom-simplest
|
5 |
-
- _self_
|
6 |
-
|
7 |
-
hydra:
|
8 |
-
run:
|
9 |
-
dir: ./output/${exp_id}
|
10 |
-
output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
|
11 |
-
|
12 |
-
enable_email: False
|
13 |
-
|
14 |
-
model: small_16k
|
15 |
-
|
16 |
-
exp_id: default
|
17 |
-
debug: False
|
18 |
-
cudnn_benchmark: True
|
19 |
-
compile: True
|
20 |
-
amp: True
|
21 |
-
weights: null
|
22 |
-
checkpoint: null
|
23 |
-
seed: 14159265
|
24 |
-
num_workers: 10 # per-GPU
|
25 |
-
pin_memory: False # set to True if your system can handle it, i.e., have enough memory
|
26 |
-
|
27 |
-
# NOTE: This DOSE NOT affect the model during inference in any way
|
28 |
-
# they are just for the dataloader to fill in the missing data in multi-modal loading
|
29 |
-
# to change the sequence length for the model, see networks.py
|
30 |
-
data_dim:
|
31 |
-
text_seq_len: 77
|
32 |
-
clip_dim: 1024
|
33 |
-
sync_dim: 768
|
34 |
-
text_dim: 1024
|
35 |
-
|
36 |
-
# ema configuration
|
37 |
-
ema:
|
38 |
-
enable: True
|
39 |
-
sigma_rels: [0.05, 0.1]
|
40 |
-
update_every: 1
|
41 |
-
checkpoint_every: 5_000
|
42 |
-
checkpoint_folder: ${hydra:run.dir}/ema_ckpts
|
43 |
-
default_output_sigma: 0.05
|
44 |
-
|
45 |
-
|
46 |
-
# sampling
|
47 |
-
sampling:
|
48 |
-
mean: 0.0
|
49 |
-
scale: 1.0
|
50 |
-
min_sigma: 0.0
|
51 |
-
method: euler
|
52 |
-
num_steps: 25
|
53 |
-
|
54 |
-
# classifier-free guidance
|
55 |
-
null_condition_probability: 0.1
|
56 |
-
cfg_strength: 4.5
|
57 |
-
|
58 |
-
# checkpoint paths to external modules
|
59 |
-
vae_16k_ckpt: ./ext_weights/v1-16.pth
|
60 |
-
vae_44k_ckpt: ./ext_weights/v1-44.pth
|
61 |
-
bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
|
62 |
-
synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/data/base.yaml
DELETED
@@ -1,70 +0,0 @@
|
|
1 |
-
VGGSound:
|
2 |
-
root: ../data/video
|
3 |
-
subset_name: sets/vgg3-train.tsv
|
4 |
-
fps: 8
|
5 |
-
height: 384
|
6 |
-
width: 384
|
7 |
-
sample_duration_sec: 8.0
|
8 |
-
|
9 |
-
VGGSound_test:
|
10 |
-
root: ../data/video
|
11 |
-
subset_name: sets/vgg3-test.tsv
|
12 |
-
fps: 8
|
13 |
-
height: 384
|
14 |
-
width: 384
|
15 |
-
sample_duration_sec: 8.0
|
16 |
-
|
17 |
-
VGGSound_val:
|
18 |
-
root: ../data/video
|
19 |
-
subset_name: sets/vgg3-val.tsv
|
20 |
-
fps: 8
|
21 |
-
height: 384
|
22 |
-
width: 384
|
23 |
-
sample_duration_sec: 8.0
|
24 |
-
|
25 |
-
ExtractedVGG:
|
26 |
-
tsv: ../data/v1-16-memmap/vgg-train.tsv
|
27 |
-
memmap_dir: ../data/v1-16-memmap/vgg-train
|
28 |
-
|
29 |
-
ExtractedVGG_test:
|
30 |
-
tag: test
|
31 |
-
gt_cache: ../data/eval-cache/vggsound-test
|
32 |
-
output_subdir: null
|
33 |
-
tsv: ../data/v1-16-memmap/vgg-test.tsv
|
34 |
-
memmap_dir: ../data/v1-16-memmap/vgg-test
|
35 |
-
|
36 |
-
ExtractedVGG_val:
|
37 |
-
tag: val
|
38 |
-
gt_cache: ../data/eval-cache/vggsound-val
|
39 |
-
output_subdir: val
|
40 |
-
tsv: ../data/v1-16-memmap/vgg-val.tsv
|
41 |
-
memmap_dir: ../data/v1-16-memmap/vgg-val
|
42 |
-
|
43 |
-
AudioCaps:
|
44 |
-
tsv: ../data/v1-16-memmap/audiocaps.tsv
|
45 |
-
memmap_dir: ../data/v1-16-memmap/audiocaps
|
46 |
-
|
47 |
-
AudioSetSL:
|
48 |
-
tsv: ../data/v1-16-memmap/audioset_sl.tsv
|
49 |
-
memmap_dir: ../data/v1-16-memmap/audioset_sl
|
50 |
-
|
51 |
-
BBCSound:
|
52 |
-
tsv: ../data/v1-16-memmap/bbcsound.tsv
|
53 |
-
memmap_dir: ../data/v1-16-memmap/bbcsound
|
54 |
-
|
55 |
-
FreeSound:
|
56 |
-
tsv: ../data/v1-16-memmap/freesound.tsv
|
57 |
-
memmap_dir: ../data/v1-16-memmap/freesound
|
58 |
-
|
59 |
-
Clotho:
|
60 |
-
tsv: ../data/v1-16-memmap/clotho.tsv
|
61 |
-
memmap_dir: ../data/v1-16-memmap/clotho
|
62 |
-
|
63 |
-
Example_video:
|
64 |
-
tsv: ./training/example_output/memmap/vgg-example.tsv
|
65 |
-
memmap_dir: ./training/example_output/memmap/vgg-example
|
66 |
-
|
67 |
-
Example_audio:
|
68 |
-
tsv: ./training/example_output/memmap/audio-example.tsv
|
69 |
-
memmap_dir: ./training/example_output/memmap/audio-example
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/eval_config.yaml
DELETED
@@ -1,17 +0,0 @@
|
|
1 |
-
defaults:
|
2 |
-
- base_config
|
3 |
-
- override hydra/job_logging: custom-simplest
|
4 |
-
- _self_
|
5 |
-
|
6 |
-
hydra:
|
7 |
-
run:
|
8 |
-
dir: ./output/${exp_id}
|
9 |
-
output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
|
10 |
-
|
11 |
-
exp_id: ${model}
|
12 |
-
dataset: audiocaps
|
13 |
-
duration_s: 8.0
|
14 |
-
|
15 |
-
# for inference, this is the per-GPU batch size
|
16 |
-
batch_size: 16
|
17 |
-
output_name: null
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/eval_data/base.yaml
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
AudioCaps:
|
2 |
-
audio_path: ../data/AudioCaps-test-audioldm-ver
|
3 |
-
# a csv file, with a header row of 'name' and 'caption'
|
4 |
-
# name should match the audio file name without extension
|
5 |
-
# Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
|
6 |
-
csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
|
7 |
-
|
8 |
-
AudioCaps_full:
|
9 |
-
audio_path: ../data/AudioCaps-test-full-ver
|
10 |
-
# a csv file, with a header row of 'name' and 'caption'
|
11 |
-
# name should match the audio file name without extension
|
12 |
-
# Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
|
13 |
-
csv_path: ../data/AudioCaps-test-full-ver/data.csv
|
14 |
-
|
15 |
-
MovieGen:
|
16 |
-
video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
|
17 |
-
jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
|
18 |
-
|
19 |
-
VGGSound:
|
20 |
-
video_path: ../data/test-videos
|
21 |
-
# from the officially released csv file
|
22 |
-
csv_path: ../data/vggsound.csv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/hydra/job_logging/custom-eval.yaml
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
# python logging configuration for tasks
|
2 |
-
version: 1
|
3 |
-
formatters:
|
4 |
-
simple:
|
5 |
-
format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
|
6 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
7 |
-
colorlog:
|
8 |
-
'()': 'colorlog.ColoredFormatter'
|
9 |
-
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
10 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
11 |
-
log_colors:
|
12 |
-
DEBUG: purple
|
13 |
-
INFO: green
|
14 |
-
WARNING: yellow
|
15 |
-
ERROR: red
|
16 |
-
CRITICAL: red
|
17 |
-
handlers:
|
18 |
-
console:
|
19 |
-
class: logging.StreamHandler
|
20 |
-
formatter: colorlog
|
21 |
-
stream: ext://sys.stdout
|
22 |
-
file:
|
23 |
-
class: logging.FileHandler
|
24 |
-
formatter: simple
|
25 |
-
# absolute file path
|
26 |
-
filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
|
27 |
-
mode: w
|
28 |
-
root:
|
29 |
-
level: INFO
|
30 |
-
handlers: [console, file]
|
31 |
-
|
32 |
-
disable_existing_loggers: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/hydra/job_logging/custom-no-rank.yaml
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
# python logging configuration for tasks
|
2 |
-
version: 1
|
3 |
-
formatters:
|
4 |
-
simple:
|
5 |
-
format: '[%(asctime)s][%(levelname)s] - %(message)s'
|
6 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
7 |
-
colorlog:
|
8 |
-
'()': 'colorlog.ColoredFormatter'
|
9 |
-
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
10 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
11 |
-
log_colors:
|
12 |
-
DEBUG: purple
|
13 |
-
INFO: green
|
14 |
-
WARNING: yellow
|
15 |
-
ERROR: red
|
16 |
-
CRITICAL: red
|
17 |
-
handlers:
|
18 |
-
console:
|
19 |
-
class: logging.StreamHandler
|
20 |
-
formatter: colorlog
|
21 |
-
stream: ext://sys.stdout
|
22 |
-
file:
|
23 |
-
class: logging.FileHandler
|
24 |
-
formatter: simple
|
25 |
-
# absolute file path
|
26 |
-
filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
|
27 |
-
mode: w
|
28 |
-
root:
|
29 |
-
level: INFO
|
30 |
-
handlers: [console, file]
|
31 |
-
|
32 |
-
disable_existing_loggers: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/hydra/job_logging/custom-simplest.yaml
DELETED
@@ -1,26 +0,0 @@
|
|
1 |
-
# python logging configuration for tasks
|
2 |
-
version: 1
|
3 |
-
formatters:
|
4 |
-
simple:
|
5 |
-
format: '[%(asctime)s][%(levelname)s] - %(message)s'
|
6 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
7 |
-
colorlog:
|
8 |
-
'()': 'colorlog.ColoredFormatter'
|
9 |
-
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
10 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
11 |
-
log_colors:
|
12 |
-
DEBUG: purple
|
13 |
-
INFO: green
|
14 |
-
WARNING: yellow
|
15 |
-
ERROR: red
|
16 |
-
CRITICAL: red
|
17 |
-
handlers:
|
18 |
-
console:
|
19 |
-
class: logging.StreamHandler
|
20 |
-
formatter: colorlog
|
21 |
-
stream: ext://sys.stdout
|
22 |
-
root:
|
23 |
-
level: INFO
|
24 |
-
handlers: [console]
|
25 |
-
|
26 |
-
disable_existing_loggers: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/hydra/job_logging/custom.yaml
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
# @package hydra.job_logging
|
2 |
-
# python logging configuration for tasks
|
3 |
-
version: 1
|
4 |
-
formatters:
|
5 |
-
simple:
|
6 |
-
format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
|
7 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
8 |
-
colorlog:
|
9 |
-
'()': 'colorlog.ColoredFormatter'
|
10 |
-
format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
|
11 |
-
datefmt: '%Y-%m-%d %H:%M:%S'
|
12 |
-
log_colors:
|
13 |
-
DEBUG: purple
|
14 |
-
INFO: green
|
15 |
-
WARNING: yellow
|
16 |
-
ERROR: red
|
17 |
-
CRITICAL: red
|
18 |
-
handlers:
|
19 |
-
console:
|
20 |
-
class: logging.StreamHandler
|
21 |
-
formatter: colorlog
|
22 |
-
stream: ext://sys.stdout
|
23 |
-
file:
|
24 |
-
class: logging.FileHandler
|
25 |
-
formatter: simple
|
26 |
-
# absolute file path
|
27 |
-
filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
|
28 |
-
mode: w
|
29 |
-
root:
|
30 |
-
level: INFO
|
31 |
-
handlers: [console, file]
|
32 |
-
|
33 |
-
disable_existing_loggers: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/train_config.yaml
DELETED
@@ -1,41 +0,0 @@
|
|
1 |
-
defaults:
|
2 |
-
- base_config
|
3 |
-
- override data: base
|
4 |
-
- override hydra/job_logging: custom
|
5 |
-
- _self_
|
6 |
-
|
7 |
-
hydra:
|
8 |
-
run:
|
9 |
-
dir: ./output/${exp_id}
|
10 |
-
output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
|
11 |
-
|
12 |
-
ema:
|
13 |
-
start: 0
|
14 |
-
|
15 |
-
mini_train: False
|
16 |
-
example_train: False
|
17 |
-
enable_grad_scaler: False
|
18 |
-
vgg_oversample_rate: 5
|
19 |
-
|
20 |
-
log_text_interval: 200
|
21 |
-
log_extra_interval: 20_000
|
22 |
-
val_interval: 5_000
|
23 |
-
eval_interval: 20_000
|
24 |
-
save_eval_interval: 40_000
|
25 |
-
save_weights_interval: 10_000
|
26 |
-
save_checkpoint_interval: 10_000
|
27 |
-
save_copy_iterations: []
|
28 |
-
|
29 |
-
batch_size: 512
|
30 |
-
eval_batch_size: 256 # per-GPU
|
31 |
-
|
32 |
-
num_iterations: 300_000
|
33 |
-
learning_rate: 1.0e-4
|
34 |
-
linear_warmup_steps: 1_000
|
35 |
-
|
36 |
-
lr_schedule: step
|
37 |
-
lr_schedule_steps: [240_000, 270_000]
|
38 |
-
lr_schedule_gamma: 0.1
|
39 |
-
|
40 |
-
clip_grad_norm: 1.0
|
41 |
-
weight_decay: 1.0e-6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
demo.py
CHANGED
@@ -62,13 +62,7 @@ def main():
|
|
62 |
skip_video_composite: bool = args.skip_video_composite
|
63 |
mask_away_clip: bool = args.mask_away_clip
|
64 |
|
65 |
-
device = '
|
66 |
-
if torch.cuda.is_available():
|
67 |
-
device = 'cuda'
|
68 |
-
elif torch.backends.mps.is_available():
|
69 |
-
device = 'mps'
|
70 |
-
else:
|
71 |
-
log.warning('CUDA/MPS are not available, running on CPU')
|
72 |
dtype = torch.float32 if args.full_precision else torch.bfloat16
|
73 |
|
74 |
output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
62 |
skip_video_composite: bool = args.skip_video_composite
|
63 |
mask_away_clip: bool = args.mask_away_clip
|
64 |
|
65 |
+
device = 'cuda'
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
dtype = torch.float32 if args.full_precision else torch.bfloat16
|
67 |
|
68 |
output_dir.mkdir(parents=True, exist_ok=True)
|
docs/EVAL.md
DELETED
@@ -1,22 +0,0 @@
|
|
1 |
-
# Evaluation
|
2 |
-
|
3 |
-
## Batch Evaluation
|
4 |
-
|
5 |
-
To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient in large-scale evaluation compared to `demo.py`, supporting batched inference, multi-GPU inference, torch compilation, and skipping video compositions.
|
6 |
-
|
7 |
-
An example of running this script with four GPUs is as follows:
|
8 |
-
|
9 |
-
```bash
|
10 |
-
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
|
11 |
-
```
|
12 |
-
|
13 |
-
You may need to update the data paths in `config/eval_data/base.yaml`.
|
14 |
-
More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
|
15 |
-
|
16 |
-
## Precomputed Results
|
17 |
-
|
18 |
-
Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
|
19 |
-
|
20 |
-
## Obtaining Quantitative Metrics
|
21 |
-
|
22 |
-
Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/MODELS.md
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
# Pretrained models
|
2 |
-
|
3 |
-
The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
|
4 |
-
The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
|
5 |
-
|
6 |
-
| Model | Download link | File size |
|
7 |
-
| -------- | ------- | ------- |
|
8 |
-
| Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
|
9 |
-
| Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
|
10 |
-
| Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
|
11 |
-
| Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
|
12 |
-
| Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
|
13 |
-
| 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
|
14 |
-
| 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
|
15 |
-
| 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
|
16 |
-
| Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
|
17 |
-
|
18 |
-
To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP, CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz) and not model sizes.
|
19 |
-
The 44.1kHz vocoder will be downloaded automatically.
|
20 |
-
The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.
|
21 |
-
|
22 |
-
The expected directory structure (full):
|
23 |
-
|
24 |
-
```bash
|
25 |
-
MMAudio
|
26 |
-
├── ext_weights
|
27 |
-
│ ├── best_netG.pt
|
28 |
-
│ ├── synchformer_state_dict.pth
|
29 |
-
│ ├── v1-16.pth
|
30 |
-
│ └── v1-44.pth
|
31 |
-
├── weights
|
32 |
-
│ ├── mmaudio_small_16k.pth
|
33 |
-
│ ├── mmaudio_small_44k.pth
|
34 |
-
│ ├── mmaudio_medium_44k.pth
|
35 |
-
│ ├── mmaudio_large_44k.pth
|
36 |
-
│ └── mmaudio_large_44k_v2.pth
|
37 |
-
└── ...
|
38 |
-
```
|
39 |
-
|
40 |
-
The expected directory structure (minimal, for the recommended model only):
|
41 |
-
|
42 |
-
```bash
|
43 |
-
MMAudio
|
44 |
-
├── ext_weights
|
45 |
-
│ ├── synchformer_state_dict.pth
|
46 |
-
│ └── v1-44.pth
|
47 |
-
├── weights
|
48 |
-
│ └── mmaudio_large_44k_v2.pth
|
49 |
-
└── ...
|
50 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/TRAINING.md
DELETED
@@ -1,184 +0,0 @@
|
|
1 |
-
# Training
|
2 |
-
|
3 |
-
## Overview
|
4 |
-
|
5 |
-
We have put a large emphasis on making training as fast as possible.
|
6 |
-
Consequently, some pre-processing steps are required.
|
7 |
-
|
8 |
-
Namely, before starting any training, we
|
9 |
-
|
10 |
-
1. Obtain training data as videos, audios, and captions.
|
11 |
-
2. Encode training audios into spectrograms and then with VAE into mean/std
|
12 |
-
3. Extract CLIP and synchronization features from videos
|
13 |
-
4. Extract CLIP features from text (captions)
|
14 |
-
5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
|
15 |
-
|
16 |
-
**NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.
|
17 |
-
|
18 |
-
The current training script does not support `_v2` training.
|
19 |
-
|
20 |
-
## Recommended Hardware Configuration
|
21 |
-
|
22 |
-
These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
|
23 |
-
|
24 |
-
- Single-node machine. We did not implement multi-node training
|
25 |
-
- GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
|
26 |
-
- System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
|
27 |
-
- Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
|
28 |
-
|
29 |
-
## Prerequisites
|
30 |
-
|
31 |
-
1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
|
32 |
-
2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
|
33 |
-
|
34 |
-
3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
|
35 |
-
|
36 |
-
```bash
|
37 |
-
conda install -c conda-forge 'ffmpeg<7'
|
38 |
-
```
|
39 |
-
|
40 |
-
4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), [WavCaps](https://arxiv.org/abs/2303.17395), and [Clotho](https://arxiv.org/abs/1910.09387) (paper to be updated). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio unsuitable for our training.
|
41 |
-
|
42 |
-
5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md) place them in `ext_weights/`.
|
43 |
-
|
44 |
-
### Helpful links for downloading the datasets
|
45 |
-
|
46 |
-
We cannot redistribute the datasets for copyright reasons, but we do find some links helpful and they might be helpful to you as well.
|
47 |
-
|
48 |
-
- https://huggingface.co/datasets/Meranti/CLAP_freesound
|
49 |
-
- https://huggingface.co/datasets/agkphysics/AudioSet
|
50 |
-
- https://sound-effects.bbcrewind.co.uk/
|
51 |
-
|
52 |
-
For certain sources of VGGSound, you might notice desychronization between the audio and the video. This happens the video keyframes do not always align with the start of the audio and what happens during playbacks is player-dependent. We used PyTorch's decoder which can correctly handle these cases.
|
53 |
-
|
54 |
-
## Preparing Audio-Video-Text Features
|
55 |
-
|
56 |
-
We have prepared some example data in `training/example_videos`.
|
57 |
-
`training/extract_video_training_latents.py` extracts audio, video, and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
|
58 |
-
|
59 |
-
To run this script, use the `torchrun` utility:
|
60 |
-
|
61 |
-
```bash
|
62 |
-
torchrun --standalone training/extract_video_training_latents.py
|
63 |
-
```
|
64 |
-
|
65 |
-
You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
|
66 |
-
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
|
67 |
-
Change the data path definitions in `data_cfg` if necessary.
|
68 |
-
|
69 |
-
Arguments:
|
70 |
-
|
71 |
-
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
|
72 |
-
- `output_dir` -- where TensorDict and the metadata file are saved.
|
73 |
-
|
74 |
-
Outputs produced in `output_dir`:
|
75 |
-
|
76 |
-
1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
|
77 |
-
a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
|
78 |
-
b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
|
79 |
-
c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
|
80 |
-
d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
|
81 |
-
e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
|
82 |
-
f. `meta.json` that contains the metadata for the above memory mappings
|
83 |
-
2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
|
84 |
-
|
85 |
-
## Preparing Audio-Text Features
|
86 |
-
|
87 |
-
We have prepared some example data in `training/example_audios`.
|
88 |
-
|
89 |
-
1. Run `training/partition_clips` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio onto the disk to save disk space)
|
90 |
-
2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
|
91 |
-
|
92 |
-
### Partitioning the audio files
|
93 |
-
|
94 |
-
Run
|
95 |
-
|
96 |
-
```bash
|
97 |
-
python training/partition_clips.py
|
98 |
-
```
|
99 |
-
|
100 |
-
Arguments:
|
101 |
-
|
102 |
-
- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
|
103 |
-
- `output_dir` -- path to the output `.csv` file
|
104 |
-
- `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
|
105 |
-
- `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
|
106 |
-
|
107 |
-
### Extracting audio and text features
|
108 |
-
|
109 |
-
Run
|
110 |
-
|
111 |
-
```bash
|
112 |
-
torchrun --standalone training/extract_audio_training_latents.py
|
113 |
-
```
|
114 |
-
|
115 |
-
You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
|
116 |
-
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
|
117 |
-
|
118 |
-
Arguments:
|
119 |
-
|
120 |
-
- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
|
121 |
-
- `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file at least with columns `id` and `caption`
|
122 |
-
- `clips_tsv` -- path to the clips file, generated in the last step
|
123 |
-
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
|
124 |
-
- `output_dir` -- where TensorDict and the metadata file are saved.
|
125 |
-
|
126 |
-
Outputs produced in `output_dir`:
|
127 |
-
|
128 |
-
1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
|
129 |
-
a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
|
130 |
-
b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
|
131 |
-
c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
|
132 |
-
f. `meta.json` that contains the metadata for the above memory mappings
|
133 |
-
2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
|
134 |
-
|
135 |
-
### Reference tsv files (with overlaps removed as mentioned in the paper)
|
136 |
-
|
137 |
-
The reference tsv files can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).
|
138 |
-
|
139 |
-
Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use it as the `captions_tsv` input though -- the script will handle duplicates gracefully.
|
140 |
-
Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are parts of WavCaps. These subsets might be smaller than the original datasets.
|
141 |
-
The Clotho data contains both the development set and the validation set.
|
142 |
-
|
143 |
-
**Update (Mar 9, 2025)**:
|
144 |
-
We have updated a corrected set of reference tsv files. The previous tsv files contained some (<1%) corrupted captions (ie, mismatch between audio and caption, see https://github.com/hkchengrex/MMAudio/issues/56). The tsv files for VGGSound are unaffected. This reason for this error is unknown, but I cannot reproduce this error in the latest version of the code. Our pre-trained models are trained with **uncorrected** tsv files. For future training, I recommend using the corrected tsv files.
|
145 |
-
|
146 |
-
The error statistics are as follows:
|
147 |
-
|
148 |
-
- AudioCaps (170/43824), 0.39%
|
149 |
-
- Freesound: (1670/180636), 0.92%
|
150 |
-
- AudioSet: (290/100776), 0.29%
|
151 |
-
- BBCSound: (3/29975), 0.01%
|
152 |
-
- Clotho: (8/24332), 0.03%
|
153 |
-
|
154 |
-
## Training on Extracted Features
|
155 |
-
|
156 |
-
We use Distributed Data Parallel (DDP) for training.
|
157 |
-
First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
|
158 |
-
|
159 |
-
To run training on the example data, use the following command:
|
160 |
-
|
161 |
-
```bash
|
162 |
-
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
|
163 |
-
```
|
164 |
-
|
165 |
-
This will not train a useful model, but it will check if everything is set up correctly.
|
166 |
-
|
167 |
-
For full training on the base model with two GPUs, use the following command:
|
168 |
-
|
169 |
-
```bash
|
170 |
-
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
|
171 |
-
```
|
172 |
-
|
173 |
-
Any outputs from training will be stored in `output/<exp_id>`.
|
174 |
-
|
175 |
-
More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
|
176 |
-
For the medium and large models, specify `vgg_oversample_rate` to be `3` to reduce overfitting.
|
177 |
-
|
178 |
-
## Checkpoints
|
179 |
-
|
180 |
-
Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
|
181 |
-
|
182 |
-
---
|
183 |
-
|
184 |
-
Godspeed!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
docs/index.html
CHANGED
@@ -40,7 +40,7 @@
|
|
40 |
<br>
|
41 |
<div class="row text-center" style="font-size:28px">
|
42 |
<div class="col">
|
43 |
-
|
44 |
</div>
|
45 |
</div>
|
46 |
<br>
|
@@ -83,21 +83,19 @@
|
|
83 |
<br>
|
84 |
|
85 |
<div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
|
86 |
-
<div class="col-sm-2">
|
87 |
-
<a href="https://arxiv.org/abs/
|
88 |
-
</div>
|
89 |
-
<div class="col-sm-2">
|
90 |
-
<a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
|
91 |
-
</div>
|
92 |
<div class="col-sm-3">
|
93 |
-
<a href="
|
94 |
-
</div>
|
95 |
-
<div class="col-sm-2">
|
96 |
-
<a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
|
97 |
</div>
|
98 |
<div class="col-sm-3">
|
99 |
-
<a href="https://
|
100 |
</div>
|
|
|
|
|
|
|
|
|
101 |
</div>
|
102 |
|
103 |
<br>
|
|
|
40 |
<br>
|
41 |
<div class="row text-center" style="font-size:28px">
|
42 |
<div class="col">
|
43 |
+
arXiv 2024
|
44 |
</div>
|
45 |
</div>
|
46 |
<br>
|
|
|
83 |
<br>
|
84 |
|
85 |
<div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
|
86 |
+
<!-- <div class="col-sm-2">
|
87 |
+
<a href="https://arxiv.org/abs/2310.12982">[arXiv]</a>
|
88 |
+
</div> -->
|
|
|
|
|
|
|
89 |
<div class="col-sm-3">
|
90 |
+
<a href="">[Paper (being prepared)]</a>
|
|
|
|
|
|
|
91 |
</div>
|
92 |
<div class="col-sm-3">
|
93 |
+
<a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
|
94 |
</div>
|
95 |
+
<!-- <div class="col-sm-2">
|
96 |
+
<a
|
97 |
+
href="https://colab.research.google.com/drive/1yo43XTbjxuWA7XgCUO9qxAi7wBI6HzvP?usp=sharing">[Colab]</a>
|
98 |
+
</div> -->
|
99 |
</div>
|
100 |
|
101 |
<br>
|
gradio_demo.py
DELETED
@@ -1,343 +0,0 @@
|
|
1 |
-
import gc
|
2 |
-
import logging
|
3 |
-
from argparse import ArgumentParser
|
4 |
-
from datetime import datetime
|
5 |
-
from fractions import Fraction
|
6 |
-
from pathlib import Path
|
7 |
-
|
8 |
-
import gradio as gr
|
9 |
-
import torch
|
10 |
-
import torchaudio
|
11 |
-
|
12 |
-
from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
|
13 |
-
load_video, make_video, setup_eval_logging)
|
14 |
-
from mmaudio.model.flow_matching import FlowMatching
|
15 |
-
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
16 |
-
from mmaudio.model.sequence_config import SequenceConfig
|
17 |
-
from mmaudio.model.utils.features_utils import FeaturesUtils
|
18 |
-
|
19 |
-
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
-
torch.backends.cudnn.allow_tf32 = True
|
21 |
-
|
22 |
-
log = logging.getLogger()
|
23 |
-
|
24 |
-
device = 'cpu'
|
25 |
-
if torch.cuda.is_available():
|
26 |
-
device = 'cuda'
|
27 |
-
elif torch.backends.mps.is_available():
|
28 |
-
device = 'mps'
|
29 |
-
else:
|
30 |
-
log.warning('CUDA/MPS are not available, running on CPU')
|
31 |
-
dtype = torch.bfloat16
|
32 |
-
|
33 |
-
model: ModelConfig = all_model_cfg['large_44k_v2']
|
34 |
-
model.download_if_needed()
|
35 |
-
output_dir = Path('./output/gradio')
|
36 |
-
|
37 |
-
setup_eval_logging()
|
38 |
-
|
39 |
-
|
40 |
-
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
41 |
-
seq_cfg = model.seq_cfg
|
42 |
-
|
43 |
-
net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
|
44 |
-
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
45 |
-
log.info(f'Loaded weights from {model.model_path}')
|
46 |
-
|
47 |
-
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
48 |
-
synchformer_ckpt=model.synchformer_ckpt,
|
49 |
-
enable_conditions=True,
|
50 |
-
mode=model.mode,
|
51 |
-
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
52 |
-
need_vae_encoder=False)
|
53 |
-
feature_utils = feature_utils.to(device, dtype).eval()
|
54 |
-
|
55 |
-
return net, feature_utils, seq_cfg
|
56 |
-
|
57 |
-
|
58 |
-
net, feature_utils, seq_cfg = get_model()
|
59 |
-
|
60 |
-
|
61 |
-
@torch.inference_mode()
|
62 |
-
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
63 |
-
cfg_strength: float, duration: float):
|
64 |
-
|
65 |
-
rng = torch.Generator(device=device)
|
66 |
-
if seed >= 0:
|
67 |
-
rng.manual_seed(seed)
|
68 |
-
else:
|
69 |
-
rng.seed()
|
70 |
-
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
71 |
-
|
72 |
-
video_info = load_video(video, duration)
|
73 |
-
clip_frames = video_info.clip_frames
|
74 |
-
sync_frames = video_info.sync_frames
|
75 |
-
duration = video_info.duration_sec
|
76 |
-
clip_frames = clip_frames.unsqueeze(0)
|
77 |
-
sync_frames = sync_frames.unsqueeze(0)
|
78 |
-
seq_cfg.duration = duration
|
79 |
-
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
80 |
-
|
81 |
-
audios = generate(clip_frames,
|
82 |
-
sync_frames, [prompt],
|
83 |
-
negative_text=[negative_prompt],
|
84 |
-
feature_utils=feature_utils,
|
85 |
-
net=net,
|
86 |
-
fm=fm,
|
87 |
-
rng=rng,
|
88 |
-
cfg_strength=cfg_strength)
|
89 |
-
audio = audios.float().cpu()[0]
|
90 |
-
|
91 |
-
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
92 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
93 |
-
video_save_path = output_dir / f'{current_time_string}.mp4'
|
94 |
-
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
95 |
-
gc.collect()
|
96 |
-
return video_save_path
|
97 |
-
|
98 |
-
|
99 |
-
@torch.inference_mode()
|
100 |
-
def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
101 |
-
cfg_strength: float, duration: float):
|
102 |
-
|
103 |
-
rng = torch.Generator(device=device)
|
104 |
-
if seed >= 0:
|
105 |
-
rng.manual_seed(seed)
|
106 |
-
else:
|
107 |
-
rng.seed()
|
108 |
-
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
109 |
-
|
110 |
-
image_info = load_image(image)
|
111 |
-
clip_frames = image_info.clip_frames
|
112 |
-
sync_frames = image_info.sync_frames
|
113 |
-
clip_frames = clip_frames.unsqueeze(0)
|
114 |
-
sync_frames = sync_frames.unsqueeze(0)
|
115 |
-
seq_cfg.duration = duration
|
116 |
-
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
117 |
-
|
118 |
-
audios = generate(clip_frames,
|
119 |
-
sync_frames, [prompt],
|
120 |
-
negative_text=[negative_prompt],
|
121 |
-
feature_utils=feature_utils,
|
122 |
-
net=net,
|
123 |
-
fm=fm,
|
124 |
-
rng=rng,
|
125 |
-
cfg_strength=cfg_strength,
|
126 |
-
image_input=True)
|
127 |
-
audio = audios.float().cpu()[0]
|
128 |
-
|
129 |
-
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
130 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
131 |
-
video_save_path = output_dir / f'{current_time_string}.mp4'
|
132 |
-
video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
|
133 |
-
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
134 |
-
gc.collect()
|
135 |
-
return video_save_path
|
136 |
-
|
137 |
-
|
138 |
-
@torch.inference_mode()
|
139 |
-
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
|
140 |
-
duration: float):
|
141 |
-
|
142 |
-
rng = torch.Generator(device=device)
|
143 |
-
if seed >= 0:
|
144 |
-
rng.manual_seed(seed)
|
145 |
-
else:
|
146 |
-
rng.seed()
|
147 |
-
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
148 |
-
|
149 |
-
clip_frames = sync_frames = None
|
150 |
-
seq_cfg.duration = duration
|
151 |
-
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
152 |
-
|
153 |
-
audios = generate(clip_frames,
|
154 |
-
sync_frames, [prompt],
|
155 |
-
negative_text=[negative_prompt],
|
156 |
-
feature_utils=feature_utils,
|
157 |
-
net=net,
|
158 |
-
fm=fm,
|
159 |
-
rng=rng,
|
160 |
-
cfg_strength=cfg_strength)
|
161 |
-
audio = audios.float().cpu()[0]
|
162 |
-
|
163 |
-
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
164 |
-
output_dir.mkdir(exist_ok=True, parents=True)
|
165 |
-
audio_save_path = output_dir / f'{current_time_string}.flac'
|
166 |
-
torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
|
167 |
-
gc.collect()
|
168 |
-
return audio_save_path
|
169 |
-
|
170 |
-
|
171 |
-
video_to_audio_tab = gr.Interface(
|
172 |
-
fn=video_to_audio,
|
173 |
-
description="""
|
174 |
-
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
175 |
-
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
176 |
-
|
177 |
-
NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
|
178 |
-
Doing so does not improve results.
|
179 |
-
""",
|
180 |
-
inputs=[
|
181 |
-
gr.Video(),
|
182 |
-
gr.Text(label='Prompt'),
|
183 |
-
gr.Text(label='Negative prompt', value='music'),
|
184 |
-
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
185 |
-
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
186 |
-
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
187 |
-
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
188 |
-
],
|
189 |
-
outputs='playable_video',
|
190 |
-
cache_examples=False,
|
191 |
-
title='MMAudio — Video-to-Audio Synthesis',
|
192 |
-
examples=[
|
193 |
-
[
|
194 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
|
195 |
-
'waves, seagulls',
|
196 |
-
'',
|
197 |
-
0,
|
198 |
-
25,
|
199 |
-
4.5,
|
200 |
-
10,
|
201 |
-
],
|
202 |
-
[
|
203 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
|
204 |
-
'',
|
205 |
-
'music',
|
206 |
-
0,
|
207 |
-
25,
|
208 |
-
4.5,
|
209 |
-
10,
|
210 |
-
],
|
211 |
-
[
|
212 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
|
213 |
-
'bubbles',
|
214 |
-
'',
|
215 |
-
0,
|
216 |
-
25,
|
217 |
-
4.5,
|
218 |
-
10,
|
219 |
-
],
|
220 |
-
[
|
221 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
|
222 |
-
'Indian holy music',
|
223 |
-
'',
|
224 |
-
0,
|
225 |
-
25,
|
226 |
-
4.5,
|
227 |
-
10,
|
228 |
-
],
|
229 |
-
[
|
230 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
|
231 |
-
'galloping',
|
232 |
-
'',
|
233 |
-
0,
|
234 |
-
25,
|
235 |
-
4.5,
|
236 |
-
10,
|
237 |
-
],
|
238 |
-
[
|
239 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
|
240 |
-
'waves, storm',
|
241 |
-
'',
|
242 |
-
0,
|
243 |
-
25,
|
244 |
-
4.5,
|
245 |
-
10,
|
246 |
-
],
|
247 |
-
[
|
248 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
|
249 |
-
'storm',
|
250 |
-
'',
|
251 |
-
0,
|
252 |
-
25,
|
253 |
-
4.5,
|
254 |
-
10,
|
255 |
-
],
|
256 |
-
[
|
257 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
|
258 |
-
'',
|
259 |
-
'',
|
260 |
-
0,
|
261 |
-
25,
|
262 |
-
4.5,
|
263 |
-
10,
|
264 |
-
],
|
265 |
-
[
|
266 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
|
267 |
-
'typing',
|
268 |
-
'',
|
269 |
-
0,
|
270 |
-
25,
|
271 |
-
4.5,
|
272 |
-
10,
|
273 |
-
],
|
274 |
-
[
|
275 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
|
276 |
-
'',
|
277 |
-
'',
|
278 |
-
0,
|
279 |
-
25,
|
280 |
-
4.5,
|
281 |
-
10,
|
282 |
-
],
|
283 |
-
[
|
284 |
-
'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
|
285 |
-
'',
|
286 |
-
'',
|
287 |
-
0,
|
288 |
-
25,
|
289 |
-
4.5,
|
290 |
-
10,
|
291 |
-
],
|
292 |
-
])
|
293 |
-
|
294 |
-
text_to_audio_tab = gr.Interface(
|
295 |
-
fn=text_to_audio,
|
296 |
-
description="""
|
297 |
-
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
298 |
-
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
299 |
-
""",
|
300 |
-
inputs=[
|
301 |
-
gr.Text(label='Prompt'),
|
302 |
-
gr.Text(label='Negative prompt'),
|
303 |
-
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
304 |
-
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
305 |
-
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
306 |
-
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
307 |
-
],
|
308 |
-
outputs='audio',
|
309 |
-
cache_examples=False,
|
310 |
-
title='MMAudio — Text-to-Audio Synthesis',
|
311 |
-
)
|
312 |
-
|
313 |
-
image_to_audio_tab = gr.Interface(
|
314 |
-
fn=image_to_audio,
|
315 |
-
description="""
|
316 |
-
Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
|
317 |
-
Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
|
318 |
-
|
319 |
-
NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
|
320 |
-
Doing so does not improve results.
|
321 |
-
""",
|
322 |
-
inputs=[
|
323 |
-
gr.Image(type='filepath'),
|
324 |
-
gr.Text(label='Prompt'),
|
325 |
-
gr.Text(label='Negative prompt'),
|
326 |
-
gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
|
327 |
-
gr.Number(label='Num steps', value=25, precision=0, minimum=1),
|
328 |
-
gr.Number(label='Guidance Strength', value=4.5, minimum=1),
|
329 |
-
gr.Number(label='Duration (sec)', value=8, minimum=1),
|
330 |
-
],
|
331 |
-
outputs='playable_video',
|
332 |
-
cache_examples=False,
|
333 |
-
title='MMAudio — Image-to-Audio Synthesis (experimental)',
|
334 |
-
)
|
335 |
-
|
336 |
-
if __name__ == "__main__":
|
337 |
-
parser = ArgumentParser()
|
338 |
-
parser.add_argument('--port', type=int, default=7860)
|
339 |
-
args = parser.parse_args()
|
340 |
-
|
341 |
-
gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
|
342 |
-
['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
|
343 |
-
server_port=args.port, allowed_paths=[output_dir])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (187 Bytes)
|
|
mmaudio/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (185 Bytes)
|
|
mmaudio/__pycache__/eval_utils.cpython-310.pyc
DELETED
Binary file (7.07 kB)
|
|
mmaudio/__pycache__/eval_utils.cpython-38.pyc
DELETED
Binary file (7.03 kB)
|
|
mmaudio/data/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (192 Bytes)
|
|
mmaudio/data/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (190 Bytes)
|
|
mmaudio/data/__pycache__/av_utils.cpython-310.pyc
DELETED
Binary file (4.91 kB)
|
|
mmaudio/data/__pycache__/av_utils.cpython-38.pyc
DELETED
Binary file (4.89 kB)
|
|
mmaudio/data/av_utils.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
from dataclasses import dataclass
|
2 |
from fractions import Fraction
|
3 |
from pathlib import Path
|
4 |
-
from typing import Optional
|
5 |
|
6 |
import av
|
7 |
import numpy as np
|
@@ -15,7 +15,7 @@ class VideoInfo:
|
|
15 |
fps: Fraction
|
16 |
clip_frames: torch.Tensor
|
17 |
sync_frames: torch.Tensor
|
18 |
-
all_frames: Optional[
|
19 |
|
20 |
@property
|
21 |
def height(self):
|
@@ -25,35 +25,9 @@ class VideoInfo:
|
|
25 |
def width(self):
|
26 |
return self.all_frames[0].shape[1]
|
27 |
|
28 |
-
@classmethod
|
29 |
-
def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
|
30 |
-
fps: Fraction) -> 'VideoInfo':
|
31 |
-
num_frames = int(duration_sec * fps)
|
32 |
-
all_frames = [image_info.original_frame] * num_frames
|
33 |
-
return cls(duration_sec=duration_sec,
|
34 |
-
fps=fps,
|
35 |
-
clip_frames=image_info.clip_frames,
|
36 |
-
sync_frames=image_info.sync_frames,
|
37 |
-
all_frames=all_frames)
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
class ImageInfo:
|
42 |
-
clip_frames: torch.Tensor
|
43 |
-
sync_frames: torch.Tensor
|
44 |
-
original_frame: Optional[np.ndarray]
|
45 |
-
|
46 |
-
@property
|
47 |
-
def height(self):
|
48 |
-
return self.original_frame.shape[0]
|
49 |
-
|
50 |
-
@property
|
51 |
-
def width(self):
|
52 |
-
return self.original_frame.shape[1]
|
53 |
-
|
54 |
-
|
55 |
-
def read_frames(video_path: Path, list_of_fps: List[float], start_sec: float, end_sec: float,
|
56 |
-
need_all_frames: bool) -> Tuple[List[np.ndarray], List[np.ndarray], Fraction]:
|
57 |
output_frames = [[] for _ in list_of_fps]
|
58 |
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
59 |
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
|
|
1 |
from dataclasses import dataclass
|
2 |
from fractions import Fraction
|
3 |
from pathlib import Path
|
4 |
+
from typing import Optional
|
5 |
|
6 |
import av
|
7 |
import numpy as np
|
|
|
15 |
fps: Fraction
|
16 |
clip_frames: torch.Tensor
|
17 |
sync_frames: torch.Tensor
|
18 |
+
all_frames: Optional[list[np.ndarray]]
|
19 |
|
20 |
@property
|
21 |
def height(self):
|
|
|
25 |
def width(self):
|
26 |
return self.all_frames[0].shape[1]
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
+
def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
|
30 |
+
need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
output_frames = [[] for _ in list_of_fps]
|
32 |
next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
|
33 |
time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
|
mmaudio/data/data_setup.py
DELETED
@@ -1,174 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import random
|
3 |
-
|
4 |
-
import numpy as np
|
5 |
-
import torch
|
6 |
-
from omegaconf import DictConfig
|
7 |
-
from torch.utils.data import DataLoader, Dataset
|
8 |
-
from torch.utils.data.dataloader import default_collate
|
9 |
-
from torch.utils.data.distributed import DistributedSampler
|
10 |
-
|
11 |
-
from mmaudio.data.eval.audiocaps import AudioCapsData
|
12 |
-
from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
|
13 |
-
from mmaudio.data.extracted_audio import ExtractedAudio
|
14 |
-
from mmaudio.data.extracted_vgg import ExtractedVGG
|
15 |
-
from mmaudio.data.mm_dataset import MultiModalDataset
|
16 |
-
from mmaudio.utils.dist_utils import local_rank
|
17 |
-
|
18 |
-
log = logging.getLogger()
|
19 |
-
|
20 |
-
|
21 |
-
# Re-seed randomness every time we start a worker
|
22 |
-
def worker_init_fn(worker_id: int):
|
23 |
-
worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
|
24 |
-
np.random.seed(worker_seed)
|
25 |
-
random.seed(worker_seed)
|
26 |
-
log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
|
27 |
-
|
28 |
-
|
29 |
-
def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
30 |
-
dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
|
31 |
-
data_dim=cfg.data_dim,
|
32 |
-
premade_mmap_dir=data_cfg.memmap_dir)
|
33 |
-
|
34 |
-
return dataset
|
35 |
-
|
36 |
-
|
37 |
-
def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
|
38 |
-
dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
|
39 |
-
data_dim=cfg.data_dim,
|
40 |
-
premade_mmap_dir=data_cfg.memmap_dir)
|
41 |
-
|
42 |
-
return dataset
|
43 |
-
|
44 |
-
|
45 |
-
def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
|
46 |
-
if cfg.mini_train:
|
47 |
-
vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
|
48 |
-
audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
|
49 |
-
dataset = MultiModalDataset([vgg], [audiocaps])
|
50 |
-
if cfg.example_train:
|
51 |
-
video = load_vgg_data(cfg, cfg.data.Example_video)
|
52 |
-
audio = load_audio_data(cfg, cfg.data.Example_audio)
|
53 |
-
dataset = MultiModalDataset([video], [audio])
|
54 |
-
else:
|
55 |
-
# load the largest one first
|
56 |
-
freesound = load_audio_data(cfg, cfg.data.FreeSound)
|
57 |
-
vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
|
58 |
-
audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
|
59 |
-
audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
|
60 |
-
bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
|
61 |
-
clotho = load_audio_data(cfg, cfg.data.Clotho)
|
62 |
-
dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
|
63 |
-
[audiocaps, audioset_sl, bbcsound, freesound, clotho])
|
64 |
-
|
65 |
-
batch_size = cfg.batch_size
|
66 |
-
num_workers = cfg.num_workers
|
67 |
-
pin_memory = cfg.pin_memory
|
68 |
-
sampler, loader = construct_loader(dataset,
|
69 |
-
batch_size,
|
70 |
-
num_workers,
|
71 |
-
shuffle=True,
|
72 |
-
drop_last=True,
|
73 |
-
pin_memory=pin_memory)
|
74 |
-
|
75 |
-
return dataset, sampler, loader
|
76 |
-
|
77 |
-
|
78 |
-
def setup_test_datasets(cfg):
|
79 |
-
dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
|
80 |
-
|
81 |
-
batch_size = cfg.batch_size
|
82 |
-
num_workers = cfg.num_workers
|
83 |
-
pin_memory = cfg.pin_memory
|
84 |
-
sampler, loader = construct_loader(dataset,
|
85 |
-
batch_size,
|
86 |
-
num_workers,
|
87 |
-
shuffle=False,
|
88 |
-
drop_last=False,
|
89 |
-
pin_memory=pin_memory)
|
90 |
-
|
91 |
-
return dataset, sampler, loader
|
92 |
-
|
93 |
-
|
94 |
-
def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
|
95 |
-
if cfg.example_train:
|
96 |
-
dataset = load_vgg_data(cfg, cfg.data.Example_video)
|
97 |
-
else:
|
98 |
-
dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
|
99 |
-
|
100 |
-
val_batch_size = cfg.batch_size
|
101 |
-
val_eval_batch_size = cfg.eval_batch_size
|
102 |
-
num_workers = cfg.num_workers
|
103 |
-
pin_memory = cfg.pin_memory
|
104 |
-
_, val_loader = construct_loader(dataset,
|
105 |
-
val_batch_size,
|
106 |
-
num_workers,
|
107 |
-
shuffle=False,
|
108 |
-
drop_last=False,
|
109 |
-
pin_memory=pin_memory)
|
110 |
-
_, eval_loader = construct_loader(dataset,
|
111 |
-
val_eval_batch_size,
|
112 |
-
num_workers,
|
113 |
-
shuffle=False,
|
114 |
-
drop_last=False,
|
115 |
-
pin_memory=pin_memory)
|
116 |
-
|
117 |
-
return dataset, val_loader, eval_loader
|
118 |
-
|
119 |
-
|
120 |
-
def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
|
121 |
-
if dataset_name.startswith('audiocaps_full'):
|
122 |
-
dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
|
123 |
-
cfg.eval_data.AudioCaps_full.csv_path)
|
124 |
-
elif dataset_name.startswith('audiocaps'):
|
125 |
-
dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
|
126 |
-
cfg.eval_data.AudioCaps.csv_path)
|
127 |
-
elif dataset_name.startswith('moviegen'):
|
128 |
-
dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
|
129 |
-
cfg.eval_data.MovieGen.jsonl_path,
|
130 |
-
duration_sec=cfg.duration_s)
|
131 |
-
elif dataset_name.startswith('vggsound'):
|
132 |
-
dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
|
133 |
-
cfg.eval_data.VGGSound.csv_path,
|
134 |
-
duration_sec=cfg.duration_s)
|
135 |
-
else:
|
136 |
-
raise ValueError(f'Invalid dataset name: {dataset_name}')
|
137 |
-
|
138 |
-
batch_size = cfg.batch_size
|
139 |
-
num_workers = cfg.num_workers
|
140 |
-
pin_memory = cfg.pin_memory
|
141 |
-
_, loader = construct_loader(dataset,
|
142 |
-
batch_size,
|
143 |
-
num_workers,
|
144 |
-
shuffle=False,
|
145 |
-
drop_last=False,
|
146 |
-
pin_memory=pin_memory,
|
147 |
-
error_avoidance=True)
|
148 |
-
return dataset, loader
|
149 |
-
|
150 |
-
|
151 |
-
def error_avoidance_collate(batch):
|
152 |
-
batch = list(filter(lambda x: x is not None, batch))
|
153 |
-
return default_collate(batch)
|
154 |
-
|
155 |
-
|
156 |
-
def construct_loader(dataset: Dataset,
|
157 |
-
batch_size: int,
|
158 |
-
num_workers: int,
|
159 |
-
*,
|
160 |
-
shuffle: bool = True,
|
161 |
-
drop_last: bool = True,
|
162 |
-
pin_memory: bool = False,
|
163 |
-
error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
|
164 |
-
train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
|
165 |
-
train_loader = DataLoader(dataset,
|
166 |
-
batch_size,
|
167 |
-
sampler=train_sampler,
|
168 |
-
num_workers=num_workers,
|
169 |
-
worker_init_fn=worker_init_fn,
|
170 |
-
drop_last=drop_last,
|
171 |
-
persistent_workers=num_workers > 0,
|
172 |
-
pin_memory=pin_memory,
|
173 |
-
collate_fn=error_avoidance_collate if error_avoidance else None)
|
174 |
-
return train_sampler, train_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/eval/__init__.py
DELETED
File without changes
|
mmaudio/data/eval/audiocaps.py
DELETED
@@ -1,39 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import os
|
3 |
-
from collections import defaultdict
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Union
|
6 |
-
|
7 |
-
import pandas as pd
|
8 |
-
import torch
|
9 |
-
from torch.utils.data.dataset import Dataset
|
10 |
-
|
11 |
-
log = logging.getLogger()
|
12 |
-
|
13 |
-
|
14 |
-
class AudioCapsData(Dataset):
|
15 |
-
|
16 |
-
def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
|
17 |
-
df = pd.read_csv(csv_path).to_dict(orient='records')
|
18 |
-
|
19 |
-
audio_files = sorted(os.listdir(audio_path))
|
20 |
-
audio_files = set(
|
21 |
-
[Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
|
22 |
-
|
23 |
-
self.data = []
|
24 |
-
for row in df:
|
25 |
-
self.data.append({
|
26 |
-
'name': row['name'],
|
27 |
-
'caption': row['caption'],
|
28 |
-
})
|
29 |
-
|
30 |
-
self.audio_path = Path(audio_path)
|
31 |
-
self.csv_path = Path(csv_path)
|
32 |
-
|
33 |
-
log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
|
34 |
-
|
35 |
-
def __getitem__(self, idx: int) -> torch.Tensor:
|
36 |
-
return self.data[idx]
|
37 |
-
|
38 |
-
def __len__(self):
|
39 |
-
return len(self.data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/eval/moviegen.py
DELETED
@@ -1,131 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Union
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from torch.utils.data.dataset import Dataset
|
9 |
-
from torchvision.transforms import v2
|
10 |
-
from torio.io import StreamingMediaDecoder
|
11 |
-
|
12 |
-
from mmaudio.utils.dist_utils import local_rank
|
13 |
-
|
14 |
-
log = logging.getLogger()
|
15 |
-
|
16 |
-
_CLIP_SIZE = 384
|
17 |
-
_CLIP_FPS = 8.0
|
18 |
-
|
19 |
-
_SYNC_SIZE = 224
|
20 |
-
_SYNC_FPS = 25.0
|
21 |
-
|
22 |
-
|
23 |
-
class MovieGenData(Dataset):
|
24 |
-
|
25 |
-
def __init__(
|
26 |
-
self,
|
27 |
-
video_root: Union[str, Path],
|
28 |
-
sync_root: Union[str, Path],
|
29 |
-
jsonl_root: Union[str, Path],
|
30 |
-
*,
|
31 |
-
duration_sec: float = 10.0,
|
32 |
-
read_clip: bool = True,
|
33 |
-
):
|
34 |
-
self.video_root = Path(video_root)
|
35 |
-
self.sync_root = Path(sync_root)
|
36 |
-
self.jsonl_root = Path(jsonl_root)
|
37 |
-
self.read_clip = read_clip
|
38 |
-
|
39 |
-
videos = sorted(os.listdir(self.video_root))
|
40 |
-
videos = [v[:-4] for v in videos] # remove extensions
|
41 |
-
self.captions = {}
|
42 |
-
|
43 |
-
for v in videos:
|
44 |
-
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
45 |
-
data = json.load(f)
|
46 |
-
self.captions[v] = data['audio_prompt']
|
47 |
-
|
48 |
-
if local_rank == 0:
|
49 |
-
log.info(f'{len(videos)} videos found in {video_root}')
|
50 |
-
|
51 |
-
self.duration_sec = duration_sec
|
52 |
-
|
53 |
-
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
54 |
-
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
55 |
-
|
56 |
-
self.clip_augment = v2.Compose([
|
57 |
-
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
58 |
-
v2.ToImage(),
|
59 |
-
v2.ToDtype(torch.float32, scale=True),
|
60 |
-
])
|
61 |
-
|
62 |
-
self.sync_augment = v2.Compose([
|
63 |
-
v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
64 |
-
v2.CenterCrop(_SYNC_SIZE),
|
65 |
-
v2.ToImage(),
|
66 |
-
v2.ToDtype(torch.float32, scale=True),
|
67 |
-
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
68 |
-
])
|
69 |
-
|
70 |
-
self.videos = videos
|
71 |
-
|
72 |
-
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
73 |
-
video_id = self.videos[idx]
|
74 |
-
caption = self.captions[video_id]
|
75 |
-
|
76 |
-
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
77 |
-
reader.add_basic_video_stream(
|
78 |
-
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
79 |
-
frame_rate=_CLIP_FPS,
|
80 |
-
format='rgb24',
|
81 |
-
)
|
82 |
-
reader.add_basic_video_stream(
|
83 |
-
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
84 |
-
frame_rate=_SYNC_FPS,
|
85 |
-
format='rgb24',
|
86 |
-
)
|
87 |
-
|
88 |
-
reader.fill_buffer()
|
89 |
-
data_chunk = reader.pop_chunks()
|
90 |
-
|
91 |
-
clip_chunk = data_chunk[0]
|
92 |
-
sync_chunk = data_chunk[1]
|
93 |
-
if clip_chunk is None:
|
94 |
-
raise RuntimeError(f'CLIP video returned None {video_id}')
|
95 |
-
if clip_chunk.shape[0] < self.clip_expected_length:
|
96 |
-
raise RuntimeError(f'CLIP video too short {video_id}')
|
97 |
-
|
98 |
-
if sync_chunk is None:
|
99 |
-
raise RuntimeError(f'Sync video returned None {video_id}')
|
100 |
-
if sync_chunk.shape[0] < self.sync_expected_length:
|
101 |
-
raise RuntimeError(f'Sync video too short {video_id}')
|
102 |
-
|
103 |
-
# truncate the video
|
104 |
-
clip_chunk = clip_chunk[:self.clip_expected_length]
|
105 |
-
if clip_chunk.shape[0] != self.clip_expected_length:
|
106 |
-
raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
107 |
-
f'expected {self.clip_expected_length}, '
|
108 |
-
f'got {clip_chunk.shape[0]}')
|
109 |
-
clip_chunk = self.clip_augment(clip_chunk)
|
110 |
-
|
111 |
-
sync_chunk = sync_chunk[:self.sync_expected_length]
|
112 |
-
if sync_chunk.shape[0] != self.sync_expected_length:
|
113 |
-
raise RuntimeError(f'Sync video wrong length {video_id}, '
|
114 |
-
f'expected {self.sync_expected_length}, '
|
115 |
-
f'got {sync_chunk.shape[0]}')
|
116 |
-
sync_chunk = self.sync_augment(sync_chunk)
|
117 |
-
|
118 |
-
data = {
|
119 |
-
'name': video_id,
|
120 |
-
'caption': caption,
|
121 |
-
'clip_video': clip_chunk,
|
122 |
-
'sync_video': sync_chunk,
|
123 |
-
}
|
124 |
-
|
125 |
-
return data
|
126 |
-
|
127 |
-
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
128 |
-
return self.sample(idx)
|
129 |
-
|
130 |
-
def __len__(self):
|
131 |
-
return len(self.captions)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/eval/video_dataset.py
DELETED
@@ -1,197 +0,0 @@
|
|
1 |
-
import json
|
2 |
-
import logging
|
3 |
-
import os
|
4 |
-
from pathlib import Path
|
5 |
-
from typing import Union
|
6 |
-
|
7 |
-
import pandas as pd
|
8 |
-
import torch
|
9 |
-
from torch.utils.data.dataset import Dataset
|
10 |
-
from torchvision.transforms import v2
|
11 |
-
from torio.io import StreamingMediaDecoder
|
12 |
-
|
13 |
-
from mmaudio.utils.dist_utils import local_rank
|
14 |
-
|
15 |
-
log = logging.getLogger()
|
16 |
-
|
17 |
-
_CLIP_SIZE = 384
|
18 |
-
_CLIP_FPS = 8.0
|
19 |
-
|
20 |
-
_SYNC_SIZE = 224
|
21 |
-
_SYNC_FPS = 25.0
|
22 |
-
|
23 |
-
|
24 |
-
class VideoDataset(Dataset):
|
25 |
-
|
26 |
-
def __init__(
|
27 |
-
self,
|
28 |
-
video_root: Union[str, Path],
|
29 |
-
*,
|
30 |
-
duration_sec: float = 8.0,
|
31 |
-
):
|
32 |
-
self.video_root = Path(video_root)
|
33 |
-
|
34 |
-
self.duration_sec = duration_sec
|
35 |
-
|
36 |
-
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
37 |
-
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
38 |
-
|
39 |
-
self.clip_transform = v2.Compose([
|
40 |
-
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
41 |
-
v2.ToImage(),
|
42 |
-
v2.ToDtype(torch.float32, scale=True),
|
43 |
-
])
|
44 |
-
|
45 |
-
self.sync_transform = v2.Compose([
|
46 |
-
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
47 |
-
v2.CenterCrop(_SYNC_SIZE),
|
48 |
-
v2.ToImage(),
|
49 |
-
v2.ToDtype(torch.float32, scale=True),
|
50 |
-
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
51 |
-
])
|
52 |
-
|
53 |
-
# to be implemented by subclasses
|
54 |
-
self.captions = {}
|
55 |
-
self.videos = sorted(list(self.captions.keys()))
|
56 |
-
|
57 |
-
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
58 |
-
video_id = self.videos[idx]
|
59 |
-
caption = self.captions[video_id]
|
60 |
-
|
61 |
-
reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
|
62 |
-
reader.add_basic_video_stream(
|
63 |
-
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
64 |
-
frame_rate=_CLIP_FPS,
|
65 |
-
format='rgb24',
|
66 |
-
)
|
67 |
-
reader.add_basic_video_stream(
|
68 |
-
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
69 |
-
frame_rate=_SYNC_FPS,
|
70 |
-
format='rgb24',
|
71 |
-
)
|
72 |
-
|
73 |
-
reader.fill_buffer()
|
74 |
-
data_chunk = reader.pop_chunks()
|
75 |
-
|
76 |
-
clip_chunk = data_chunk[0]
|
77 |
-
sync_chunk = data_chunk[1]
|
78 |
-
if clip_chunk is None:
|
79 |
-
raise RuntimeError(f'CLIP video returned None {video_id}')
|
80 |
-
if clip_chunk.shape[0] < self.clip_expected_length:
|
81 |
-
raise RuntimeError(
|
82 |
-
f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
83 |
-
)
|
84 |
-
|
85 |
-
if sync_chunk is None:
|
86 |
-
raise RuntimeError(f'Sync video returned None {video_id}')
|
87 |
-
if sync_chunk.shape[0] < self.sync_expected_length:
|
88 |
-
raise RuntimeError(
|
89 |
-
f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
90 |
-
)
|
91 |
-
|
92 |
-
# truncate the video
|
93 |
-
clip_chunk = clip_chunk[:self.clip_expected_length]
|
94 |
-
if clip_chunk.shape[0] != self.clip_expected_length:
|
95 |
-
raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
96 |
-
f'expected {self.clip_expected_length}, '
|
97 |
-
f'got {clip_chunk.shape[0]}')
|
98 |
-
clip_chunk = self.clip_transform(clip_chunk)
|
99 |
-
|
100 |
-
sync_chunk = sync_chunk[:self.sync_expected_length]
|
101 |
-
if sync_chunk.shape[0] != self.sync_expected_length:
|
102 |
-
raise RuntimeError(f'Sync video wrong length {video_id}, '
|
103 |
-
f'expected {self.sync_expected_length}, '
|
104 |
-
f'got {sync_chunk.shape[0]}')
|
105 |
-
sync_chunk = self.sync_transform(sync_chunk)
|
106 |
-
|
107 |
-
data = {
|
108 |
-
'name': video_id,
|
109 |
-
'caption': caption,
|
110 |
-
'clip_video': clip_chunk,
|
111 |
-
'sync_video': sync_chunk,
|
112 |
-
}
|
113 |
-
|
114 |
-
return data
|
115 |
-
|
116 |
-
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
117 |
-
try:
|
118 |
-
return self.sample(idx)
|
119 |
-
except Exception as e:
|
120 |
-
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
121 |
-
return None
|
122 |
-
|
123 |
-
def __len__(self):
|
124 |
-
return len(self.captions)
|
125 |
-
|
126 |
-
|
127 |
-
class VGGSound(VideoDataset):
|
128 |
-
|
129 |
-
def __init__(
|
130 |
-
self,
|
131 |
-
video_root: Union[str, Path],
|
132 |
-
csv_path: Union[str, Path],
|
133 |
-
*,
|
134 |
-
duration_sec: float = 8.0,
|
135 |
-
):
|
136 |
-
super().__init__(video_root, duration_sec=duration_sec)
|
137 |
-
self.video_root = Path(video_root)
|
138 |
-
self.csv_path = Path(csv_path)
|
139 |
-
|
140 |
-
videos = sorted(os.listdir(self.video_root))
|
141 |
-
if local_rank == 0:
|
142 |
-
log.info(f'{len(videos)} videos found in {video_root}')
|
143 |
-
self.captions = {}
|
144 |
-
|
145 |
-
df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
|
146 |
-
'split']).to_dict(orient='records')
|
147 |
-
|
148 |
-
videos_no_found = []
|
149 |
-
for row in df:
|
150 |
-
if row['split'] == 'test':
|
151 |
-
start_sec = int(row['sec'])
|
152 |
-
video_id = str(row['id'])
|
153 |
-
# this is how our videos are named
|
154 |
-
video_name = f'{video_id}_{start_sec:06d}'
|
155 |
-
if video_name + '.mp4' not in videos:
|
156 |
-
videos_no_found.append(video_name)
|
157 |
-
continue
|
158 |
-
|
159 |
-
self.captions[video_name] = row['caption']
|
160 |
-
|
161 |
-
if local_rank == 0:
|
162 |
-
log.info(f'{len(videos)} videos found in {video_root}')
|
163 |
-
log.info(f'{len(self.captions)} useable videos found')
|
164 |
-
if videos_no_found:
|
165 |
-
log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
|
166 |
-
log.info(
|
167 |
-
'A small amount is expected, as not all videos are still available on YouTube')
|
168 |
-
|
169 |
-
self.videos = sorted(list(self.captions.keys()))
|
170 |
-
|
171 |
-
|
172 |
-
class MovieGen(VideoDataset):
|
173 |
-
|
174 |
-
def __init__(
|
175 |
-
self,
|
176 |
-
video_root: Union[str, Path],
|
177 |
-
jsonl_root: Union[str, Path],
|
178 |
-
*,
|
179 |
-
duration_sec: float = 10.0,
|
180 |
-
):
|
181 |
-
super().__init__(video_root, duration_sec=duration_sec)
|
182 |
-
self.video_root = Path(video_root)
|
183 |
-
self.jsonl_root = Path(jsonl_root)
|
184 |
-
|
185 |
-
videos = sorted(os.listdir(self.video_root))
|
186 |
-
videos = [v[:-4] for v in videos] # remove extensions
|
187 |
-
self.captions = {}
|
188 |
-
|
189 |
-
for v in videos:
|
190 |
-
with open(self.jsonl_root / (v + '.jsonl')) as f:
|
191 |
-
data = json.load(f)
|
192 |
-
self.captions[v] = data['audio_prompt']
|
193 |
-
|
194 |
-
if local_rank == 0:
|
195 |
-
log.info(f'{len(videos)} videos found in {video_root}')
|
196 |
-
|
197 |
-
self.videos = videos
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/extracted_audio.py
DELETED
@@ -1,88 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
from pathlib import Path
|
3 |
-
from typing import Union
|
4 |
-
|
5 |
-
import pandas as pd
|
6 |
-
import torch
|
7 |
-
from tensordict import TensorDict
|
8 |
-
from torch.utils.data.dataset import Dataset
|
9 |
-
|
10 |
-
from mmaudio.utils.dist_utils import local_rank
|
11 |
-
|
12 |
-
log = logging.getLogger()
|
13 |
-
|
14 |
-
|
15 |
-
class ExtractedAudio(Dataset):
|
16 |
-
|
17 |
-
def __init__(
|
18 |
-
self,
|
19 |
-
tsv_path: Union[str, Path],
|
20 |
-
*,
|
21 |
-
premade_mmap_dir: Union[str, Path],
|
22 |
-
data_dim: dict[str, int],
|
23 |
-
):
|
24 |
-
super().__init__()
|
25 |
-
|
26 |
-
self.data_dim = data_dim
|
27 |
-
self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
|
28 |
-
self.ids = [str(d['id']) for d in self.df_list]
|
29 |
-
|
30 |
-
log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
|
31 |
-
# load precomputed memory mapped tensors
|
32 |
-
premade_mmap_dir = Path(premade_mmap_dir)
|
33 |
-
td = TensorDict.load_memmap(premade_mmap_dir)
|
34 |
-
log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
|
35 |
-
self.mean = td['mean']
|
36 |
-
self.std = td['std']
|
37 |
-
self.text_features = td['text_features']
|
38 |
-
|
39 |
-
log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
|
40 |
-
log.info(f'Loaded mean: {self.mean.shape}.')
|
41 |
-
log.info(f'Loaded std: {self.std.shape}.')
|
42 |
-
log.info(f'Loaded text features: {self.text_features.shape}.')
|
43 |
-
|
44 |
-
assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
|
45 |
-
f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
46 |
-
assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
|
47 |
-
f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
48 |
-
|
49 |
-
assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
|
50 |
-
f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
|
51 |
-
assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
|
52 |
-
f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
|
53 |
-
|
54 |
-
self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
|
55 |
-
self.data_dim['clip_dim'])
|
56 |
-
self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
|
57 |
-
self.data_dim['sync_dim'])
|
58 |
-
self.video_exist = torch.tensor(0, dtype=torch.bool)
|
59 |
-
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
60 |
-
|
61 |
-
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
62 |
-
latents = self.mean
|
63 |
-
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
64 |
-
|
65 |
-
def get_memory_mapped_tensor(self) -> TensorDict:
|
66 |
-
td = TensorDict({
|
67 |
-
'mean': self.mean,
|
68 |
-
'std': self.std,
|
69 |
-
'text_features': self.text_features,
|
70 |
-
})
|
71 |
-
return td
|
72 |
-
|
73 |
-
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
74 |
-
data = {
|
75 |
-
'id': str(self.df_list[idx]['id']),
|
76 |
-
'a_mean': self.mean[idx],
|
77 |
-
'a_std': self.std[idx],
|
78 |
-
'clip_features': self.fake_clip_features,
|
79 |
-
'sync_features': self.fake_sync_features,
|
80 |
-
'text_features': self.text_features[idx],
|
81 |
-
'caption': self.df_list[idx]['caption'],
|
82 |
-
'video_exist': self.video_exist,
|
83 |
-
'text_exist': self.text_exist,
|
84 |
-
}
|
85 |
-
return data
|
86 |
-
|
87 |
-
def __len__(self):
|
88 |
-
return len(self.ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/extracted_vgg.py
DELETED
@@ -1,101 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
from pathlib import Path
|
3 |
-
from typing import Union
|
4 |
-
|
5 |
-
import pandas as pd
|
6 |
-
import torch
|
7 |
-
from tensordict import TensorDict
|
8 |
-
from torch.utils.data.dataset import Dataset
|
9 |
-
|
10 |
-
from mmaudio.utils.dist_utils import local_rank
|
11 |
-
|
12 |
-
log = logging.getLogger()
|
13 |
-
|
14 |
-
|
15 |
-
class ExtractedVGG(Dataset):
|
16 |
-
|
17 |
-
def __init__(
|
18 |
-
self,
|
19 |
-
tsv_path: Union[str, Path],
|
20 |
-
*,
|
21 |
-
premade_mmap_dir: Union[str, Path],
|
22 |
-
data_dim: dict[str, int],
|
23 |
-
):
|
24 |
-
super().__init__()
|
25 |
-
|
26 |
-
self.data_dim = data_dim
|
27 |
-
self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
|
28 |
-
self.ids = [d['id'] for d in self.df_list]
|
29 |
-
|
30 |
-
log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
|
31 |
-
# load precomputed memory mapped tensors
|
32 |
-
premade_mmap_dir = Path(premade_mmap_dir)
|
33 |
-
td = TensorDict.load_memmap(premade_mmap_dir)
|
34 |
-
log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
|
35 |
-
self.mean = td['mean']
|
36 |
-
self.std = td['std']
|
37 |
-
self.clip_features = td['clip_features']
|
38 |
-
self.sync_features = td['sync_features']
|
39 |
-
self.text_features = td['text_features']
|
40 |
-
|
41 |
-
if local_rank == 0:
|
42 |
-
log.info(f'Loaded {len(self)} samples.')
|
43 |
-
log.info(f'Loaded mean: {self.mean.shape}.')
|
44 |
-
log.info(f'Loaded std: {self.std.shape}.')
|
45 |
-
log.info(f'Loaded clip_features: {self.clip_features.shape}.')
|
46 |
-
log.info(f'Loaded sync_features: {self.sync_features.shape}.')
|
47 |
-
log.info(f'Loaded text_features: {self.text_features.shape}.')
|
48 |
-
|
49 |
-
assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
|
50 |
-
f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
51 |
-
assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
|
52 |
-
f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
|
53 |
-
|
54 |
-
assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
|
55 |
-
f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
|
56 |
-
assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
|
57 |
-
f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
|
58 |
-
assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
|
59 |
-
f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
|
60 |
-
|
61 |
-
assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
|
62 |
-
f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
|
63 |
-
assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
|
64 |
-
f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
|
65 |
-
assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
|
66 |
-
f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
|
67 |
-
|
68 |
-
self.video_exist = torch.tensor(1, dtype=torch.bool)
|
69 |
-
self.text_exist = torch.tensor(1, dtype=torch.bool)
|
70 |
-
|
71 |
-
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
72 |
-
latents = self.mean
|
73 |
-
return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
|
74 |
-
|
75 |
-
def get_memory_mapped_tensor(self) -> TensorDict:
|
76 |
-
td = TensorDict({
|
77 |
-
'mean': self.mean,
|
78 |
-
'std': self.std,
|
79 |
-
'clip_features': self.clip_features,
|
80 |
-
'sync_features': self.sync_features,
|
81 |
-
'text_features': self.text_features,
|
82 |
-
})
|
83 |
-
return td
|
84 |
-
|
85 |
-
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
86 |
-
data = {
|
87 |
-
'id': self.df_list[idx]['id'],
|
88 |
-
'a_mean': self.mean[idx],
|
89 |
-
'a_std': self.std[idx],
|
90 |
-
'clip_features': self.clip_features[idx],
|
91 |
-
'sync_features': self.sync_features[idx],
|
92 |
-
'text_features': self.text_features[idx],
|
93 |
-
'caption': self.df_list[idx]['label'],
|
94 |
-
'video_exist': self.video_exist,
|
95 |
-
'text_exist': self.text_exist,
|
96 |
-
}
|
97 |
-
|
98 |
-
return data
|
99 |
-
|
100 |
-
def __len__(self):
|
101 |
-
return len(self.ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/extraction/__init__.py
DELETED
File without changes
|
mmaudio/data/extraction/vgg_sound.py
DELETED
@@ -1,193 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Optional, Union
|
5 |
-
|
6 |
-
import pandas as pd
|
7 |
-
import torch
|
8 |
-
import torchaudio
|
9 |
-
from torch.utils.data.dataset import Dataset
|
10 |
-
from torchvision.transforms import v2
|
11 |
-
from torio.io import StreamingMediaDecoder
|
12 |
-
|
13 |
-
from mmaudio.utils.dist_utils import local_rank
|
14 |
-
|
15 |
-
log = logging.getLogger()
|
16 |
-
|
17 |
-
_CLIP_SIZE = 384
|
18 |
-
_CLIP_FPS = 8.0
|
19 |
-
|
20 |
-
_SYNC_SIZE = 224
|
21 |
-
_SYNC_FPS = 25.0
|
22 |
-
|
23 |
-
|
24 |
-
class VGGSound(Dataset):
|
25 |
-
|
26 |
-
def __init__(
|
27 |
-
self,
|
28 |
-
root: Union[str, Path],
|
29 |
-
*,
|
30 |
-
tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
|
31 |
-
sample_rate: int = 16_000,
|
32 |
-
duration_sec: float = 8.0,
|
33 |
-
audio_samples: Optional[int] = None,
|
34 |
-
normalize_audio: bool = False,
|
35 |
-
):
|
36 |
-
self.root = Path(root)
|
37 |
-
self.normalize_audio = normalize_audio
|
38 |
-
if audio_samples is None:
|
39 |
-
self.audio_samples = int(sample_rate * duration_sec)
|
40 |
-
else:
|
41 |
-
self.audio_samples = audio_samples
|
42 |
-
effective_duration = audio_samples / sample_rate
|
43 |
-
# make sure the duration is close enough, within 15ms
|
44 |
-
assert abs(effective_duration - duration_sec) < 0.015, \
|
45 |
-
f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
|
46 |
-
|
47 |
-
videos = sorted(os.listdir(self.root))
|
48 |
-
videos = set([Path(v).stem for v in videos]) # remove extensions
|
49 |
-
self.labels = {}
|
50 |
-
self.videos = []
|
51 |
-
missing_videos = []
|
52 |
-
|
53 |
-
# read the tsv for subset information
|
54 |
-
df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
|
55 |
-
for record in df_list:
|
56 |
-
id = record['id']
|
57 |
-
label = record['label']
|
58 |
-
if id in videos:
|
59 |
-
self.labels[id] = label
|
60 |
-
self.videos.append(id)
|
61 |
-
else:
|
62 |
-
missing_videos.append(id)
|
63 |
-
|
64 |
-
if local_rank == 0:
|
65 |
-
log.info(f'{len(videos)} videos found in {root}')
|
66 |
-
log.info(f'{len(self.videos)} videos found in {tsv_path}')
|
67 |
-
log.info(f'{len(missing_videos)} videos missing in {root}')
|
68 |
-
|
69 |
-
self.sample_rate = sample_rate
|
70 |
-
self.duration_sec = duration_sec
|
71 |
-
|
72 |
-
self.expected_audio_length = audio_samples
|
73 |
-
self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
|
74 |
-
self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
|
75 |
-
|
76 |
-
self.clip_transform = v2.Compose([
|
77 |
-
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
78 |
-
v2.ToImage(),
|
79 |
-
v2.ToDtype(torch.float32, scale=True),
|
80 |
-
])
|
81 |
-
|
82 |
-
self.sync_transform = v2.Compose([
|
83 |
-
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
84 |
-
v2.CenterCrop(_SYNC_SIZE),
|
85 |
-
v2.ToImage(),
|
86 |
-
v2.ToDtype(torch.float32, scale=True),
|
87 |
-
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
88 |
-
])
|
89 |
-
|
90 |
-
self.resampler = {}
|
91 |
-
|
92 |
-
def sample(self, idx: int) -> dict[str, torch.Tensor]:
|
93 |
-
video_id = self.videos[idx]
|
94 |
-
label = self.labels[video_id]
|
95 |
-
|
96 |
-
reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
|
97 |
-
reader.add_basic_video_stream(
|
98 |
-
frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
|
99 |
-
frame_rate=_CLIP_FPS,
|
100 |
-
format='rgb24',
|
101 |
-
)
|
102 |
-
reader.add_basic_video_stream(
|
103 |
-
frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
|
104 |
-
frame_rate=_SYNC_FPS,
|
105 |
-
format='rgb24',
|
106 |
-
)
|
107 |
-
reader.add_basic_audio_stream(frames_per_chunk=2**30, )
|
108 |
-
|
109 |
-
reader.fill_buffer()
|
110 |
-
data_chunk = reader.pop_chunks()
|
111 |
-
|
112 |
-
clip_chunk = data_chunk[0]
|
113 |
-
sync_chunk = data_chunk[1]
|
114 |
-
audio_chunk = data_chunk[2]
|
115 |
-
|
116 |
-
if clip_chunk is None:
|
117 |
-
raise RuntimeError(f'CLIP video returned None {video_id}')
|
118 |
-
if clip_chunk.shape[0] < self.clip_expected_length:
|
119 |
-
raise RuntimeError(
|
120 |
-
f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
|
121 |
-
)
|
122 |
-
|
123 |
-
if sync_chunk is None:
|
124 |
-
raise RuntimeError(f'Sync video returned None {video_id}')
|
125 |
-
if sync_chunk.shape[0] < self.sync_expected_length:
|
126 |
-
raise RuntimeError(
|
127 |
-
f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
|
128 |
-
)
|
129 |
-
|
130 |
-
# process audio
|
131 |
-
sample_rate = int(reader.get_out_stream_info(2).sample_rate)
|
132 |
-
audio_chunk = audio_chunk.transpose(0, 1)
|
133 |
-
audio_chunk = audio_chunk.mean(dim=0) # mono
|
134 |
-
if self.normalize_audio:
|
135 |
-
abs_max = audio_chunk.abs().max()
|
136 |
-
audio_chunk = audio_chunk / abs_max * 0.95
|
137 |
-
if abs_max <= 1e-6:
|
138 |
-
raise RuntimeError(f'Audio is silent {video_id}')
|
139 |
-
|
140 |
-
# resample
|
141 |
-
if sample_rate == self.sample_rate:
|
142 |
-
audio_chunk = audio_chunk
|
143 |
-
else:
|
144 |
-
if sample_rate not in self.resampler:
|
145 |
-
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
146 |
-
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
147 |
-
sample_rate,
|
148 |
-
self.sample_rate,
|
149 |
-
lowpass_filter_width=64,
|
150 |
-
rolloff=0.9475937167399596,
|
151 |
-
resampling_method='sinc_interp_kaiser',
|
152 |
-
beta=14.769656459379492,
|
153 |
-
)
|
154 |
-
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
155 |
-
|
156 |
-
if audio_chunk.shape[0] < self.expected_audio_length:
|
157 |
-
raise RuntimeError(f'Audio too short {video_id}')
|
158 |
-
audio_chunk = audio_chunk[:self.expected_audio_length]
|
159 |
-
|
160 |
-
# truncate the video
|
161 |
-
clip_chunk = clip_chunk[:self.clip_expected_length]
|
162 |
-
if clip_chunk.shape[0] != self.clip_expected_length:
|
163 |
-
raise RuntimeError(f'CLIP video wrong length {video_id}, '
|
164 |
-
f'expected {self.clip_expected_length}, '
|
165 |
-
f'got {clip_chunk.shape[0]}')
|
166 |
-
clip_chunk = self.clip_transform(clip_chunk)
|
167 |
-
|
168 |
-
sync_chunk = sync_chunk[:self.sync_expected_length]
|
169 |
-
if sync_chunk.shape[0] != self.sync_expected_length:
|
170 |
-
raise RuntimeError(f'Sync video wrong length {video_id}, '
|
171 |
-
f'expected {self.sync_expected_length}, '
|
172 |
-
f'got {sync_chunk.shape[0]}')
|
173 |
-
sync_chunk = self.sync_transform(sync_chunk)
|
174 |
-
|
175 |
-
data = {
|
176 |
-
'id': video_id,
|
177 |
-
'caption': label,
|
178 |
-
'audio': audio_chunk,
|
179 |
-
'clip_video': clip_chunk,
|
180 |
-
'sync_video': sync_chunk,
|
181 |
-
}
|
182 |
-
|
183 |
-
return data
|
184 |
-
|
185 |
-
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
186 |
-
try:
|
187 |
-
return self.sample(idx)
|
188 |
-
except Exception as e:
|
189 |
-
log.error(f'Error loading video {self.videos[idx]}: {e}')
|
190 |
-
return None
|
191 |
-
|
192 |
-
def __len__(self):
|
193 |
-
return len(self.labels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/extraction/wav_dataset.py
DELETED
@@ -1,132 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import os
|
3 |
-
from pathlib import Path
|
4 |
-
from typing import Union
|
5 |
-
|
6 |
-
import open_clip
|
7 |
-
import pandas as pd
|
8 |
-
import torch
|
9 |
-
import torchaudio
|
10 |
-
from torch.utils.data.dataset import Dataset
|
11 |
-
|
12 |
-
log = logging.getLogger()
|
13 |
-
|
14 |
-
|
15 |
-
class WavTextClipsDataset(Dataset):
|
16 |
-
|
17 |
-
def __init__(
|
18 |
-
self,
|
19 |
-
root: Union[str, Path],
|
20 |
-
*,
|
21 |
-
captions_tsv: Union[str, Path],
|
22 |
-
clips_tsv: Union[str, Path],
|
23 |
-
sample_rate: int,
|
24 |
-
num_samples: int,
|
25 |
-
normalize_audio: bool = False,
|
26 |
-
reject_silent: bool = False,
|
27 |
-
tokenizer_id: str = 'ViT-H-14-378-quickgelu',
|
28 |
-
):
|
29 |
-
self.root = Path(root)
|
30 |
-
self.sample_rate = sample_rate
|
31 |
-
self.num_samples = num_samples
|
32 |
-
self.normalize_audio = normalize_audio
|
33 |
-
self.reject_silent = reject_silent
|
34 |
-
self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
|
35 |
-
|
36 |
-
audios = sorted(os.listdir(self.root))
|
37 |
-
audios = set([
|
38 |
-
Path(audio).stem for audio in audios
|
39 |
-
if audio.endswith('.wav') or audio.endswith('.flac')
|
40 |
-
])
|
41 |
-
self.captions = {}
|
42 |
-
|
43 |
-
# read the caption tsv
|
44 |
-
df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
|
45 |
-
for record in df_list:
|
46 |
-
id = record['id']
|
47 |
-
caption = record['caption']
|
48 |
-
self.captions[id] = caption
|
49 |
-
|
50 |
-
# read the clip tsv
|
51 |
-
df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
|
52 |
-
'id': str,
|
53 |
-
'name': str
|
54 |
-
}).to_dict('records')
|
55 |
-
self.clips = []
|
56 |
-
for record in df_list:
|
57 |
-
record['id'] = record['id']
|
58 |
-
record['name'] = record['name']
|
59 |
-
id = record['id']
|
60 |
-
name = record['name']
|
61 |
-
if name not in self.captions:
|
62 |
-
log.warning(f'Audio {name} not found in {captions_tsv}')
|
63 |
-
continue
|
64 |
-
record['caption'] = self.captions[name]
|
65 |
-
self.clips.append(record)
|
66 |
-
|
67 |
-
log.info(f'Found {len(self.clips)} audio files in {self.root}')
|
68 |
-
|
69 |
-
self.resampler = {}
|
70 |
-
|
71 |
-
def __getitem__(self, idx: int) -> torch.Tensor:
|
72 |
-
try:
|
73 |
-
clip = self.clips[idx]
|
74 |
-
audio_name = clip['name']
|
75 |
-
audio_id = clip['id']
|
76 |
-
caption = clip['caption']
|
77 |
-
start_sample = clip['start_sample']
|
78 |
-
end_sample = clip['end_sample']
|
79 |
-
|
80 |
-
audio_path = self.root / f'{audio_name}.flac'
|
81 |
-
if not audio_path.exists():
|
82 |
-
audio_path = self.root / f'{audio_name}.wav'
|
83 |
-
assert audio_path.exists()
|
84 |
-
|
85 |
-
audio_chunk, sample_rate = torchaudio.load(audio_path)
|
86 |
-
audio_chunk = audio_chunk.mean(dim=0) # mono
|
87 |
-
abs_max = audio_chunk.abs().max()
|
88 |
-
if self.normalize_audio:
|
89 |
-
audio_chunk = audio_chunk / abs_max * 0.95
|
90 |
-
|
91 |
-
if self.reject_silent and abs_max < 1e-6:
|
92 |
-
log.warning(f'Rejecting silent audio')
|
93 |
-
return None
|
94 |
-
|
95 |
-
audio_chunk = audio_chunk[start_sample:end_sample]
|
96 |
-
|
97 |
-
# resample
|
98 |
-
if sample_rate == self.sample_rate:
|
99 |
-
audio_chunk = audio_chunk
|
100 |
-
else:
|
101 |
-
if sample_rate not in self.resampler:
|
102 |
-
# https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
|
103 |
-
self.resampler[sample_rate] = torchaudio.transforms.Resample(
|
104 |
-
sample_rate,
|
105 |
-
self.sample_rate,
|
106 |
-
lowpass_filter_width=64,
|
107 |
-
rolloff=0.9475937167399596,
|
108 |
-
resampling_method='sinc_interp_kaiser',
|
109 |
-
beta=14.769656459379492,
|
110 |
-
)
|
111 |
-
audio_chunk = self.resampler[sample_rate](audio_chunk)
|
112 |
-
|
113 |
-
if audio_chunk.shape[0] < self.num_samples:
|
114 |
-
raise ValueError('Audio is too short')
|
115 |
-
audio_chunk = audio_chunk[:self.num_samples]
|
116 |
-
|
117 |
-
tokens = self.tokenizer([caption])[0]
|
118 |
-
|
119 |
-
output = {
|
120 |
-
'waveform': audio_chunk,
|
121 |
-
'id': audio_id,
|
122 |
-
'caption': caption,
|
123 |
-
'tokens': tokens,
|
124 |
-
}
|
125 |
-
|
126 |
-
return output
|
127 |
-
except Exception as e:
|
128 |
-
log.error(f'Error reading {audio_path}: {e}')
|
129 |
-
return None
|
130 |
-
|
131 |
-
def __len__(self):
|
132 |
-
return len(self.clips)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/mm_dataset.py
DELETED
@@ -1,45 +0,0 @@
|
|
1 |
-
import bisect
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch.utils.data.dataset import Dataset
|
5 |
-
|
6 |
-
|
7 |
-
# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
|
8 |
-
class MultiModalDataset(Dataset):
|
9 |
-
datasets: list[Dataset]
|
10 |
-
cumulative_sizes: list[int]
|
11 |
-
|
12 |
-
@staticmethod
|
13 |
-
def cumsum(sequence):
|
14 |
-
r, s = [], 0
|
15 |
-
for e in sequence:
|
16 |
-
l = len(e)
|
17 |
-
r.append(l + s)
|
18 |
-
s += l
|
19 |
-
return r
|
20 |
-
|
21 |
-
def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
|
22 |
-
super().__init__()
|
23 |
-
self.video_datasets = list(video_datasets)
|
24 |
-
self.audio_datasets = list(audio_datasets)
|
25 |
-
self.datasets = self.video_datasets + self.audio_datasets
|
26 |
-
|
27 |
-
self.cumulative_sizes = self.cumsum(self.datasets)
|
28 |
-
|
29 |
-
def __len__(self):
|
30 |
-
return self.cumulative_sizes[-1]
|
31 |
-
|
32 |
-
def __getitem__(self, idx):
|
33 |
-
if idx < 0:
|
34 |
-
if -idx > len(self):
|
35 |
-
raise ValueError("absolute value of index should not exceed dataset length")
|
36 |
-
idx = len(self) + idx
|
37 |
-
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
38 |
-
if dataset_idx == 0:
|
39 |
-
sample_idx = idx
|
40 |
-
else:
|
41 |
-
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
42 |
-
return self.datasets[dataset_idx][sample_idx]
|
43 |
-
|
44 |
-
def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
|
45 |
-
return self.video_datasets[0].compute_latent_stats()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/data/utils.py
DELETED
@@ -1,148 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
import os
|
3 |
-
import random
|
4 |
-
import tempfile
|
5 |
-
from pathlib import Path
|
6 |
-
from typing import Any, Optional, Union
|
7 |
-
|
8 |
-
import torch
|
9 |
-
import torch.distributed as dist
|
10 |
-
from tensordict import MemoryMappedTensor
|
11 |
-
from torch.utils.data import DataLoader
|
12 |
-
from torch.utils.data.dataset import Dataset
|
13 |
-
from tqdm import tqdm
|
14 |
-
|
15 |
-
from mmaudio.utils.dist_utils import local_rank, world_size
|
16 |
-
|
17 |
-
scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
|
18 |
-
shm_path = Path('/dev/shm')
|
19 |
-
|
20 |
-
log = logging.getLogger()
|
21 |
-
|
22 |
-
|
23 |
-
def reseed(seed):
|
24 |
-
random.seed(seed)
|
25 |
-
torch.manual_seed(seed)
|
26 |
-
|
27 |
-
|
28 |
-
def local_scatter_torch(obj: Optional[Any]):
|
29 |
-
if world_size == 1:
|
30 |
-
# Just one worker. Do nothing.
|
31 |
-
return obj
|
32 |
-
|
33 |
-
array = [obj] * world_size
|
34 |
-
target_array = [None]
|
35 |
-
if local_rank == 0:
|
36 |
-
dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
|
37 |
-
else:
|
38 |
-
dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
|
39 |
-
return target_array[0]
|
40 |
-
|
41 |
-
|
42 |
-
class ShardDataset(Dataset):
|
43 |
-
|
44 |
-
def __init__(self, root):
|
45 |
-
self.root = root
|
46 |
-
self.shards = sorted(os.listdir(root))
|
47 |
-
|
48 |
-
def __len__(self):
|
49 |
-
return len(self.shards)
|
50 |
-
|
51 |
-
def __getitem__(self, idx):
|
52 |
-
return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
|
53 |
-
|
54 |
-
|
55 |
-
def get_tmp_dir(in_memory: bool) -> Path:
|
56 |
-
return shm_path if in_memory else scratch_path
|
57 |
-
|
58 |
-
|
59 |
-
def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
|
60 |
-
in_memory: bool) -> MemoryMappedTensor:
|
61 |
-
if local_rank == 0:
|
62 |
-
with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
|
63 |
-
log.info(f'Loading shards from {data_path} into {f.name}...')
|
64 |
-
data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
|
65 |
-
data = share_tensor_to_all(data)
|
66 |
-
torch.distributed.barrier()
|
67 |
-
f.close() # why does the context manager not close the file for me?
|
68 |
-
else:
|
69 |
-
log.info('Waiting for the data to be shared with me...')
|
70 |
-
data = share_tensor_to_all(None)
|
71 |
-
torch.distributed.barrier()
|
72 |
-
|
73 |
-
return data
|
74 |
-
|
75 |
-
|
76 |
-
def load_shards(
|
77 |
-
data_path: Union[str, Path],
|
78 |
-
ids: list[int],
|
79 |
-
*,
|
80 |
-
tmp_file_path: str,
|
81 |
-
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
|
82 |
-
|
83 |
-
id_set = set(ids)
|
84 |
-
shards = sorted(os.listdir(data_path))
|
85 |
-
log.info(f'Found {len(shards)} shards in {data_path}.')
|
86 |
-
first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
|
87 |
-
|
88 |
-
log.info(f'Rank {local_rank} created file {tmp_file_path}')
|
89 |
-
first_item = next(iter(first_shard.values()))
|
90 |
-
log.info(f'First item shape: {first_item.shape}')
|
91 |
-
mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
|
92 |
-
dtype=torch.float32,
|
93 |
-
filename=tmp_file_path,
|
94 |
-
existsok=True)
|
95 |
-
total_count = 0
|
96 |
-
used_index = set()
|
97 |
-
id_indexing = {i: idx for idx, i in enumerate(ids)}
|
98 |
-
# faster with no workers; otherwise we need to set_sharing_strategy('file_system')
|
99 |
-
loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
|
100 |
-
for data in tqdm(loader, desc='Loading shards'):
|
101 |
-
for i, v in data.items():
|
102 |
-
if i not in id_set:
|
103 |
-
continue
|
104 |
-
|
105 |
-
# tensor_index = ids.index(i)
|
106 |
-
tensor_index = id_indexing[i]
|
107 |
-
if tensor_index in used_index:
|
108 |
-
raise ValueError(f'Duplicate id {i} found in {data_path}.')
|
109 |
-
used_index.add(tensor_index)
|
110 |
-
mm_tensor[tensor_index] = v
|
111 |
-
total_count += 1
|
112 |
-
|
113 |
-
assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
|
114 |
-
log.info(f'Loaded {total_count} tensors from {data_path}.')
|
115 |
-
|
116 |
-
return mm_tensor
|
117 |
-
|
118 |
-
|
119 |
-
def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
|
120 |
-
"""
|
121 |
-
x: the tensor to be shared; None if local_rank != 0
|
122 |
-
return: the shared tensor
|
123 |
-
"""
|
124 |
-
|
125 |
-
# there is no need to share your stuff with anyone if you are alone; must be in memory
|
126 |
-
if world_size == 1:
|
127 |
-
return x
|
128 |
-
|
129 |
-
if local_rank == 0:
|
130 |
-
assert x is not None, 'x must not be None if local_rank == 0'
|
131 |
-
else:
|
132 |
-
assert x is None, 'x must be None if local_rank != 0'
|
133 |
-
|
134 |
-
if local_rank == 0:
|
135 |
-
filename = x.filename
|
136 |
-
meta_information = (filename, x.shape, x.dtype)
|
137 |
-
else:
|
138 |
-
meta_information = None
|
139 |
-
|
140 |
-
filename, data_shape, data_type = local_scatter_torch(meta_information)
|
141 |
-
if local_rank == 0:
|
142 |
-
data = x
|
143 |
-
else:
|
144 |
-
data = MemoryMappedTensor.from_filename(filename=filename,
|
145 |
-
dtype=data_type,
|
146 |
-
shape=data_shape)
|
147 |
-
|
148 |
-
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mmaudio/eval_utils.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1 |
import dataclasses
|
2 |
import logging
|
3 |
from pathlib import Path
|
4 |
-
from typing import Optional
|
5 |
|
6 |
-
import numpy as np
|
7 |
import torch
|
8 |
from colorlog import ColoredFormatter
|
9 |
-
from PIL import Image
|
10 |
from torchvision.transforms import v2
|
11 |
|
12 |
-
from mmaudio.data.av_utils import
|
13 |
from mmaudio.model.flow_matching import FlowMatching
|
14 |
from mmaudio.model.networks import MMAudio
|
15 |
-
from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
|
16 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
17 |
from mmaudio.utils.download_utils import download_model_if_needed
|
18 |
|
@@ -26,7 +24,7 @@ class ModelConfig:
|
|
26 |
vae_path: Path
|
27 |
bigvgan_16k_path: Optional[Path]
|
28 |
mode: str
|
29 |
-
synchformer_ckpt: Path = Path('./
|
30 |
|
31 |
@property
|
32 |
def seq_cfg(self) -> SequenceConfig:
|
@@ -44,31 +42,31 @@ class ModelConfig:
|
|
44 |
|
45 |
|
46 |
small_16k = ModelConfig(model_name='small_16k',
|
47 |
-
model_path=Path('./
|
48 |
-
vae_path=Path('./
|
49 |
-
bigvgan_16k_path=Path('./
|
50 |
mode='16k')
|
51 |
small_44k = ModelConfig(model_name='small_44k',
|
52 |
-
model_path=Path('./
|
53 |
-
vae_path=Path('./
|
54 |
bigvgan_16k_path=None,
|
55 |
mode='44k')
|
56 |
medium_44k = ModelConfig(model_name='medium_44k',
|
57 |
-
model_path=Path('./
|
58 |
-
vae_path=Path('./
|
59 |
bigvgan_16k_path=None,
|
60 |
mode='44k')
|
61 |
large_44k = ModelConfig(model_name='large_44k',
|
62 |
-
model_path=Path('./
|
63 |
-
vae_path=Path('./
|
64 |
bigvgan_16k_path=None,
|
65 |
mode='44k')
|
66 |
large_44k_v2 = ModelConfig(model_name='large_44k_v2',
|
67 |
-
model_path=Path('./
|
68 |
-
vae_path=Path('./
|
69 |
bigvgan_16k_path=None,
|
70 |
mode='44k')
|
71 |
-
all_model_cfg:
|
72 |
'small_16k': small_16k,
|
73 |
'small_44k': small_44k,
|
74 |
'medium_44k': medium_44k,
|
@@ -80,9 +78,9 @@ all_model_cfg: Dict[str, ModelConfig] = {
|
|
80 |
def generate(
|
81 |
clip_video: Optional[torch.Tensor],
|
82 |
sync_video: Optional[torch.Tensor],
|
83 |
-
text: Optional[
|
84 |
*,
|
85 |
-
negative_text: Optional[
|
86 |
feature_utils: FeaturesUtils,
|
87 |
net: MMAudio,
|
88 |
fm: FlowMatching,
|
@@ -90,7 +88,6 @@ def generate(
|
|
90 |
cfg_strength: float,
|
91 |
clip_batch_size_multiplier: int = 40,
|
92 |
sync_batch_size_multiplier: int = 40,
|
93 |
-
image_input: bool = False,
|
94 |
) -> torch.Tensor:
|
95 |
device = feature_utils.device
|
96 |
dtype = feature_utils.dtype
|
@@ -101,12 +98,10 @@ def generate(
|
|
101 |
clip_features = feature_utils.encode_video_with_clip(clip_video,
|
102 |
batch_size=bs *
|
103 |
clip_batch_size_multiplier)
|
104 |
-
if image_input:
|
105 |
-
clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
|
106 |
else:
|
107 |
clip_features = net.get_empty_clip_sequence(bs)
|
108 |
|
109 |
-
if sync_video is not None
|
110 |
sync_video = sync_video.to(device, dtype, non_blocking=True)
|
111 |
sync_features = feature_utils.encode_video_with_sync(sync_video,
|
112 |
batch_size=bs *
|
@@ -144,7 +139,7 @@ def generate(
|
|
144 |
return audio
|
145 |
|
146 |
|
147 |
-
LOGFORMAT = "
|
148 |
|
149 |
|
150 |
def setup_eval_logging(log_level: int = logging.INFO):
|
@@ -158,14 +153,12 @@ def setup_eval_logging(log_level: int = logging.INFO):
|
|
158 |
log.addHandler(stream)
|
159 |
|
160 |
|
161 |
-
_CLIP_SIZE = 384
|
162 |
-
_CLIP_FPS = 8.0
|
163 |
-
|
164 |
-
_SYNC_SIZE = 224
|
165 |
-
_SYNC_FPS = 25.0
|
166 |
-
|
167 |
-
|
168 |
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
|
|
|
|
|
|
|
|
|
|
|
169 |
|
170 |
clip_transform = v2.Compose([
|
171 |
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
@@ -220,36 +213,5 @@ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = Tr
|
|
220 |
return video_info
|
221 |
|
222 |
|
223 |
-
def load_image(image_path: Path) -> VideoInfo:
|
224 |
-
clip_transform = v2.Compose([
|
225 |
-
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
226 |
-
v2.ToImage(),
|
227 |
-
v2.ToDtype(torch.float32, scale=True),
|
228 |
-
])
|
229 |
-
|
230 |
-
sync_transform = v2.Compose([
|
231 |
-
v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
|
232 |
-
v2.CenterCrop(_SYNC_SIZE),
|
233 |
-
v2.ToImage(),
|
234 |
-
v2.ToDtype(torch.float32, scale=True),
|
235 |
-
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
236 |
-
])
|
237 |
-
|
238 |
-
frame = np.array(Image.open(image_path))
|
239 |
-
|
240 |
-
clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
|
241 |
-
sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
|
242 |
-
|
243 |
-
clip_frames = clip_transform(clip_chunk)
|
244 |
-
sync_frames = sync_transform(sync_chunk)
|
245 |
-
|
246 |
-
video_info = ImageInfo(
|
247 |
-
clip_frames=clip_frames,
|
248 |
-
sync_frames=sync_frames,
|
249 |
-
original_frame=frame,
|
250 |
-
)
|
251 |
-
return video_info
|
252 |
-
|
253 |
-
|
254 |
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
255 |
reencode_with_audio(video_info, output_path, audio, sampling_rate)
|
|
|
1 |
import dataclasses
|
2 |
import logging
|
3 |
from pathlib import Path
|
4 |
+
from typing import Optional
|
5 |
|
|
|
6 |
import torch
|
7 |
from colorlog import ColoredFormatter
|
|
|
8 |
from torchvision.transforms import v2
|
9 |
|
10 |
+
from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
|
11 |
from mmaudio.model.flow_matching import FlowMatching
|
12 |
from mmaudio.model.networks import MMAudio
|
13 |
+
from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
|
14 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
15 |
from mmaudio.utils.download_utils import download_model_if_needed
|
16 |
|
|
|
24 |
vae_path: Path
|
25 |
bigvgan_16k_path: Optional[Path]
|
26 |
mode: str
|
27 |
+
synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
|
28 |
|
29 |
@property
|
30 |
def seq_cfg(self) -> SequenceConfig:
|
|
|
42 |
|
43 |
|
44 |
small_16k = ModelConfig(model_name='small_16k',
|
45 |
+
model_path=Path('./weights/mmaudio_small_16k.pth'),
|
46 |
+
vae_path=Path('./ext_weights/v1-16.pth'),
|
47 |
+
bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
|
48 |
mode='16k')
|
49 |
small_44k = ModelConfig(model_name='small_44k',
|
50 |
+
model_path=Path('./weights/mmaudio_small_44k.pth'),
|
51 |
+
vae_path=Path('./ext_weights/v1-44.pth'),
|
52 |
bigvgan_16k_path=None,
|
53 |
mode='44k')
|
54 |
medium_44k = ModelConfig(model_name='medium_44k',
|
55 |
+
model_path=Path('./weights/mmaudio_medium_44k.pth'),
|
56 |
+
vae_path=Path('./ext_weights/v1-44.pth'),
|
57 |
bigvgan_16k_path=None,
|
58 |
mode='44k')
|
59 |
large_44k = ModelConfig(model_name='large_44k',
|
60 |
+
model_path=Path('./weights/mmaudio_large_44k.pth'),
|
61 |
+
vae_path=Path('./ext_weights/v1-44.pth'),
|
62 |
bigvgan_16k_path=None,
|
63 |
mode='44k')
|
64 |
large_44k_v2 = ModelConfig(model_name='large_44k_v2',
|
65 |
+
model_path=Path('./weights/mmaudio_large_44k_v2.pth'),
|
66 |
+
vae_path=Path('./ext_weights/v1-44.pth'),
|
67 |
bigvgan_16k_path=None,
|
68 |
mode='44k')
|
69 |
+
all_model_cfg: dict[str, ModelConfig] = {
|
70 |
'small_16k': small_16k,
|
71 |
'small_44k': small_44k,
|
72 |
'medium_44k': medium_44k,
|
|
|
78 |
def generate(
|
79 |
clip_video: Optional[torch.Tensor],
|
80 |
sync_video: Optional[torch.Tensor],
|
81 |
+
text: Optional[list[str]],
|
82 |
*,
|
83 |
+
negative_text: Optional[list[str]] = None,
|
84 |
feature_utils: FeaturesUtils,
|
85 |
net: MMAudio,
|
86 |
fm: FlowMatching,
|
|
|
88 |
cfg_strength: float,
|
89 |
clip_batch_size_multiplier: int = 40,
|
90 |
sync_batch_size_multiplier: int = 40,
|
|
|
91 |
) -> torch.Tensor:
|
92 |
device = feature_utils.device
|
93 |
dtype = feature_utils.dtype
|
|
|
98 |
clip_features = feature_utils.encode_video_with_clip(clip_video,
|
99 |
batch_size=bs *
|
100 |
clip_batch_size_multiplier)
|
|
|
|
|
101 |
else:
|
102 |
clip_features = net.get_empty_clip_sequence(bs)
|
103 |
|
104 |
+
if sync_video is not None:
|
105 |
sync_video = sync_video.to(device, dtype, non_blocking=True)
|
106 |
sync_features = feature_utils.encode_video_with_sync(sync_video,
|
107 |
batch_size=bs *
|
|
|
139 |
return audio
|
140 |
|
141 |
|
142 |
+
LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
|
143 |
|
144 |
|
145 |
def setup_eval_logging(log_level: int = logging.INFO):
|
|
|
153 |
log.addHandler(stream)
|
154 |
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
|
157 |
+
_CLIP_SIZE = 384
|
158 |
+
_CLIP_FPS = 8.0
|
159 |
+
|
160 |
+
_SYNC_SIZE = 224
|
161 |
+
_SYNC_FPS = 25.0
|
162 |
|
163 |
clip_transform = v2.Compose([
|
164 |
v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
|
|
|
213 |
return video_info
|
214 |
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
|
217 |
reencode_with_audio(video_info, output_path, audio, sampling_rate)
|
mmaudio/ext/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (191 Bytes)
|
|
mmaudio/ext/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (189 Bytes)
|
|
mmaudio/ext/__pycache__/mel_converter.cpython-310.pyc
DELETED
Binary file (2.87 kB)
|
|
mmaudio/ext/__pycache__/mel_converter.cpython-38.pyc
DELETED
Binary file (2.84 kB)
|
|
mmaudio/ext/__pycache__/rotary_embeddings.cpython-310.pyc
DELETED
Binary file (1.48 kB)
|
|
mmaudio/ext/__pycache__/rotary_embeddings.cpython-38.pyc
DELETED
Binary file (1.45 kB)
|
|
mmaudio/ext/autoencoder/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (256 Bytes)
|
|
mmaudio/ext/autoencoder/__pycache__/__init__.cpython-38.pyc
DELETED
Binary file (254 Bytes)
|
|
mmaudio/ext/autoencoder/__pycache__/autoencoder.cpython-310.pyc
DELETED
Binary file (2.14 kB)
|
|