lym0302 committed
Commit 0163d98
1 Parent(s): 4cfc7ca
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +69 -101
  2. app.py +34 -102
  3. batch_eval.py +0 -110
  4. config/__init__.py +0 -0
  5. config/base_config.yaml +0 -62
  6. config/data/base.yaml +0 -70
  7. config/eval_config.yaml +0 -17
  8. config/eval_data/base.yaml +0 -22
  9. config/hydra/job_logging/custom-eval.yaml +0 -32
  10. config/hydra/job_logging/custom-no-rank.yaml +0 -32
  11. config/hydra/job_logging/custom-simplest.yaml +0 -26
  12. config/hydra/job_logging/custom.yaml +0 -33
  13. config/train_config.yaml +0 -41
  14. demo.py +1 -7
  15. docs/EVAL.md +0 -22
  16. docs/MODELS.md +0 -50
  17. docs/TRAINING.md +0 -184
  18. docs/index.html +10 -12
  19. gradio_demo.py +0 -343
  20. mmaudio/__pycache__/__init__.cpython-310.pyc +0 -0
  21. mmaudio/__pycache__/__init__.cpython-38.pyc +0 -0
  22. mmaudio/__pycache__/eval_utils.cpython-310.pyc +0 -0
  23. mmaudio/__pycache__/eval_utils.cpython-38.pyc +0 -0
  24. mmaudio/data/__pycache__/__init__.cpython-310.pyc +0 -0
  25. mmaudio/data/__pycache__/__init__.cpython-38.pyc +0 -0
  26. mmaudio/data/__pycache__/av_utils.cpython-310.pyc +0 -0
  27. mmaudio/data/__pycache__/av_utils.cpython-38.pyc +0 -0
  28. mmaudio/data/av_utils.py +4 -30
  29. mmaudio/data/data_setup.py +0 -174
  30. mmaudio/data/eval/__init__.py +0 -0
  31. mmaudio/data/eval/audiocaps.py +0 -39
  32. mmaudio/data/eval/moviegen.py +0 -131
  33. mmaudio/data/eval/video_dataset.py +0 -197
  34. mmaudio/data/extracted_audio.py +0 -88
  35. mmaudio/data/extracted_vgg.py +0 -101
  36. mmaudio/data/extraction/__init__.py +0 -0
  37. mmaudio/data/extraction/vgg_sound.py +0 -193
  38. mmaudio/data/extraction/wav_dataset.py +0 -132
  39. mmaudio/data/mm_dataset.py +0 -45
  40. mmaudio/data/utils.py +0 -148
  41. mmaudio/eval_utils.py +25 -63
  42. mmaudio/ext/__pycache__/__init__.cpython-310.pyc +0 -0
  43. mmaudio/ext/__pycache__/__init__.cpython-38.pyc +0 -0
  44. mmaudio/ext/__pycache__/mel_converter.cpython-310.pyc +0 -0
  45. mmaudio/ext/__pycache__/mel_converter.cpython-38.pyc +0 -0
  46. mmaudio/ext/__pycache__/rotary_embeddings.cpython-310.pyc +0 -0
  47. mmaudio/ext/__pycache__/rotary_embeddings.cpython-38.pyc +0 -0
  48. mmaudio/ext/autoencoder/__pycache__/__init__.cpython-310.pyc +0 -0
  49. mmaudio/ext/autoencoder/__pycache__/__init__.cpython-38.pyc +0 -0
  50. mmaudio/ext/autoencoder/__pycache__/autoencoder.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,27 +1,25 @@
1
  ---
2
  title: DeepSound-V1
3
- colorFrom: indigo
4
- colorTo: purple
 
5
  sdk: gradio
6
- sdk_version: 5.22.0
7
  app_file: app.py
8
  pinned: false
9
  ---
10
 
11
- <div align="center">
12
- <p align="center">
13
- <h2>MMAudio</h2>
14
- <a href="https://arxiv.org/abs/2412.15322">Paper</a> | <a href="https://hkchengrex.github.io/MMAudio">Webpage</a> | <a href="https://huggingface.co/hkchengrex/MMAudio/tree/main">Models</a> | <a href="https://huggingface.co/spaces/hkchengrex/MMAudio"> Huggingface Demo</a> | <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">Colab Demo</a> | <a href="https://replicate.com/zsxkib/mmaudio">Replicate Demo</a>
15
- </p>
16
- </div>
17
 
18
- ## [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio)
19
 
20
  [Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
21
 
22
  University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
23
 
24
- CVPR 2025
25
 
26
  ## Highlight
27
 
@@ -29,25 +27,22 @@ MMAudio generates synchronized audio given video and/or text inputs.
29
  Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
30
  Moreover, a synchronization module aligns the generated audio with the video frames.
31
 
 
32
  ## Results
33
 
34
  (All audio from our algorithm MMAudio)
35
 
36
- Videos from Sora:
37
 
38
  https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
39
 
40
- Videos from Veo 2:
41
-
42
- https://github.com/user-attachments/assets/8a11419e-fee2-46e0-9e67-dfb03c48d00e
43
 
44
- Videos from MovieGen/Hunyuan Video/VGGSound:
45
 
46
  https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
47
 
48
  For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
49
 
50
-
51
  ## Installation
52
 
53
  We have only tested this on Ubuntu.
@@ -56,30 +51,17 @@ We have only tested this on Ubuntu.
56
 
57
  We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
58
 
59
- - Python 3.9+
60
- - PyTorch **2.5.1+** and corresponding torchvision/torchaudio (pick your CUDA version https://pytorch.org/, pip install recommended)
61
- <!-- - ffmpeg<7 ([this is required by torchaudio](https://pytorch.org/audio/master/installation.html#optional-dependencies), you can install it in a miniforge environment with `conda install -c conda-forge 'ffmpeg<7'`) -->
62
 
63
- **1. Install prerequisite if not yet met:**
64
-
65
- ```bash
66
- pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 --upgrade
67
- ```
68
-
69
- (Or any other CUDA versions that your GPUs/driver support)
70
-
71
- <!-- ```
72
- conda install -c conda-forge 'ffmpeg<7
73
- ```
74
- (Optional, if you use miniforge and don't already have the appropriate ffmpeg) -->
75
-
76
- **2. Clone our repository:**
77
 
78
  ```bash
79
  git clone https://github.com/hkchengrex/MMAudio.git
80
  ```
81
 
82
- **3. Install with pip (install pytorch first before attempting this!):**
83
 
84
  ```bash
85
  cd MMAudio
@@ -88,108 +70,94 @@ pip install -e .
88
 
89
  (If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip)
90
 
91
-
92
  **Pretrained models:**
93
 
94
- The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
95
- The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
96
- See [MODELS.md](docs/MODELS.md) for more details.
97
 
98
  ## Demo
99
 
100
- By default, these scripts use the `large_44k_v2` model.
101
  In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs.
102
 
103
  ### Command-line interface
104
 
105
  With `demo.py`
106
-
107
  ```bash
108
  python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
109
  ```
110
-
111
  The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
112
  See the file for more options.
113
  Simply omit the `--video` option for text-to-audio synthesis.
114
  The default output (and training) duration is 8 seconds. Longer/shorter durations could also work, but a large deviation from the training duration may result in a lower quality.
115
 
 
116
  ### Gradio interface
117
 
118
  Supports video-to-audio and text-to-audio synthesis.
119
- You can also try experimental image-to-audio synthesis which duplicates the input image to a video for processing. This might be interesting to some but it is not something MMAudio has been trained for.
120
- Use [port forwarding](https://unix.stackexchange.com/questions/115897/whats-ssh-port-forwarding-and-whats-the-difference-between-ssh-local-and-remot) (e.g., `ssh -L 7860:localhost:7860 server`) if necessary. The default port is `7860` which you can specify with `--port`.
121
 
122
- ```bash
123
  python gradio_demo.py
124
  ```
125
 
126
- ### FAQ
127
-
128
- 1. Video processing
129
- - Processing higher-resolution videos takes longer due to encoding and decoding (which can take >95% of the processing time!), but it does not improve the quality of results.
130
- - The CLIP encoder resizes input frames to 384×384 pixels.
131
- - Synchformer resizes the shorter edge to 224 pixels and applies a center crop, focusing only on the central square of each frame.
132
- 2. Frame rates
133
- - The CLIP model operates at 8 FPS, while Synchformer works at 25 FPS.
134
- - Frame rate conversion happens on-the-fly via the video reader.
135
- - For input videos with a frame rate below 25 FPS, frames will be duplicated to match the required rate.
136
- 3. Failure cases
137
- As with most models of this type, failures can occur, and the reasons are not always clear. Below are some known failure modes. If you notice a failure mode or believe there’s a bug, feel free to open an issue in the repository.
138
- 4. Performance variations
139
- We notice that there can be subtle performance variations in different hardware and software environments. Some of the reasons include using/not using `torch.compile`, video reader library/backend, inference precision, batch sizes, random seeds, etc. We (will) provide pre-computed results on standard benchmark for reference. Results obtained from this codebase should be similar but might not be exactly the same.
140
-
141
  ### Known limitations
142
 
143
- 1. The model sometimes generates unintelligible human speech-like sounds
144
- 2. The model sometimes generates background music (without explicit training, it would not be high quality)
145
  3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing".
146
 
147
  We believe all of these three limitations can be addressed with more high-quality training data.
148
 
149
  ## Training
150
-
151
- See [TRAINING.md](docs/TRAINING.md).
152
 
153
  ## Evaluation
154
-
155
- See [EVAL.md](docs/EVAL.md).
156
-
157
- ## Training Datasets
158
-
159
- MMAudio was trained on several datasets, including [AudioSet](https://research.google.com/audioset/), [Freesound](https://github.com/LAION-AI/audio-dataset/blob/main/laion-audio-630k/README.md), [VGGSound](https://www.robots.ox.ac.uk/~vgg/data/vggsound/), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://github.com/XinhaoMei/WavCaps). These datasets are subject to specific licenses, which can be accessed on their respective websites. We do not guarantee that the pre-trained models are suitable for commercial use. Please use them at your own risk.
160
-
161
- ## Update Logs
162
-
163
- - 2025-03-09: Uploaded the corrected tsv files. See [TRAINING.md](docs/TRAINING.md).
164
- - 2025-02-27: Disabled the GradScaler by default to improve training stability. See #49.
165
- - 2024-12-23: Added training and batch evaluation scripts.
166
- - 2024-12-14: Removed the `ffmpeg<7` requirement for the demos by replacing `torio.io.StreamingMediaDecoder` with `pyav` for reading frames. The read frames are also cached, so we are not reading the same frames again during reconstruction. This should speed things up and make installation less of a hassle.
167
- - 2024-12-13: Improved for-loop processing in CLIP/Sync feature extraction by introducing a batch size multiplier. We can approximately use 40x batch size for CLIP/Sync without using more memory, thereby speeding up processing. Removed VAE encoder during inference -- we don't need it.
168
- - 2024-12-11: Replaced `torio.io.StreamingMediaDecoder` with `pyav` for reading framerate when reconstructing the input video. `torio.io.StreamingMediaDecoder` does not work reliably in huggingface ZeroGPU's environment, and I suspect that it might not work in some other environments as well.
169
-
170
- ## Citation
171
-
172
- ```bibtex
173
- @inproceedings{cheng2025taming,
174
- title={Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis},
175
- author={Cheng, Ho Kei and Ishii, Masato and Hayakawa, Akio and Shibuya, Takashi and Schwing, Alexander and Mitsufuji, Yuki},
176
- booktitle={CVPR},
177
- year={2025}
178
- }
179
- ```
180
-
181
- ## Relevant Repositories
182
-
183
- - [av-benchmark](https://github.com/hkchengrex/av-benchmark) for benchmarking results.
184
-
185
- ## Disclaimer
186
-
187
- We have no affiliation with and have no knowledge of the party behind the domain "mmaudio.net".
188
 
189
  ## Acknowledgement
190
-
191
  Many thanks to:
192
- - [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model and the VAE architecture
193
  - [BigVGAN](https://github.com/NVIDIA/BigVGAN)
194
  - [Synchformer](https://github.com/v-iashin/Synchformer)
195
- - [EDM2](https://github.com/NVlabs/edm2) for the magnitude-preserving VAE network architecture
 
1
  ---
2
  title: DeepSound-V1
3
+ emoji: 🔊
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
 
7
  app_file: app.py
8
  pinned: false
9
  ---
10
 
11
 
12
+ # [Taming Multimodal Joint Training for High-Quality Video-to-Audio Synthesis](https://hkchengrex.github.io/MMAudio)
13
 
14
  [Ho Kei Cheng](https://hkchengrex.github.io/), [Masato Ishii](https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ), [Akio Hayakawa](https://scholar.google.com/citations?user=sXAjHFIAAAAJ), [Takashi Shibuya](https://scholar.google.com/citations?user=XCRO260AAAAJ), [Alexander Schwing](https://www.alexander-schwing.de/), [Yuki Mitsufuji](https://www.yukimitsufuji.com/)
15
 
16
  University of Illinois Urbana-Champaign, Sony AI, and Sony Group Corporation
17
 
18
+
19
+ [[Paper (being prepared)]](https://hkchengrex.github.io/MMAudio) [[Project Page]](https://hkchengrex.github.io/MMAudio)
20
+
21
+
22
+ **Note: This repository is still under construction. Single-example inference should work as expected. The training code will be added. Code is subject to non-backward-compatible changes.**
23
 
24
  ## Highlight
25
 
 
27
  Our key innovation is multimodal joint training which allows training on a wide range of audio-visual and audio-text datasets.
28
  Moreover, a synchronization module aligns the generated audio with the video frames.
29
 
30
+
31
  ## Results
32
 
33
  (All audio from our algorithm MMAudio)
34
 
35
+ Videos from Sora:
36
 
37
  https://github.com/user-attachments/assets/82afd192-0cee-48a1-86ca-bd39b8c8f330
38
 
39
 
40
+ Videos from MovieGen/Hunyuan Video/VGGSound:
41
 
42
  https://github.com/user-attachments/assets/29230d4e-21c1-4cf8-a221-c28f2af6d0ca
43
 
44
  For more results, visit https://hkchengrex.com/MMAudio/video_main.html.
45
 
 
46
  ## Installation
47
 
48
  We have only tested this on Ubuntu.
 
51
 
52
  We recommend using a [miniforge](https://github.com/conda-forge/miniforge) environment.
53
 
54
+ - Python 3.8+
55
+ - PyTorch **2.5.1+** and corresponding torchvision/torchaudio (pick your CUDA version https://pytorch.org/)
56
+ - ffmpeg<7 ([this is required by torchaudio](https://pytorch.org/audio/master/installation.html#optional-dependencies), you can install it in a miniforge environment with `conda install -c conda-forge 'ffmpeg<7'`)
57
 
58
+ **Clone our repository:**
59
 
60
  ```bash
61
  git clone https://github.com/hkchengrex/MMAudio.git
62
  ```
63
 
64
+ **Install with pip:**
65
 
66
  ```bash
67
  cd MMAudio
 
70
 
71
  (If you encounter the File "setup.py" not found error, upgrade your pip with pip install --upgrade pip)
72
 
 
73
  **Pretrained models:**
74
 
75
+ The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`
76
+
77
+ | Model | Download link | File size |
78
+ | -------- | ------- | ------- |
79
+ | Flow prediction network, small 16kHz | <a href="https://databank.illinois.edu/datafiles/k6jve/download" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
80
+ | Flow prediction network, small 44.1kHz | <a href="https://databank.illinois.edu/datafiles/864ya/download" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
81
+ | Flow prediction network, medium 44.1kHz | <a href="https://databank.illinois.edu/datafiles/pa94t/download" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
82
+ | Flow prediction network, large 44.1kHz **(recommended)** | <a href="https://databank.illinois.edu/datafiles/4jx76/download" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
83
+ | 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
84
+ | 16kHz BigVGAN vocoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
85
+ | 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
86
+ | Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
87
+
88
+ The 44.1kHz vocoder will be downloaded automatically.
89
+
90
+ The expected directory structure (full):
91
+
92
+ ```bash
93
+ MMAudio
94
+ ├── ext_weights
95
+ │ ├── best_netG.pt
96
+ │ ├── synchformer_state_dict.pth
97
+ │ ├── v1-16.pth
98
+ │ └── v1-44.pth
99
+ ├── weights
100
+ │ ├── mmaudio_small_16k.pth
101
+ │ ├── mmaudio_small_44k.pth
102
+ │ ├── mmaudio_medium_44k.pth
103
+ │ └── mmaudio_large_44k.pth
104
+ └── ...
105
+ ```
106
+
107
+ The expected directory structure (minimal, for the recommended model only):
108
+
109
+ ```bash
110
+ MMAudio
111
+ ├── ext_weights
112
+ │ ├── synchformer_state_dict.pth
113
+ │ └── v1-44.pth
114
+ ├── weights
115
+ │ └── mmaudio_large_44k.pth
116
+ └── ...
117
+ ```
118
 
119
  ## Demo
120
 
121
+ By default, these scripts use the `large_44k` model.
122
  In our experiments, inference only takes around 6GB of GPU memory (in 16-bit mode) which should fit in most modern GPUs.
123
 
124
  ### Command-line interface
125
 
126
  With `demo.py`
 
127
  ```bash
128
  python demo.py --duration=8 --video=<path to video> --prompt "your prompt"
129
  ```
 
130
  The output (audio in `.flac` format, and video in `.mp4` format) will be saved in `./output`.
131
  See the file for more options.
132
  Simply omit the `--video` option for text-to-audio synthesis.
133
  The default output (and training) duration is 8 seconds. Longer/shorter durations could also work, but a large deviation from the training duration may result in a lower quality.
134
 
135
+
136
  ### Gradio interface
137
 
138
  Supports video-to-audio and text-to-audio synthesis.
 
 
139
 
140
+ ```
141
  python gradio_demo.py
142
  ```
143
 
144
  ### Known limitations
145
 
146
+ 1. The model sometimes generates undesired unintelligible human speech-like sounds
147
+ 2. The model sometimes generates undesired background music
148
  3. The model struggles with unfamiliar concepts, e.g., it can generate "gunfires" but not "RPG firing".
149
 
150
  We believe all of these three limitations can be addressed with more high-quality training data.
151
 
152
  ## Training
153
+ Work in progress.
 
154
 
155
  ## Evaluation
156
+ Work in progress.
157
 
158
  ## Acknowledgement
 
159
  Many thanks to:
160
+ - [Make-An-Audio 2](https://github.com/bytedance/Make-An-Audio-2) for the 16kHz BigVGAN pretrained model
161
  - [BigVGAN](https://github.com/NVIDIA/BigVGAN)
162
  - [Synchformer](https://github.com/v-iashin/Synchformer)
163
+
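Both demo paths above go through the same Python entry points that appear later in this diff (`app.py`, `demo.py`, and the deleted `batch_eval.py`). A minimal text-to-audio sketch pieced together from those calls is shown below; the checkpoint paths come from the `ModelConfig`, and passing `None` for the two video streams is an assumption based on the text-to-audio tab, not a documented API.

```python
import torch
import torchaudio

from mmaudio.eval_utils import all_model_cfg, generate
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils

device = 'cuda'

# Pick a variant and make sure its weights are on disk (paths live in the ModelConfig).
model = all_model_cfg['large_44k_v2']
model.download_if_needed()
seq_cfg = model.seq_cfg
seq_cfg.duration = 8.0  # matches the training duration

net: MMAudio = get_my_mmaudio('large_44k_v2').to(device).eval()
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                              synchformer_ckpt=model.synchformer_ckpt,
                              enable_conditions=True,
                              mode=model.mode,
                              bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                              need_vae_encoder=False).to(device).eval()

fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
rng = torch.Generator(device=device).manual_seed(42)

# Assumption: None for the clip/sync streams means pure text conditioning.
audios = generate(None, None, ['rain hitting a tin roof'],
                  negative_text=[''],
                  feature_utils=feature_utils, net=net, fm=fm, rng=rng,
                  cfg_strength=4.5)
torchaudio.save('output.flac', audios.float().cpu()[0], seq_cfg.sampling_rate)
```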
app.py CHANGED
@@ -1,33 +1,33 @@
1
- import gc
2
  import logging
3
- from argparse import ArgumentParser
4
  from datetime import datetime
5
- from fractions import Fraction
6
  from pathlib import Path
7
 
8
  import gradio as gr
9
  import torch
10
  import torchaudio
 
11
 
12
- from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
13
- load_video, make_video, setup_eval_logging)
 
14
  from mmaudio.model.flow_matching import FlowMatching
15
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
16
  from mmaudio.model.sequence_config import SequenceConfig
17
  from mmaudio.model.utils.features_utils import FeaturesUtils
 
18
 
19
  torch.backends.cuda.matmul.allow_tf32 = True
20
  torch.backends.cudnn.allow_tf32 = True
21
 
22
  log = logging.getLogger()
23
 
24
- device = 'cpu'
25
- if torch.cuda.is_available():
26
- device = 'cuda'
27
- elif torch.backends.mps.is_available():
28
- device = 'mps'
29
- else:
30
- log.warning('CUDA/MPS are not available, running on CPU')
31
  dtype = torch.bfloat16
32
 
33
  model: ModelConfig = all_model_cfg['large_44k_v2']
@@ -58,6 +58,7 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
58
  net, feature_utils, seq_cfg = get_model()
59
 
60
 
 
61
  @torch.inference_mode()
62
  def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
63
  cfg_strength: float, duration: float):
@@ -88,53 +89,16 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
88
  cfg_strength=cfg_strength)
89
  audio = audios.float().cpu()[0]
90
 
91
- current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
92
- output_dir.mkdir(exist_ok=True, parents=True)
93
- video_save_path = output_dir / f'{current_time_string}.mp4'
94
- make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
95
- gc.collect()
96
- return video_save_path
97
-
98
-
99
- @torch.inference_mode()
100
- def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
101
- cfg_strength: float, duration: float):
102
-
103
- rng = torch.Generator(device=device)
104
- if seed >= 0:
105
- rng.manual_seed(seed)
106
- else:
107
- rng.seed()
108
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
-
110
- image_info = load_image(image)
111
- clip_frames = image_info.clip_frames
112
- sync_frames = image_info.sync_frames
113
- clip_frames = clip_frames.unsqueeze(0)
114
- sync_frames = sync_frames.unsqueeze(0)
115
- seq_cfg.duration = duration
116
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
117
-
118
- audios = generate(clip_frames,
119
- sync_frames, [prompt],
120
- negative_text=[negative_prompt],
121
- feature_utils=feature_utils,
122
- net=net,
123
- fm=fm,
124
- rng=rng,
125
- cfg_strength=cfg_strength,
126
- image_input=True)
127
- audio = audios.float().cpu()[0]
128
-
129
- current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
130
- output_dir.mkdir(exist_ok=True, parents=True)
131
- video_save_path = output_dir / f'{current_time_string}.mp4'
132
- video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
133
  make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
134
- gc.collect()
135
  return video_save_path
136
 
137
 
 
138
  @torch.inference_mode()
139
  def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
140
  duration: float):
@@ -160,11 +124,9 @@ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
160
  cfg_strength=cfg_strength)
161
  audio = audios.float().cpu()[0]
162
 
163
- current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
164
- output_dir.mkdir(exist_ok=True, parents=True)
165
- audio_save_path = output_dir / f'{current_time_string}.flac'
166
  torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
167
- gc.collect()
168
  return audio_save_path
169
 
170
 
@@ -176,6 +138,8 @@ video_to_audio_tab = gr.Interface(
176
 
177
  NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
178
  Doing so does not improve results.
 
 
179
  """,
180
  inputs=[
181
  gr.Video(),
@@ -245,8 +209,8 @@ video_to_audio_tab = gr.Interface(
245
  10,
246
  ],
247
  [
248
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
249
- 'storm',
250
  '',
251
  0,
252
  25,
@@ -254,8 +218,8 @@ video_to_audio_tab = gr.Interface(
254
  10,
255
  ],
256
  [
257
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
258
- '',
259
  '',
260
  0,
261
  25,
@@ -263,8 +227,8 @@ video_to_audio_tab = gr.Interface(
263
  10,
264
  ],
265
  [
266
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
267
- 'typing',
268
  '',
269
  0,
270
  25,
@@ -272,8 +236,8 @@ video_to_audio_tab = gr.Interface(
272
  10,
273
  ],
274
  [
275
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
276
- '',
277
  '',
278
  0,
279
  25,
@@ -281,7 +245,7 @@ video_to_audio_tab = gr.Interface(
281
  10,
282
  ],
283
  [
284
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
285
  '',
286
  '',
287
  0,
@@ -293,10 +257,6 @@ video_to_audio_tab = gr.Interface(
293
 
294
  text_to_audio_tab = gr.Interface(
295
  fn=text_to_audio,
296
- description="""
297
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
298
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
299
- """,
300
  inputs=[
301
  gr.Text(label='Prompt'),
302
  gr.Text(label='Negative prompt'),
@@ -310,34 +270,6 @@ text_to_audio_tab = gr.Interface(
310
  title='MMAudio — Text-to-Audio Synthesis',
311
  )
312
 
313
- image_to_audio_tab = gr.Interface(
314
- fn=image_to_audio,
315
- description="""
316
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
317
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
318
-
319
- NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
320
- Doing so does not improve results.
321
- """,
322
- inputs=[
323
- gr.Image(type='filepath'),
324
- gr.Text(label='Prompt'),
325
- gr.Text(label='Negative prompt'),
326
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
327
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
328
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
329
- gr.Number(label='Duration (sec)', value=8, minimum=1),
330
- ],
331
- outputs='playable_video',
332
- cache_examples=False,
333
- title='MMAudio — Image-to-Audio Synthesis (experimental)',
334
- )
335
-
336
  if __name__ == "__main__":
337
- parser = ArgumentParser()
338
- parser.add_argument('--port', type=int, default=7860)
339
- args = parser.parse_args()
340
-
341
- gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
342
- ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
343
- server_port=args.port, allowed_paths=[output_dir])
 
1
+ import spaces
2
  import logging
 
3
  from datetime import datetime
 
4
  from pathlib import Path
5
 
6
  import gradio as gr
7
  import torch
8
  import torchaudio
9
+ import os
10
 
11
+ try:
12
+ import mmaudio
13
+ except ImportError:
14
+ os.system("pip install -e .")
15
+ import mmaudio
16
+
17
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
18
+ setup_eval_logging)
19
  from mmaudio.model.flow_matching import FlowMatching
20
  from mmaudio.model.networks import MMAudio, get_my_mmaudio
21
  from mmaudio.model.sequence_config import SequenceConfig
22
  from mmaudio.model.utils.features_utils import FeaturesUtils
23
+ import tempfile
24
 
25
  torch.backends.cuda.matmul.allow_tf32 = True
26
  torch.backends.cudnn.allow_tf32 = True
27
 
28
  log = logging.getLogger()
29
 
30
+ device = 'cuda'
 
  dtype = torch.bfloat16
32
 
33
  model: ModelConfig = all_model_cfg['large_44k_v2']
 
58
  net, feature_utils, seq_cfg = get_model()
59
 
60
 
61
+ @spaces.GPU(duration=120)
62
  @torch.inference_mode()
63
  def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
64
  cfg_strength: float, duration: float):
 
89
  cfg_strength=cfg_strength)
90
  audio = audios.float().cpu()[0]
91
 
92
+ # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
93
+ video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
94
+ # output_dir.mkdir(exist_ok=True, parents=True)
95
+ # video_save_path = output_dir / f'{current_time_string}.mp4'
96
  make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
97
+ log.info(f'Saved video to {video_save_path}')
98
  return video_save_path
99
 
100
 
101
+ @spaces.GPU(duration=120)
102
  @torch.inference_mode()
103
  def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
104
  duration: float):
 
124
  cfg_strength=cfg_strength)
125
  audio = audios.float().cpu()[0]
126
 
127
+ audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
 
 
128
  torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
129
+ log.info(f'Saved audio to {audio_save_path}')
130
  return audio_save_path
131
 
132
 
 
138
 
139
  NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
140
  Doing so does not improve results.
141
+
142
+ The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
143
  """,
144
  inputs=[
145
  gr.Video(),
 
209
  10,
210
  ],
211
  [
212
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
213
+ '',
214
  '',
215
  0,
216
  25,
 
218
  10,
219
  ],
220
  [
221
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
222
+ 'storm',
223
  '',
224
  0,
225
  25,
 
227
  10,
228
  ],
229
  [
230
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
231
+ '',
232
  '',
233
  0,
234
  25,
 
236
  10,
237
  ],
238
  [
239
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
240
+ 'typing',
241
  '',
242
  0,
243
  25,
 
245
  10,
246
  ],
247
  [
248
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
249
  '',
250
  '',
251
  0,
 
257
 
258
  text_to_audio_tab = gr.Interface(
259
  fn=text_to_audio,
260
  inputs=[
261
  gr.Text(label='Prompt'),
262
  gr.Text(label='Negative prompt'),
 
270
  title='MMAudio — Text-to-Audio Synthesis',
271
  )
272
 
273
  if __name__ == "__main__":
274
+ gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab],
275
+ ['Video-to-Audio', 'Text-to-Audio']).launch(allowed_paths=[output_dir])
 
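The rewritten `app.py` above stops writing time-stamped files into `./output` and instead hands Gradio a path from `tempfile`, which suits the ephemeral storage of a Hugging Face Space. The pattern in isolation, as a sketch with a dummy waveform standing in for the generated audio:

```python
import tempfile

import torch
import torchaudio

# Stand-in for the model output: 1 channel, 1 second of silence at 44.1 kHz.
audio = torch.zeros(1, 44100)

# delete=False keeps the file after the handle closes so Gradio can still serve it;
# the Space's ephemeral filesystem is left to clean it up.
audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
torchaudio.save(audio_save_path, audio, 44100)
print(f'Saved audio to {audio_save_path}')
```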
batch_eval.py DELETED
@@ -1,110 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
-
5
- import hydra
6
- import torch
7
- import torch.distributed as distributed
8
- import torchaudio
9
- from hydra.core.hydra_config import HydraConfig
10
- from omegaconf import DictConfig
11
- from tqdm import tqdm
12
-
13
- from mmaudio.data.data_setup import setup_eval_dataset
14
- from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
15
- from mmaudio.model.flow_matching import FlowMatching
16
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
17
- from mmaudio.model.utils.features_utils import FeaturesUtils
18
-
19
- torch.backends.cuda.matmul.allow_tf32 = True
20
- torch.backends.cudnn.allow_tf32 = True
21
-
22
- local_rank = int(os.environ['LOCAL_RANK'])
23
- world_size = int(os.environ['WORLD_SIZE'])
24
- log = logging.getLogger()
25
-
26
-
27
- @torch.inference_mode()
28
- @hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
29
- def main(cfg: DictConfig):
30
- device = 'cuda'
31
- torch.cuda.set_device(local_rank)
32
-
33
- if cfg.model not in all_model_cfg:
34
- raise ValueError(f'Unknown model variant: {cfg.model}')
35
- model: ModelConfig = all_model_cfg[cfg.model]
36
- model.download_if_needed()
37
- seq_cfg = model.seq_cfg
38
-
39
- run_dir = Path(HydraConfig.get().run.dir)
40
- if cfg.output_name is None:
41
- output_dir = run_dir / cfg.dataset
42
- else:
43
- output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
44
- output_dir.mkdir(parents=True, exist_ok=True)
45
-
46
- # load a pretrained model
47
- seq_cfg.duration = cfg.duration_s
48
- net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
49
- net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
50
- log.info(f'Loaded weights from {model.model_path}')
51
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
52
- log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
53
- log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
54
- log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
55
-
56
- # misc setup
57
- rng = torch.Generator(device=device)
58
- rng.manual_seed(cfg.seed)
59
- fm = FlowMatching(cfg.sampling.min_sigma,
60
- inference_mode=cfg.sampling.method,
61
- num_steps=cfg.sampling.num_steps)
62
-
63
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
64
- synchformer_ckpt=model.synchformer_ckpt,
65
- enable_conditions=True,
66
- mode=model.mode,
67
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
68
- need_vae_encoder=False)
69
- feature_utils = feature_utils.to(device).eval()
70
-
71
- if cfg.compile:
72
- net.preprocess_conditions = torch.compile(net.preprocess_conditions)
73
- net.predict_flow = torch.compile(net.predict_flow)
74
- feature_utils.compile()
75
-
76
- dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
77
-
78
- with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
79
- for batch in tqdm(loader):
80
- audios = generate(batch.get('clip_video', None),
81
- batch.get('sync_video', None),
82
- batch.get('caption', None),
83
- feature_utils=feature_utils,
84
- net=net,
85
- fm=fm,
86
- rng=rng,
87
- cfg_strength=cfg.cfg_strength,
88
- clip_batch_size_multiplier=64,
89
- sync_batch_size_multiplier=64)
90
- audios = audios.float().cpu()
91
- names = batch['name']
92
- for audio, name in zip(audios, names):
93
- torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
94
-
95
-
96
- def distributed_setup():
97
- distributed.init_process_group(backend="nccl")
98
- local_rank = distributed.get_rank()
99
- world_size = distributed.get_world_size()
100
- log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
101
- return local_rank, world_size
102
-
103
-
104
- if __name__ == '__main__':
105
- distributed_setup()
106
-
107
- main()
108
-
109
- # clean-up
110
- distributed.destroy_process_group()
 
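The deleted `batch_eval.py` assumed a `torchrun` launch: the launcher sets `LOCAL_RANK` and `WORLD_SIZE` for every worker, and the script read them before initializing NCCL. Its rank plumbing, reduced to a standalone sketch (requires GPUs; run with `torchrun --standalone --nproc_per_node=<n> sketch.py`):

```python
import os

import torch
import torch.distributed as distributed

# torchrun exports these environment variables for each spawned worker.
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])

distributed.init_process_group(backend='nccl')
torch.cuda.set_device(local_rank)

# Each rank would process its own shard of the evaluation set here.
print(f'rank {local_rank}/{world_size} ready')

distributed.destroy_process_group()
```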
config/__init__.py DELETED
File without changes
config/base_config.yaml DELETED
@@ -1,62 +0,0 @@
1
- defaults:
2
- - data: base
3
- - eval_data: base
4
- - override hydra/job_logging: custom-simplest
5
- - _self_
6
-
7
- hydra:
8
- run:
9
- dir: ./output/${exp_id}
10
- output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
11
-
12
- enable_email: False
13
-
14
- model: small_16k
15
-
16
- exp_id: default
17
- debug: False
18
- cudnn_benchmark: True
19
- compile: True
20
- amp: True
21
- weights: null
22
- checkpoint: null
23
- seed: 14159265
24
- num_workers: 10 # per-GPU
25
- pin_memory: False # set to True if your system can handle it, i.e., have enough memory
26
-
27
- # NOTE: This DOES NOT affect the model during inference in any way
28
- # they are just for the dataloader to fill in the missing data in multi-modal loading
29
- # to change the sequence length for the model, see networks.py
30
- data_dim:
31
- text_seq_len: 77
32
- clip_dim: 1024
33
- sync_dim: 768
34
- text_dim: 1024
35
-
36
- # ema configuration
37
- ema:
38
- enable: True
39
- sigma_rels: [0.05, 0.1]
40
- update_every: 1
41
- checkpoint_every: 5_000
42
- checkpoint_folder: ${hydra:run.dir}/ema_ckpts
43
- default_output_sigma: 0.05
44
-
45
-
46
- # sampling
47
- sampling:
48
- mean: 0.0
49
- scale: 1.0
50
- min_sigma: 0.0
51
- method: euler
52
- num_steps: 25
53
-
54
- # classifier-free guidance
55
- null_condition_probability: 0.1
56
- cfg_strength: 4.5
57
-
58
- # checkpoint paths to external modules
59
- vae_16k_ckpt: ./ext_weights/v1-16.pth
60
- vae_44k_ckpt: ./ext_weights/v1-44.pth
61
- bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
62
- synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
 
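These YAML files were consumed through Hydra (see the `@hydra.main` decorator in the deleted `batch_eval.py`). A minimal sketch of how the composed config was read and overridden from the command line, assuming the `config/` directory from the parent commit still sits next to the script:

```python
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig) -> None:
    # Defaults come from base_config.yaml / eval_config.yaml; any key can be overridden
    # on the CLI, e.g. `python show_config.py model=small_16k duration_s=8 compile=False`.
    print(OmegaConf.to_yaml(cfg))
    print(cfg.model, cfg.sampling.num_steps, cfg.cfg_strength)


if __name__ == '__main__':
    main()
```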
config/data/base.yaml DELETED
@@ -1,70 +0,0 @@
1
- VGGSound:
2
- root: ../data/video
3
- subset_name: sets/vgg3-train.tsv
4
- fps: 8
5
- height: 384
6
- width: 384
7
- sample_duration_sec: 8.0
8
-
9
- VGGSound_test:
10
- root: ../data/video
11
- subset_name: sets/vgg3-test.tsv
12
- fps: 8
13
- height: 384
14
- width: 384
15
- sample_duration_sec: 8.0
16
-
17
- VGGSound_val:
18
- root: ../data/video
19
- subset_name: sets/vgg3-val.tsv
20
- fps: 8
21
- height: 384
22
- width: 384
23
- sample_duration_sec: 8.0
24
-
25
- ExtractedVGG:
26
- tsv: ../data/v1-16-memmap/vgg-train.tsv
27
- memmap_dir: ../data/v1-16-memmap/vgg-train
28
-
29
- ExtractedVGG_test:
30
- tag: test
31
- gt_cache: ../data/eval-cache/vggsound-test
32
- output_subdir: null
33
- tsv: ../data/v1-16-memmap/vgg-test.tsv
34
- memmap_dir: ../data/v1-16-memmap/vgg-test
35
-
36
- ExtractedVGG_val:
37
- tag: val
38
- gt_cache: ../data/eval-cache/vggsound-val
39
- output_subdir: val
40
- tsv: ../data/v1-16-memmap/vgg-val.tsv
41
- memmap_dir: ../data/v1-16-memmap/vgg-val
42
-
43
- AudioCaps:
44
- tsv: ../data/v1-16-memmap/audiocaps.tsv
45
- memmap_dir: ../data/v1-16-memmap/audiocaps
46
-
47
- AudioSetSL:
48
- tsv: ../data/v1-16-memmap/audioset_sl.tsv
49
- memmap_dir: ../data/v1-16-memmap/audioset_sl
50
-
51
- BBCSound:
52
- tsv: ../data/v1-16-memmap/bbcsound.tsv
53
- memmap_dir: ../data/v1-16-memmap/bbcsound
54
-
55
- FreeSound:
56
- tsv: ../data/v1-16-memmap/freesound.tsv
57
- memmap_dir: ../data/v1-16-memmap/freesound
58
-
59
- Clotho:
60
- tsv: ../data/v1-16-memmap/clotho.tsv
61
- memmap_dir: ../data/v1-16-memmap/clotho
62
-
63
- Example_video:
64
- tsv: ./training/example_output/memmap/vgg-example.tsv
65
- memmap_dir: ./training/example_output/memmap/vgg-example
66
-
67
- Example_audio:
68
- tsv: ./training/example_output/memmap/audio-example.tsv
69
- memmap_dir: ./training/example_output/memmap/audio-example
70
-
 
config/eval_config.yaml DELETED
@@ -1,17 +0,0 @@
1
- defaults:
2
- - base_config
3
- - override hydra/job_logging: custom-simplest
4
- - _self_
5
-
6
- hydra:
7
- run:
8
- dir: ./output/${exp_id}
9
- output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
10
-
11
- exp_id: ${model}
12
- dataset: audiocaps
13
- duration_s: 8.0
14
-
15
- # for inference, this is the per-GPU batch size
16
- batch_size: 16
17
- output_name: null
 
config/eval_data/base.yaml DELETED
@@ -1,22 +0,0 @@
1
- AudioCaps:
2
- audio_path: ../data/AudioCaps-test-audioldm-ver
3
- # a csv file, with a header row of 'name' and 'caption'
4
- # name should match the audio file name without extension
5
- # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
6
- csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
7
-
8
- AudioCaps_full:
9
- audio_path: ../data/AudioCaps-test-full-ver
10
- # a csv file, with a header row of 'name' and 'caption'
11
- # name should match the audio file name without extension
12
- # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
13
- csv_path: ../data/AudioCaps-test-full-ver/data.csv
14
-
15
- MovieGen:
16
- video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
17
- jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
18
-
19
- VGGSound:
20
- video_path: ../data/test-videos
21
- # from the officially released csv file
22
- csv_path: ../data/vggsound.csv
 
config/hydra/job_logging/custom-eval.yaml DELETED
@@ -1,32 +0,0 @@
1
- # python logging configuration for tasks
2
- version: 1
3
- formatters:
4
- simple:
5
- format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
- datefmt: '%Y-%m-%d %H:%M:%S'
7
- colorlog:
8
- '()': 'colorlog.ColoredFormatter'
9
- format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
- datefmt: '%Y-%m-%d %H:%M:%S'
11
- log_colors:
12
- DEBUG: purple
13
- INFO: green
14
- WARNING: yellow
15
- ERROR: red
16
- CRITICAL: red
17
- handlers:
18
- console:
19
- class: logging.StreamHandler
20
- formatter: colorlog
21
- stream: ext://sys.stdout
22
- file:
23
- class: logging.FileHandler
24
- formatter: simple
25
- # absolute file path
26
- filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
27
- mode: w
28
- root:
29
- level: INFO
30
- handlers: [console, file]
31
-
32
- disable_existing_loggers: false
 
config/hydra/job_logging/custom-no-rank.yaml DELETED
@@ -1,32 +0,0 @@
1
- # python logging configuration for tasks
2
- version: 1
3
- formatters:
4
- simple:
5
- format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
- datefmt: '%Y-%m-%d %H:%M:%S'
7
- colorlog:
8
- '()': 'colorlog.ColoredFormatter'
9
- format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
- datefmt: '%Y-%m-%d %H:%M:%S'
11
- log_colors:
12
- DEBUG: purple
13
- INFO: green
14
- WARNING: yellow
15
- ERROR: red
16
- CRITICAL: red
17
- handlers:
18
- console:
19
- class: logging.StreamHandler
20
- formatter: colorlog
21
- stream: ext://sys.stdout
22
- file:
23
- class: logging.FileHandler
24
- formatter: simple
25
- # absolute file path
26
- filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
27
- mode: w
28
- root:
29
- level: INFO
30
- handlers: [console, file]
31
-
32
- disable_existing_loggers: false
 
config/hydra/job_logging/custom-simplest.yaml DELETED
@@ -1,26 +0,0 @@
1
- # python logging configuration for tasks
2
- version: 1
3
- formatters:
4
- simple:
5
- format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
- datefmt: '%Y-%m-%d %H:%M:%S'
7
- colorlog:
8
- '()': 'colorlog.ColoredFormatter'
9
- format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
- datefmt: '%Y-%m-%d %H:%M:%S'
11
- log_colors:
12
- DEBUG: purple
13
- INFO: green
14
- WARNING: yellow
15
- ERROR: red
16
- CRITICAL: red
17
- handlers:
18
- console:
19
- class: logging.StreamHandler
20
- formatter: colorlog
21
- stream: ext://sys.stdout
22
- root:
23
- level: INFO
24
- handlers: [console]
25
-
26
- disable_existing_loggers: false
 
config/hydra/job_logging/custom.yaml DELETED
@@ -1,33 +0,0 @@
1
- # @package hydra.job_logging
2
- # python logging configuration for tasks
3
- version: 1
4
- formatters:
5
- simple:
6
- format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
7
- datefmt: '%Y-%m-%d %H:%M:%S'
8
- colorlog:
9
- '()': 'colorlog.ColoredFormatter'
10
- format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
11
- datefmt: '%Y-%m-%d %H:%M:%S'
12
- log_colors:
13
- DEBUG: purple
14
- INFO: green
15
- WARNING: yellow
16
- ERROR: red
17
- CRITICAL: red
18
- handlers:
19
- console:
20
- class: logging.StreamHandler
21
- formatter: colorlog
22
- stream: ext://sys.stdout
23
- file:
24
- class: logging.FileHandler
25
- formatter: simple
26
- # absolute file path
27
- filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
28
- mode: w
29
- root:
30
- level: INFO
31
- handlers: [console, file]
32
-
33
- disable_existing_loggers: false
 
config/train_config.yaml DELETED
@@ -1,41 +0,0 @@
1
- defaults:
2
- - base_config
3
- - override data: base
4
- - override hydra/job_logging: custom
5
- - _self_
6
-
7
- hydra:
8
- run:
9
- dir: ./output/${exp_id}
10
- output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
-
12
- ema:
13
- start: 0
14
-
15
- mini_train: False
16
- example_train: False
17
- enable_grad_scaler: False
18
- vgg_oversample_rate: 5
19
-
20
- log_text_interval: 200
21
- log_extra_interval: 20_000
22
- val_interval: 5_000
23
- eval_interval: 20_000
24
- save_eval_interval: 40_000
25
- save_weights_interval: 10_000
26
- save_checkpoint_interval: 10_000
27
- save_copy_iterations: []
28
-
29
- batch_size: 512
30
- eval_batch_size: 256 # per-GPU
31
-
32
- num_iterations: 300_000
33
- learning_rate: 1.0e-4
34
- linear_warmup_steps: 1_000
35
-
36
- lr_schedule: step
37
- lr_schedule_steps: [240_000, 270_000]
38
- lr_schedule_gamma: 0.1
39
-
40
- clip_grad_norm: 1.0
41
- weight_decay: 1.0e-6
 
demo.py CHANGED
@@ -62,13 +62,7 @@ def main():
62
  skip_video_composite: bool = args.skip_video_composite
63
  mask_away_clip: bool = args.mask_away_clip
64
 
65
- device = 'cpu'
66
- if torch.cuda.is_available():
67
- device = 'cuda'
68
- elif torch.backends.mps.is_available():
69
- device = 'mps'
70
- else:
71
- log.warning('CUDA/MPS are not available, running on CPU')
72
  dtype = torch.float32 if args.full_precision else torch.bfloat16
73
 
74
  output_dir.mkdir(parents=True, exist_ok=True)
 
62
  skip_video_composite: bool = args.skip_video_composite
63
  mask_away_clip: bool = args.mask_away_clip
64
 
65
+ device = 'cuda'
66
  dtype = torch.float32 if args.full_precision else torch.bfloat16
67
 
68
  output_dir.mkdir(parents=True, exist_ok=True)
docs/EVAL.md DELETED
@@ -1,22 +0,0 @@
1
- # Evaluation
2
-
3
- ## Batch Evaluation
4
-
5
- To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient in large-scale evaluation compared to `demo.py`, supporting batched inference, multi-GPU inference, torch compilation, and skipping video compositions.
6
-
7
- An example of running this script with four GPUs is as follows:
8
-
9
- ```bash
10
- OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
11
- ```
12
-
13
- You may need to update the data paths in `config/eval_data/base.yaml`.
14
- More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
15
-
16
- ## Precomputed Results
17
-
18
- Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
19
-
20
- ## Obtaining Quantitative Metrics
21
-
22
- Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
 
docs/MODELS.md DELETED
@@ -1,50 +0,0 @@
1
- # Pretrained models
2
-
3
- The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
4
- The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
5
-
6
- | Model | Download link | File size |
7
- | -------- | ------- | ------- |
8
- | Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
9
- | Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
10
- | Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
11
- | Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
12
- | Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
13
- | 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
14
- | 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
15
- | 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
16
- | Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
17
-
18
- To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP, CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz) and not model sizes.
19
- The 44.1kHz vocoder will be downloaded automatically.
20
- The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.
21
-
22
- The expected directory structure (full):
23
-
24
- ```bash
25
- MMAudio
26
- ├── ext_weights
27
- │ ├── best_netG.pt
28
- │ ├── synchformer_state_dict.pth
29
- │ ├── v1-16.pth
30
- │ └── v1-44.pth
31
- ├── weights
32
- │ ├── mmaudio_small_16k.pth
33
- │ ├── mmaudio_small_44k.pth
34
- │ ├── mmaudio_medium_44k.pth
35
- │ ├── mmaudio_large_44k.pth
36
- │ └── mmaudio_large_44k_v2.pth
37
- └── ...
38
- ```
39
-
40
- The expected directory structure (minimal, for the recommended model only):
41
-
42
- ```bash
43
- MMAudio
44
- ├── ext_weights
45
- │ ├── synchformer_state_dict.pth
46
- │ └── v1-44.pth
47
- ├── weights
48
- │ └── mmaudio_large_44k_v2.pth
49
- └── ...
50
- ```
 
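The deleted MODELS.md points at MD5 checksums in `mmaudio/utils/download_utils.py` but does not show how to verify a manually downloaded checkpoint. A generic sketch; the path is just the recommended model from the table above:

```python
import hashlib
from pathlib import Path


def md5sum(path: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the file in 1 MiB chunks so multi-GB checkpoints never sit in memory."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


# Compare the printed value against the checksum listed in download_utils.py.
print(md5sum(Path('weights/mmaudio_large_44k_v2.pth')))
```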
docs/TRAINING.md DELETED
@@ -1,184 +0,0 @@
1
- # Training
2
-
3
- ## Overview
4
-
5
- We have put a large emphasis on making training as fast as possible.
6
- Consequently, some pre-processing steps are required.
7
-
8
- Namely, before starting any training, we
9
-
10
- 1. Obtain training data as videos, audios, and captions.
11
- 2. Encode training audios into spectrograms and then with VAE into mean/std
12
- 3. Extract CLIP and synchronization features from videos
13
- 4. Extract CLIP features from text (captions)
14
- 5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
15
-
16
- **NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.
17
-
18
- The current training script does not support `_v2` training.
19
-
20
- ## Recommended Hardware Configuration
21
-
22
- These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
23
-
24
- - Single-node machine. We did not implement multi-node training
25
- - GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
26
- - System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
27
- - Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
28
-
29
- ## Prerequisites
30
-
31
- 1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
32
- 2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
33
-
34
- 3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
35
-
36
- ```bash
37
- conda install -c conda-forge 'ffmpeg<7'
38
- ```
39
-
40
- 4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), [WavCaps](https://arxiv.org/abs/2303.17395), and [Clotho](https://arxiv.org/abs/1910.09387) (paper to be updated). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio unsuitable for our training.
41
-
42
- 5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md); place them in `ext_weights/`.
43
-
44
- ### Helpful links for downloading the datasets
45
-
46
- We cannot redistribute the datasets for copyright reasons, but we do find some links helpful and they might be helpful to you as well.
47
-
48
- - https://huggingface.co/datasets/Meranti/CLAP_freesound
49
- - https://huggingface.co/datasets/agkphysics/AudioSet
50
- - https://sound-effects.bbcrewind.co.uk/
51
-
52
- For certain sources of VGGSound, you might notice desynchronization between the audio and the video. This happens because the video keyframes do not always align with the start of the audio, and what happens during playback is player-dependent. We used PyTorch's decoder, which handles these cases correctly.
53
-
54
- ## Preparing Audio-Video-Text Features
55
-
56
- We have prepared some example data in `training/example_videos`.
57
- `training/extract_video_training_latents.py` extracts audio, video, and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
58
-
59
- To run this script, use the `torchrun` utility:
60
-
61
- ```bash
62
- torchrun --standalone training/extract_video_training_latents.py
63
- ```
64
-
65
- You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
66
- Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
67
- Change the data path definitions in `data_cfg` if necessary.
68
-
69
- Arguments:
70
-
71
- - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
72
- - `output_dir` -- where TensorDict and the metadata file are saved.
73
-
74
- Outputs produced in `output_dir`:
75
-
76
- 1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
77
- a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
78
- b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
79
- c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
80
- d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
81
- e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
82
- f. `meta.json` that contains the metadata for the above memory mappings
83
- 2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
84
-
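
The cache above can be sanity-checked with a short sketch; `TensorDict.load_memmap` is the same call the training datasets use, while the split name and output path below are assumptions.

```python
# Minimal sketch: open the memory-mapped features and the metadata tsv
# produced by extract_video_training_latents.py. Paths are placeholders.
from pathlib import Path

import pandas as pd
from tensordict import TensorDict

output_dir = Path('output/example_video_features')  # your `output_dir`
split = 'train'                                      # assumed split name

td = TensorDict.load_memmap(output_dir / f'vgg-{split}')   # lazy, memory-mapped
meta = pd.read_csv(output_dir / f'vgg-{split}.tsv', sep='\t')

print(td['mean'].shape)            # (num_videos, latent_seq_len, channels)
print(td['clip_features'].shape)   # (num_videos, 64, 1024)
print(td['sync_features'].shape)   # (num_videos, 192, 768)
print(meta.columns.tolist())       # ['id', 'label']
```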
85
- ## Preparing Audio-Text Features
86
-
87
- We have prepared some example data in `training/example_audios`.
88
-
89
- 1. Run `training/partition_clips.py` to partition each audio file into clips (we only find start and end points; the partitioned audio is not written to disk, to save space)
90
- 2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
91
-
92
- ### Partitioning the audio files
93
-
94
- Run
95
-
96
- ```bash
97
- python training/partition_clips.py
98
- ```
99
-
100
- Arguments:
101
-
102
- - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
103
- - `output_dir` -- path to the output `.csv` file
104
- - `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
105
- - `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
106
-
107
- ### Extracting audio and text features
108
-
109
- Run
110
-
111
- ```bash
112
- torchrun --standalone training/extract_audio_training_latents.py
113
- ```
114
-
115
- You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
116
- Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
117
-
118
- Arguments:
119
-
120
- - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
121
- `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file with at least the columns `id` and `caption` (a minimal example follows this list)
122
- - `clips_tsv` -- path to the clips file, generated in the last step
123
- - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
124
- - `output_dir` -- where TensorDict and the metadata file are saved.
125
-
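
For concreteness, a minimal `captions_tsv` in the expected shape can be written as follows; the ids and captions are made-up placeholders.

```python
# Minimal sketch: a captions tsv with the two required columns.
import pandas as pd

captions = pd.DataFrame([
    {'id': 'audio_0001', 'caption': 'rain falling on a tin roof'},
    {'id': 'audio_0002', 'caption': 'a dog barking in the distance'},
])
captions.to_csv('captions.tsv', sep='\t', index=False)
```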
126
- Outputs produced in `output_dir`:
127
-
128
- 1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
129
- a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
130
- b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
131
- c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
132
- d. `meta.json` that contains the metadata for the above memory mappings
133
- 2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
134
-
135
- ### Reference tsv files (with overlaps removed as mentioned in the paper)
136
-
137
- The reference tsv files can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).
138
-
139
- Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use them as the `captions_tsv` input, though -- the script handles duplicates gracefully.
140
- Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are part of WavCaps. These subsets might be smaller than the original datasets.
141
- The Clotho data contains both the development set and the validation set.
142
-
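
If you prefer a deduplicated captions file before reusing one of these references (the script tolerates the duplicates either way), keeping one row per `id` is enough; the file name below is just one of the reference tsv files mentioned above.

```python
# Minimal sketch: keep a single row per audio id in a reference tsv.
import pandas as pd

df = pd.read_csv('freesound.tsv', sep='\t', dtype={'id': str})
df = df.drop_duplicates(subset='id', keep='first')
df.to_csv('freesound_dedup.tsv', sep='\t', index=False)
```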
143
- **Update (Mar 9, 2025)**:
144
- We have uploaded a corrected set of reference tsv files. The previous tsv files contained some (<1%) corrupted captions (i.e., a mismatch between audio and caption; see https://github.com/hkchengrex/MMAudio/issues/56). The tsv files for VGGSound are unaffected. The cause of this error is unknown, and I cannot reproduce it in the latest version of the code. Our pre-trained models were trained with the **uncorrected** tsv files. For future training, I recommend using the corrected tsv files.
145
-
146
- The error statistics are as follows:
147
-
148
- AudioCaps: (170/43824), 0.39%
149
- - Freesound: (1670/180636), 0.92%
150
- - AudioSet: (290/100776), 0.29%
151
- - BBCSound: (3/29975), 0.01%
152
- - Clotho: (8/24332), 0.03%
153
-
154
- ## Training on Extracted Features
155
-
156
- We use Distributed Data Parallel (DDP) for training.
157
- First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
158
-
159
- To run training on the example data, use the following command:
160
-
161
- ```bash
162
- OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
163
- ```
164
-
165
- This will not train a useful model, but it will check if everything is set up correctly.
166
-
167
- For full training on the base model with two GPUs, use the following command:
168
-
169
- ```bash
170
- OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
171
- ```
172
-
173
- Any outputs from training will be stored in `output/<exp_id>`.
174
-
175
- More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
176
- For the medium and large models, set `vgg_oversample_rate` to `3` to reduce overfitting.
177
-
178
- ## Checkpoints
179
-
180
- Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
181
-
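
For reference, the released inference weights can be loaded following the same pattern as the repository's demo code; the config key below is one of the published options, and loading on CPU is just a conservative default.

```python
# Minimal sketch: fetch a released checkpoint and load its weights.
import torch

from mmaudio.eval_utils import all_model_cfg
from mmaudio.model.networks import MMAudio, get_my_mmaudio

cfg = all_model_cfg['large_44k_v2']
cfg.download_if_needed()  # downloads the weights from Hugging Face if missing

net: MMAudio = get_my_mmaudio(cfg.model_name).eval()
net.load_weights(torch.load(cfg.model_path, map_location='cpu', weights_only=True))
```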
182
- ---
183
-
184
- Godspeed!
 
docs/index.html CHANGED
@@ -40,7 +40,7 @@
40
  <br>
41
  <div class="row text-center" style="font-size:28px">
42
  <div class="col">
43
- CVPR 2025
44
  </div>
45
  </div>
46
  <br>
@@ -83,21 +83,19 @@
83
  <br>
84
 
85
  <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
86
- <div class="col-sm-2">
87
- <a href="https://arxiv.org/abs/2412.15322">[Paper]</a>
88
- </div>
89
- <div class="col-sm-2">
90
- <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
91
- </div>
92
  <div class="col-sm-3">
93
- <a href="https://huggingface.co/spaces/hkchengrex/MMAudio">[Huggingface Demo]</a>
94
- </div>
95
- <div class="col-sm-2">
96
- <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
97
  </div>
98
  <div class="col-sm-3">
99
- <a href="https://replicate.com/zsxkib/mmaudio">[Replicate Demo]</a>
100
  </div>
 
 
 
 
101
  </div>
102
 
103
  <br>
 
40
  <br>
41
  <div class="row text-center" style="font-size:28px">
42
  <div class="col">
43
+ arXiv 2024
44
  </div>
45
  </div>
46
  <br>
 
83
  <br>
84
 
85
  <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
86
+ <!-- <div class="col-sm-2">
87
+ <a href="https://arxiv.org/abs/2310.12982">[arXiv]</a>
88
+ </div> -->
 
 
 
89
  <div class="col-sm-3">
90
+ <a href="">[Paper (being prepared)]</a>
 
 
 
91
  </div>
92
  <div class="col-sm-3">
93
+ <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
94
  </div>
95
+ <!-- <div class="col-sm-2">
96
+ <a
97
+ href="https://colab.research.google.com/drive/1yo43XTbjxuWA7XgCUO9qxAi7wBI6HzvP?usp=sharing">[Colab]</a>
98
+ </div> -->
99
  </div>
100
 
101
  <br>
gradio_demo.py DELETED
@@ -1,343 +0,0 @@
1
- import gc
2
- import logging
3
- from argparse import ArgumentParser
4
- from datetime import datetime
5
- from fractions import Fraction
6
- from pathlib import Path
7
-
8
- import gradio as gr
9
- import torch
10
- import torchaudio
11
-
12
- from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
13
- load_video, make_video, setup_eval_logging)
14
- from mmaudio.model.flow_matching import FlowMatching
15
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
16
- from mmaudio.model.sequence_config import SequenceConfig
17
- from mmaudio.model.utils.features_utils import FeaturesUtils
18
-
19
- torch.backends.cuda.matmul.allow_tf32 = True
20
- torch.backends.cudnn.allow_tf32 = True
21
-
22
- log = logging.getLogger()
23
-
24
- device = 'cpu'
25
- if torch.cuda.is_available():
26
- device = 'cuda'
27
- elif torch.backends.mps.is_available():
28
- device = 'mps'
29
- else:
30
- log.warning('CUDA/MPS are not available, running on CPU')
31
- dtype = torch.bfloat16
32
-
33
- model: ModelConfig = all_model_cfg['large_44k_v2']
34
- model.download_if_needed()
35
- output_dir = Path('./output/gradio')
36
-
37
- setup_eval_logging()
38
-
39
-
40
- def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
41
- seq_cfg = model.seq_cfg
42
-
43
- net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
44
- net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
45
- log.info(f'Loaded weights from {model.model_path}')
46
-
47
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
48
- synchformer_ckpt=model.synchformer_ckpt,
49
- enable_conditions=True,
50
- mode=model.mode,
51
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
52
- need_vae_encoder=False)
53
- feature_utils = feature_utils.to(device, dtype).eval()
54
-
55
- return net, feature_utils, seq_cfg
56
-
57
-
58
- net, feature_utils, seq_cfg = get_model()
59
-
60
-
61
- @torch.inference_mode()
62
- def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
63
- cfg_strength: float, duration: float):
64
-
65
- rng = torch.Generator(device=device)
66
- if seed >= 0:
67
- rng.manual_seed(seed)
68
- else:
69
- rng.seed()
70
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
71
-
72
- video_info = load_video(video, duration)
73
- clip_frames = video_info.clip_frames
74
- sync_frames = video_info.sync_frames
75
- duration = video_info.duration_sec
76
- clip_frames = clip_frames.unsqueeze(0)
77
- sync_frames = sync_frames.unsqueeze(0)
78
- seq_cfg.duration = duration
79
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
80
-
81
- audios = generate(clip_frames,
82
- sync_frames, [prompt],
83
- negative_text=[negative_prompt],
84
- feature_utils=feature_utils,
85
- net=net,
86
- fm=fm,
87
- rng=rng,
88
- cfg_strength=cfg_strength)
89
- audio = audios.float().cpu()[0]
90
-
91
- current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
92
- output_dir.mkdir(exist_ok=True, parents=True)
93
- video_save_path = output_dir / f'{current_time_string}.mp4'
94
- make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
95
- gc.collect()
96
- return video_save_path
97
-
98
-
99
- @torch.inference_mode()
100
- def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
101
- cfg_strength: float, duration: float):
102
-
103
- rng = torch.Generator(device=device)
104
- if seed >= 0:
105
- rng.manual_seed(seed)
106
- else:
107
- rng.seed()
108
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
-
110
- image_info = load_image(image)
111
- clip_frames = image_info.clip_frames
112
- sync_frames = image_info.sync_frames
113
- clip_frames = clip_frames.unsqueeze(0)
114
- sync_frames = sync_frames.unsqueeze(0)
115
- seq_cfg.duration = duration
116
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
117
-
118
- audios = generate(clip_frames,
119
- sync_frames, [prompt],
120
- negative_text=[negative_prompt],
121
- feature_utils=feature_utils,
122
- net=net,
123
- fm=fm,
124
- rng=rng,
125
- cfg_strength=cfg_strength,
126
- image_input=True)
127
- audio = audios.float().cpu()[0]
128
-
129
- current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
130
- output_dir.mkdir(exist_ok=True, parents=True)
131
- video_save_path = output_dir / f'{current_time_string}.mp4'
132
- video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
133
- make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
134
- gc.collect()
135
- return video_save_path
136
-
137
-
138
- @torch.inference_mode()
139
- def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
140
- duration: float):
141
-
142
- rng = torch.Generator(device=device)
143
- if seed >= 0:
144
- rng.manual_seed(seed)
145
- else:
146
- rng.seed()
147
- fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
148
-
149
- clip_frames = sync_frames = None
150
- seq_cfg.duration = duration
151
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
152
-
153
- audios = generate(clip_frames,
154
- sync_frames, [prompt],
155
- negative_text=[negative_prompt],
156
- feature_utils=feature_utils,
157
- net=net,
158
- fm=fm,
159
- rng=rng,
160
- cfg_strength=cfg_strength)
161
- audio = audios.float().cpu()[0]
162
-
163
- current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
164
- output_dir.mkdir(exist_ok=True, parents=True)
165
- audio_save_path = output_dir / f'{current_time_string}.flac'
166
- torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
167
- gc.collect()
168
- return audio_save_path
169
-
170
-
171
- video_to_audio_tab = gr.Interface(
172
- fn=video_to_audio,
173
- description="""
174
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
175
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
176
-
177
- NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
178
- Doing so does not improve results.
179
- """,
180
- inputs=[
181
- gr.Video(),
182
- gr.Text(label='Prompt'),
183
- gr.Text(label='Negative prompt', value='music'),
184
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
185
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
186
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
187
- gr.Number(label='Duration (sec)', value=8, minimum=1),
188
- ],
189
- outputs='playable_video',
190
- cache_examples=False,
191
- title='MMAudio — Video-to-Audio Synthesis',
192
- examples=[
193
- [
194
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
195
- 'waves, seagulls',
196
- '',
197
- 0,
198
- 25,
199
- 4.5,
200
- 10,
201
- ],
202
- [
203
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
204
- '',
205
- 'music',
206
- 0,
207
- 25,
208
- 4.5,
209
- 10,
210
- ],
211
- [
212
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
213
- 'bubbles',
214
- '',
215
- 0,
216
- 25,
217
- 4.5,
218
- 10,
219
- ],
220
- [
221
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
222
- 'Indian holy music',
223
- '',
224
- 0,
225
- 25,
226
- 4.5,
227
- 10,
228
- ],
229
- [
230
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
231
- 'galloping',
232
- '',
233
- 0,
234
- 25,
235
- 4.5,
236
- 10,
237
- ],
238
- [
239
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
240
- 'waves, storm',
241
- '',
242
- 0,
243
- 25,
244
- 4.5,
245
- 10,
246
- ],
247
- [
248
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
249
- 'storm',
250
- '',
251
- 0,
252
- 25,
253
- 4.5,
254
- 10,
255
- ],
256
- [
257
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
258
- '',
259
- '',
260
- 0,
261
- 25,
262
- 4.5,
263
- 10,
264
- ],
265
- [
266
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
267
- 'typing',
268
- '',
269
- 0,
270
- 25,
271
- 4.5,
272
- 10,
273
- ],
274
- [
275
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
276
- '',
277
- '',
278
- 0,
279
- 25,
280
- 4.5,
281
- 10,
282
- ],
283
- [
284
- 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
285
- '',
286
- '',
287
- 0,
288
- 25,
289
- 4.5,
290
- 10,
291
- ],
292
- ])
293
-
294
- text_to_audio_tab = gr.Interface(
295
- fn=text_to_audio,
296
- description="""
297
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
298
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
299
- """,
300
- inputs=[
301
- gr.Text(label='Prompt'),
302
- gr.Text(label='Negative prompt'),
303
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
304
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
305
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
306
- gr.Number(label='Duration (sec)', value=8, minimum=1),
307
- ],
308
- outputs='audio',
309
- cache_examples=False,
310
- title='MMAudio — Text-to-Audio Synthesis',
311
- )
312
-
313
- image_to_audio_tab = gr.Interface(
314
- fn=image_to_audio,
315
- description="""
316
- Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
317
- Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
318
-
319
- NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
320
- Doing so does not improve results.
321
- """,
322
- inputs=[
323
- gr.Image(type='filepath'),
324
- gr.Text(label='Prompt'),
325
- gr.Text(label='Negative prompt'),
326
- gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
327
- gr.Number(label='Num steps', value=25, precision=0, minimum=1),
328
- gr.Number(label='Guidance Strength', value=4.5, minimum=1),
329
- gr.Number(label='Duration (sec)', value=8, minimum=1),
330
- ],
331
- outputs='playable_video',
332
- cache_examples=False,
333
- title='MMAudio — Image-to-Audio Synthesis (experimental)',
334
- )
335
-
336
- if __name__ == "__main__":
337
- parser = ArgumentParser()
338
- parser.add_argument('--port', type=int, default=7860)
339
- args = parser.parse_args()
340
-
341
- gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
342
- ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
343
- server_port=args.port, allowed_paths=[output_dir])
 
mmaudio/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (187 Bytes)
 
mmaudio/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (185 Bytes)
 
mmaudio/__pycache__/eval_utils.cpython-310.pyc DELETED
Binary file (7.07 kB)
 
mmaudio/__pycache__/eval_utils.cpython-38.pyc DELETED
Binary file (7.03 kB)
 
mmaudio/data/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (192 Bytes)
 
mmaudio/data/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (190 Bytes)
 
mmaudio/data/__pycache__/av_utils.cpython-310.pyc DELETED
Binary file (4.91 kB)
 
mmaudio/data/__pycache__/av_utils.cpython-38.pyc DELETED
Binary file (4.89 kB)
 
mmaudio/data/av_utils.py CHANGED
@@ -1,7 +1,7 @@
1
  from dataclasses import dataclass
2
  from fractions import Fraction
3
  from pathlib import Path
4
- from typing import Optional, List, Tuple
5
 
6
  import av
7
  import numpy as np
@@ -15,7 +15,7 @@ class VideoInfo:
15
  fps: Fraction
16
  clip_frames: torch.Tensor
17
  sync_frames: torch.Tensor
18
- all_frames: Optional[List[np.ndarray]]
19
 
20
  @property
21
  def height(self):
@@ -25,35 +25,9 @@ class VideoInfo:
25
  def width(self):
26
  return self.all_frames[0].shape[1]
27
 
28
- @classmethod
29
- def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
30
- fps: Fraction) -> 'VideoInfo':
31
- num_frames = int(duration_sec * fps)
32
- all_frames = [image_info.original_frame] * num_frames
33
- return cls(duration_sec=duration_sec,
34
- fps=fps,
35
- clip_frames=image_info.clip_frames,
36
- sync_frames=image_info.sync_frames,
37
- all_frames=all_frames)
38
 
39
-
40
- @dataclass
41
- class ImageInfo:
42
- clip_frames: torch.Tensor
43
- sync_frames: torch.Tensor
44
- original_frame: Optional[np.ndarray]
45
-
46
- @property
47
- def height(self):
48
- return self.original_frame.shape[0]
49
-
50
- @property
51
- def width(self):
52
- return self.original_frame.shape[1]
53
-
54
-
55
- def read_frames(video_path: Path, list_of_fps: List[float], start_sec: float, end_sec: float,
56
- need_all_frames: bool) -> Tuple[List[np.ndarray], List[np.ndarray], Fraction]:
57
  output_frames = [[] for _ in list_of_fps]
58
  next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
59
  time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
 
1
  from dataclasses import dataclass
2
  from fractions import Fraction
3
  from pathlib import Path
4
+ from typing import Optional
5
 
6
  import av
7
  import numpy as np
 
15
  fps: Fraction
16
  clip_frames: torch.Tensor
17
  sync_frames: torch.Tensor
18
+ all_frames: Optional[list[np.ndarray]]
19
 
20
  @property
21
  def height(self):
 
25
  def width(self):
26
  return self.all_frames[0].shape[1]
27
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
30
+ need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  output_frames = [[] for _ in list_of_fps]
32
  next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
33
  time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
mmaudio/data/data_setup.py DELETED
@@ -1,174 +0,0 @@
1
- import logging
2
- import random
3
-
4
- import numpy as np
5
- import torch
6
- from omegaconf import DictConfig
7
- from torch.utils.data import DataLoader, Dataset
8
- from torch.utils.data.dataloader import default_collate
9
- from torch.utils.data.distributed import DistributedSampler
10
-
11
- from mmaudio.data.eval.audiocaps import AudioCapsData
12
- from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
13
- from mmaudio.data.extracted_audio import ExtractedAudio
14
- from mmaudio.data.extracted_vgg import ExtractedVGG
15
- from mmaudio.data.mm_dataset import MultiModalDataset
16
- from mmaudio.utils.dist_utils import local_rank
17
-
18
- log = logging.getLogger()
19
-
20
-
21
- # Re-seed randomness every time we start a worker
22
- def worker_init_fn(worker_id: int):
23
- worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
24
- np.random.seed(worker_seed)
25
- random.seed(worker_seed)
26
- log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
27
-
28
-
29
- def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
30
- dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
31
- data_dim=cfg.data_dim,
32
- premade_mmap_dir=data_cfg.memmap_dir)
33
-
34
- return dataset
35
-
36
-
37
- def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
38
- dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
39
- data_dim=cfg.data_dim,
40
- premade_mmap_dir=data_cfg.memmap_dir)
41
-
42
- return dataset
43
-
44
-
45
- def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
46
- if cfg.mini_train:
47
- vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
48
- audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
49
- dataset = MultiModalDataset([vgg], [audiocaps])
50
- if cfg.example_train:
51
- video = load_vgg_data(cfg, cfg.data.Example_video)
52
- audio = load_audio_data(cfg, cfg.data.Example_audio)
53
- dataset = MultiModalDataset([video], [audio])
54
- else:
55
- # load the largest one first
56
- freesound = load_audio_data(cfg, cfg.data.FreeSound)
57
- vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
58
- audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
59
- audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
60
- bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
61
- clotho = load_audio_data(cfg, cfg.data.Clotho)
62
- dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
63
- [audiocaps, audioset_sl, bbcsound, freesound, clotho])
64
-
65
- batch_size = cfg.batch_size
66
- num_workers = cfg.num_workers
67
- pin_memory = cfg.pin_memory
68
- sampler, loader = construct_loader(dataset,
69
- batch_size,
70
- num_workers,
71
- shuffle=True,
72
- drop_last=True,
73
- pin_memory=pin_memory)
74
-
75
- return dataset, sampler, loader
76
-
77
-
78
- def setup_test_datasets(cfg):
79
- dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
80
-
81
- batch_size = cfg.batch_size
82
- num_workers = cfg.num_workers
83
- pin_memory = cfg.pin_memory
84
- sampler, loader = construct_loader(dataset,
85
- batch_size,
86
- num_workers,
87
- shuffle=False,
88
- drop_last=False,
89
- pin_memory=pin_memory)
90
-
91
- return dataset, sampler, loader
92
-
93
-
94
- def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
95
- if cfg.example_train:
96
- dataset = load_vgg_data(cfg, cfg.data.Example_video)
97
- else:
98
- dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
99
-
100
- val_batch_size = cfg.batch_size
101
- val_eval_batch_size = cfg.eval_batch_size
102
- num_workers = cfg.num_workers
103
- pin_memory = cfg.pin_memory
104
- _, val_loader = construct_loader(dataset,
105
- val_batch_size,
106
- num_workers,
107
- shuffle=False,
108
- drop_last=False,
109
- pin_memory=pin_memory)
110
- _, eval_loader = construct_loader(dataset,
111
- val_eval_batch_size,
112
- num_workers,
113
- shuffle=False,
114
- drop_last=False,
115
- pin_memory=pin_memory)
116
-
117
- return dataset, val_loader, eval_loader
118
-
119
-
120
- def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
121
- if dataset_name.startswith('audiocaps_full'):
122
- dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
123
- cfg.eval_data.AudioCaps_full.csv_path)
124
- elif dataset_name.startswith('audiocaps'):
125
- dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
126
- cfg.eval_data.AudioCaps.csv_path)
127
- elif dataset_name.startswith('moviegen'):
128
- dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
129
- cfg.eval_data.MovieGen.jsonl_path,
130
- duration_sec=cfg.duration_s)
131
- elif dataset_name.startswith('vggsound'):
132
- dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
133
- cfg.eval_data.VGGSound.csv_path,
134
- duration_sec=cfg.duration_s)
135
- else:
136
- raise ValueError(f'Invalid dataset name: {dataset_name}')
137
-
138
- batch_size = cfg.batch_size
139
- num_workers = cfg.num_workers
140
- pin_memory = cfg.pin_memory
141
- _, loader = construct_loader(dataset,
142
- batch_size,
143
- num_workers,
144
- shuffle=False,
145
- drop_last=False,
146
- pin_memory=pin_memory,
147
- error_avoidance=True)
148
- return dataset, loader
149
-
150
-
151
- def error_avoidance_collate(batch):
152
- batch = list(filter(lambda x: x is not None, batch))
153
- return default_collate(batch)
154
-
155
-
156
- def construct_loader(dataset: Dataset,
157
- batch_size: int,
158
- num_workers: int,
159
- *,
160
- shuffle: bool = True,
161
- drop_last: bool = True,
162
- pin_memory: bool = False,
163
- error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
164
- train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
165
- train_loader = DataLoader(dataset,
166
- batch_size,
167
- sampler=train_sampler,
168
- num_workers=num_workers,
169
- worker_init_fn=worker_init_fn,
170
- drop_last=drop_last,
171
- persistent_workers=num_workers > 0,
172
- pin_memory=pin_memory,
173
- collate_fn=error_avoidance_collate if error_avoidance else None)
174
- return train_sampler, train_loader
 
mmaudio/data/eval/__init__.py DELETED
File without changes
mmaudio/data/eval/audiocaps.py DELETED
@@ -1,39 +0,0 @@
1
- import logging
2
- import os
3
- from collections import defaultdict
4
- from pathlib import Path
5
- from typing import Union
6
-
7
- import pandas as pd
8
- import torch
9
- from torch.utils.data.dataset import Dataset
10
-
11
- log = logging.getLogger()
12
-
13
-
14
- class AudioCapsData(Dataset):
15
-
16
- def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
17
- df = pd.read_csv(csv_path).to_dict(orient='records')
18
-
19
- audio_files = sorted(os.listdir(audio_path))
20
- audio_files = set(
21
- [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
22
-
23
- self.data = []
24
- for row in df:
25
- self.data.append({
26
- 'name': row['name'],
27
- 'caption': row['caption'],
28
- })
29
-
30
- self.audio_path = Path(audio_path)
31
- self.csv_path = Path(csv_path)
32
-
33
- log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
34
-
35
- def __getitem__(self, idx: int) -> torch.Tensor:
36
- return self.data[idx]
37
-
38
- def __len__(self):
39
- return len(self.data)
 
mmaudio/data/eval/moviegen.py DELETED
@@ -1,131 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- from pathlib import Path
5
- from typing import Union
6
-
7
- import torch
8
- from torch.utils.data.dataset import Dataset
9
- from torchvision.transforms import v2
10
- from torio.io import StreamingMediaDecoder
11
-
12
- from mmaudio.utils.dist_utils import local_rank
13
-
14
- log = logging.getLogger()
15
-
16
- _CLIP_SIZE = 384
17
- _CLIP_FPS = 8.0
18
-
19
- _SYNC_SIZE = 224
20
- _SYNC_FPS = 25.0
21
-
22
-
23
- class MovieGenData(Dataset):
24
-
25
- def __init__(
26
- self,
27
- video_root: Union[str, Path],
28
- sync_root: Union[str, Path],
29
- jsonl_root: Union[str, Path],
30
- *,
31
- duration_sec: float = 10.0,
32
- read_clip: bool = True,
33
- ):
34
- self.video_root = Path(video_root)
35
- self.sync_root = Path(sync_root)
36
- self.jsonl_root = Path(jsonl_root)
37
- self.read_clip = read_clip
38
-
39
- videos = sorted(os.listdir(self.video_root))
40
- videos = [v[:-4] for v in videos] # remove extensions
41
- self.captions = {}
42
-
43
- for v in videos:
44
- with open(self.jsonl_root / (v + '.jsonl')) as f:
45
- data = json.load(f)
46
- self.captions[v] = data['audio_prompt']
47
-
48
- if local_rank == 0:
49
- log.info(f'{len(videos)} videos found in {video_root}')
50
-
51
- self.duration_sec = duration_sec
52
-
53
- self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
54
- self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
55
-
56
- self.clip_augment = v2.Compose([
57
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
58
- v2.ToImage(),
59
- v2.ToDtype(torch.float32, scale=True),
60
- ])
61
-
62
- self.sync_augment = v2.Compose([
63
- v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
64
- v2.CenterCrop(_SYNC_SIZE),
65
- v2.ToImage(),
66
- v2.ToDtype(torch.float32, scale=True),
67
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
68
- ])
69
-
70
- self.videos = videos
71
-
72
- def sample(self, idx: int) -> dict[str, torch.Tensor]:
73
- video_id = self.videos[idx]
74
- caption = self.captions[video_id]
75
-
76
- reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
77
- reader.add_basic_video_stream(
78
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
79
- frame_rate=_CLIP_FPS,
80
- format='rgb24',
81
- )
82
- reader.add_basic_video_stream(
83
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
84
- frame_rate=_SYNC_FPS,
85
- format='rgb24',
86
- )
87
-
88
- reader.fill_buffer()
89
- data_chunk = reader.pop_chunks()
90
-
91
- clip_chunk = data_chunk[0]
92
- sync_chunk = data_chunk[1]
93
- if clip_chunk is None:
94
- raise RuntimeError(f'CLIP video returned None {video_id}')
95
- if clip_chunk.shape[0] < self.clip_expected_length:
96
- raise RuntimeError(f'CLIP video too short {video_id}')
97
-
98
- if sync_chunk is None:
99
- raise RuntimeError(f'Sync video returned None {video_id}')
100
- if sync_chunk.shape[0] < self.sync_expected_length:
101
- raise RuntimeError(f'Sync video too short {video_id}')
102
-
103
- # truncate the video
104
- clip_chunk = clip_chunk[:self.clip_expected_length]
105
- if clip_chunk.shape[0] != self.clip_expected_length:
106
- raise RuntimeError(f'CLIP video wrong length {video_id}, '
107
- f'expected {self.clip_expected_length}, '
108
- f'got {clip_chunk.shape[0]}')
109
- clip_chunk = self.clip_augment(clip_chunk)
110
-
111
- sync_chunk = sync_chunk[:self.sync_expected_length]
112
- if sync_chunk.shape[0] != self.sync_expected_length:
113
- raise RuntimeError(f'Sync video wrong length {video_id}, '
114
- f'expected {self.sync_expected_length}, '
115
- f'got {sync_chunk.shape[0]}')
116
- sync_chunk = self.sync_augment(sync_chunk)
117
-
118
- data = {
119
- 'name': video_id,
120
- 'caption': caption,
121
- 'clip_video': clip_chunk,
122
- 'sync_video': sync_chunk,
123
- }
124
-
125
- return data
126
-
127
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
128
- return self.sample(idx)
129
-
130
- def __len__(self):
131
- return len(self.captions)
 
mmaudio/data/eval/video_dataset.py DELETED
@@ -1,197 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- from pathlib import Path
5
- from typing import Union
6
-
7
- import pandas as pd
8
- import torch
9
- from torch.utils.data.dataset import Dataset
10
- from torchvision.transforms import v2
11
- from torio.io import StreamingMediaDecoder
12
-
13
- from mmaudio.utils.dist_utils import local_rank
14
-
15
- log = logging.getLogger()
16
-
17
- _CLIP_SIZE = 384
18
- _CLIP_FPS = 8.0
19
-
20
- _SYNC_SIZE = 224
21
- _SYNC_FPS = 25.0
22
-
23
-
24
- class VideoDataset(Dataset):
25
-
26
- def __init__(
27
- self,
28
- video_root: Union[str, Path],
29
- *,
30
- duration_sec: float = 8.0,
31
- ):
32
- self.video_root = Path(video_root)
33
-
34
- self.duration_sec = duration_sec
35
-
36
- self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
37
- self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
38
-
39
- self.clip_transform = v2.Compose([
40
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
41
- v2.ToImage(),
42
- v2.ToDtype(torch.float32, scale=True),
43
- ])
44
-
45
- self.sync_transform = v2.Compose([
46
- v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
47
- v2.CenterCrop(_SYNC_SIZE),
48
- v2.ToImage(),
49
- v2.ToDtype(torch.float32, scale=True),
50
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
51
- ])
52
-
53
- # to be implemented by subclasses
54
- self.captions = {}
55
- self.videos = sorted(list(self.captions.keys()))
56
-
57
- def sample(self, idx: int) -> dict[str, torch.Tensor]:
58
- video_id = self.videos[idx]
59
- caption = self.captions[video_id]
60
-
61
- reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
62
- reader.add_basic_video_stream(
63
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
64
- frame_rate=_CLIP_FPS,
65
- format='rgb24',
66
- )
67
- reader.add_basic_video_stream(
68
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
69
- frame_rate=_SYNC_FPS,
70
- format='rgb24',
71
- )
72
-
73
- reader.fill_buffer()
74
- data_chunk = reader.pop_chunks()
75
-
76
- clip_chunk = data_chunk[0]
77
- sync_chunk = data_chunk[1]
78
- if clip_chunk is None:
79
- raise RuntimeError(f'CLIP video returned None {video_id}')
80
- if clip_chunk.shape[0] < self.clip_expected_length:
81
- raise RuntimeError(
82
- f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
83
- )
84
-
85
- if sync_chunk is None:
86
- raise RuntimeError(f'Sync video returned None {video_id}')
87
- if sync_chunk.shape[0] < self.sync_expected_length:
88
- raise RuntimeError(
89
- f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
90
- )
91
-
92
- # truncate the video
93
- clip_chunk = clip_chunk[:self.clip_expected_length]
94
- if clip_chunk.shape[0] != self.clip_expected_length:
95
- raise RuntimeError(f'CLIP video wrong length {video_id}, '
96
- f'expected {self.clip_expected_length}, '
97
- f'got {clip_chunk.shape[0]}')
98
- clip_chunk = self.clip_transform(clip_chunk)
99
-
100
- sync_chunk = sync_chunk[:self.sync_expected_length]
101
- if sync_chunk.shape[0] != self.sync_expected_length:
102
- raise RuntimeError(f'Sync video wrong length {video_id}, '
103
- f'expected {self.sync_expected_length}, '
104
- f'got {sync_chunk.shape[0]}')
105
- sync_chunk = self.sync_transform(sync_chunk)
106
-
107
- data = {
108
- 'name': video_id,
109
- 'caption': caption,
110
- 'clip_video': clip_chunk,
111
- 'sync_video': sync_chunk,
112
- }
113
-
114
- return data
115
-
116
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
117
- try:
118
- return self.sample(idx)
119
- except Exception as e:
120
- log.error(f'Error loading video {self.videos[idx]}: {e}')
121
- return None
122
-
123
- def __len__(self):
124
- return len(self.captions)
125
-
126
-
127
- class VGGSound(VideoDataset):
128
-
129
- def __init__(
130
- self,
131
- video_root: Union[str, Path],
132
- csv_path: Union[str, Path],
133
- *,
134
- duration_sec: float = 8.0,
135
- ):
136
- super().__init__(video_root, duration_sec=duration_sec)
137
- self.video_root = Path(video_root)
138
- self.csv_path = Path(csv_path)
139
-
140
- videos = sorted(os.listdir(self.video_root))
141
- if local_rank == 0:
142
- log.info(f'{len(videos)} videos found in {video_root}')
143
- self.captions = {}
144
-
145
- df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
146
- 'split']).to_dict(orient='records')
147
-
148
- videos_no_found = []
149
- for row in df:
150
- if row['split'] == 'test':
151
- start_sec = int(row['sec'])
152
- video_id = str(row['id'])
153
- # this is how our videos are named
154
- video_name = f'{video_id}_{start_sec:06d}'
155
- if video_name + '.mp4' not in videos:
156
- videos_no_found.append(video_name)
157
- continue
158
-
159
- self.captions[video_name] = row['caption']
160
-
161
- if local_rank == 0:
162
- log.info(f'{len(videos)} videos found in {video_root}')
163
- log.info(f'{len(self.captions)} useable videos found')
164
- if videos_no_found:
165
- log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
166
- log.info(
167
- 'A small amount is expected, as not all videos are still available on YouTube')
168
-
169
- self.videos = sorted(list(self.captions.keys()))
170
-
171
-
172
- class MovieGen(VideoDataset):
173
-
174
- def __init__(
175
- self,
176
- video_root: Union[str, Path],
177
- jsonl_root: Union[str, Path],
178
- *,
179
- duration_sec: float = 10.0,
180
- ):
181
- super().__init__(video_root, duration_sec=duration_sec)
182
- self.video_root = Path(video_root)
183
- self.jsonl_root = Path(jsonl_root)
184
-
185
- videos = sorted(os.listdir(self.video_root))
186
- videos = [v[:-4] for v in videos] # remove extensions
187
- self.captions = {}
188
-
189
- for v in videos:
190
- with open(self.jsonl_root / (v + '.jsonl')) as f:
191
- data = json.load(f)
192
- self.captions[v] = data['audio_prompt']
193
-
194
- if local_rank == 0:
195
- log.info(f'{len(videos)} videos found in {video_root}')
196
-
197
- self.videos = videos
 
mmaudio/data/extracted_audio.py DELETED
@@ -1,88 +0,0 @@
1
- import logging
2
- from pathlib import Path
3
- from typing import Union
4
-
5
- import pandas as pd
6
- import torch
7
- from tensordict import TensorDict
8
- from torch.utils.data.dataset import Dataset
9
-
10
- from mmaudio.utils.dist_utils import local_rank
11
-
12
- log = logging.getLogger()
13
-
14
-
15
- class ExtractedAudio(Dataset):
16
-
17
- def __init__(
18
- self,
19
- tsv_path: Union[str, Path],
20
- *,
21
- premade_mmap_dir: Union[str, Path],
22
- data_dim: dict[str, int],
23
- ):
24
- super().__init__()
25
-
26
- self.data_dim = data_dim
27
- self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
- self.ids = [str(d['id']) for d in self.df_list]
29
-
30
- log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
- # load precomputed memory mapped tensors
32
- premade_mmap_dir = Path(premade_mmap_dir)
33
- td = TensorDict.load_memmap(premade_mmap_dir)
34
- log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
- self.mean = td['mean']
36
- self.std = td['std']
37
- self.text_features = td['text_features']
38
-
39
- log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
40
- log.info(f'Loaded mean: {self.mean.shape}.')
41
- log.info(f'Loaded std: {self.std.shape}.')
42
- log.info(f'Loaded text features: {self.text_features.shape}.')
43
-
44
- assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
45
- f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
46
- assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
47
- f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
48
-
49
- assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
50
- f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
51
- assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
52
- f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
53
-
54
- self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
55
- self.data_dim['clip_dim'])
56
- self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
57
- self.data_dim['sync_dim'])
58
- self.video_exist = torch.tensor(0, dtype=torch.bool)
59
- self.text_exist = torch.tensor(1, dtype=torch.bool)
60
-
61
- def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
62
- latents = self.mean
63
- return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
64
-
65
- def get_memory_mapped_tensor(self) -> TensorDict:
66
- td = TensorDict({
67
- 'mean': self.mean,
68
- 'std': self.std,
69
- 'text_features': self.text_features,
70
- })
71
- return td
72
-
73
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
74
- data = {
75
- 'id': str(self.df_list[idx]['id']),
76
- 'a_mean': self.mean[idx],
77
- 'a_std': self.std[idx],
78
- 'clip_features': self.fake_clip_features,
79
- 'sync_features': self.fake_sync_features,
80
- 'text_features': self.text_features[idx],
81
- 'caption': self.df_list[idx]['caption'],
82
- 'video_exist': self.video_exist,
83
- 'text_exist': self.text_exist,
84
- }
85
- return data
86
-
87
- def __len__(self):
88
- return len(self.ids)
 
mmaudio/data/extracted_vgg.py DELETED
@@ -1,101 +0,0 @@
1
- import logging
2
- from pathlib import Path
3
- from typing import Union
4
-
5
- import pandas as pd
6
- import torch
7
- from tensordict import TensorDict
8
- from torch.utils.data.dataset import Dataset
9
-
10
- from mmaudio.utils.dist_utils import local_rank
11
-
12
- log = logging.getLogger()
13
-
14
-
15
- class ExtractedVGG(Dataset):
16
-
17
- def __init__(
18
- self,
19
- tsv_path: Union[str, Path],
20
- *,
21
- premade_mmap_dir: Union[str, Path],
22
- data_dim: dict[str, int],
23
- ):
24
- super().__init__()
25
-
26
- self.data_dim = data_dim
27
- self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
- self.ids = [d['id'] for d in self.df_list]
29
-
30
- log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
- # load precomputed memory mapped tensors
32
- premade_mmap_dir = Path(premade_mmap_dir)
33
- td = TensorDict.load_memmap(premade_mmap_dir)
34
- log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
- self.mean = td['mean']
36
- self.std = td['std']
37
- self.clip_features = td['clip_features']
38
- self.sync_features = td['sync_features']
39
- self.text_features = td['text_features']
40
-
41
- if local_rank == 0:
42
- log.info(f'Loaded {len(self)} samples.')
43
- log.info(f'Loaded mean: {self.mean.shape}.')
44
- log.info(f'Loaded std: {self.std.shape}.')
45
- log.info(f'Loaded clip_features: {self.clip_features.shape}.')
46
- log.info(f'Loaded sync_features: {self.sync_features.shape}.')
47
- log.info(f'Loaded text_features: {self.text_features.shape}.')
48
-
49
- assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
50
- f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
51
- assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
52
- f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
53
-
54
- assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
55
- f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
56
- assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
57
- f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
58
- assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
59
- f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
60
-
61
- assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
62
- f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
63
- assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
64
- f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
65
- assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
66
- f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
67
-
68
- self.video_exist = torch.tensor(1, dtype=torch.bool)
69
- self.text_exist = torch.tensor(1, dtype=torch.bool)
70
-
71
- def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
72
- latents = self.mean
73
- return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
74
-
75
- def get_memory_mapped_tensor(self) -> TensorDict:
76
- td = TensorDict({
77
- 'mean': self.mean,
78
- 'std': self.std,
79
- 'clip_features': self.clip_features,
80
- 'sync_features': self.sync_features,
81
- 'text_features': self.text_features,
82
- })
83
- return td
84
-
85
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
86
- data = {
87
- 'id': self.df_list[idx]['id'],
88
- 'a_mean': self.mean[idx],
89
- 'a_std': self.std[idx],
90
- 'clip_features': self.clip_features[idx],
91
- 'sync_features': self.sync_features[idx],
92
- 'text_features': self.text_features[idx],
93
- 'caption': self.df_list[idx]['label'],
94
- 'video_exist': self.video_exist,
95
- 'text_exist': self.text_exist,
96
- }
97
-
98
- return data
99
-
100
- def __len__(self):
101
- return len(self.ids)
 
mmaudio/data/extraction/__init__.py DELETED
File without changes
mmaudio/data/extraction/vgg_sound.py DELETED
@@ -1,193 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
- from typing import Optional, Union
5
-
6
- import pandas as pd
7
- import torch
8
- import torchaudio
9
- from torch.utils.data.dataset import Dataset
10
- from torchvision.transforms import v2
11
- from torio.io import StreamingMediaDecoder
12
-
13
- from mmaudio.utils.dist_utils import local_rank
14
-
15
- log = logging.getLogger()
16
-
17
- _CLIP_SIZE = 384
18
- _CLIP_FPS = 8.0
19
-
20
- _SYNC_SIZE = 224
21
- _SYNC_FPS = 25.0
22
-
23
-
24
- class VGGSound(Dataset):
25
-
26
- def __init__(
27
- self,
28
- root: Union[str, Path],
29
- *,
30
- tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
31
- sample_rate: int = 16_000,
32
- duration_sec: float = 8.0,
33
- audio_samples: Optional[int] = None,
34
- normalize_audio: bool = False,
35
- ):
36
- self.root = Path(root)
37
- self.normalize_audio = normalize_audio
38
- if audio_samples is None:
39
- self.audio_samples = int(sample_rate * duration_sec)
40
- else:
41
- self.audio_samples = audio_samples
42
- effective_duration = audio_samples / sample_rate
43
- # make sure the duration is close enough, within 15ms
44
- assert abs(effective_duration - duration_sec) < 0.015, \
45
- f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
46
-
47
- videos = sorted(os.listdir(self.root))
48
- videos = set([Path(v).stem for v in videos]) # remove extensions
49
- self.labels = {}
50
- self.videos = []
51
- missing_videos = []
52
-
53
- # read the tsv for subset information
54
- df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
55
- for record in df_list:
56
- id = record['id']
57
- label = record['label']
58
- if id in videos:
59
- self.labels[id] = label
60
- self.videos.append(id)
61
- else:
62
- missing_videos.append(id)
63
-
64
- if local_rank == 0:
65
- log.info(f'{len(videos)} videos found in {root}')
66
- log.info(f'{len(self.videos)} videos found in {tsv_path}')
67
- log.info(f'{len(missing_videos)} videos missing in {root}')
68
-
69
- self.sample_rate = sample_rate
70
- self.duration_sec = duration_sec
71
-
72
- self.expected_audio_length = audio_samples
73
- self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
74
- self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
75
-
76
- self.clip_transform = v2.Compose([
77
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
78
- v2.ToImage(),
79
- v2.ToDtype(torch.float32, scale=True),
80
- ])
81
-
82
- self.sync_transform = v2.Compose([
83
- v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
84
- v2.CenterCrop(_SYNC_SIZE),
85
- v2.ToImage(),
86
- v2.ToDtype(torch.float32, scale=True),
87
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
88
- ])
89
-
90
- self.resampler = {}
91
-
92
- def sample(self, idx: int) -> dict[str, torch.Tensor]:
93
- video_id = self.videos[idx]
94
- label = self.labels[video_id]
95
-
96
- reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
97
- reader.add_basic_video_stream(
98
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
99
- frame_rate=_CLIP_FPS,
100
- format='rgb24',
101
- )
102
- reader.add_basic_video_stream(
103
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
104
- frame_rate=_SYNC_FPS,
105
- format='rgb24',
106
- )
107
- reader.add_basic_audio_stream(frames_per_chunk=2**30, )
108
-
109
- reader.fill_buffer()
110
- data_chunk = reader.pop_chunks()
111
-
112
- clip_chunk = data_chunk[0]
113
- sync_chunk = data_chunk[1]
114
- audio_chunk = data_chunk[2]
115
-
116
- if clip_chunk is None:
117
- raise RuntimeError(f'CLIP video returned None {video_id}')
118
- if clip_chunk.shape[0] < self.clip_expected_length:
119
- raise RuntimeError(
120
- f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
121
- )
122
-
123
- if sync_chunk is None:
124
- raise RuntimeError(f'Sync video returned None {video_id}')
125
- if sync_chunk.shape[0] < self.sync_expected_length:
126
- raise RuntimeError(
127
- f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
128
- )
129
-
130
- # process audio
131
- sample_rate = int(reader.get_out_stream_info(2).sample_rate)
132
- audio_chunk = audio_chunk.transpose(0, 1)
133
- audio_chunk = audio_chunk.mean(dim=0) # mono
134
- if self.normalize_audio:
135
- abs_max = audio_chunk.abs().max()
136
- audio_chunk = audio_chunk / abs_max * 0.95
137
- if abs_max <= 1e-6:
138
- raise RuntimeError(f'Audio is silent {video_id}')
139
-
140
- # resample
141
- if sample_rate == self.sample_rate:
142
- audio_chunk = audio_chunk
143
- else:
144
- if sample_rate not in self.resampler:
145
- # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
146
- self.resampler[sample_rate] = torchaudio.transforms.Resample(
147
- sample_rate,
148
- self.sample_rate,
149
- lowpass_filter_width=64,
150
- rolloff=0.9475937167399596,
151
- resampling_method='sinc_interp_kaiser',
152
- beta=14.769656459379492,
153
- )
154
- audio_chunk = self.resampler[sample_rate](audio_chunk)
155
-
156
- if audio_chunk.shape[0] < self.expected_audio_length:
157
- raise RuntimeError(f'Audio too short {video_id}')
158
- audio_chunk = audio_chunk[:self.expected_audio_length]
159
-
160
- # truncate the video
161
- clip_chunk = clip_chunk[:self.clip_expected_length]
162
- if clip_chunk.shape[0] != self.clip_expected_length:
163
- raise RuntimeError(f'CLIP video wrong length {video_id}, '
164
- f'expected {self.clip_expected_length}, '
165
- f'got {clip_chunk.shape[0]}')
166
- clip_chunk = self.clip_transform(clip_chunk)
167
-
168
- sync_chunk = sync_chunk[:self.sync_expected_length]
169
- if sync_chunk.shape[0] != self.sync_expected_length:
170
- raise RuntimeError(f'Sync video wrong length {video_id}, '
171
- f'expected {self.sync_expected_length}, '
172
- f'got {sync_chunk.shape[0]}')
173
- sync_chunk = self.sync_transform(sync_chunk)
174
-
175
- data = {
176
- 'id': video_id,
177
- 'caption': label,
178
- 'audio': audio_chunk,
179
- 'clip_video': clip_chunk,
180
- 'sync_video': sync_chunk,
181
- }
182
-
183
- return data
184
-
185
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
186
- try:
187
- return self.sample(idx)
188
- except Exception as e:
189
- log.error(f'Error loading video {self.videos[idx]}: {e}')
190
- return None
191
-
192
- def __len__(self):
193
- return len(self.labels)
 
mmaudio/data/extraction/wav_dataset.py DELETED
@@ -1,132 +0,0 @@
-import logging
-import os
-from pathlib import Path
-from typing import Union
-
-import open_clip
-import pandas as pd
-import torch
-import torchaudio
-from torch.utils.data.dataset import Dataset
-
-log = logging.getLogger()
-
-
-class WavTextClipsDataset(Dataset):
-
-    def __init__(
-        self,
-        root: Union[str, Path],
-        *,
-        captions_tsv: Union[str, Path],
-        clips_tsv: Union[str, Path],
-        sample_rate: int,
-        num_samples: int,
-        normalize_audio: bool = False,
-        reject_silent: bool = False,
-        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
-    ):
-        self.root = Path(root)
-        self.sample_rate = sample_rate
-        self.num_samples = num_samples
-        self.normalize_audio = normalize_audio
-        self.reject_silent = reject_silent
-        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
-
-        audios = sorted(os.listdir(self.root))
-        audios = set([
-            Path(audio).stem for audio in audios
-            if audio.endswith('.wav') or audio.endswith('.flac')
-        ])
-        self.captions = {}
-
-        # read the caption tsv
-        df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
-        for record in df_list:
-            id = record['id']
-            caption = record['caption']
-            self.captions[id] = caption
-
-        # read the clip tsv
-        df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
-            'id': str,
-            'name': str
-        }).to_dict('records')
-        self.clips = []
-        for record in df_list:
-            record['id'] = record['id']
-            record['name'] = record['name']
-            id = record['id']
-            name = record['name']
-            if name not in self.captions:
-                log.warning(f'Audio {name} not found in {captions_tsv}')
-                continue
-            record['caption'] = self.captions[name]
-            self.clips.append(record)
-
-        log.info(f'Found {len(self.clips)} audio files in {self.root}')
-
-        self.resampler = {}
-
-    def __getitem__(self, idx: int) -> torch.Tensor:
-        try:
-            clip = self.clips[idx]
-            audio_name = clip['name']
-            audio_id = clip['id']
-            caption = clip['caption']
-            start_sample = clip['start_sample']
-            end_sample = clip['end_sample']
-
-            audio_path = self.root / f'{audio_name}.flac'
-            if not audio_path.exists():
-                audio_path = self.root / f'{audio_name}.wav'
-                assert audio_path.exists()
-
-            audio_chunk, sample_rate = torchaudio.load(audio_path)
-            audio_chunk = audio_chunk.mean(dim=0)  # mono
-            abs_max = audio_chunk.abs().max()
-            if self.normalize_audio:
-                audio_chunk = audio_chunk / abs_max * 0.95
-
-            if self.reject_silent and abs_max < 1e-6:
-                log.warning(f'Rejecting silent audio')
-                return None
-
-            audio_chunk = audio_chunk[start_sample:end_sample]
-
-            # resample
-            if sample_rate == self.sample_rate:
-                audio_chunk = audio_chunk
-            else:
-                if sample_rate not in self.resampler:
-                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
-                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
-                        sample_rate,
-                        self.sample_rate,
-                        lowpass_filter_width=64,
-                        rolloff=0.9475937167399596,
-                        resampling_method='sinc_interp_kaiser',
-                        beta=14.769656459379492,
-                    )
-                audio_chunk = self.resampler[sample_rate](audio_chunk)
-
-            if audio_chunk.shape[0] < self.num_samples:
-                raise ValueError('Audio is too short')
-            audio_chunk = audio_chunk[:self.num_samples]
-
-            tokens = self.tokenizer([caption])[0]
-
-            output = {
-                'waveform': audio_chunk,
-                'id': audio_id,
-                'caption': caption,
-                'tokens': tokens,
-            }
-
-            return output
-        except Exception as e:
-            log.error(f'Error reading {audio_path}: {e}')
-            return None
-
-    def __len__(self):
-        return len(self.clips)
 
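Like the VGGSound loader above, the deleted `WavTextClipsDataset` returns `None` when a clip cannot be used (read error, rejected silent audio). A dataset that can yield `None` needs a collate function that drops those entries before batching; the following is a generic companion sketch, not code from this repository.

```python
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


def collate_skip_none(batch):
    """Drop failed samples (None) before running the default collation."""
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None  # the training loop should skip this iteration
    return default_collate(batch)


# Hypothetical usage with a dataset that may return None:
# loader = DataLoader(dataset, batch_size=16, num_workers=4, collate_fn=collate_skip_none)
```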
mmaudio/data/mm_dataset.py DELETED
@@ -1,45 +0,0 @@
-import bisect
-
-import torch
-from torch.utils.data.dataset import Dataset
-
-
-# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
-class MultiModalDataset(Dataset):
-    datasets: list[Dataset]
-    cumulative_sizes: list[int]
-
-    @staticmethod
-    def cumsum(sequence):
-        r, s = [], 0
-        for e in sequence:
-            l = len(e)
-            r.append(l + s)
-            s += l
-        return r
-
-    def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
-        super().__init__()
-        self.video_datasets = list(video_datasets)
-        self.audio_datasets = list(audio_datasets)
-        self.datasets = self.video_datasets + self.audio_datasets
-
-        self.cumulative_sizes = self.cumsum(self.datasets)
-
-    def __len__(self):
-        return self.cumulative_sizes[-1]
-
-    def __getitem__(self, idx):
-        if idx < 0:
-            if -idx > len(self):
-                raise ValueError("absolute value of index should not exceed dataset length")
-            idx = len(self) + idx
-        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
-        if dataset_idx == 0:
-            sample_idx = idx
-        else:
-            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
-        return self.datasets[dataset_idx][sample_idx]
-
-    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.video_datasets[0].compute_latent_stats()
 
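The deleted `MultiModalDataset` maps a flat index to a (dataset, local index) pair through cumulative sizes and `bisect_right`, the same scheme as PyTorch's `ConcatDataset`. A small illustration of the index arithmetic with made-up dataset sizes:

```python
import bisect

# Three hypothetical datasets of length 100, 150 and 150.
cumulative_sizes = [100, 250, 400]


def locate(idx: int) -> tuple[int, int]:
    """Return (dataset index, index within that dataset) for a global index."""
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx


assert locate(0) == (0, 0)
assert locate(99) == (0, 99)
assert locate(100) == (1, 0)    # first item of the second dataset
assert locate(399) == (2, 149)  # last item overall
```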
mmaudio/data/utils.py DELETED
@@ -1,148 +0,0 @@
-import logging
-import os
-import random
-import tempfile
-from pathlib import Path
-from typing import Any, Optional, Union
-
-import torch
-import torch.distributed as dist
-from tensordict import MemoryMappedTensor
-from torch.utils.data import DataLoader
-from torch.utils.data.dataset import Dataset
-from tqdm import tqdm
-
-from mmaudio.utils.dist_utils import local_rank, world_size
-
-scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
-shm_path = Path('/dev/shm')
-
-log = logging.getLogger()
-
-
-def reseed(seed):
-    random.seed(seed)
-    torch.manual_seed(seed)
-
-
-def local_scatter_torch(obj: Optional[Any]):
-    if world_size == 1:
-        # Just one worker. Do nothing.
-        return obj
-
-    array = [obj] * world_size
-    target_array = [None]
-    if local_rank == 0:
-        dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
-    else:
-        dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
-    return target_array[0]
-
-
-class ShardDataset(Dataset):
-
-    def __init__(self, root):
-        self.root = root
-        self.shards = sorted(os.listdir(root))
-
-    def __len__(self):
-        return len(self.shards)
-
-    def __getitem__(self, idx):
-        return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
-
-
-def get_tmp_dir(in_memory: bool) -> Path:
-    return shm_path if in_memory else scratch_path
-
-
-def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
-                          in_memory: bool) -> MemoryMappedTensor:
-    if local_rank == 0:
-        with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
-            log.info(f'Loading shards from {data_path} into {f.name}...')
-            data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
-            data = share_tensor_to_all(data)
-            torch.distributed.barrier()
-            f.close()  # why does the context manager not close the file for me?
-    else:
-        log.info('Waiting for the data to be shared with me...')
-        data = share_tensor_to_all(None)
-        torch.distributed.barrier()
-
-    return data
-
-
-def load_shards(
-    data_path: Union[str, Path],
-    ids: list[int],
-    *,
-    tmp_file_path: str,
-) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
-
-    id_set = set(ids)
-    shards = sorted(os.listdir(data_path))
-    log.info(f'Found {len(shards)} shards in {data_path}.')
-    first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
-
-    log.info(f'Rank {local_rank} created file {tmp_file_path}')
-    first_item = next(iter(first_shard.values()))
-    log.info(f'First item shape: {first_item.shape}')
-    mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
-                                         dtype=torch.float32,
-                                         filename=tmp_file_path,
-                                         existsok=True)
-    total_count = 0
-    used_index = set()
-    id_indexing = {i: idx for idx, i in enumerate(ids)}
-    # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
-    loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
-    for data in tqdm(loader, desc='Loading shards'):
-        for i, v in data.items():
-            if i not in id_set:
-                continue
-
-            # tensor_index = ids.index(i)
-            tensor_index = id_indexing[i]
-            if tensor_index in used_index:
-                raise ValueError(f'Duplicate id {i} found in {data_path}.')
-            used_index.add(tensor_index)
-            mm_tensor[tensor_index] = v
-            total_count += 1
-
-    assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
-    log.info(f'Loaded {total_count} tensors from {data_path}.')
-
-    return mm_tensor
-
-
-def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
-    """
-    x: the tensor to be shared; None if local_rank != 0
-    return: the shared tensor
-    """
-
-    # there is no need to share your stuff with anyone if you are alone; must be in memory
-    if world_size == 1:
-        return x
-
-    if local_rank == 0:
-        assert x is not None, 'x must not be None if local_rank == 0'
-    else:
-        assert x is None, 'x must be None if local_rank != 0'
-
-    if local_rank == 0:
-        filename = x.filename
-        meta_information = (filename, x.shape, x.dtype)
-    else:
-        meta_information = None
-
-    filename, data_shape, data_type = local_scatter_torch(meta_information)
-    if local_rank == 0:
-        data = x
-    else:
-        data = MemoryMappedTensor.from_filename(filename=filename,
-                                                dtype=data_type,
-                                                shape=data_shape)
-
-    return data
 
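The deleted `utils.py` loads extracted features once on rank 0, writes them into a tensordict `MemoryMappedTensor`, and then scatters only the (filename, shape, dtype) triple so that the other ranks attach to the same backing file instead of holding their own copy. Here is a single-node sketch of that idea without the distributed plumbing; the `/dev/shm` location and toy shape are assumptions for illustration.

```python
import tempfile

import torch
from tensordict import MemoryMappedTensor

# "Rank 0": materialise the data into a file kept in shared memory.
tmp = tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir='/dev/shm', delete=False)
mm = MemoryMappedTensor.empty(shape=(4, 8), dtype=torch.float32,
                              filename=tmp.name, existsok=True)
mm.copy_(torch.arange(32, dtype=torch.float32).reshape(4, 8))

# This metadata is all that would need to be scattered to the other ranks.
meta = (tmp.name, mm.shape, mm.dtype)

# "Other rank": re-attach to the same backing file from the metadata alone.
filename, shape, dtype = meta
shared = MemoryMappedTensor.from_filename(filename=filename, dtype=dtype, shape=shape)
assert torch.equal(shared, mm)
```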
mmaudio/eval_utils.py CHANGED
@@ -1,18 +1,16 @@
 import dataclasses
 import logging
 from pathlib import Path
-from typing import Optional, Tuple, List, Dict
+from typing import Optional
 
-import numpy as np
 import torch
 from colorlog import ColoredFormatter
-from PIL import Image
 from torchvision.transforms import v2
 
-from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
+from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
 from mmaudio.model.flow_matching import FlowMatching
 from mmaudio.model.networks import MMAudio
-from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
+from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
 from mmaudio.model.utils.features_utils import FeaturesUtils
 from mmaudio.utils.download_utils import download_model_if_needed
 
@@ -26,7 +24,7 @@ class ModelConfig:
     vae_path: Path
     bigvgan_16k_path: Optional[Path]
     mode: str
-    synchformer_ckpt: Path = Path('./pretrained/v2a/mmaudio/ext_weights/synchformer_state_dict.pth')
+    synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
 
     @property
     def seq_cfg(self) -> SequenceConfig:
@@ -44,31 +42,31 @@
 
 
 small_16k = ModelConfig(model_name='small_16k',
-                        model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_16k.pth'),
-                        vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-16.pth'),
-                        bigvgan_16k_path=Path('./pretrained/v2a/mmaudio/ext_weights/best_netG.pt'),
+                        model_path=Path('./weights/mmaudio_small_16k.pth'),
+                        vae_path=Path('./ext_weights/v1-16.pth'),
+                        bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
                         mode='16k')
 small_44k = ModelConfig(model_name='small_44k',
-                        model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_small_44k.pth'),
-                        vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
+                        model_path=Path('./weights/mmaudio_small_44k.pth'),
+                        vae_path=Path('./ext_weights/v1-44.pth'),
                         bigvgan_16k_path=None,
                         mode='44k')
 medium_44k = ModelConfig(model_name='medium_44k',
-                         model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_medium_44k.pth'),
-                         vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
+                         model_path=Path('./weights/mmaudio_medium_44k.pth'),
+                         vae_path=Path('./ext_weights/v1-44.pth'),
                          bigvgan_16k_path=None,
                          mode='44k')
 large_44k = ModelConfig(model_name='large_44k',
-                        model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k.pth'),
-                        vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
+                        model_path=Path('./weights/mmaudio_large_44k.pth'),
+                        vae_path=Path('./ext_weights/v1-44.pth'),
                         bigvgan_16k_path=None,
                         mode='44k')
 large_44k_v2 = ModelConfig(model_name='large_44k_v2',
-                           model_path=Path('./pretrained/v2a/mmaudio/weights/mmaudio_large_44k_v2.pth'),
-                           vae_path=Path('./pretrained/v2a/mmaudio/ext_weights/v1-44.pth'),
+                           model_path=Path('./weights/mmaudio_large_44k_v2.pth'),
+                           vae_path=Path('./ext_weights/v1-44.pth'),
                            bigvgan_16k_path=None,
                            mode='44k')
-all_model_cfg: Dict[str, ModelConfig] = {
+all_model_cfg: dict[str, ModelConfig] = {
     'small_16k': small_16k,
     'small_44k': small_44k,
     'medium_44k': medium_44k,
@@ -80,9 +78,9 @@ all_model_cfg: Dict[str, ModelConfig] = {
 def generate(
     clip_video: Optional[torch.Tensor],
     sync_video: Optional[torch.Tensor],
-    text: Optional[List[str]],
+    text: Optional[list[str]],
     *,
-    negative_text: Optional[List[str]] = None,
+    negative_text: Optional[list[str]] = None,
     feature_utils: FeaturesUtils,
     net: MMAudio,
    fm: FlowMatching,
@@ -90,7 +88,6 @@
     cfg_strength: float,
     clip_batch_size_multiplier: int = 40,
     sync_batch_size_multiplier: int = 40,
-    image_input: bool = False,
 ) -> torch.Tensor:
     device = feature_utils.device
     dtype = feature_utils.dtype
@@ -101,12 +98,10 @@
         clip_features = feature_utils.encode_video_with_clip(clip_video,
                                                              batch_size=bs *
                                                              clip_batch_size_multiplier)
-        if image_input:
-            clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
     else:
         clip_features = net.get_empty_clip_sequence(bs)
 
-    if sync_video is not None and not image_input:
+    if sync_video is not None:
         sync_video = sync_video.to(device, dtype, non_blocking=True)
         sync_features = feature_utils.encode_video_with_sync(sync_video,
                                                              batch_size=bs *
@@ -144,7 +139,7 @@
     return audio
 
 
-LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
+LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
 
 
 def setup_eval_logging(log_level: int = logging.INFO):
@@ -158,14 +153,12 @@
     log.addHandler(stream)
 
 
-_CLIP_SIZE = 384
-_CLIP_FPS = 8.0
-
-_SYNC_SIZE = 224
-_SYNC_FPS = 25.0
-
-
 def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
+    _CLIP_SIZE = 384
+    _CLIP_FPS = 8.0
+
+    _SYNC_SIZE = 224
+    _SYNC_FPS = 25.0
 
     clip_transform = v2.Compose([
         v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
@@ -220,36 +213,5 @@ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = Tr
     return video_info
 
 
-def load_image(image_path: Path) -> VideoInfo:
-    clip_transform = v2.Compose([
-        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
-        v2.ToImage(),
-        v2.ToDtype(torch.float32, scale=True),
-    ])
-
-    sync_transform = v2.Compose([
-        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
-        v2.CenterCrop(_SYNC_SIZE),
-        v2.ToImage(),
-        v2.ToDtype(torch.float32, scale=True),
-        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-    ])
-
-    frame = np.array(Image.open(image_path))
-
-    clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
-    sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
-
-    clip_frames = clip_transform(clip_chunk)
-    sync_frames = sync_transform(sync_chunk)
-
-    video_info = ImageInfo(
-        clip_frames=clip_frames,
-        sync_frames=sync_frames,
-        original_frame=frame,
-    )
-    return video_info
-
-
 def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
     reencode_with_audio(video_info, output_path, audio, sampling_rate)
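Taken together, the hunks above point the default checkpoints back at `./weights/` and `./ext_weights/`, move the CLIP/Synchformer constants inside `load_video`, and drop still-image support (`load_image` and the `image_input` flag). The sketch below shows roughly how the trimmed-down API would be driven; it only mirrors the signatures visible in this diff, and the `rng` argument, the `FlowMatching` constructor values, and the externally built `net`/`feature_utils` objects are assumptions borrowed from the demo script rather than verified here.

```python
from pathlib import Path

import torch

from mmaudio.eval_utils import generate, load_video, make_video
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio
from mmaudio.model.utils.features_utils import FeaturesUtils


@torch.inference_mode()
def video_to_audio(video_path: Path, prompt: str, *, net: MMAudio,
                   feature_utils: FeaturesUtils, duration_sec: float = 8.0,
                   cfg_strength: float = 4.5, seed: int = 42,
                   sampling_rate: int = 44100,
                   output_path: Path = Path('output.mp4')) -> None:
    # load_video already yields CLIP-sized and Synchformer-sized frame streams.
    video_info = load_video(video_path, duration_sec)
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)

    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
    rng = torch.Generator(device=feature_utils.device).manual_seed(seed)

    # Note: no `image_input` flag any more; still images are no longer handled here.
    audio = generate(clip_frames, sync_frames, [prompt],
                     feature_utils=feature_utils, net=net, fm=fm,
                     rng=rng, cfg_strength=cfg_strength)

    make_video(video_info, output_path, audio.float().cpu()[0], sampling_rate=sampling_rate)
```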
mmaudio/ext/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (191 Bytes)
 
mmaudio/ext/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (189 Bytes)
 
mmaudio/ext/__pycache__/mel_converter.cpython-310.pyc DELETED
Binary file (2.87 kB)
 
mmaudio/ext/__pycache__/mel_converter.cpython-38.pyc DELETED
Binary file (2.84 kB)
 
mmaudio/ext/__pycache__/rotary_embeddings.cpython-310.pyc DELETED
Binary file (1.48 kB)
 
mmaudio/ext/__pycache__/rotary_embeddings.cpython-38.pyc DELETED
Binary file (1.45 kB)
 
mmaudio/ext/autoencoder/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (256 Bytes)
 
mmaudio/ext/autoencoder/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (254 Bytes)
 
mmaudio/ext/autoencoder/__pycache__/autoencoder.cpython-310.pyc DELETED
Binary file (2.14 kB)