Major Update

- metadata
- new models
- background and style
- readme
- requirements
- gradio extension

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitattributes +2 -0
- CHANGELOG.md +61 -12
- README.md +48 -7
- app.py +270 -202
- assets/KuritaSurnLogox64.png +0 -0
- assets/Vermilion-Musical-Notes-Typography-No-Background.svg +0 -0
- assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
- assets/icon_delete.png +0 -0
- assets/icon_download.png +0 -0
- assets/icon_refresh.png +0 -0
- assets/logo_animation_256.gif +3 -0
- assets/screenshot.png +3 -0
- assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
- audiocraft/__init__.py +1 -1
- audiocraft/data/__init__.py +1 -1
- audiocraft/data/audio.py +86 -1
- audiocraft/data/audio_dataset.py +93 -31
- audiocraft/data/audio_utils.py +9 -6
- audiocraft/data/info_audio_dataset.py +110 -0
- audiocraft/data/zip.py +8 -6
- audiocraft/environment.py +176 -0
- audiocraft/models/__init__.py +7 -0
- audiocraft/models/builders.py +89 -49
- audiocraft/models/encodec.py +267 -63
- audiocraft/models/lm.py +35 -29
- audiocraft/models/loaders.py +146 -5
- audiocraft/models/musicgen.py +108 -34
- audiocraft/models/unet.py +214 -0
- audiocraft/modules/__init__.py +1 -0
- audiocraft/modules/chroma.py +66 -0
- audiocraft/modules/codebooks_patterns.py +17 -12
- audiocraft/modules/conditioners.py +724 -298
- audiocraft/modules/conv.py +1 -1
- audiocraft/modules/diffusion_schedule.py +272 -0
- audiocraft/modules/rope.py +20 -19
- audiocraft/modules/transformer.py +36 -38
- audiocraft/quantization/core_vq.py +8 -3
- audiocraft/utils/cache.py +324 -0
- audiocraft/utils/cluster.py +75 -0
- audiocraft/utils/export.py +50 -27
- audiocraft/utils/export_legacy.py +56 -0
- audiocraft/utils/extend.py +5 -4
- audiocraft/utils/utils.py +68 -2
- modules/file_utils.py +91 -0
- modules/gradio.py +272 -0
- modules/user_history.py +598 -0
- modules/version_info.py +123 -0
- pre-requirements.txt +1 -1
- requirements.txt +28 -12
- style_20250331.css +215 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/logo_animation_256.gif filter=lfs diff=lfs merge=lfs -text
+assets/screenshot.png filter=lfs diff=lfs merge=lfs -text
CHANGELOG.md
CHANGED
@@ -1,13 +1,69 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+## [1.2.Surn] - 2025-04-02
+
+Implemented Unlimited Music Generation (UMG) with the [hf checkpoints](https://huggingface.co/facebook/unlimited-music-generation).
+
+## [1.4.0a2] - 2025-01-14
+
+Add training and inference code for JASCO (https://arxiv.org/abs/2406.10970) along with the [hf checkpoints](https://huggingface.co/facebook/jasco-chords-drums-melody-1B).
+
+## [1.4.0a1] - 2024-06-03
+
+Adding new metric PesqMetric ([Perceptual Evaluation of Speech Quality](https://doi.org/10.5281/zenodo.6549559)).
+
+Adding multiple audio augmentation functions: generating pink noise, up-/downsampling, low-/highpass filtering, bandpass filtering, smoothing, duck masking, boosting. All are wrapped in `audiocraft.utils.audio_effects.AudioEffects` and can be called with the API `audiocraft.utils.audio_effects.select_audio_effects`.
+
+Add training code for AudioSeal (https://arxiv.org/abs/2401.17264) along with the [hf checkpoints](https://huggingface.co/facebook/audioseal).
+
+## [1.3.0] - 2024-05-02
+
+Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app.
+
+Typo fixes.
+
+Fixing setup.py to install only audiocraft, not the unit tests and scripts.
+
+Fix FSDP support with PyTorch 2.1.0.
+
+## [1.2.0] - 2024-01-11
+
+Adding stereo models.
+
+Fixed the commitment loss, which was until now only applied to the first RVQ layer.
+
+Removed compression model state from the LM checkpoints; for consistency, it should always be loaded from the original `compression_model_checkpoint`.
+
+## [1.1.0] - 2023-11-06
+
+Not using torchaudio anymore when writing audio files, relying instead directly on the command-line ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
+
+Fixed DAC support with a non-default number of codebooks.
+
+Fixed a bug where `two_step_cfg` was overridden when calling `generate()`.
+
+Fixed samples being always prompted with audio, rather than having both prompted and unprompted samples.
+
+**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way into the public release.
+The released models were trained without this. It impacts the linear layers applied to the output of the T5 or melody conditioners.
+We removed it, so you might need to retrain models.
+
+**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained a model with CLAP before).
+
+**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
+retrained a model with this pattern, so hopefully this won't impact you!
+
+## [1.0.0] - 2023-09-07
+
+Major revision: added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
+Added pretrained models for AudioGen and MultiBandDiffusion.
+
+## [0.0.2] - 2023-08-01
 
 Improved demo, fixed top p (thanks @jnordberg).
 
@@ -24,10 +80,3 @@ Note that other implementations exist: https://github.com/camenduru/MusicGen-col
 ## [0.0.1] - 2023-06-09
 
 Initial release, with model evaluation only.
-
-
-# Changelog
-
-All notable changes to this project will be documented in this file.
-
-The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
README.md
CHANGED
@@ -4,13 +4,21 @@ emoji: 🎼
 colorFrom: gray
 colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 5.23.3
+python_version: 3.12.8
 app_file: app.py
-pinned:
+pinned: true
 license: creativeml-openrail-m
 tags:
 - musicgen
 - unlimited
+- user history
+- metadata
+hf_oauth: true
+disable_embedding: true
+short_description: 'unlimited Audio generation with a few added features'
+thumbnail: >-
+  https://cdn-uploads.huggingface.co/production/uploads/6346595c9e5f0fe83fc60444/Z8E8OaKV84zuVAvvGpMDJ.png
 ---
 
 [arxiv]: https://arxiv.org/abs/2306.05284
@@ -18,7 +26,18 @@ tags:
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 # UnlimitedMusicGen
+Charles Fettinger's modification of the Audiocraft project to enable unlimited audio generation. I have added a few features to the original project to enable this, and a few features to the Gradio interface to make it easier to use.
+
+Please review my other AI-related Spaces at https://huggingface.co/Surn
+
+Check your video's generative metadata with https://mediaarea.net/en/MediaInfo
+
+Also note that I wrote an extension to Gradio for the waveform in the video after v4.48.0 removed it.
+
+The key update here is in the extend utility: we segment the melody input and then condition each next segment with the current output tensors and the tensors taken from the corresponding time in the conditioning melody file.
+This allows the generation to follow the same arrangement as the original melody. A sketch of this approach is shown below.
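The loop below is a minimal sketch of that idea, not the repository's actual `generate_music_segments` implementation: it assumes a loaded audiocraft MusicGen melody checkpoint and a melody array from `librosa.load()`, and the helper name itself is illustrative.

```python
# Minimal sketch of segment-by-segment melody conditioning (illustrative, not the repo's exact code).
import torch

def generate_in_segments(model, text, melody, sr, total_duration, segment_duration=30):
    """model: a loaded audiocraft MusicGen melody model; melody: 1-D numpy array from librosa.load()."""
    wav = torch.from_numpy(melody).float()
    if wav.dim() == 1:
        wav = wav[None]                       # (channels, samples)
    segments, start = [], 0.0
    while total_duration > 0:
        chunk = min(segment_duration, total_duration)
        model.set_generation_params(duration=chunk)
        # Slice the conditioning melody at the *current* position in the source file,
        # so every generated segment follows the arrangement of the original melody.
        lo, hi = int(start * sr), int((start + chunk) * sr)
        melody_window = wav[None, :, lo:hi]   # (batch, channels, samples)
        out = model.generate_with_chroma(
            descriptions=[text],
            melody_wavs=melody_window,
            melody_sample_rate=sr,
            progress=True,
        )
        segments.append(out)
        start += chunk
        total_duration -= chunk
    return torch.cat(segments, dim=-1)        # join the segments along the time axis
```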
+
+**Thank you Huggingface for the community grant to run this project**!!
 
 # Audiocraft
 
@@ -46,12 +65,12 @@ Check out our [sample page][musicgen_samples] or test the available demo!
 
 We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
 
 ## Installation
+Audiocraft requires Python 3.9, PyTorch 2.1.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
 
 ```shell
 # Best to make sure you have torch installed first, in particular before installing xformers.
 # Don't run this if you already have PyTorch installed.
+pip install 'torch>=2.1'
 # Then proceed to one of the following
 pip install -U audiocraft  # stable release
 pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft  # bleeding edge
 pip install -e .  # or if you cloned the repo locally
 ```
 
 ## Usage
 We offer a number of ways to interact with MusicGen:
+1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/Surn/UnlimitedMusicGen) (huge thanks to all the HF team for their support).
 2. You can run the Gradio demo in Colab: [colab notebook](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing).
 3. You can use the gradio demo locally by running `python app.py`.
 4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
 
@@ -178,6 +197,25 @@ For more details on using the MusicGen model for inference using the 🤗 Transf
 [MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on
 [Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb).
 
+## User History
+
+User History is a plugin that you can add to your Spaces to cache generated images for your users.
+
+Key features:
+- 🤗 Sign in with Hugging Face
+- Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc.
+- Export your history as a zip.
+- Delete your history to respect privacy.
+- Compatible with Persistent Storage for long-term storage.
+- Admin panel to check configuration and disk usage.
+
+Useful links:
+- Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history
+- README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md
+- Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py
+- Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions
+
+A minimal integration sketch follows the screenshot below.
+
+
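The snippet below is a rough sketch of how this Space wires the plugin in, mirroring the `modules.user_history.render()` and `save_file()` calls that appear in `app.py` further down; `run_pipeline` is a placeholder for the real generation code.

```python
# Sketch of the gradio-user-history integration used by this Space (abbreviated).
import gradio as gr
import modules.user_history

def generate(prompt: str, profile: gr.OAuthProfile | None = None):
    # run_pipeline is a placeholder for the real MusicGen pipeline in app.py.
    video_path, audio_path, metadata = run_pipeline(prompt)
    # Persist the result and its metadata to the signed-in user's gallery.
    modules.user_history.save_file(
        profile=profile,
        audio=audio_path,
        video=video_path,
        label=prompt,
        metadata=metadata,
    )
    return video_path

with gr.Blocks() as demo:
    with gr.Tab("Generate"):
        prompt = gr.Text(label="Prompt")
        out = gr.Video()
        gr.Button("Generate").click(generate, inputs=[prompt], outputs=[out])
    with gr.Tab("User History"):
        modules.user_history.render()  # per-user gallery rendered by the plugin
```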
 
 ## Model Card
 
@@ -212,4 +250,7 @@ Check [@camenduru tutorial on Youtube](https://www.youtube.com/watch?v=EGfxuTy9E
 
 * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
 * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
 [arxiv]: https://arxiv.org/abs/2306.05284
+
+[arxiv]: https://arxiv.org/abs/2306.05284
+[musicgen_samples]: https://ai.honu.io/papers/musicgen/
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
CHANGED
@@ -21,11 +21,17 @@ from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
 from audiocraft.data.audio_utils import apply_fade, apply_tafade, apply_splice_effect
 from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, INTERRUPTING
+from audiocraft.utils import utils
 import numpy as np
 import random
+import shutil
+from mutagen.mp4 import MP4
 #from typing import List, Union
 import librosa
+import modules.user_history
+from modules.version_info import versions_html, commit_hash, get_xformers_version
+from modules.gradio import *
+from modules.file_utils import get_file_parts, get_filename_from_filepath, convert_title_to_filename, get_filename, delete_file
 
 MODEL = None
 MODELS = None
@@ -35,7 +41,12 @@ UNLOAD_MODEL = False
 MOVE_TO_CPU = False
 MAX_PROMPT_INDEX = 0
 git = os.environ.get('GIT', "git")
+#s.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
+os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
+os.environ['USE_FLASH_ATTENTION'] = '1'
+os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = '1'
 
 def interrupt_callback():
     return INTERRUPTED
@@ -72,7 +83,7 @@ def toggle_audio_src(choice):
     else:
         return gr.update(source="upload", value=None, label="File")
 
-def make_waveform(*args, **kwargs):
+def get_waveform(*args, **kwargs):
     # Further remove some warnings.
     be = time.time()
     with warnings.catch_warnings():
@@ -80,6 +91,7 @@ def make_waveform(*args, **kwargs):
         out = gr.make_waveform(*args, **kwargs)
     print("Make a video took", time.time() - be)
     return out
+
 
 def load_model(version):
     global MODEL, MODELS, UNLOAD_MODEL
@@ -102,32 +114,12 @@ def load_model(version):
     print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
     return result
 
-def get_filename(file):
-    # extract filename from file object
-    filename = None
-    if file is not None:
-        filename = file.name
-    return filename
-
-def get_filename_from_filepath(filepath):
-    file_name = os.path.basename(filepath)
-    file_base, file_extension = os.path.splitext(file_name)
-    return file_base, file_extension
-
 def get_melody(melody_filepath):
     audio_data = list(librosa.load(melody_filepath, sr=None))
     audio_data[0], audio_data[1] = audio_data[1], audio_data[0]
    melody = tuple(audio_data)
     return melody
 
-
-def commit_hash():
-    try:
-        return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
-    except Exception:
-        return "<none>"
-
-
 def git_tag():
     try:
         return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
@@ -140,28 +132,6 @@ def git_tag():
     except Exception:
         return "<none>"
 
-def versions_html():
-    import torch
-
-    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
-    commit = commit_hash()
-    #tag = git_tag()
-
-    import xformers
-    xformers_version = xformers.__version__
-
-    return f"""
-version: <a href="https://github.com/Oncorporation/audiocraft/commit/{"huggingface" if commit == "<none>" else commit}" target="_blank">{"huggingface" if commit == "<none>" else commit}</a>
- • 
-python: <span title="{sys.version}">{python_version}</span>
- • 
-torch: {getattr(torch, '__long_version__', torch.__version__)}
- • 
-xformers: {xformers_version}
- • 
-gradio: {gr.__version__}
-"""
-
 def load_melody_filepath(melody_filepath, title):
     # get melody filename
     #$Union[str, os.PathLike]
@@ -187,12 +157,13 @@ def load_melody_filepath(melody_filepath, title):
     print(f"Melody length: {len(melody_data)}, Melody segments: {total_melodys}\n")
     MAX_PROMPT_INDEX = total_melodys
 
-    return gr.
+    return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value="melody", interactive=True)
 
 def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False):
     global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
     output_segments = None
     melody_name = "Not Used"
+    melody_extension = "Not Used"
     melody = None
     if melody_filepath:
         melody_name, melody_extension = get_filename_from_filepath(melody_filepath)
@@ -201,17 +172,23 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
     INTERRUPTED = False
     INTERRUPTING = False
     if temperature < 0:
+        temperature -0
         raise gr.Error("Temperature must be >= 0.")
     if topk < 0:
+        topk = 1
        raise gr.Error("Topk must be non-negative.")
     if topp < 0:
+        topp = 1
         raise gr.Error("Topp must be non-negative.")
 
-    MODEL
+    try:
+        if MODEL is None or MODEL.name != model:
+            MODEL = load_model(model)
+        else:
+            if MOVE_TO_CPU:
+                MODEL.to('cuda')
+    except Exception as e:
+        raise gr.Error(f"Error loading model '{model}': {str(e)}. Try a different model.")
 
     # prevent hacking
     duration = min(duration, 720)
@@ -251,35 +228,41 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
             rep_penalty=0.5
         )
 
+        try:
+            if melody:
+                # return excess duration, load next model and continue in loop structure building up output_segments
+                if duration > MODEL.lm.cfg.dataset.segment_duration:
+                    output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False)
+                else:
+                    # pure original code
+                    sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
+                    print(melody.shape)
+                    if melody.dim() == 2:
+                        melody = melody[None]
+                    melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
+                    output = MODEL.generate_with_chroma(
+                        descriptions=[text],
+                        melody_wavs=melody,
+                        melody_sample_rate=sr,
+                        progress=True
+                    )
+                # All output_segments are populated, so we can break the loop or set duration to 0
+                break
             else:
+                #output = MODEL.generate(descriptions=[text], progress=False)
+                if not output_segments:
+                    next_segment = MODEL.generate(descriptions=[text], progress=True)
+                    duration -= segment_duration
+                else:
+                    last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
+                    next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=True)
+                    duration -= segment_duration - overlap
+                if next_segment != None:
+                    output_segments.append(next_segment)
+        except Exception as e:
+            print(f"Error generating audio: {e}")
+            gr.Error(f"Error generating audio: {e}")
+            return None, None, seed
 
     if INTERRUPTING:
         INTERRUPTED = True
@@ -287,6 +270,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
         print("Function execution interrupted!")
         raise gr.Error("Interrupted.")
 
+    print(f"\nOutput segments: {len(output_segments)}\n")
     if output_segments:
         try:
             # Combine the output segments into one long audio file or stack tracks
@@ -312,7 +296,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
                     ##overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=1) #stack tracks
                     ##print(f" overlap size stack:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
                     #overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=2) #stack tracks
-                    #print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
+                    #print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
                     output = torch.cat([output[:, :, :-overlap_samples], overlapping_output, output_segments[i][:, :, overlap_samples:]], dim=dimension)
                 else:
                     output = torch.cat([output, output_segments[i]], dim=dimension)
@@ -321,143 +305,227 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
             print(f"Error combining segments: {e}. Using the first segment only.")
             output = output_segments[0].detach().cpu().float()[0]
     else:
-        output
+        if (output is None) or (output.dim() == 0):
+            return None, None, seed
+        else:
+            output = output.detach().cpu().float()[0]
+    profile: gr.OAuthProfile | None = None
+    title_file_name = convert_title_to_filename(title)
+    with NamedTemporaryFile("wb", suffix=".wav", delete=False, prefix=title_file_name) as file:
         video_description = f"{text}\n Duration: {str(initial_duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}\n Model: {model}\n Melody Condition:{melody_name}\n Sample Segment: {prompt_index}"
         if include_settings or include_title:
            background = add_settings_to_image(title if include_title else "", video_description if include_settings else "", background_path=background, font=settings_font, font_color=settings_font_color)
         audio_write(
             file.name, output, MODEL.sample_rate, strategy="loudness",
             loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2)
+        waveform_video_path = get_waveform(file.name, bg_image=background, bar_count=45, name=title_file_name)
+        # Remove the extension from file.name
+        file_name_without_extension = os.path.splitext(file.name)[0]
+        # Get the directory, filename, name, extension, and new extension of the waveform video path
+        video_dir, video_name, video_name, video_ext, video_new_ext = get_file_parts(waveform_video_path)
+
+        new_video_path = os.path.join(video_dir, title_file_name + video_new_ext)
+
+        mp4 = MP4(waveform_video_path)
+        mp4["©nam"] = title_file_name  # Title tag
+        mp4["desc"] = f"{text}\n Duration: {str(initial_duration)}"  # Description tag
+
+        commit = commit_hash()
+        metadata = {
+            "prompt": text,
+            "negative_prompt": "",
+            "Seed": seed,
+            "steps": 1,
+            "width": "768px",
+            "height": "512px",
+            "Dimension": dimension,
+            "Top-k": topk,
+            "Top-p": topp,
+            "Randomness": temperature,
+            "cfg": cfg_coef,
+            "overlap": overlap,
+            "Melody Condition": melody_name,
+            "Sample Segment": prompt_index,
+            "Duration": initial_duration,
+            "Audio": file.name,
+            "font": settings_font,
+            "font_color": settings_font_color,
+            "harmony_only": harmony_only,
+            "background": background,
+            "include_title": include_title,
+            "include_settings": include_settings,
+            "profile": profile,
+            "commit": commit_hash(),
+            "tag": git_tag(),
+            "version": gr.__version__,
+            "model_version": MODEL.version,
+            "model_name": MODEL.name,
+            "model_description": f"{MODEL.audio_channels} channels, {MODEL.sample_rate} Hz",
+            "melody_name": melody_name if melody_name else "",
+            "melody_extension": melody_extension if melody_extension else "",
+            "hostname": "https://huggingface.co/spaces/Surn/UnlimitedMusicGen",
+            "version": f"""https://huggingface.co/spaces/Surn/UnlimitedMusicGen/commit/{"huggingface" if commit == "<none>" else commit}""",
+            "python": sys.version,
+            "torch": getattr(torch, '__long_version__', torch.__version__),
+            "xformers": get_xformers_version(),
+            "gradio": gr.__version__,
+            "huggingface_space": os.environ.get('SPACE_ID', ''),
+            "CUDA": f"""{"CUDA is available. device: " + torch.cuda.get_device_name(0) + " version: " + torch.version.cuda if torch.cuda.is_available() else "CUDA is not available."}""",
+        }
+        # Add additional metadata from the metadata dictionary (if it exists)
+        for key, value in metadata.items():
+            mp4[key] = str(value)  # Convert values to strings as required by mutagen
+
+        # Save the metadata changes to the file
+        mp4.save()
+
+        try:
+            if os.path.exists(new_video_path):
+                delete_file(new_video_path)
+            # Open the original MP4 file in binary read mode and the new file in binary write mode
+            with open(waveform_video_path, "rb") as src, open(new_video_path, "wb") as dst:
+                if os.path.exists(waveform_video_path):
+                    # Copy the contents from the source file to the destination file
+                    shutil.copyfileobj(src, dst)
+            waveform_video_path = new_video_path
+        except Exception as e:
+            print(f"Error copying file: {e}")
+
+        if waveform_video_path:
+            modules.user_history.save_file(
+                profile=profile,
+                image=background,
+                audio=file,
+                video=waveform_video_path,
+                label=text,
+                metadata=metadata,
+            )
+
     if MOVE_TO_CPU:
         MODEL.to('cpu')
     if UNLOAD_MODEL:
         MODEL = None
     torch.cuda.empty_cache()
     torch.cuda.ipc_collect()
-    return
+    return waveform_video_path, file.name, seed
 
+gr.set_static_paths(paths=["fonts/", "assets/"])
 def ui(**kwargs):
-    #btn-generate:hover {background-image:linear-gradient(to right bottom, rgb(229, 255, 229), rgb(255, 255, 255));}
-    #btn-generate:active {background-image:linear-gradient(to right bottom, rgb(229, 255, 235), rgb(157, 255, 157));}
-    #versions {margin-top: 1em; width:100%; text-align:center;}
-    .small-btn {max-width:75px;}
-    """
-    with gr.Blocks(title="UnlimitedMusicGen", css=css) as demo:
-        gr.Markdown(
-            """
-            Todo: Working on improved Interrupt
-            ""
-                text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi")
-                with gr.Column():
-                    duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True)
-                    model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-melody", "stereo-medium", "stereo-small", "stereo-large", "stereo-melody-large"], label="AI Model", value="melody-large", interactive=True)
-            with gr.Row():
-                submit = gr.Button("Generate", elem_id="btn-generate")
-                # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
-                _ = gr.Button("Interrupt", elem_id="btn-interrupt").click(fn=interrupt, queue=False)
-            with gr.Row():
-                with gr.Column():
-                    radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic")
-                    melody_filepath = gr.Audio(source="upload", type="filepath", label="Melody Condition (optional)", interactive=True, elem_id="melody-input")
-            with gr.
+    with gr.Blocks(title="UnlimitedMusicGen", css_paths="style_20250331.css", theme='Surn/beeuty') as interface:
+        with gr.Tab("UnlimitedMusicGen"):
+            gr.Markdown(
+                """
                 # UnlimitedMusicGen
                 This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation
                 presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
 
                 Disclaimer: This won't run on CPU only. Clone this App and run on GPU instance!
 
+                Todo: Working on improved Interrupt.
+                Theme Available at ["Surn/Beeuty"](https://huggingface.co/spaces/Surn/Beeuty)
+
+                """
+            )
+            if IS_SHARED_SPACE and not torch.cuda.is_available():
+                gr.Markdown("""
+                    ⚠ This Space doesn't work in this shared UI ⚠
+
+                    <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+                    <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+                    to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
+                    """)
+            with gr.Row():
                 with gr.Column():
+                    with gr.Row():
+                        text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi")
+                        with gr.Column():
+                            duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True)
+                            model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large"], label="AI Model", value="melody", interactive=True)
+                    with gr.Row():
+                        submit = gr.Button("Generate", elem_id="btn-generate")
+                        # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+                        _ = gr.Button("Interrupt", elem_id="btn-interrupt").click(fn=interrupt, queue=False)
+                    with gr.Row():
+                        with gr.Column():
+                            radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic")
+                            melody_filepath = gr.Audio(sources=["upload"], type="filepath", label="Melody Condition (optional)", interactive=True, elem_id="melody-input")
+                        with gr.Column():
+                            harmony_only = gr.Radio(label="Use Harmony Only", choices=["No", "Yes"], value="No", interactive=True, info="Remove Drums?")
+                            prompt_index = gr.Slider(label="Melody Condition Sample Segment", minimum=-1, maximum=MAX_PROMPT_INDEX, step=1, value=0, interactive=True, info="Which 30 second segment to condition with, -1 conditions each segment independently")
+                    with gr.Accordion("Video", open=False):
+                        with gr.Row():
+                            background = gr.Image(value="./assets/background.png", sources=["upload"], label="Background", width=768, height=512, type="filepath", interactive=True)
+                            with gr.Column():
+                                include_title = gr.Checkbox(label="Add Title", value=True, interactive=True)
+                                include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
+                        with gr.Row():
+                            title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
+                            settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True)
+                            settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#c87f05", interactive=True)
+                    with gr.Accordion("Expert", open=False):
+                        with gr.Row():
+                            overlap = gr.Slider(minimum=0, maximum=15, value=2, step=1, label="Verse Overlap", interactive=True)
+                            dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segments of audio. (1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True)
+                        with gr.Row():
+                            topk = gr.Number(label="Top-k", value=280, precision=0, interactive=True)
+                            topp = gr.Number(label="Top-p", value=1150, precision=0, interactive=True)
+                            temperature = gr.Number(label="Randomness Temperature", value=0.7, precision=None, interactive=True)
+                            cfg_coef = gr.Number(label="Classifier Free Guidance", value=8.5, precision=None, interactive=True)
+                        with gr.Row():
+                            seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
+                            gr.Button('\U0001f3b2\ufe0f', elem_classes="small-btn").click(fn=lambda: -1, outputs=[seed], queue=False)
+                            reuse_seed = gr.Button('\u267b\ufe0f', elem_classes="small-btn")
+                with gr.Column() as c:
+                    output = gr.Video(label="Generated Music")
+                    wave_file = gr.File(label=".wav file", elem_id="output_wavefile", interactive=True)
+                    seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
+
+            radio.change(toggle_audio_src, radio, [melody_filepath], queue=False, show_progress=False)
+            melody_filepath.change(load_melody_filepath, inputs=[melody_filepath, title], outputs=[title, prompt_index, model], api_name="melody_filepath_change", queue=False)
+            reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False, api_name="reuse_seed")
+            submit.click(predict, inputs=[model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only], outputs=[output, wave_file, seed_used], api_name="submit")
-            [
-                "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums",
-                None,
-                "medium",
-                "90s Rock Guitar"
-            ],
-            [
-                "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
-                "./assets/bach.mp3",
-                "melody",
-                "EDM my Bach"
-            ],
-            [
-                "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples",
-                None,
-                "medium",
-                "LoFi Chill"
-            ],
-        ],
-        inputs=[text, melody_filepath, model, title],
-        outputs=[output]
-    )
-    gr.HTML(value=versions_html(), visible=True, elem_id="versions")
+            gr.Examples(
+                fn=predict,
+                examples=[
+                    [
+                        "4/4 120bpm 320kbps 48khz, An 80s driving pop song with heavy drums and synth pads in the background",
+                        "./assets/bach.mp3",
+                        "stereo-melody-large",
+                        "80s Pop Synth"
+                    ],
+                    [
+                        "4/4 120bpm 320kbps 48khz, A cheerful country song with acoustic guitars",
+                        "./assets/bolero_ravel.mp3",
+                        "melody",
+                        "Country Guitar"
+                    ],
+                    [
+                        "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums",
+                        None,
+                        "stereo-medium",
+                        "90s Rock Guitar"
+                    ],
+                    [
+                        "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
+                        "./assets/bach.mp3",
+                        "melody-large",
+                        "EDM my Bach"
+                    ],
+                    [
+                        "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples",
+                        None,
+                        "medium",
+                        "LoFi Chill"
+                    ],
                 ],
+                inputs=[text, melody_filepath, model, title],
+                outputs=[output]
+            )
+            gr.HTML(value=versions_html(), visible=True, elem_id="versions")
+        with gr.Tab("User History") as history_tab:
+            modules.user_history.render()
 
     # Show the interface
     launch_kwargs = {}
     share = kwargs.get('share', False)
@@ -471,10 +539,10 @@ def ui(**kwargs):
     if share:
         launch_kwargs['share'] = share
     launch_kwargs['favicon_path'] = "./assets/favicon.ico"
+
 
-    demo.queue(max_size=10, concurrency_count=1, api_open=False).launch(**launch_kwargs)
+    interface.queue(max_size=10, api_open=False).launch(**launch_kwargs)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -518,7 +586,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     launch_kwargs = {}
-    launch_kwargs['
+    launch_kwargs['listen'] = args.listen
 
     if args.username and args.password:
         launch_kwargs['auth'] = (args.username, args.password)
@@ -528,7 +596,7 @@ if __name__ == "__main__":
     launch_kwargs['inbrowser'] = args.inbrowser
     if args.share:
         launch_kwargs['share'] = args.share
-    launch_kwargs['favicon_path'] = "./assets/favicon.ico"
+    launch_kwargs['favicon_path'] = "./assets/favicon.ico"
 
     UNLOAD_MODEL = args.unload_model
@@ -538,6 +606,6 @@ if __name__ == "__main__":
 
     ui(
         unload_to_cpu = MOVE_TO_CPU,
-        share=args.share
+        share=args.share,
+        **launch_kwargs,
     )
assets/KuritaSurnLogox64.png
ADDED
assets/Vermilion-Musical-Notes-Typography-No-Background.svg
ADDED
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3
ADDED (binary file, 15.2 kB)
assets/icon_delete.png
ADDED
assets/icon_download.png
ADDED
assets/icon_refresh.png
ADDED
assets/logo_animation_256.gif
ADDED (stored with Git LFS)
assets/screenshot.png
ADDED (stored with Git LFS)
assets/sirens_and_a_humming_engine_approach_and_pass.mp3
ADDED (binary file, 15.2 kB)
audiocraft/__init__.py
CHANGED
@@ -7,4 +7,4 @@
 # flake8: noqa
 from . import data, modules, models
 
-__version__ = '
+__version__ = '1.4.Surn'
audiocraft/data/__init__.py
CHANGED
@@ -5,4 +5,4 @@
 # LICENSE file in the root directory of this source tree.
 
 # flake8: noqa
-from . import audio, audio_dataset
+from . import audio, audio_dataset, info_audio_dataset
audiocraft/data/audio.py
CHANGED
@@ -21,6 +21,7 @@ from torch.nn import functional as F
 import torchaudio as ta
 
 import av
+import subprocess as sp
 
 from .audio_utils import f32_pcm, i16_pcm, normalize_audio, convert_audio
 
@@ -149,7 +150,17 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
         wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
     return wav, sr
 
+def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
+    # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
+    assert wav.dim() == 2, wav.shape
+    command = [
+        'ffmpeg',
+        '-loglevel', 'error',
+        '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
+        '-i', '-'] + flags + [str(out_path)]
+    input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
+    sp.run(command, input=input_, check=True)
+
 def audio_write(stem_name: tp.Union[str, Path],
                 wav: torch.Tensor, sample_rate: int,
                 format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
@@ -215,3 +226,77 @@ def audio_write(stem_name: tp.Union[str, Path],
             path.unlink()
         raise
     return path
+
+def audio_write2(stem_name: tp.Union[str, Path],
+                 wav: torch.Tensor, sample_rate: int,
+                 format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
+                 normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+                 rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+                 loudness_compressor: bool = False,
+                 log_clipping: bool = True, make_parent_dir: bool = True,
+                 add_suffix: bool = True) -> Path:
+    """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+    Args:
+        stem_name (str or Path): Filename without extension which will be added automatically.
+        wav (torch.Tensor): Audio data to save.
+        sample_rate (int): Sample rate of audio data.
+        format (str): Either "wav", "mp3", "ogg", or "flac".
+        mp3_rate (int): kbps when using mp3s.
+        ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
+        normalize (bool): if `True` (default), normalizes according to the prescribed
+            strategy (see after). If `False`, the strategy is only used in case clipping
+            would happen.
+        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+            with extra headroom to avoid clipping. 'clip' just clips.
+        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+            than the `peak_clip` one to avoid further clipping.
+        loudness_headroom_db (float): Target loudness for loudness normalization.
+        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+        log_clipping (bool): If True, basic logging on stderr when clipping still
+            occurs despite strategy (only for 'rms').
+        make_parent_dir (bool): Make parent directory if it doesn't exist.
+    Returns:
+        Path: Path of the saved audio.
+    """
+    assert wav.dtype.is_floating_point, "wav is not floating point"
+    if wav.dim() == 1:
+        wav = wav[None]
+    elif wav.dim() > 2:
+        raise ValueError("Input wav should be at most 2 dimension.")
+    assert wav.isfinite().all()
+    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
+                          log_clipping=log_clipping, sample_rate=sample_rate,
+                          stem_name=str(stem_name))
+    if format == 'mp3':
+        suffix = '.mp3'
+        flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
+    elif format == 'wav':
+        suffix = '.wav'
+        flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
+    elif format == 'ogg':
+        suffix = '.ogg'
+        flags = ['-f', 'ogg', '-c:a', 'libvorbis']
+        if ogg_rate is not None:
+            flags += ['-b:a', f'{ogg_rate}k']
+    elif format == 'flac':
+        suffix = '.flac'
+        flags = ['-f', 'flac']
+    else:
+        raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg, or flac are supported.")
+    if not add_suffix:
+        suffix = ''
+    path = Path(str(stem_name) + suffix)
+    if make_parent_dir:
+        path.parent.mkdir(exist_ok=True, parents=True)
+    try:
+        _piping_to_ffmpeg(path, wav, sample_rate, flags)
+    except Exception:
+        if path.exists():
+            # we do not want to leave half written files around.
+            path.unlink()
+        raise
+    return path
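A short usage sketch of the helper added above, assuming ffmpeg is available on PATH; the output stem and tensor are placeholders:

```python
# Illustrative call of the ffmpeg-backed writer defined above (requires ffmpeg on PATH).
import torch
from audiocraft.data.audio import audio_write2

wav = torch.randn(2, 32000 * 5) * 0.1   # placeholder: 5 seconds of stereo noise at 32 kHz
path = audio_write2(
    "out/example_clip",                  # ".mp3" suffix is appended automatically
    wav,
    sample_rate=32000,
    format="mp3",
    mp3_rate=320,
    strategy="peak",                     # peak-normalize before encoding
)
print(f"wrote {path}")
```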
audiocraft/data/audio_dataset.py
CHANGED
@@ -3,12 +3,16 @@
 # LICENSE file in the root directory of this source tree.
+"""AudioDataset support. In order to handle a larger number of files
+without having to scan again the folders, we precompute some metadata
+(filename, sample rate, duration), and use that to efficiently sample audio segments.
+"""
 import argparse
 import copy
 from concurrent.futures import ThreadPoolExecutor, Future
 from dataclasses import dataclass, fields
 from contextlib import ExitStack
+from functools import lru_cache
 import gzip
 import json
 import logging
@@ -81,9 +85,12 @@ class AudioMeta(BaseInfo):
 class SegmentInfo(BaseInfo):
     meta: AudioMeta
     seek_time: float
+    # The following values are given once the audio is processed, e.g.
+    # at the target sample rate and target number of channels.
+    n_frames: int      # actual number of frames without padding
     total_frames: int  # total number of frames, padding included
+    sample_rate: int   # actual sample rate
+    channels: int      # number of audio channels.


 DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
@@ -114,8 +121,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
     Args:
         m (AudioMeta): Audio meta to resolve.
+        fast (bool): If True, uses a really fast check for determining if a file
+            is already absolute or not. Only valid on Linux/Mac.
     Returns:
         AudioMeta: Audio meta with resolved path.
     """
@@ -151,7 +158,7 @@ def find_audio_files(path: tp.Union[Path, str],
         progress (bool): Whether to log progress on audio files collection.
         workers (int): number of parallel workers, if 0, use only the current thread.
     Returns:
+        list of AudioMeta: List of audio file path and its metadata.
     """
     audio_files = []
     futures: tp.List[Future] = []
@@ -203,7 +210,7 @@ def load_audio_meta(path: tp.Union[str, Path],
         resolve (bool): Whether to resolve the path from AudioMeta (default=True).
         fast (bool): activates some tricks to make things faster.
     Returns:
+        list of AudioMeta: List of audio file path and its total duration.
     """
     open_fn = gzip.open if str(path).lower().endswith('.gz') else open
     with open_fn(path, 'rb') as fp:  # type: ignore
@@ -250,9 +257,14 @@ class AudioDataset:
     allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
     original audio meta.

+    Note that you can call `start_epoch(epoch)` in order to get
+    a deterministic "randomization" for `shuffle=True`.
+    For a given epoch and dataset index, this will always return the same extract.
+    You can get back some diversity by setting the `shuffle_seed` param.
+
     Args:
+        meta (list of AudioMeta): List of audio files metadata.
+        segment_duration (float, optional): Optional segment duration of audio to load.
             If not specified, the dataset will load the full audio segment from the file.
         shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
         sample_rate (int): Target sample rate of the loaded audio samples.
@@ -266,10 +278,19 @@ class AudioDataset:
            is shorter than the desired segment.
         max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
         return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+        min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
+        max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
+        shuffle_seed (int): can be used to further randomize
+        load_wav (bool): if False, skip loading the wav but returns a tensor of 0
+            with the expected segment_duration (which must be provided if load_wav is False).
+        permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
+            are False. Will ensure a permutation on files when going through the dataset.
+            In that case the epoch number must be provided in order for the model
+            to continue the permutation across epochs. In that case, it is assumed
+            that `num_samples = total_batch_size * num_updates_per_epoch`, with
+            `total_batch_size` the overall batch size accounting for all gpus.
     """
     def __init__(self,
                  meta: tp.List[AudioMeta],
@@ -285,16 +306,14 @@ class AudioDataset:
                  max_read_retry: int = 10,
                  return_info: bool = False,
                  min_audio_duration: tp.Optional[float] = None,
+                 max_audio_duration: tp.Optional[float] = None,
+                 shuffle_seed: int = 0,
+                 load_wav: bool = True,
+                 permutation_on_files: bool = False,
                  ):
+        assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
@@ -317,13 +336,25 @@ class AudioDataset:
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info
+        self.shuffle_seed = shuffle_seed
+        self.current_epoch: tp.Optional[int] = None
+        self.load_wav = load_wav
+        if not load_wav:
+            assert segment_duration is not None
+        self.permutation_on_files = permutation_on_files
+        if permutation_on_files:
+            assert not self.sample_on_duration
+            assert not self.sample_on_weight
+            assert self.shuffle
+
+    def start_epoch(self, epoch: int):
+        self.current_epoch = epoch

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
+        """Return the sampling probabilities for each file inside `self.meta`."""
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
@@ -337,12 +368,32 @@ class AudioDataset:
        probabilities /= probabilities.sum()
        return probabilities

+    @staticmethod
+    @lru_cache(16)
+    def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
+        # Used to keep the most recent files permutation in memory implicitely.
+        # will work unless someone is using a lot of Datasets in parallel.
+        rng = torch.Generator()
+        rng.manual_seed(base_seed + permutation_index)
+        return torch.randperm(num_files, generator=rng)
+
+    def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
+        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
+        You can further make use of the index accessed.
        """
+        if self.permutation_on_files:
+            assert self.current_epoch is not None
+            total_index = self.current_epoch * len(self) + index
+            permutation_index = total_index // len(self.meta)
+            relative_index = total_index % len(self.meta)
+            permutation = AudioDataset._get_file_permutation(
+                len(self.meta), permutation_index, self.shuffle_seed)
+            file_index = permutation[relative_index]
+            return self.meta[file_index]
+
        if not self.sample_on_weight and not self.sample_on_duration:
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
@@ -350,6 +401,15 @@ class AudioDataset:
        return self.meta[file_index]

+    def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
+        # Override this method in subclass if needed.
+        if self.load_wav:
+            return audio_read(path, seek_time, duration, pad=False)
+        else:
+            assert self.segment_duration is not None
+            n_frames = int(self.sample_rate * self.segment_duration)
+            return torch.zeros(self.channels, n_frames), self.sample_rate
+
    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        if self.segment_duration is None:
            file_meta = self.meta[index]
@@ -357,18 +417,22 @@ class AudioDataset:
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+                                       sample_rate=self.sample_rate, channels=out.shape[0])
        else:
            rng = torch.Generator()
            if self.shuffle:
+                # We use index, plus extra randomness, either totally random if we don't know the epoch.
+                # otherwise we make use of the epoch number and optional shuffle_seed.
+                if self.current_epoch is None:
+                    rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+                else:
+                    rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
            else:
                # We only use index
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
+                file_meta = self.sample_file(index, rng)
                # We add some variance in the file position even if audio file is smaller than segment
                # without ending up with empty segments
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
@@ -381,7 +445,7 @@ class AudioDataset:
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+                                               sample_rate=self.sample_rate, channels=out.shape[0])
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
@@ -423,7 +487,7 @@ class AudioDataset:
        if to_pad:
            # Each wav could be of a different duration as they are not segmented.
            for i in range(len(samples)):
+                # Determines the total length of the signal with padding, so we update here as we pad.
                segment_infos[i].total_frames = max_len
                wavs[i] = _pad_wav(wavs[i])
@@ -436,9 +500,7 @@ class AudioDataset:
        return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+        """Filters out audio files with audio durations that will not allow to sample examples from them."""
        orig_len = len(meta)

        # Filter data that is too short.
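As a quick illustration of the new deterministic shuffling, here is a minimal sketch; the manifest path under egs/ is hypothetical and would normally be produced ahead of time from your audio folder.

from audiocraft.data.audio_dataset import AudioDataset, load_audio_meta

# Hypothetical precomputed manifest (one AudioMeta entry per file).
meta = load_audio_meta('egs/my_dataset/data.jsonl.gz')

dataset = AudioDataset(meta, segment_duration=10.0, sample_rate=32000, channels=1,
                       shuffle=True, shuffle_seed=42, return_info=True)

# With a known epoch, the index -> extract mapping becomes reproducible.
dataset.start_epoch(epoch=0)
wav, info = dataset[0]
print(wav.shape, info.sample_rate, info.channels)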
audiocraft/data/audio_utils.py
CHANGED
@@ -3,7 +3,8 @@
 # LICENSE file in the root directory of this source tree.
+"""Various utilities for audio convertion (pcm format, sample rate and channels),
+and volume normalization."""
 import sys
 import typing as tp
@@ -150,17 +151,19 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
     """
     if wav.dtype.is_floating_point:
         return wav
-    assert wav.dtype == torch.int16
+    elif wav.dtype == torch.int16:
         return wav.float() / 2**15
+    elif wav.dtype == torch.int32:
+        return wav.float() / 2**31
+    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")


 def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
     """Convert audio to int 16 bits PCM format.

+    ..Warning:: There exist many formula for doing this conversion. None are perfect
+    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
+    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
     it is possible that `i16_pcm(f32_pcm)) != Identity`.
     """
     if wav.dtype.is_floating_point:
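A small sketch of the widened PCM conversion; the buffers below are purely illustrative.

import torch
from audiocraft.data.audio_utils import f32_pcm, i16_pcm

# int16 and int32 inputs now both map into [-1, 1] floats.
int16_wav = torch.randint(-2**15, 2**15 - 1, (1, 32000), dtype=torch.int16)
f_from_i16 = f32_pcm(int16_wav)                  # scaled by 1 / 2**15
f_from_i32 = f32_pcm(int16_wav.int() * 2**15)    # int32 input, scaled by 1 / 2**31

# As the docstring warns, the round trip is close but not an exact identity.
back = i16_pcm(f_from_i16)
print((back.float() - int16_wav.float()).abs().max())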
audiocraft/data/info_audio_dataset.py
ADDED
@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+"""Base classes for the datasets that also provide non-audio metadata,
+e.g. description, text transcription etc.
+"""
+from dataclasses import dataclass
+import logging
+import math
+import re
+import typing as tp
+
+import torch
+
+from .audio_dataset import AudioDataset, AudioMeta
+from ..environment import AudioCraftEnvironment
+from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
+
+
+logger = logging.getLogger(__name__)
+
+
+def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
+    """Monkey-patch meta to match cluster specificities."""
+    meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
+    if meta.info_path is not None:
+        meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
+    return meta
+
+
+def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+    """Monkey-patch all meta to match cluster specificities."""
+    return [_clusterify_meta(m) for m in meta]
+
+
+@dataclass
+class AudioInfo(SegmentWithAttributes):
+    """Dummy SegmentInfo with empty attributes.
+
+    The InfoAudioDataset is expected to return metadata that inherits
+    from SegmentWithAttributes class and can return conditioning attributes.
+
+    This basically guarantees all datasets will be compatible with current
+    solver that contain conditioners requiring this.
+    """
+    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using cached batch for training a LM.
+
+    def to_condition_attributes(self) -> ConditioningAttributes:
+        return ConditioningAttributes()
+
+
+class InfoAudioDataset(AudioDataset):
+    """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
+
+    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
+    """
+    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
+        super().__init__(clusterify_all_meta(meta), **kwargs)
+
+    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
+        if not self.return_info:
+            wav = super().__getitem__(index)
+            assert isinstance(wav, torch.Tensor)
+            return wav
+        wav, meta = super().__getitem__(index)
+        return wav, AudioInfo(**meta.to_dict())
+
+
+def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
+    """Preprocess a single keyword or possible a list of keywords."""
+    if isinstance(value, list):
+        return get_keyword_list(value)
+    else:
+        return get_keyword(value)
+
+
+def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip()
+
+
+def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
+    """Preprocess a single keyword."""
+    if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
+        return None
+    else:
+        return value.strip().lower()
+
+
+def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
+    """Preprocess a list of keywords."""
+    if isinstance(values, str):
+        values = [v.strip() for v in re.split(r'[,\s]', values)]
+    elif isinstance(values, float) and math.isnan(values):
+        values = []
+    if not isinstance(values, list):
+        logger.debug(f"Unexpected keyword list {values}")
+        values = [str(values)]
+
+    kws = [get_keyword(v) for v in values]
+    kw_list = [k for k in kws if k is not None]
+    if len(kw_list) == 0:
+        return None
+    else:
+        return kw_list
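The keyword helpers above normalize free-form metadata fields before they reach the conditioners; a few illustrative calls:

from audiocraft.data.info_audio_dataset import get_keyword_list, get_string

print(get_keyword_list('Rock, Indie  guitar'))   # ['rock', 'indie', 'guitar']
print(get_string('  Warm analog synths  '))      # 'Warm analog synths'
print(get_string('None'))                        # None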
audiocraft/data/zip.py
CHANGED
@@ -3,6 +3,8 @@
 # LICENSE file in the root directory of this source tree.
+"""Utility for reading some info from inside a zip file.
+"""

 import typing
 import zipfile
@@ -18,13 +20,13 @@ MODE = Literal['r', 'w', 'x', 'a']

 @dataclass(order=True)
 class PathInZip:
-    """
+    """Hold a path of file within a zip file.

     Args:
-        path: The convention is <path_to_zip>:<relative_path_inside_zip>
+        path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
         Let's assume there is a zip file /some/location/foo.zip
         and inside of it is a json file located at /data/file1.json,
-        Then we expect path = "/some/location/foo.zip:/data/file1.json"
+        Then we expect path = "/some/location/foo.zip:/data/file1.json".
     """

     INFO_PATH_SEP = ':'
@@ -55,7 +57,7 @@ def set_zip_cache_size(max_size: int):
     """Sets the maximal LRU caching for zip file opening.

     Args:
-        max_size: the maximal LRU cache.
+        max_size (int): the maximal LRU cache.
     """
     global _cached_open_zip
     _cached_open_zip = lru_cache(max_size)(_open_zip)
@@ -65,8 +67,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
     """Opens a file stored inside a zip and returns a file-like object.

     Args:
-        path_in_zip: A PathInZip object representing the file to return a file-like object of.
-        mode: The mode in which to open the file with.
+        path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
+        mode (str): The mode in which to open the file with.
     Returns:
         A file-like object for PathInZip.
     """
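A short sketch of the convention documented above, reusing the docstring's example path (the archive itself is hypothetical):

from audiocraft.data.zip import PathInZip, open_file_in_zip, set_zip_cache_size

info_path = PathInZip('/some/location/foo.zip:/data/file1.json')
set_zip_cache_size(32)  # keep up to 32 zip handles open via the LRU cache

with open_file_in_zip(info_path, 'r') as fp:
    raw = fp.read()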
audiocraft/environment.py
ADDED
@@ -0,0 +1,176 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Provides cluster and tools configuration across clusters (slurm, dora, utilities).
+"""
+
+import logging
+import os
+from pathlib import Path
+import re
+import typing as tp
+
+import omegaconf
+
+from .utils.cluster import _guess_cluster_type
+
+
+logger = logging.getLogger(__name__)
+
+
+class AudioCraftEnvironment:
+    """Environment configuration for teams and clusters.
+
+    AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
+    or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
+    provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
+    allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
+    map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters.
+
+    The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
+    Use the following environment variables to specify the cluster, team or configuration:
+
+        AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
+            cannot be inferred automatically.
+        AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
+            If not set, configuration is read from config/teams.yaml.
+        AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
+            Cluster configuration are shared across teams to match compute allocation,
+            specify your cluster configuration in the configuration file under a key mapping
+            your team name.
+    """
+    _instance = None
+    DEFAULT_TEAM = "default"
+
+    def __init__(self) -> None:
+        """Loads configuration."""
+        self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
+        cluster_type = _guess_cluster_type()
+        cluster = os.getenv(
+            "AUDIOCRAFT_CLUSTER", cluster_type.value
+        )
+        logger.info("Detecting cluster type %s", cluster_type)
+
+        self.cluster: str = cluster
+
+        config_path = os.getenv(
+            "AUDIOCRAFT_CONFIG",
+            Path(__file__)
+            .parent.parent.joinpath("config/teams", self.team)
+            .with_suffix(".yaml"),
+        )
+        self.config = omegaconf.OmegaConf.load(config_path)
+        self._dataset_mappers = []
+        cluster_config = self._get_cluster_config()
+        if "dataset_mappers" in cluster_config:
+            for pattern, repl in cluster_config["dataset_mappers"].items():
+                regex = re.compile(pattern)
+                self._dataset_mappers.append((regex, repl))
+
+    def _get_cluster_config(self) -> omegaconf.DictConfig:
+        assert isinstance(self.config, omegaconf.DictConfig)
+        return self.config[self.cluster]
+
+    @classmethod
+    def instance(cls):
+        if cls._instance is None:
+            cls._instance = cls()
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Clears the environment and forces a reload on next invocation."""
+        cls._instance = None
+
+    @classmethod
+    def get_team(cls) -> str:
+        """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
+        If not defined, defaults to "labs".
+        """
+        return cls.instance().team
+
+    @classmethod
+    def get_cluster(cls) -> str:
+        """Gets the detected cluster.
+        This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
+        """
+        return cls.instance().cluster
+
+    @classmethod
+    def get_dora_dir(cls) -> Path:
+        """Gets the path to the dora directory for the current team and cluster.
+        Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
+        """
+        cluster_config = cls.instance()._get_cluster_config()
+        dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
+        logger.warning(f"Dora directory: {dora_dir}")
+        return Path(dora_dir)
+
+    @classmethod
+    def get_reference_dir(cls) -> Path:
+        """Gets the path to the reference directory for the current team and cluster.
+        Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
+        """
+        cluster_config = cls.instance()._get_cluster_config()
+        return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
+
+    @classmethod
+    def get_slurm_exclude(cls) -> tp.Optional[str]:
+        """Get the list of nodes to exclude for that cluster."""
+        cluster_config = cls.instance()._get_cluster_config()
+        return cluster_config.get("slurm_exclude")
+
+    @classmethod
+    def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
+        """Gets the requested partitions for the current team and cluster as a comma-separated string.
+
+        Args:
+            partition_types (list[str], optional): partition types to retrieve. Values must be
+                from ['global', 'team']. If not provided, the global partition is returned.
+        """
+        if not partition_types:
+            partition_types = ["global"]
+
+        cluster_config = cls.instance()._get_cluster_config()
+        partitions = [
+            cluster_config["partitions"][partition_type]
+            for partition_type in partition_types
+        ]
+        return ",".join(partitions)
+
+    @classmethod
+    def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
+        """Converts reference placeholder in path with configured reference dir to resolve paths.
+
+        Args:
+            path (str or Path): Path to resolve.
+        Returns:
+            Path: Resolved path.
+        """
+        path = str(path)
+
+        if path.startswith("//reference"):
+            reference_dir = cls.get_reference_dir()
+            logger.warn(f"Reference directory: {reference_dir}")
+            assert (
+                reference_dir.exists() and reference_dir.is_dir()
+            ), f"Reference directory does not exist: {reference_dir}."
+            path = re.sub("^//reference", str(reference_dir), path)
+
+        return Path(path)
+
+    @classmethod
+    def apply_dataset_mappers(cls, path: str) -> str:
+        """Applies dataset mapping regex rules as defined in the configuration.
+        If no rules are defined, the path is returned as-is.
+        """
+        instance = cls.instance()
+
+        for pattern, repl in instance._dataset_mappers:
+            path = pattern.sub(repl, path)
+
+        return path
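A rough sketch of how these helpers are typically queried, assuming the bundled config/teams/default.yaml covers your cluster type; the dataset path below is made up.

from audiocraft.environment import AudioCraftEnvironment

# Cluster type is guessed automatically, or forced via AUDIOCRAFT_CLUSTER.
print(AudioCraftEnvironment.get_cluster())

# With no dataset_mappers configured this is a no-op; on a cluster it rewrites
# manifest paths so the same egs files work everywhere.
print(AudioCraftEnvironment.apply_dataset_mappers('/datasets/music/track_000.wav'))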
audiocraft/models/__init__.py
CHANGED
@@ -4,7 +4,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+"""
+Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
+"""
 # flake8: noqa
+from . import builders, loaders
+from .encodec import (
+    CompressionModel, EncodecModel, DAC,
+    HFEncodecModel, HFEncodecCompressionModel)
 from .musicgen import MusicGen
 from .lm import LMModel
 from .encodec import CompressionModel, EncodecModel
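With the wider export surface, model classes can be pulled straight from `audiocraft.models`; a minimal sketch (weights are downloaded from the Hub on first use):

from audiocraft.models import MusicGen

model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=8)
wav = model.generate(['lo-fi beat with warm pads'])  # tensor of shape [B, C, T]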
audiocraft/models/builders.py
CHANGED
@@ -10,32 +10,34 @@ from the Hydra config.
 """

 import typing as tp
-import warnings

 import audiocraft
 import omegaconf
 import torch

+from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
 from .lm import LMModel
 from ..modules.codebooks_patterns import (
     CodebooksPatternProvider,
     DelayedPatternProvider,
+    MusicLMPattern,
     ParallelPatternProvider,
     UnrolledPatternProvider,
+    CoarseFirstPattern,
 )
 from ..modules.conditioners import (
     BaseConditioner,
+    ChromaStemConditioner,
+    CLAPEmbeddingConditioner,
+    ConditionFuser,
     ConditioningProvider,
     LUTConditioner,
     T5Conditioner,
 )
+from .unet import DiffusionUnet
 from .. import quantization as qt
 from ..utils.utils import dict_from_config
+from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
@@ -60,12 +62,11 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
         decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
         return encoder, decoder
     else:
+        raise KeyError(f"Unexpected compression model {cfg.compression_model}")


 def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+    """Instantiate a compression model."""
     if cfg.compression_model == 'encodec':
         kwargs = dict_from_config(getattr(cfg, 'encodec'))
         encoder_name = kwargs.pop('autoencoder')
@@ -73,20 +74,17 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
         encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
         quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
         frame_rate = kwargs['sample_rate'] // encoder.hop_length
+        renormalize = kwargs.pop('renormalize', False)
+        # deprecated params
+        kwargs.pop('renorm', None)
-            renormalize = renorm is not None
-            warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
         return EncodecModel(encoder, decoder, quantizer,
                             frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
     else:
+        raise KeyError(f"Unexpected compression model {cfg.compression_model}")


 def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+    """Instantiate a transformer LM."""
     if cfg.lm_model == 'transformer_lm':
         kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
         n_q = kwargs['n_q']
@@ -94,14 +92,14 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
         codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
         cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+        cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
         fuser = get_condition_fuser(cfg)
         condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programmatically
             kwargs['cross_attention'] = True
         if codebooks_pattern_cfg.modeling is None:
             assert q_modeling is not None, \
+                "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
             codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                 {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
             )
@@ -118,45 +116,50 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
             **kwargs
         ).to(cfg.device)
     else:
+        raise KeyError(f"Unexpected LM model {cfg.lm_model}")


 def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+    """Instantiate a conditioning model."""
     device = cfg.device
     duration = cfg.dataset.segment_duration
+    cfg = getattr(cfg, 'conditioners')
+    dict_cfg = {} if cfg is None else dict_from_config(cfg)
     conditioners: tp.Dict[str, BaseConditioner] = {}
+    condition_provider_args = dict_cfg.pop('args', {})
+    condition_provider_args.pop('merge_text_conditions_p', None)
+    condition_provider_args.pop('drop_desc_p', None)
+
+    for cond, cond_cfg in dict_cfg.items():
+        model_type = cond_cfg['model']
         model_args = cond_cfg[model_type]
+        if model_type == 't5':
             conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+        elif model_type == 'lut':
             conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+        elif model_type == 'chroma_stem':
             conditioners[str(cond)] = ChromaStemConditioner(
                 output_dim=output_dim,
                 duration=duration,
                 device=device,
                 **model_args
             )
+        elif model_type == 'clap':
+            conditioners[str(cond)] = CLAPEmbeddingConditioner(
+                output_dim=output_dim,
+                device=device,
+                **model_args
+            )
         else:
+            raise ValueError(f"Unrecognized conditioning model: {model_type}")
     conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
     return conditioner


 def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+    """Instantiate a condition fuser object."""
+    fuser_cfg = getattr(cfg, 'fuser')
+    fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
     fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
     kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
     fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
@@ -164,13 +167,12 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:


 def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+    """Instantiate a codebooks pattern provider object."""
     pattern_providers = {
         'parallel': ParallelPatternProvider,
         'delay': DelayedPatternProvider,
         'unroll': UnrolledPatternProvider,
+        'coarse_first': CoarseFirstPattern,
         'musiclm': MusicLMPattern,
     }
     name = cfg.modeling
@@ -179,14 +181,20 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
     return klass(n_q, **kwargs)


+def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
+    """Instantiate a debug compression model to be used for unit tests."""
+    assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
+    model_ratios = {
+        16000: [10, 8, 8],   # 25 Hz at 16kHz
+        32000: [10, 8, 16]   # 25 Hz at 32kHz
+    }
+    ratios: tp.List[int] = model_ratios[sample_rate]
+    frame_rate = 25
+    seanet_kwargs: dict = {
         'n_filters': 4,
         'n_residual_layers': 1,
         'dimension': 32,
+        'ratios': ratios,
     }
     encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
     decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
@@ -195,13 +203,31 @@ def get_debug_compression_model(device='cpu'):
     quantizer(init_x, 1)  # initialize kmeans etc.
     compression_model = EncodecModel(
         encoder, decoder, quantizer,
+        frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
     return compression_model.eval()


+def get_diffusion_model(cfg: omegaconf.DictConfig):
+    # TODO Find a way to infer the channels from dset
+    channels = cfg.channels
+    num_steps = cfg.schedule.num_steps
+    return DiffusionUnet(
+        chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+
+
+def get_processor(cfg, sample_rate: int = 24000):
+    sample_processor = SampleProcessor()
+    if cfg.use:
+        kw = dict(cfg)
+        kw.pop('use')
+        kw.pop('name')
+        if cfg.name == "multi_band_processor":
+            sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
+    return sample_processor
+
+
 def get_debug_lm_model(device='cpu'):
+    """Instantiate a debug LM to be used for unit tests."""
     pattern = DelayedPatternProvider(n_q=4)
     dim = 16
     providers = {
@@ -216,3 +242,17 @@ def get_debug_lm_model(device='cpu'):
         n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
         cross_attention=True, causal=True)
     return lm.to(device).eval()
+
+
+def get_wrapped_compression_model(
+        compression_model: CompressionModel,
+        cfg: omegaconf.DictConfig) -> CompressionModel:
+    if hasattr(cfg, 'interleave_stereo_codebooks'):
+        if cfg.interleave_stereo_codebooks.use:
+            kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
+            kwargs.pop('use')
+            compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
+    if hasattr(cfg, 'compression_model_n_q'):
+        if cfg.compression_model_n_q is not None:
+            compression_model.set_num_codebooks(cfg.compression_model_n_q)
+    return compression_model
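Most builders above are driven by a Hydra/omegaconf config, but the debug builders are handy for quick shape checks without loading real weights; a minimal sketch:

import torch
from audiocraft.models import builders

compression_model = builders.get_debug_compression_model(sample_rate=32000)
lm = builders.get_debug_lm_model()

wav = torch.randn(1, 1, 32000)         # [B, C, T], one second at 32 kHz
codes, scale = compression_model.encode(wav)
print(codes.shape)                      # [1, n_q, 25] at the 25 Hz debug frame rate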
audiocraft/models/encodec.py
CHANGED
@@ -3,18 +3,32 @@
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
|
6 |
|
7 |
from abc import ABC, abstractmethod
|
|
|
|
|
|
|
8 |
import typing as tp
|
9 |
|
10 |
from einops import rearrange
|
|
|
11 |
import torch
|
12 |
from torch import nn
|
|
|
13 |
|
14 |
from .. import quantization as qt
|
15 |
|
16 |
|
|
|
|
|
|
|
17 |
class CompressionModel(ABC, nn.Module):
|
|
|
|
|
|
|
18 |
|
19 |
@abstractmethod
|
20 |
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
@@ -22,12 +36,17 @@ class CompressionModel(ABC, nn.Module):
|
|
22 |
|
23 |
@abstractmethod
|
24 |
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
25 |
-
"""See `EncodecModel.encode
|
26 |
...
|
27 |
|
28 |
@abstractmethod
|
29 |
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
30 |
-
"""See `EncodecModel.decode
|
|
|
|
|
|
|
|
|
|
|
31 |
...
|
32 |
|
33 |
@property
|
@@ -37,7 +56,7 @@ class CompressionModel(ABC, nn.Module):
[Left-hand (old file) pane of the side-by-side diff for audiocraft/models/encodec.py. The viewer truncated the removed lines and left their paired rows blank. Hunks touched: @@ -62,10 +81,46 @@, @@ -80,9 +135,9 @@, @@ -111,25 +166,21 @@, @@ -176,7 +227,7 @@, @@ -192,41 +243,174 @@ in CompressionModel/EncodecModel, plus @@ -236,30 +420,27 @@ and @@ -268,35 +449,58 @@ in the old FlattenedCompressionModel wrapper, which is reworked below as InterleaveStereoCompressionModel. The complete updated file content follows.]
|
|
|
|
|
|
|
3 |
#
|
4 |
# This source code is licensed under the license found in the
|
5 |
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Compression models or wrapper around existing models.
|
7 |
+
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
|
8 |
+
"""
|
9 |
|
10 |
from abc import ABC, abstractmethod
|
11 |
+
import logging
|
12 |
+
import math
|
13 |
+
from pathlib import Path
|
14 |
import typing as tp
|
15 |
|
16 |
from einops import rearrange
|
17 |
+
import numpy as np
|
18 |
import torch
|
19 |
from torch import nn
|
20 |
+
from transformers import EncodecModel as HFEncodecModel
|
21 |
|
22 |
from .. import quantization as qt
|
23 |
|
24 |
|
25 |
+
logger = logging.getLogger()
|
26 |
+
|
27 |
+
|
28 |
class CompressionModel(ABC, nn.Module):
|
29 |
+
"""Base API for all compression model that aim at being used as audio tokenizers
|
30 |
+
with a language model.
|
31 |
+
"""
|
32 |
|
33 |
@abstractmethod
|
34 |
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
|
|
36 |
|
37 |
@abstractmethod
|
38 |
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
39 |
+
"""See `EncodecModel.encode`."""
|
40 |
...
|
41 |
|
42 |
@abstractmethod
|
43 |
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
44 |
+
"""See `EncodecModel.decode`."""
|
45 |
+
...
|
46 |
+
|
47 |
+
@abstractmethod
|
48 |
+
def decode_latent(self, codes: torch.Tensor):
|
49 |
+
"""Decode from the discrete codes to continuous latent space."""
|
50 |
...
|
51 |
|
52 |
@property
|
|
|
56 |
|
57 |
@property
|
58 |
@abstractmethod
|
59 |
+
def frame_rate(self) -> float:
|
60 |
...
|
61 |
|
62 |
@property
|
|
|
81 |
|
82 |
@abstractmethod
|
83 |
def set_num_codebooks(self, n: int):
|
84 |
+
"""Set the active number of codebooks used by the quantizer."""
|
|
|
85 |
...
|
86 |
|
87 |
+
@staticmethod
|
88 |
+
def get_pretrained(
|
89 |
+
name: str, device: tp.Union[torch.device, str] = 'cpu'
|
90 |
+
) -> 'CompressionModel':
|
91 |
+
"""Instantiate a CompressionModel from a given pretrained model.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
name (Path or str): name of the pretrained model. See after.
|
95 |
+
device (torch.device or str): Device on which the model is loaded.
|
96 |
+
|
97 |
+
Pretrained models:
|
98 |
+
- dac_44khz (https://github.com/descriptinc/descript-audio-codec)
|
99 |
+
- dac_24khz (same)
|
100 |
+
- facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
|
101 |
+
- facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
|
102 |
+
- your own model on HuggingFace. Export instructions to come...
|
103 |
+
"""
|
104 |
+
|
105 |
+
from . import builders, loaders
|
106 |
+
model: CompressionModel
|
107 |
+
if name in ['dac_44khz', 'dac_24khz']:
|
108 |
+
model_type = name.split('_')[1]
|
109 |
+
logger.info("Getting pretrained compression model from DAC %s", model_type)
|
110 |
+
model = DAC(model_type)
|
111 |
+
elif name in ['debug_compression_model']:
|
112 |
+
logger.info("Getting pretrained compression model for debug")
|
113 |
+
model = builders.get_debug_compression_model()
|
114 |
+
elif Path(name).exists():
|
115 |
+
# We assume here that if the path exists, it is an AC checkpoint
|
116 |
+
# that was exported using `audiocraft.utils.export` functions.
|
117 |
+
model = loaders.load_compression_model(name, device=device)
|
118 |
+
else:
|
119 |
+
logger.info("Getting pretrained compression model from HF %s", name)
|
120 |
+
hf_model = HFEncodecModel.from_pretrained(name)
|
121 |
+
model = HFEncodecCompressionModel(hf_model).to(device)
|
122 |
+
return model.to(device).eval()
|
123 |
+
|
124 |
|
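
As a quick illustration of the `get_pretrained` entry point above, the sketch below loads a pretrained tokenizer and round-trips a waveform through encode/decode. This is a minimal sketch, not part of the commit: it assumes this version of audiocraft plus torch are installed and that the checkpoint names listed in the docstring (here 'facebook/encodec_32khz') can be downloaded; the silent one-second waveform is an invented example input.

# Minimal sketch: round-trip audio through a pretrained compression model.
# Checkpoint name comes from the docstring above and needs network access.
import torch
from audiocraft.models.encodec import CompressionModel

model = CompressionModel.get_pretrained('facebook/encodec_32khz', device='cpu')

# One second of silence at the model's sample rate, shape [B, C, T].
wav = torch.zeros(1, model.channels, model.sample_rate)
with torch.no_grad():
    codes, scale = model.encode(wav)             # codes: [B, K, T_frames] int tokens
    reconstruction = model.decode(codes, scale)  # back to [B, C, T'] audio

print(codes.shape, model.frame_rate, model.cardinality)
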
125 |
class EncodecModel(CompressionModel):
|
126 |
"""Encodec model operating on the raw waveform.
|
|
|
135 |
causal (bool): Whether to use a causal version of the model.
|
136 |
renormalize (bool): Whether to renormalize the audio before running the model.
|
137 |
"""
|
138 |
+
# we need assignment to override the property in the abstract class,
|
139 |
# I couldn't find a better way...
|
140 |
+
frame_rate: float = 0
|
141 |
sample_rate: int = 0
|
142 |
channels: int = 0
|
143 |
|
|
|
166 |
|
167 |
@property
|
168 |
def total_codebooks(self):
|
169 |
+
"""Total number of quantizer codebooks available."""
|
|
|
170 |
return self.quantizer.total_codebooks
|
171 |
|
172 |
@property
|
173 |
def num_codebooks(self):
|
174 |
+
"""Active number of codebooks used by the quantizer."""
|
|
|
175 |
return self.quantizer.num_codebooks
|
176 |
|
177 |
def set_num_codebooks(self, n: int):
|
178 |
+
"""Set the active number of codebooks used by the quantizer."""
|
|
|
179 |
self.quantizer.set_num_codebooks(n)
|
180 |
|
181 |
@property
|
182 |
def cardinality(self):
|
183 |
+
"""Cardinality of each codebook."""
|
|
|
184 |
return self.quantizer.bins
|
185 |
|
186 |
def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
|
|
227 |
x (torch.Tensor): Float tensor of shape [B, C, T]
|
228 |
|
229 |
Returns:
|
230 |
+
codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
|
231 |
codes, an int tensor of shape [B, K, T] with K the number of codebooks and T the number of timesteps.
|
232 |
scale, a float tensor containing the scale for audio renormalization.
|
233 |
"""
|
|
|
243 |
|
244 |
Args:
|
245 |
codes (torch.Tensor): Int tensor of shape [B, K, T]
|
246 |
+
scale (torch.Tensor, optional): Float tensor containing the scale value.
|
247 |
|
248 |
Returns:
|
249 |
out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
|
250 |
"""
|
251 |
+
emb = self.decode_latent(codes)
|
252 |
out = self.decoder(emb)
|
253 |
out = self.postprocess(out, scale)
|
254 |
# out contains extra padding added by the encoder and decoder
|
255 |
return out
|
256 |
|
257 |
+
def decode_latent(self, codes: torch.Tensor):
|
258 |
+
"""Decode from the discrete codes to continuous latent space."""
|
259 |
+
return self.quantizer.decode(codes)
|
260 |
+
|
261 |
+
|
262 |
+
class DAC(CompressionModel):
|
263 |
+
def __init__(self, model_type: str = "44khz"):
|
264 |
+
super().__init__()
|
265 |
+
try:
|
266 |
+
import dac.utils
|
267 |
+
except ImportError:
|
268 |
+
raise RuntimeError("Could not import dac, make sure it is installed, "
|
269 |
+
"please run `pip install descript-audio-codec`")
|
270 |
+
self.model = dac.utils.load_model(model_type=model_type)
|
271 |
+
self.n_quantizers = self.total_codebooks
|
272 |
+
self.model.eval()
|
273 |
+
|
274 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
275 |
+
# We don't support training with this.
|
276 |
+
raise NotImplementedError("Forward and training with DAC not supported.")
|
277 |
+
|
278 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
279 |
+
codes = self.model.encode(x, self.n_quantizers)[1]
|
280 |
+
return codes[:, :self.n_quantizers], None
|
281 |
+
|
282 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
283 |
+
assert scale is None
|
284 |
+
z_q = self.decode_latent(codes)
|
285 |
+
return self.model.decode(z_q)
|
286 |
+
|
287 |
+
def decode_latent(self, codes: torch.Tensor):
|
288 |
+
"""Decode from the discrete codes to continuous latent space."""
|
289 |
+
return self.model.quantizer.from_codes(codes)[0]
|
290 |
+
|
291 |
+
@property
|
292 |
+
def channels(self) -> int:
|
293 |
+
return 1
|
294 |
+
|
295 |
+
@property
|
296 |
+
def frame_rate(self) -> float:
|
297 |
+
return self.model.sample_rate / self.model.hop_length
|
298 |
+
|
299 |
+
@property
|
300 |
+
def sample_rate(self) -> int:
|
301 |
+
return self.model.sample_rate
|
302 |
+
|
303 |
+
@property
|
304 |
+
def cardinality(self) -> int:
|
305 |
+
return self.model.codebook_size
|
306 |
+
|
307 |
+
@property
|
308 |
+
def num_codebooks(self) -> int:
|
309 |
+
return self.n_quantizers
|
310 |
|
311 |
+
@property
|
312 |
+
def total_codebooks(self) -> int:
|
313 |
+
return self.model.n_codebooks
|
314 |
+
|
315 |
+
def set_num_codebooks(self, n: int):
|
316 |
+
"""Set the active number of codebooks used by the quantizer.
|
317 |
+
"""
|
318 |
+
assert n >= 1
|
319 |
+
assert n <= self.total_codebooks
|
320 |
+
self.n_quantizers = n
|
321 |
+
|
322 |
+
|
323 |
+
class HFEncodecCompressionModel(CompressionModel):
|
324 |
+
"""Wrapper around HuggingFace Encodec.
|
325 |
+
"""
|
326 |
+
def __init__(self, model: HFEncodecModel):
|
327 |
+
super().__init__()
|
328 |
+
self.model = model
|
329 |
+
bws = self.model.config.target_bandwidths
|
330 |
+
num_codebooks = [
|
331 |
+
bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
|
332 |
+
for bw in bws
|
333 |
+
]
|
334 |
+
deltas = [nc - int(nc) for nc in num_codebooks]
|
335 |
+
# Checking we didn't do some bad maths and we indeed have integers!
|
336 |
+
assert all(delta <= 1e-3 for delta in deltas), deltas
|
337 |
+
self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
|
338 |
+
self.set_num_codebooks(max(self.possible_num_codebooks))
|
339 |
+
|
340 |
+
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
341 |
+
# We don't support training with this.
|
342 |
+
raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
|
343 |
+
|
344 |
+
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
345 |
+
bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
|
346 |
+
bandwidth = self.model.config.target_bandwidths[bandwidth_index]
|
347 |
+
res = self.model.encode(x, None, bandwidth)
|
348 |
+
assert len(res[0]) == 1
|
349 |
+
assert len(res[1]) == 1
|
350 |
+
return res[0][0], res[1][0]
|
351 |
+
|
352 |
+
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
353 |
+
if scale is None:
|
354 |
+
scales = [None] # type: ignore
|
355 |
+
else:
|
356 |
+
scales = scale # type: ignore
|
357 |
+
res = self.model.decode(codes[None], scales)
|
358 |
+
return res[0]
|
359 |
+
|
360 |
+
def decode_latent(self, codes: torch.Tensor):
|
361 |
+
"""Decode from the discrete codes to continuous latent space."""
|
362 |
+
return self.model.quantizer.decode(codes.transpose(0, 1))
|
363 |
+
|
364 |
+
@property
|
365 |
+
def channels(self) -> int:
|
366 |
+
return self.model.config.audio_channels
|
367 |
+
|
368 |
+
@property
|
369 |
+
def frame_rate(self) -> float:
|
370 |
+
hop_length = int(np.prod(self.model.config.upsampling_ratios))
|
371 |
+
return self.sample_rate / hop_length
|
372 |
+
|
373 |
+
@property
|
374 |
+
def sample_rate(self) -> int:
|
375 |
+
return self.model.config.sampling_rate
|
376 |
+
|
377 |
+
@property
|
378 |
+
def cardinality(self) -> int:
|
379 |
+
return self.model.config.codebook_size
|
380 |
+
|
381 |
+
@property
|
382 |
+
def num_codebooks(self) -> int:
|
383 |
+
return self._num_codebooks
|
384 |
+
|
385 |
+
@property
|
386 |
+
def total_codebooks(self) -> int:
|
387 |
+
return max(self.possible_num_codebooks)
|
388 |
+
|
389 |
+
def set_num_codebooks(self, n: int):
|
390 |
+
"""Set the active number of codebooks used by the quantizer.
|
391 |
+
"""
|
392 |
+
if n not in self.possible_num_codebooks:
|
393 |
+
raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
|
394 |
+
self._num_codebooks = n
|
395 |
+
|
396 |
+
|
397 |
+
class InterleaveStereoCompressionModel(CompressionModel):
|
398 |
+
"""Wraps a CompressionModel to support stereo inputs. The wrapped model
|
399 |
+
will be applied independently to the left and right channels, and both codebooks
|
400 |
+
will be interleaved. If the wrapped model returns a representation `[B, K, T]` per
|
401 |
+
channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
|
402 |
+
`per_timestep`.
|
403 |
|
404 |
Args:
|
405 |
+
model (CompressionModel): Compression model to wrap.
|
406 |
+
per_timestep (bool): Whether to interleave on the timestep dimension
|
407 |
+
or on the codebooks dimension.
|
|
|
|
|
|
|
|
|
|
|
|
|
408 |
"""
|
409 |
+
def __init__(self, model: CompressionModel, per_timestep: bool = False):
|
|
|
410 |
super().__init__()
|
411 |
self.model = model
|
412 |
+
self.per_timestep = per_timestep
|
413 |
+
assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
|
414 |
|
415 |
@property
|
416 |
def total_codebooks(self):
|
|
|
420 |
def num_codebooks(self):
|
421 |
"""Active number of codebooks used by the quantizer.
|
422 |
|
423 |
+
..Warning:: this reports the number of codebooks after the interleaving
|
424 |
of the codebooks!
|
425 |
"""
|
426 |
+
return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
|
|
|
427 |
|
428 |
def set_num_codebooks(self, n: int):
|
429 |
"""Set the active number of codebooks used by the quantizer.
|
430 |
|
431 |
+
..Warning:: this sets the number of codebooks before the interleaving!
|
|
|
432 |
"""
|
|
|
433 |
self.model.set_num_codebooks(n)
|
434 |
|
435 |
@property
|
436 |
+
def num_virtual_steps(self) -> float:
|
437 |
"""Return the number of virtual steps, e.g. one real step
|
438 |
will be split into that many steps.
|
439 |
"""
|
440 |
+
return 2 if self.per_timestep else 1
|
441 |
|
442 |
@property
|
443 |
+
def frame_rate(self) -> float:
|
444 |
return self.model.frame_rate * self.num_virtual_steps
|
445 |
|
446 |
@property
|
|
|
449 |
|
450 |
@property
|
451 |
def channels(self) -> int:
|
452 |
+
return 2
|
453 |
|
454 |
@property
|
455 |
def cardinality(self):
|
456 |
"""Cardinality of each codebook.
|
457 |
"""
|
458 |
+
return self.model.cardinality
|
|
|
|
|
|
|
459 |
|
460 |
def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
|
461 |
raise NotImplementedError("Not supported, use encode and decode.")
|
462 |
|
463 |
def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
|
464 |
+
B, C, T = x.shape
|
465 |
+
assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
|
466 |
+
|
467 |
+
indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
|
468 |
+
indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
|
469 |
+
indices = torch.stack([indices_c0, indices_c1], dim=0)
|
470 |
+
scales: tp.Optional[torch.Tensor] = None
|
471 |
+
if scales_c0 is not None and scales_c1 is not None:
|
472 |
+
scales = torch.stack([scales_c0, scales_c1], dim=1)
|
473 |
+
|
474 |
+
if self.per_timestep:
|
475 |
+
indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
|
476 |
+
else:
|
477 |
+
indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
|
478 |
+
|
479 |
return (indices, scales)
|
480 |
|
481 |
+
def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
482 |
+
if self.per_timestep:
|
483 |
+
codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
|
484 |
+
else:
|
485 |
+
codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
|
486 |
+
return codes[0], codes[1]
|
487 |
+
|
488 |
def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
|
489 |
B, K, T = codes.shape
|
490 |
+
assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
|
491 |
+
assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
|
492 |
+
|
493 |
+
scale_c0, scale_c1 = None, None
|
494 |
+
if scale is not None:
|
495 |
+
assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
|
496 |
+
scale_c0 = scale[0, ...]
|
497 |
+
scale_c1 = scale[1, ...]
|
498 |
+
|
499 |
+
codes_c0, codes_c1 = self.get_left_right_codes(codes)
|
500 |
+
audio_c0 = self.model.decode(codes_c0, scale_c0)
|
501 |
+
audio_c1 = self.model.decode(codes_c1, scale_c1)
|
502 |
+
return torch.cat([audio_c0, audio_c1], dim=1)
|
503 |
+
|
504 |
+
def decode_latent(self, codes: torch.Tensor):
|
505 |
+
"""Decode from the discrete codes to continuous latent space."""
|
506 |
+
raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
|
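
The interleaving convention used by InterleaveStereoCompressionModel above is easy to check in isolation. The snippet below is an illustrative sketch of just the rearrange patterns from encode/get_left_right_codes, run on random integer codes rather than a real model, so the batch, codebook, and timestep sizes are made up.

# Sketch of the stereo interleaving above: two mono code tensors [B, K, T]
# are merged either along the codebook axis (b (k c) t) or the time axis (b k (t c)).
import torch
from einops import rearrange

B, K, T = 2, 4, 10
codes_left = torch.randint(0, 2048, (B, K, T))
codes_right = torch.randint(0, 2048, (B, K, T))
stacked = torch.stack([codes_left, codes_right], dim=0)          # [2, B, K, T]

per_codebook = rearrange(stacked, 'c b k t -> b (k c) t', c=2)   # [B, 2K, T]
per_timestep = rearrange(stacked, 'c b k t -> b k (t c)', c=2)   # [B, K, 2T]

# De-interleaving recovers the two channels, as in get_left_right_codes().
left, right = rearrange(per_codebook, 'b (k c) t -> c b k t', c=2)
assert torch.equal(left, codes_left) and torch.equal(right, codes_right)
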
audiocraft/models/lm.py
CHANGED
[Left-hand (old file) pane of the side-by-side diff for audiocraft/models/lm.py, with the removed lines truncated by the viewer. Hunks touched: @@ -41,7 +41,7 @@, @@ -70,7 +70,7 @@, @@ -130,10 +130,10 @@, @@ -179,11 +179,11 @@, @@ -225,17 +225,17 @@, @@ -271,10 +271,10 @@, @@ -314,7 +314,8 @@, @@ -322,21 +323,22 @@, @@ -388,7 +390,7 @@, @@ -396,15 +398,19 @@, @@ -412,7 +418,7 @@, @@ -422,7 +428,7 @@, @@ -432,7 +438,7 @@, @@ -457,8 +463,8 @@, @@ -490,7 +496,7 @@; mostly docstring completions plus the new two_step_cfg plumbing in _sample_next_token and generate. The complete updated hunks follow.]
|
41 |
method (str): Method name for init function. Valid options are:
|
42 |
'gaussian', 'uniform'.
|
43 |
input_dim (int): Input dimension of the initialized module.
|
44 |
+
init_depth (int, optional): Optional init depth value used to rescale
|
45 |
the standard deviation if defined.
|
46 |
"""
|
47 |
# Compute std
|
|
|
70 |
Args:
|
71 |
m (nn.Module): Module to initialize.
|
72 |
method (str): Method name for the init function.
|
73 |
+
init_depth (int, optional): Optional init depth value used to rescale
|
74 |
the standard deviation if defined.
|
75 |
zero_bias_init (bool): Whether to initialize the bias to 0 or not.
|
76 |
"""
|
|
|
130 |
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
131 |
norm (str): Normalization method.
|
132 |
norm_first (bool): Use pre-norm instead of post-norm.
|
133 |
+
emb_lr (float, optional): Embedding-specific learning rate.
|
134 |
bias_proj (bool): Use bias for output projections.
|
135 |
+
weight_init (str, optional): Method for weight initialization.
|
136 |
+
depthwise_init (str, optional): Method for depthwise weight initialization.
|
137 |
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
138 |
cfg_dropout (float): Classifier-free guidance dropout.
|
139 |
cfg_coef (float): Classifier-free guidance coefficient.
|
|
|
179 |
"""Initialization of the transformer module weights.
|
180 |
|
181 |
Args:
|
182 |
+
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
|
183 |
+
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
|
184 |
'current' where the depth corresponds to the current layer index or 'global' where the total number
|
185 |
of layer is used as depth. If not set, no depthwise initialization strategy is used.
|
186 |
+
zero_bias_init (bool): Whether to initialize bias to zero or not.
|
187 |
"""
|
188 |
assert depthwise_init is None or depthwise_init in ['current', 'global']
|
189 |
assert depthwise_init is None or weight_init is not None, \
|
|
|
225 |
S the sequence steps, return the logits with shape [B, card, K, S].
|
226 |
|
227 |
Args:
|
228 |
+
indices (torch.Tensor): Indices of the codes to model.
|
229 |
+
conditions (list of ConditioningAttributes): Conditions to use when modeling
|
230 |
the given codes. Note that when evaluating multiple time with the same conditioning
|
231 |
you should pre-compute those and pass them as `condition_tensors`.
|
232 |
+
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
|
233 |
tensors, see `conditions`.
|
234 |
Returns:
|
235 |
torch.Tensor: Logits.
|
236 |
"""
|
237 |
B, K, S = sequence.shape
|
238 |
+
assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
|
239 |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
|
240 |
if condition_tensors is None:
|
241 |
assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
|
|
|
271 |
Args:
|
272 |
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
|
273 |
K the number of codebooks and T the number of timesteps.
|
274 |
+
conditions (list of ConditioningAttributes): conditionings to use when modeling
|
275 |
the given codes. Note that when evaluating multiple times with the same conditioning
|
276 |
you should pre-compute those and pass them as `condition_tensors`.
|
277 |
+
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
|
278 |
tensors, see `conditions`.
|
279 |
Returns:
|
280 |
LMOutput: Language model outputs
|
|
|
314 |
temp: float = 1.0,
|
315 |
top_k: int = 0,
|
316 |
top_p: float = 0.0,
|
317 |
+
cfg_coef: tp.Optional[float] = None,
|
318 |
+
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
|
319 |
"""Sample next token from the model given a sequence and a set of conditions. The model supports
|
320 |
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
|
321 |
|
|
|
323 |
sequence (torch.Tensor): Current sequence of shape [B, K, S]
|
324 |
with K corresponding to the number of codebooks and S the number of sequence steps.
|
325 |
S = 1 in streaming mode, except for the first step that contains a bigger prompt.
|
326 |
+
condition_tensors (dict[str, ConditionType]): Set of conditions. If CFG is used,
|
327 |
should be twice the batch size, being the concatenation of the conditions + null conditions.
|
328 |
use_sampling (bool): Whether to use a sampling strategy or not.
|
329 |
temp (float): Sampling temperature.
|
330 |
top_k (int): K for "top-k" sampling.
|
331 |
top_p (float): P for "top-p" sampling.
|
332 |
+
cfg_coef (float, optional): classifier free guidance coefficient
|
333 |
Returns:
|
334 |
next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
|
335 |
"""
|
336 |
B = sequence.shape[0]
|
337 |
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
|
338 |
model = self if self._fsdp is None else self._fsdp
|
339 |
+
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
340 |
+
if two_step_cfg and cfg_conditions != {}:
|
341 |
+
assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
|
342 |
condition_tensors, null_condition_tensors = cfg_conditions
|
343 |
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
|
344 |
state = self.get_streaming_state()
|
|
|
390 |
top_k: int = 250,
|
391 |
top_p: float = 0.0,
|
392 |
cfg_coef: tp.Optional[float] = None,
|
393 |
+
two_step_cfg: tp.Optional[bool] = None,
|
394 |
remove_prompts: bool = False,
|
395 |
check: bool = False,
|
396 |
callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
|
|
|
398 |
be performed in a greedy fashion or using sampling with top K and top P strategies.
|
399 |
|
400 |
Args:
|
401 |
+
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
|
402 |
+
conditions (list of ConditioningAttributes, optional): List of conditions.
|
403 |
+
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
|
404 |
max_gen_len (int): Maximum generation length.
|
405 |
use_sampling (bool): Whether to use a sampling strategy or not.
|
406 |
temp (float): Sampling temperature.
|
407 |
top_k (int): K for "top-k" sampling.
|
408 |
top_p (float): P for "top-p" sampling.
|
409 |
+
cfg_coef (float, optional): Classifier-free guidance coefficient.
|
410 |
+
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
|
411 |
remove_prompts (bool): Whether to remove prompts from generation or not.
|
412 |
+
check (bool): Whether to apply further checks on generated sequence.
|
413 |
+
callback (Callback, optional): Callback function to report generation progress.
|
414 |
Returns:
|
415 |
torch.Tensor: Generated tokens.
|
416 |
"""
|
|
|
418 |
first_param = next(iter(self.parameters()))
|
419 |
device = first_param.device
|
420 |
|
421 |
+
# Checking all input shapes are consistent.
|
422 |
possible_num_samples = []
|
423 |
if num_samples is not None:
|
424 |
possible_num_samples.append(num_samples)
|
|
|
428 |
possible_num_samples.append(len(conditions))
|
429 |
else:
|
430 |
possible_num_samples.append(1)
|
431 |
+
assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
|
432 |
num_samples = possible_num_samples[0]
|
433 |
|
434 |
# below we create set of conditions: one conditional and one unconditional
|
|
|
438 |
# 1. it is about x2 faster than doing 2 forward passes
|
439 |
# 2. avoid the streaming API treating the 2 passes as part of different time steps
|
440 |
# We also support doing two different passes, in particular to ensure that
|
441 |
+
# the padding structure is exactly the same between train and test.
|
442 |
# With a batch size of 1, this can be slower though.
|
443 |
cfg_conditions: CFGConditions
|
444 |
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
|
|
|
463 |
B, K, T = prompt.shape
|
464 |
start_offset = T
|
465 |
print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
|
466 |
+
assert start_offset < max_gen_len
|
467 |
+
|
468 |
pattern = self.pattern_provider.get_pattern(max_gen_len)
|
469 |
# this token is used as default value for codes that are not generated yet
|
470 |
unknown_token = -1
|
|
|
496 |
# sample next token from the model, next token shape is [B, K, 1]
|
497 |
next_token = self._sample_next_token(
|
498 |
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
|
499 |
+
cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
|
500 |
# ensure the tokens that should be masked are properly set to special_token_id
|
501 |
# as the model never output special_token_id
|
502 |
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
|
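
The lm.py changes above mostly tighten docstrings and thread two_step_cfg through _sample_next_token and generate. For readers unfamiliar with the sampling step those docstrings describe, here is a self-contained sketch of classifier-free guidance followed by temperature and top-k sampling. It is a simplified standalone version written in plain torch, not the module's own helpers, and the batch and codebook sizes are invented.

# Illustrative sketch of the sampling described in _sample_next_token:
# mix conditional/unconditional logits with a CFG coefficient, then sample top-k.
import torch

def cfg_top_k_sample(cond_logits, uncond_logits, cfg_coef=3.0, temp=1.0, top_k=250):
    # Classifier-free guidance: push logits away from the unconditional estimate.
    logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
    probs = torch.softmax(logits / temp, dim=-1)
    top_probs, top_idx = probs.topk(top_k, dim=-1)
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    flat = top_probs.reshape(-1, top_k)               # multinomial expects a 2D input
    choice = torch.multinomial(flat, num_samples=1)
    choice = choice.reshape(*top_probs.shape[:-1], 1)
    return top_idx.gather(-1, choice)

# Fake logits over a 2048-entry codebook for a batch of 2, one codebook, one step.
cond = torch.randn(2, 1, 2048)
uncond = torch.randn(2, 1, 2048)
next_token = cfg_top_k_sample(cond, uncond)           # shape [2, 1, 1]
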
audiocraft/models/loaders.py
CHANGED
[Left-hand (old file) pane of the side-by-side diff for audiocraft/models/loaders.py, with the removed lines truncated by the viewer. Hunks touched: @@ -24,10 +24,16 @@, @@ -50,6 +56,8 @@, @@ -72,21 +80,120 @@, @@ -95,8 +202,42 @@. The complete updated hunks follow.]
|
|
|
|
|
|
|
24 |
import typing as tp
|
25 |
import os
|
26 |
|
27 |
+
from omegaconf import OmegaConf, DictConfig
|
28 |
import torch
|
29 |
|
30 |
+
import audiocraft
|
31 |
from . import builders
|
32 |
+
from .encodec import CompressionModel
|
33 |
+
|
34 |
+
|
35 |
+
def get_audiocraft_cache_dir() -> tp.Optional[str]:
|
36 |
+
return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
|
37 |
|
38 |
|
39 |
HF_MODEL_CHECKPOINTS_MAP = {
|
|
|
56 |
device='cpu',
|
57 |
cache_dir: tp.Optional[str] = None,
|
58 |
):
|
59 |
+
if cache_dir is None:
|
60 |
+
cache_dir = get_audiocraft_cache_dir()
|
61 |
# Return the state dict either from a file or url
|
62 |
file_or_url_or_id = str(file_or_url_or_id)
|
63 |
assert isinstance(file_or_url_or_id, str)
|
|
|
80 |
return torch.load(file, map_location=device)
|
81 |
|
82 |
else:
|
83 |
+
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
84 |
+
|
85 |
+
file = hf_hub_download(
|
86 |
+
repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
|
87 |
+
library_name="audiocraft", library_version=audiocraft.__version__)
|
88 |
+
return torch.load(file, map_location=device)
|
89 |
+
|
90 |
+
def create_melody_config(model_id: str, device: str) -> DictConfig:
|
91 |
+
"""Create a fallback configuration for melody models.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
model_id: The model identifier
|
95 |
+
device: The device to use
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
A compatible OmegaConf DictConfig
|
99 |
+
"""
|
100 |
+
base_cfg = {
|
101 |
+
"device": str(device),
|
102 |
+
"channels": 2 if "stereo" in model_id else 1,
|
103 |
+
"sample_rate": 32000,
|
104 |
+
"audio_channels": 2 if "stereo" in model_id else 1,
|
105 |
+
"frame_rate": 50,
|
106 |
+
"codec_name": "encodec",
|
107 |
+
"codec": {
|
108 |
+
"dim": 128,
|
109 |
+
"hidden_dim": 1024,
|
110 |
+
"stride": 320,
|
111 |
+
"n_q": 4,
|
112 |
+
"codebook_size": 2048,
|
113 |
+
"normalize": True,
|
114 |
+
}
|
115 |
+
}
|
116 |
+
return OmegaConf.create(base_cfg)
|
117 |
+
|
118 |
+
def create_default_config(model_id: str, device: str) -> DictConfig:
|
119 |
+
"""Create a fallback configuration for standard models.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
model_id: The model identifier
|
123 |
+
device: The device to use
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
A compatible OmegaConf DictConfig
|
127 |
+
"""
|
128 |
+
base_cfg = {
|
129 |
+
"device": str(device),
|
130 |
+
"channels": 2 if "stereo" in model_id else 1,
|
131 |
+
"sample_rate": 32000,
|
132 |
+
"audio_channels": 2 if "stereo" in model_id else 1,
|
133 |
+
"frame_rate": 50,
|
134 |
+
"codec_name": "encodec",
|
135 |
+
"codec": {
|
136 |
+
"dim": 128,
|
137 |
+
"hidden_dim": 1024,
|
138 |
+
"stride": 320,
|
139 |
+
"n_q": 4,
|
140 |
+
"codebook_size": 1024,
|
141 |
+
"normalize": True,
|
142 |
+
}
|
143 |
+
}
|
144 |
+
return OmegaConf.create(base_cfg)
|
145 |
+
|
146 |
+
|
147 |
+
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
|
148 |
+
return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
|
149 |
|
150 |
|
151 |
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
152 |
+
pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
153 |
+
if 'pretrained' in pkg:
|
154 |
+
return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
|
155 |
+
|
156 |
+
# Handle newer model formats that might not have xp.cfg
|
157 |
+
if 'xp.cfg' not in pkg:
|
158 |
+
if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
|
159 |
+
'stereo-small', 'stereo-large', 'stereo-melody-large']:
|
160 |
+
print(f"Using fallback configuration for {file_or_url_or_id}")
|
161 |
+
# Create a default configuration based on the model type
|
162 |
+
# This is where you'd need to add model-specific configurations
|
163 |
+
if 'melody' in file_or_url_or_id:
|
164 |
+
cfg = create_melody_config(file_or_url_or_id, device)
|
165 |
+
else:
|
166 |
+
cfg = create_default_config(file_or_url_or_id, device)
|
167 |
+
else:
|
168 |
+
raise KeyError(f"Missing configuration for model {file_or_url_or_id}")
|
169 |
+
else:
|
170 |
+
cfg = OmegaConf.create(pkg['xp.cfg'])
|
171 |
+
|
172 |
cfg.device = str(device)
|
173 |
model = builders.get_compression_model(cfg)
|
174 |
model.load_state_dict(pkg['best_state'])
|
175 |
model.eval()
|
176 |
return model
|
177 |
|
178 |
+
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
|
179 |
+
return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
|
180 |
+
|
181 |
+
|
182 |
+
def _delete_param(cfg: DictConfig, full_name: str):
|
183 |
+
parts = full_name.split('.')
|
184 |
+
for part in parts[:-1]:
|
185 |
+
if part in cfg:
|
186 |
+
cfg = cfg[part]
|
187 |
+
else:
|
188 |
+
return
|
189 |
+
OmegaConf.set_struct(cfg, False)
|
190 |
+
if parts[-1] in cfg:
|
191 |
+
del cfg[parts[-1]]
|
192 |
+
OmegaConf.set_struct(cfg, True)
|
193 |
+
|
194 |
|
195 |
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
196 |
+
pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
|
197 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
198 |
cfg.device = str(device)
|
199 |
if cfg.device == 'cpu':
|
|
|
202 |
cfg.dtype = 'float32'
|
203 |
else:
|
204 |
cfg.dtype = 'float16'
|
205 |
+
_delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
|
206 |
+
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
|
207 |
+
_delete_param(cfg, 'conditioners.args.drop_desc_p')
|
208 |
model = builders.get_lm_model(cfg)
|
209 |
model.load_state_dict(pkg['best_state'])
|
210 |
model.eval()
|
211 |
model.cfg = cfg
|
212 |
return model
|
213 |
+
|
214 |
+
|
215 |
+
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
|
216 |
+
filename: tp.Optional[str] = None,
|
217 |
+
cache_dir: tp.Optional[str] = None):
|
218 |
+
return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
|
219 |
+
|
220 |
+
|
221 |
+
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
|
222 |
+
device='cpu',
|
223 |
+
filename: tp.Optional[str] = None,
|
224 |
+
cache_dir: tp.Optional[str] = None):
|
225 |
+
pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
|
226 |
+
models = []
|
227 |
+
processors = []
|
228 |
+
cfgs = []
|
229 |
+
sample_rate = pkg['sample_rate']
|
230 |
+
for i in range(pkg['n_bands']):
|
231 |
+
cfg = pkg[i]['cfg']
|
232 |
+
model = builders.get_diffusion_model(cfg)
|
233 |
+
model_dict = pkg[i]['model_state']
|
234 |
+
model.load_state_dict(model_dict)
|
235 |
+
model.to(device)
|
236 |
+
processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
|
237 |
+
processor_dict = pkg[i]['processor_state']
|
238 |
+
processor.load_state_dict(processor_dict)
|
239 |
+
processor.to(device)
|
240 |
+
models.append(model)
|
241 |
+
processors.append(processor)
|
242 |
+
cfgs.append(cfg)
|
243 |
+
return models, processors, cfgs
|
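
Taken together, the loader changes above add a cache-dir override, per-checkpoint helpers, and fallback configs. Below is a hedged usage sketch: the AUDIOCRAFT_CACHE_DIR variable and the load_* functions are the ones introduced in this diff, while the repo id 'facebook/musicgen-small' and the cache path are example values that assume the matching state_dict.bin and compression_state_dict.bin files exist on the Hub.

# Sketch of the new loader helpers above; repo id and cache path are examples.
import os
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/tmp/audiocraft_cache"   # hypothetical path

from audiocraft.models.loaders import (
    load_compression_model, load_lm_model, get_audiocraft_cache_dir)

print(get_audiocraft_cache_dir())            # -> /tmp/audiocraft_cache
repo = "facebook/musicgen-small"             # assumed to host the exported checkpoints
compression_model = load_compression_model(repo, device="cpu")
lm = load_lm_model(repo, device="cpu")
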
audiocraft/models/musicgen.py
CHANGED
[Left-hand (old file) pane of the side-by-side diff for audiocraft/models/musicgen.py, with the removed lines truncated by the viewer. Hunks touched: @@ -11,18 +11,19 @@, @@ -35,11 +36,32 @@, @@ -53,7 +75,12 @@, @@ -100,12 +127,15 @@, @@ -125,6 +155,9 @@, @@ -137,47 +170,61 @@, @@ -197,10 +244,14 @@, @@ -249,19 +300,24 @@, @@ -272,7 +328,10 @@, @@ -284,9 +343,9 @@, @@ -296,11 +355,12 @@, @@ -309,13 +369,17 @@, @@ -396,8 +460,10 @@, @@ -411,13 +477,21 @@. The complete updated hunks follow.]
|
11 |
|
12 |
import os
|
13 |
import typing as tp
|
14 |
+
import warnings
|
15 |
|
16 |
+
import omegaconf
|
17 |
import torch
|
18 |
|
19 |
from .encodec import CompressionModel
|
20 |
from .lm import LMModel
|
21 |
+
from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
|
22 |
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
23 |
from ..data.audio_utils import convert_audio
|
24 |
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
25 |
from ..utils.autocast import TorchAutocast
|
26 |
|
|
|
27 |
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
28 |
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
29 |
|
|
|
36 |
compression_model (CompressionModel): Compression model
|
37 |
used to map audio to invertible discrete representations.
|
38 |
lm (LMModel): Language model over discrete representations.
|
39 |
+
max_duration (float, optional): maximum duration the model can produce,
|
40 |
+
otherwise, inferred from the training params.
|
41 |
"""
|
42 |
+
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: tp.Optional[float] = 30):
|
43 |
self.name = name
|
44 |
self.compression_model = compression_model
|
45 |
self.lm = lm
|
46 |
+
self.cfg: tp.Optional[omegaconf.DictConfig] = None
|
47 |
+
# Just to be safe, let's put everything in eval mode.
|
48 |
+
self.compression_model.eval()
|
49 |
+
self.lm.eval()
|
50 |
+
|
51 |
+
if hasattr(lm, 'cfg'):
|
52 |
+
cfg = lm.cfg
|
53 |
+
assert isinstance(cfg, omegaconf.DictConfig)
|
54 |
+
self.cfg = cfg
|
55 |
+
|
56 |
+
if self.cfg is not None:
|
57 |
+
self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
|
58 |
+
|
59 |
+
if max_duration is None:
|
60 |
+
if self.cfg is not None:
|
61 |
+
max_duration = lm.cfg.dataset.segment_duration # type: ignore
|
62 |
+
else:
|
63 |
+
raise ValueError("You must provide max_duration when building directly MusicGen")
|
64 |
+
assert max_duration is not None
|
65 |
self.max_duration = max_duration
|
66 |
self.duration = 15.0 # default duration
|
67 |
self.device = next(iter(lm.parameters())).device
|
|
|
75 |
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
76 |
|
77 |
@property
|
78 |
+
def version(self) -> str:
|
79 |
+
from audiocraft import __version__ as audiocraft_version
|
80 |
+
return audiocraft_version
|
81 |
+
|
82 |
+
@property
|
83 |
+
def frame_rate(self) -> float:
|
84 |
"""Roughly the number of AR steps per seconds."""
|
85 |
return self.compression_model.frame_rate
|
86 |
|
|
|
127 |
f"{name} is not a valid checkpoint name. "
|
128 |
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
129 |
)
|
130 |
+
else:
|
131 |
+
name = HF_MODEL_CHECKPOINTS_MAP[name]
|
132 |
|
133 |
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
134 |
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
135 |
lm = load_lm_model(name, device=device, cache_dir=cache_dir)
|
136 |
+
if name.__contains__('melody') or 'self_wav' in lm.condition_provider.conditioners:
|
137 |
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
138 |
+
lm.condition_provider.conditioners['self_wav']._use_masking = False
|
139 |
|
140 |
return MusicGen(name, compression_model, lm)
|
141 |
|
|
|
155 |
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
156 |
instead of batching together the two. This has some impact on how things
|
157 |
are padded but seems to have little impact in practice.
|
158 |
+
extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
|
159 |
+
should we extend the audio each time. Larger values will mean less context is
|
160 |
+
preserved, and shorter value will require extra computations.
|
161 |
rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
|
162 |
"""
|
163 |
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
|
|
170 |
'top_k': top_k,
|
171 |
'top_p': top_p,
|
172 |
'cfg_coef': cfg_coef,
|
173 |
+
'two_step_cfg': two_step_cfg,
|
174 |
}
|
175 |
|
176 |
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
177 |
"""Override the default progress callback."""
|
178 |
self._progress_callback = progress_callback
|
179 |
|
180 |
+
def generate_unconditional(self, num_samples: int, progress: bool = False,
|
181 |
+
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
182 |
+
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
183 |
"""Generate samples in an unconditional manner.
|
184 |
|
185 |
Args:
|
186 |
num_samples (int): Number of samples to be generated.
|
187 |
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
188 |
+
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
189 |
"""
|
190 |
descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
|
191 |
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
192 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
193 |
+
if return_tokens:
|
194 |
+
return self.generate_audio(tokens), tokens
|
195 |
+
return self.generate_audio(tokens)
|
196 |
|
197 |
+
def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
|
198 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
199 |
"""Generate samples conditioned on text.
|
200 |
|
201 |
Args:
|
202 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
203 |
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
204 |
+
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
205 |
"""
|
206 |
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
207 |
assert prompt_tokens is None
|
208 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
209 |
+
if return_tokens:
|
210 |
+
return self.generate_audio(tokens), tokens
|
211 |
+
return self.generate_audio(tokens)
|
212 |
|
213 |
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
214 |
+
melody_sample_rate: int, progress: bool = False,
|
215 |
+
return_tokens: bool = False) -> tp.Union[torch.Tensor,
|
216 |
+
tp.Tuple[torch.Tensor, torch.Tensor]]:
|
217 |
"""Generate samples conditioned on text and melody.
|
218 |
|
219 |
Args:
|
220 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
221 |
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
222 |
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
223 |
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
224 |
a list of [C, T] tensors.
|
225 |
melody_sample_rate: (int): Sample rate of the melody waveforms.
|
226 |
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
227 |
+
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
228 |
"""
|
229 |
if isinstance(melody_wavs, torch.Tensor):
|
230 |
if melody_wavs.dim() == 2:
|
|
|
244 |
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
|
245 |
melody_wavs=melody_wavs)
|
246 |
assert prompt_tokens is None
|
247 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
248 |
+
if return_tokens:
|
249 |
+
return self.generate_audio(tokens), tokens
|
250 |
+
return self.generate_audio(tokens)
|
251 |
|
252 |
def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
253 |
+
sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None, return_tokens: bool = False) \
|
254 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
255 |
"""Generate samples conditioned on text and melody and audio prompts.
|
256 |
Args:
|
257 |
descriptions (tp.List[str]): A list of strings used as text conditioning.
|
|
|
300 |
assert prompt_tokens is not None
|
301 |
else:
|
302 |
assert prompt_tokens is None
|
303 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
304 |
+
if return_tokens:
|
305 |
+
return self.generate_audio(tokens), tokens
|
306 |
+
return self.generate_audio(tokens)
|
307 |
|
308 |
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
309 |
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
310 |
+
progress: bool = False, return_tokens: bool = False) \
|
311 |
+
-> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
|
312 |
"""Generate samples conditioned on audio prompts.
|
313 |
|
314 |
Args:
|
315 |
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
316 |
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
317 |
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
318 |
+
descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
|
319 |
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
320 |
+
return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
|
321 |
"""
|
322 |
if prompt.dim() == 2:
|
323 |
prompt = prompt[None]
|
|
|
328 |
descriptions = [None] * len(prompt)
|
329 |
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
330 |
assert prompt_tokens is not None
|
331 |
+
tokens = self._generate_tokens(attributes, prompt_tokens, progress)
|
332 |
+
if return_tokens:
|
333 |
+
return self.generate_audio(tokens), tokens
|
334 |
+
return self.generate_audio(tokens)
|
335 |
|
336 |
@torch.no_grad()
|
337 |
def _prepare_tokens_and_attributes(
|
|
|
343 |
"""Prepare model inputs.
|
344 |
|
345 |
Args:
|
346 |
+
descriptions (list of str): A list of strings used as text conditioning.
|
347 |
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
348 |
+
melody_wavs (torch.Tensor, optional): A batch of waveforms
|
349 |
used as melody conditioning. Defaults to None.
|
350 |
"""
|
351 |
attributes = [
|
|
|
355 |
if melody_wavs is None:
|
356 |
for attr in attributes:
|
357 |
attr.wav['self_wav'] = WavCondition(
|
358 |
+
torch.zeros((1, 1, 1), device=self.device),
|
359 |
torch.tensor([0], device=self.device),
|
360 |
+
sample_rate=[self.sample_rate],
|
361 |
+
path=[None]) # type: ignore
|
362 |
else:
|
363 |
+
if 'self_wav' not in self.lm.condition_provider.conditioners:
|
364 |
raise RuntimeError("This model doesn't support melody conditioning. "
|
365 |
"Use the `melody` model.")
|
366 |
assert len(melody_wavs) == len(descriptions), \
|
|
|
369 |
for attr, melody in zip(attributes, melody_wavs):
|
370 |
if melody is None:
|
371 |
attr.wav['self_wav'] = WavCondition(
|
372 |
+
torch.zeros((1, 1, 1), device=self.device),
|
373 |
torch.tensor([0], device=self.device),
|
374 |
+
sample_rate=[self.sample_rate],
|
375 |
+
path=[None]) # type: ignore
|
376 |
else:
|
377 |
attr.wav['self_wav'] = WavCondition(
|
378 |
+
melody[None].to(device=self.device),
|
379 |
+
torch.tensor([melody.shape[-1]], device=self.device),
|
380 |
+
sample_rate=[self.sample_rate],
|
381 |
+
path=[None],
|
382 |
+
)
|
383 |
|
384 |
if prompt is not None:
|
385 |
if descriptions is not None:
|
|
|
460 |
positions = torch.arange(initial_position,
|
461 |
initial_position + wav_target_length, device=self.device)
|
462 |
attr.wav['self_wav'] = WavCondition(
|
463 |
+
ref_wav[0][..., positions % wav_length],
|
464 |
+
torch.full_like(ref_wav[1], wav_target_length),
|
465 |
+
[self.sample_rate] * ref_wav[0].size(0),
|
466 |
+
[None], [0.])
|
467 |
with self.autocast:
|
468 |
gen_tokens = self.lm.generate(
|
469 |
prompt_tokens, attributes,
|
|
|
477 |
current_gen_offset += stride_tokens
|
478 |
|
479 |
gen_tokens = torch.cat(all_tokens, dim=-1)
|
480 |
+
return gen_tokens
|
481 |
|
482 |
# generate audio
|
|
|
|
|
|
|
|
|
483 |
|
484 |
+
def generate_audio(self, gen_tokens: torch.Tensor):
|
485 |
+
try:
|
486 |
+
"""Generate Audio from tokens"""
|
487 |
+
assert gen_tokens.dim() == 3
|
488 |
+
with torch.no_grad():
|
489 |
+
gen_audio = self.compression_model.decode(gen_tokens, None)
|
490 |
+
return gen_audio
|
491 |
+
except Exception as e:
|
492 |
+
print(f"Error generating audio: {e}")
|
493 |
+
return None
|
494 |
+
|
495 |
#def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
|
496 |
# prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
|
497 |
# """Generate discrete audio tokens given audio prompt and/or conditions.
|
audiocraft/models/unet.py
ADDED
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Pytorch Unet Module used for diffusion.
+"""
+
+from dataclasses import dataclass
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
+
+
+@dataclass
+class Output:
+    sample: torch.Tensor
+
+
+def get_model(cfg, channels: int, side: int, num_steps: int):
+    if cfg.model == 'unet':
+        return DiffusionUnet(
+            chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
+    else:
+        raise RuntimeError('Not Implemented')
+
+
+class ResBlock(nn.Module):
+    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
+                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        stride = 1
+        padding = dilation * (kernel - stride) // 2
+        Conv = nn.Conv1d
+        Drop = nn.Dropout1d
+        self.norm1 = nn.GroupNorm(norm_groups, channels)
+        self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+        self.activation1 = activation()
+        self.dropout1 = Drop(dropout)
+
+        self.norm2 = nn.GroupNorm(norm_groups, channels)
+        self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
+        self.activation2 = activation()
+        self.dropout2 = Drop(dropout)
+
+    def forward(self, x):
+        h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
+        h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
+        return x + h
+
+
+class DecoderLayer(nn.Module):
+    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        padding = (kernel - stride) // 2
+        self.res_blocks = nn.Sequential(
+            *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+              for idx in range(res_blocks)])
+        self.norm = nn.GroupNorm(norm_groups, chin)
+        ConvTr = nn.ConvTranspose1d
+        self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
+        self.activation = activation()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.res_blocks(x)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.convtr(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
+                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
+                 dropout: float = 0.):
+        super().__init__()
+        padding = (kernel - stride) // 2
+        Conv = nn.Conv1d
+        self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
+        self.norm = nn.GroupNorm(norm_groups, chout)
+        self.activation = activation()
+        self.res_blocks = nn.Sequential(
+            *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
+              for idx in range(res_blocks)])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, C, T = x.shape
+        stride, = self.conv.stride
+        pad = (stride - (T % stride)) % stride
+        x = F.pad(x, (0, pad))
+
+        x = self.conv(x)
+        x = self.norm(x)
+        x = self.activation(x)
+        x = self.res_blocks(x)
+        return x
+
+
+class BLSTM(nn.Module):
+    """BiLSTM with same hidden units as input dim.
+    """
+    def __init__(self, dim, layers=2):
+        super().__init__()
+        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
+        self.linear = nn.Linear(2 * dim, dim)
+
+    def forward(self, x):
+        x = x.permute(2, 0, 1)
+        x = self.lstm(x)[0]
+        x = self.linear(x)
+        x = x.permute(1, 2, 0)
+        return x
+
+
+class DiffusionUnet(nn.Module):
+    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
+                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
+                 bilstm: bool = False, transformer: bool = False,
+                 codec_dim: tp.Optional[int] = None, **kwargs):
+        super().__init__()
+        self.encoders = nn.ModuleList()
+        self.decoders = nn.ModuleList()
+        self.embeddings: tp.Optional[nn.ModuleList] = None
+        self.embedding = nn.Embedding(num_steps, hidden)
+        if emb_all_layers:
+            self.embeddings = nn.ModuleList()
+        self.condition_embedding: tp.Optional[nn.Module] = None
+        for d in range(depth):
+            encoder = EncoderLayer(chin, hidden, **kwargs)
+            decoder = DecoderLayer(hidden, chin, **kwargs)
+            self.encoders.append(encoder)
+            self.decoders.insert(0, decoder)
+            if emb_all_layers and d > 0:
+                assert self.embeddings is not None
+                self.embeddings.append(nn.Embedding(num_steps, hidden))
+            chin = hidden
+            hidden = min(int(chin * growth), max_channels)
+        self.bilstm: tp.Optional[nn.Module]
+        if bilstm:
+            self.bilstm = BLSTM(chin)
+        else:
+            self.bilstm = None
+        self.use_transformer = transformer
+        self.cross_attention = False
+        if transformer:
+            self.cross_attention = cross_attention
+            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
+                                                    cross_attention=cross_attention)
+
+        self.use_codec = False
+        if codec_dim is not None:
+            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
+            self.use_codec = True
+
+    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
+        skips = []
+        bs = x.size(0)
+        z = x
+        view_args = [1]
+        if type(step) is torch.Tensor:
+            step_tensor = step
+        else:
+            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
+
+        for idx, encoder in enumerate(self.encoders):
+            z = encoder(z)
+            if idx == 0:
+                z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
+            elif self.embeddings is not None:
+                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
+
+            skips.append(z)
+
+        if self.use_codec:  # insert condition in the bottleneck
+            assert condition is not None, "Model defined for conditionnal generation"
+            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
+            assert condition_emb.size(-1) <= 2 * z.size(-1), \
+                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
+            if not self.cross_attention:
+
+                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
+                assert z.size() == condition_emb.size()
+                z += condition_emb
+                cross_attention_src = None
+            else:
+                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
+                B, T, C = cross_attention_src.shape
+                positions = torch.arange(T, device=x.device).view(1, -1, 1)
+                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
+                cross_attention_src = cross_attention_src + pos_emb
+        if self.use_transformer:
+            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
+        else:
+            if self.bilstm is None:
+                z = torch.zeros_like(z)
+            else:
+                z = self.bilstm(z)
+
+        for decoder in self.decoders:
+            s = skips.pop(-1)
+            z = z[:, :, :s.shape[2]]
+            z = z + s
+            z = decoder(z)
+
+        z = z[:, :, :x.shape[2]]
+        return Output(z)
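
A rough sketch of how the added DiffusionUnet could be exercised on its own; the channel count, sequence length, and hyperparameters below are illustrative assumptions, not values from this Space's configs:

# Hypothetical smoke test for the new DiffusionUnet module.
import torch
from audiocraft.models.unet import DiffusionUnet

model = DiffusionUnet(chin=1, hidden=24, depth=3, num_steps=1000, bilstm=True)
x = torch.randn(2, 1, 16000)           # [B, C, T] noisy waveform
step = torch.randint(0, 1000, (2,))    # one diffusion step index per batch item
out = model(x, step)                   # returns Output(sample=...)
print(out.sample.shape)                # same [B, C, T] shape as the input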
audiocraft/modules/__init__.py
CHANGED
@@ -18,3 +18,4 @@ from .conv import (
 )
 from .lstm import StreamableLSTM
 from .seanet import SEANetEncoder, SEANetDecoder
+from .transformer import StreamingTransformer
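
With this re-export in place, downstream code can pull the transformer straight from audiocraft.modules; a hedged sketch (model width, heads, and layer count are arbitrary here, following the positional-argument usage seen in unet.py above):

# Hypothetical check that the new re-export resolves.
import torch
from audiocraft.modules import StreamingTransformer

tr = StreamingTransformer(64, 4, 2)   # d_model, num_heads, num_layers
x = torch.randn(1, 50, 64)            # [B, T, d_model]
print(tr(x).shape)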
audiocraft/modules/chroma.py
ADDED
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import typing as tp
+
+from einops import rearrange
+from librosa import filters
+import torch
+from torch import nn
+import torch.nn.functional as F
+import torchaudio
+
+
+class ChromaExtractor(nn.Module):
+    """Chroma extraction and quantization.
+
+    Args:
+        sample_rate (int): Sample rate for the chroma extraction.
+        n_chroma (int): Number of chroma bins for the chroma extraction.
+        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
+        nfft (int, optional): Number of FFT.
+        winlen (int, optional): Window length.
+        winhop (int, optional): Window hop size.
+        argmax (bool, optional): Whether to use argmax. Defaults to False.
+        norm (float, optional): Norm for chroma normalization. Defaults to inf.
+    """
+    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
+                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
+                 norm: float = torch.inf):
+        super().__init__()
+        self.winlen = winlen or 2 ** radix2_exp
+        self.nfft = nfft or self.winlen
+        self.winhop = winhop or (self.winlen // 4)
+        self.sample_rate = sample_rate
+        self.n_chroma = n_chroma
+        self.norm = norm
+        self.argmax = argmax
+        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+                                                                       n_chroma=self.n_chroma)), persistent=False)
+        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+                                                      hop_length=self.winhop, power=2, center=True,
+                                                      pad=0, normalized=True)
+
+    def forward(self, wav: torch.Tensor) -> torch.Tensor:
+        T = wav.shape[-1]
+        # in case we are getting a wav that was dropped out (nullified)
+        # from the conditioner, make sure wav length is no less that nfft
+        if T < self.nfft:
+            pad = self.nfft - T
+            r = 0 if pad % 2 == 0 else 1
+            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
+
+        spec = self.spec(wav).squeeze(1)
+        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
+        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
+
+        if self.argmax:
+            idx = norm_chroma.argmax(-1, keepdim=True)
+            norm_chroma[:] = 0
+            norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+        return norm_chroma
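
A short sketch of the extractor on a dummy waveform; the sample rate and tensor sizes are arbitrary illustrations:

# Hypothetical example for the added ChromaExtractor module.
import torch
from audiocraft.modules.chroma import ChromaExtractor

chroma = ChromaExtractor(sample_rate=32000, n_chroma=12, radix2_exp=12)
wav = torch.randn(2, 1, 64000)   # [B, C, T], about 2 seconds at 32 kHz
feats = chroma(wav)              # [B, frames, n_chroma] normalized chroma
print(feats.shape)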
audiocraft/modules/codebooks_patterns.py
CHANGED
@@ -122,7 +122,7 @@ class Pattern:
         Args:
             timesteps (int): Maximum number of timesteps steps to consider.
             keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
-            device (
+            device (torch.device or str): Device for created tensors.
         Returns:
             indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
             mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
@@ -189,9 +189,9 @@ class Pattern:
             keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                 Steps that are beyond valid steps will be replaced by the special_token in that case.
             is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
-            device (
+            device (torch.device or str): Device for created tensors.
         Returns:
-            torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
+            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
             mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
         """
         ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
@@ -295,7 +295,7 @@ class CodebooksPatternProvider(ABC):
         """Builds pattern with specific interleaving between codebooks.
 
         Args:
-            timesteps (int): Total
+            timesteps (int): Total number of timesteps.
         """
         raise NotImplementedError()
 
@@ -318,7 +318,7 @@ class DelayedPatternProvider(CodebooksPatternProvider):
 
     Args:
        n_q (int): Number of codebooks.
-        delays (
+        delays (list of int, optional): Delay for each of the codebooks.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty list of coordinates.
@@ -406,10 +406,10 @@ class UnrolledPatternProvider(CodebooksPatternProvider):
 
    Args:
        n_q (int): Number of codebooks.
-        flattening (
+        flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
            have n_q extra steps for each timestep.
-        delays (
+        delays (list of int, optional): Delay for each of the codebooks. If not defined,
            no delay is added and therefore will default to [0] * ``n_q``.
            Note that two codebooks that will be flattened to the same inner step
            should have the same delay, otherwise the pattern is considered as invalid.
@@ -462,7 +462,7 @@ class UnrolledPatternProvider(CodebooksPatternProvider):
        """Builds pattern for delay across codebooks.
 
        Args:
-            timesteps (int): Total
+            timesteps (int): Total number of timesteps.
        """
        # the PatternLayout is built as a tuple of sequence position and list of coordinates
        # so that it can be reordered properly given the required delay between codebooks of given timesteps
@@ -486,13 +486,18 @@ class UnrolledPatternProvider(CodebooksPatternProvider):
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
 
 
-class
-    """
-
+class CoarseFirstPattern(CodebooksPatternProvider):
+    """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
+    potentially with delays.
+
+    ..Warning:: You must always generate the full training duration at test time, for instance,
+        30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
+        location. This is due to the non causality of the remaining codebooks with respect to
+        the first ones.
 
     Args:
        n_q (int): Number of codebooks.
-        delays (
+        delays (list of int, optional): Delay for each of the codebooks.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
     """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
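
For context, the pattern providers touched above are used roughly like this; a sketch under the assumption of a 4-codebook model with 2048-entry codebooks, not code taken from this Space:

# Hypothetical example of building a delayed interleaving pattern.
import torch
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=4)        # by default codebook k is delayed by k steps
pattern = provider.get_pattern(timesteps=100)   # Pattern over 100 timesteps

tokens = torch.randint(0, 2048, (1, 4, 100))    # [B, K, T] codebook indices
special_token = 2048                            # id used to pad positions not yet valid
seq, indexes, mask = pattern.build_pattern_sequence(tokens, special_token)
print(seq.shape)                                # [B, K, S] with S >= T because of the delays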
audiocraft/modules/conditioners.py
CHANGED
@@ -10,87 +10,61 @@ from dataclasses import dataclass, field
|
|
10 |
from itertools import chain
|
11 |
import logging
|
12 |
import math
|
|
|
13 |
import random
|
14 |
import re
|
15 |
import typing as tp
|
16 |
import warnings
|
17 |
|
18 |
-
|
19 |
from num2words import num2words
|
20 |
import spacy
|
21 |
-
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
|
22 |
-
import torchaudio
|
23 |
import torch
|
24 |
from torch import nn
|
25 |
-
from torch import Tensor
|
26 |
import torch.nn.functional as F
|
27 |
from torch.nn.utils.rnn import pad_sequence
|
28 |
|
|
|
29 |
from .streaming import StreamingModule
|
30 |
from .transformer import create_sin_embedding
|
|
|
31 |
from ..data.audio_dataset import SegmentInfo
|
|
|
|
|
|
|
32 |
from ..utils.autocast import TorchAutocast
|
33 |
-
from ..utils.
|
|
|
34 |
|
35 |
|
36 |
logger = logging.getLogger(__name__)
|
37 |
TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
|
38 |
-
ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask
|
39 |
|
40 |
|
41 |
class WavCondition(tp.NamedTuple):
|
42 |
-
wav: Tensor
|
43 |
-
length: Tensor
|
|
|
44 |
path: tp.List[tp.Optional[str]] = []
|
|
|
45 |
|
46 |
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
dim (int): the dimension that will be truncated (should be the time dimension)
|
55 |
-
WARNING!: dim should not be the batch dimension!
|
56 |
-
Returns:
|
57 |
-
ConditionType: a tuple of null condition and mask
|
58 |
-
"""
|
59 |
-
assert dim != 0, "dim cannot be the batch dimension!"
|
60 |
-
assert type(condition) == tuple and \
|
61 |
-
type(condition[0]) == Tensor and \
|
62 |
-
type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
|
63 |
-
cond, mask = condition
|
64 |
-
B = cond.shape[0]
|
65 |
-
last_dim = cond.dim() - 1
|
66 |
-
out = cond.transpose(dim, last_dim)
|
67 |
-
out = 0. * out[..., :1]
|
68 |
-
out = out.transpose(dim, last_dim)
|
69 |
-
mask = torch.zeros((B, 1), device=out.device).int()
|
70 |
-
assert cond.dim() == out.dim()
|
71 |
-
return out, mask
|
72 |
-
|
73 |
-
|
74 |
-
def nullify_wav(wav: Tensor) -> WavCondition:
|
75 |
-
"""Create a nullified WavCondition from a wav tensor with appropriate shape.
|
76 |
-
|
77 |
-
Args:
|
78 |
-
wav (Tensor): tensor of shape [B, T]
|
79 |
-
Returns:
|
80 |
-
WavCondition: wav condition with nullified wav.
|
81 |
-
"""
|
82 |
-
null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
|
83 |
-
return WavCondition(
|
84 |
-
wav=null_wav,
|
85 |
-
length=torch.tensor([0] * wav.shape[0], device=wav.device),
|
86 |
-
path=['null_wav'] * wav.shape[0]
|
87 |
-
)
|
88 |
|
89 |
|
90 |
@dataclass
|
91 |
class ConditioningAttributes:
|
92 |
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
|
93 |
wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
|
|
|
94 |
|
95 |
def __getitem__(self, item):
|
96 |
return getattr(self, item)
|
@@ -103,14 +77,23 @@ class ConditioningAttributes:
|
|
103 |
def wav_attributes(self):
|
104 |
return self.wav.keys()
|
105 |
|
|
|
|
|
|
|
|
|
106 |
@property
|
107 |
def attributes(self):
|
108 |
-
return {
|
|
|
|
|
|
|
|
|
109 |
|
110 |
def to_flat_dict(self):
|
111 |
return {
|
112 |
**{f"text.{k}": v for k, v in self.text.items()},
|
113 |
**{f"wav.{k}": v for k, v in self.wav.items()},
|
|
|
114 |
}
|
115 |
|
116 |
@classmethod
|
@@ -131,11 +114,74 @@ class SegmentWithAttributes(SegmentInfo):
|
|
131 |
raise NotImplementedError()
|
132 |
|
133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
class Tokenizer:
|
135 |
-
"""Base
|
136 |
(in case we want to introduce more advances tokenizers in the future).
|
137 |
"""
|
138 |
-
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
|
139 |
raise NotImplementedError()
|
140 |
|
141 |
|
@@ -146,7 +192,7 @@ class WhiteSpaceTokenizer(Tokenizer):
|
|
146 |
[[78, 62, 31, 4, 78, 25, 19, 34],
|
147 |
[59, 77, 0, 0, 0, 0, 0, 0]]
|
148 |
"""
|
149 |
-
|
150 |
|
151 |
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
|
152 |
lemma: bool = True, stopwords: bool = True) -> None:
|
@@ -161,18 +207,15 @@ class WhiteSpaceTokenizer(Tokenizer):
|
|
161 |
self.nlp = spacy.load(language)
|
162 |
|
163 |
@tp.no_type_check
|
164 |
-
def __call__(
|
165 |
-
|
166 |
-
texts: tp.List[tp.Optional[str]],
|
167 |
-
return_text: bool = False
|
168 |
-
) -> tp.Tuple[Tensor, Tensor]:
|
169 |
"""Take a list of strings and convert them to a tensor of indices.
|
170 |
|
171 |
Args:
|
172 |
-
texts (
|
173 |
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
|
174 |
Returns:
|
175 |
-
|
176 |
- Indices of words in the LUT.
|
177 |
- And a mask indicating where the padding tokens are
|
178 |
"""
|
@@ -181,7 +224,7 @@ class WhiteSpaceTokenizer(Tokenizer):
|
|
181 |
for i, text in enumerate(texts):
|
182 |
# if current sample doesn't have a certain attribute, replace with pad token
|
183 |
if text is None:
|
184 |
-
output.append(Tensor([self.pad_idx]))
|
185 |
lengths.append(0)
|
186 |
continue
|
187 |
|
@@ -192,15 +235,15 @@ class WhiteSpaceTokenizer(Tokenizer):
|
|
192 |
# remove stopwords
|
193 |
if self.stopwords:
|
194 |
text = [w for w in text if not w.is_stop] # type: ignore
|
195 |
-
# remove
|
196 |
-
text = [w for w in text if w.text not in self.
|
197 |
# lemmatize if needed
|
198 |
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
|
199 |
|
200 |
texts[i] = " ".join(text)
|
201 |
lengths.append(len(text))
|
202 |
# convert to tensor
|
203 |
-
tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
|
204 |
output.append(tokens)
|
205 |
|
206 |
mask = length_to_mask(torch.IntTensor(lengths)).int()
|
@@ -224,7 +267,7 @@ class NoopTokenizer(Tokenizer):
|
|
224 |
self.n_bins = n_bins
|
225 |
self.pad_idx = pad_idx
|
226 |
|
227 |
-
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
|
228 |
output, lengths = [], []
|
229 |
for text in texts:
|
230 |
# if current sample doesn't have a certain attribute, replace with pad token
|
@@ -241,15 +284,16 @@ class NoopTokenizer(Tokenizer):
|
|
241 |
|
242 |
|
243 |
class BaseConditioner(nn.Module):
|
244 |
-
"""Base model for all conditioner modules.
|
245 |
-
|
|
|
246 |
2) make all condition dims consistent.
|
247 |
|
248 |
Args:
|
249 |
-
dim (int): Hidden dim of the model
|
250 |
output_dim (int): Output dim of the conditioner.
|
251 |
"""
|
252 |
-
def __init__(self, dim, output_dim):
|
253 |
super().__init__()
|
254 |
self.dim = dim
|
255 |
self.output_dim = output_dim
|
@@ -294,9 +338,9 @@ class LUTConditioner(TextConditioner):
|
|
294 |
super().__init__(dim, output_dim)
|
295 |
self.embed = nn.Embedding(n_bins, dim)
|
296 |
self.tokenizer: Tokenizer
|
297 |
-
if tokenizer ==
|
298 |
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
|
299 |
-
elif tokenizer ==
|
300 |
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
|
301 |
else:
|
302 |
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
|
@@ -346,13 +390,12 @@ class T5Conditioner(TextConditioner):
|
|
346 |
def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
|
347 |
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
|
348 |
normalize_text: bool = False):
|
349 |
-
assert name in self.MODELS, f"
|
350 |
super().__init__(self.MODELS_DIMS[name], output_dim)
|
351 |
self.device = device
|
352 |
self.name = name
|
353 |
self.finetune = finetune
|
354 |
self.word_dropout = word_dropout
|
355 |
-
|
356 |
if autocast_dtype is None or self.device == 'cpu':
|
357 |
self.autocast = TorchAutocast(enabled=False)
|
358 |
if self.device != 'cpu':
|
@@ -378,7 +421,7 @@ class T5Conditioner(TextConditioner):
|
|
378 |
else:
|
379 |
# this makes sure that the t5 models is not part
|
380 |
# of the saved checkpoint
|
381 |
-
self.__dict__[
|
382 |
|
383 |
self.normalize_text = normalize_text
|
384 |
if normalize_text:
|
@@ -398,13 +441,13 @@ class T5Conditioner(TextConditioner):
|
|
398 |
|
399 |
empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
|
400 |
|
401 |
-
inputs = self.t5_tokenizer(entries, return_tensors=
|
402 |
-
mask = inputs[
|
403 |
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
404 |
return inputs
|
405 |
|
406 |
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
|
407 |
-
mask = inputs[
|
408 |
with torch.set_grad_enabled(self.finetune), self.autocast:
|
409 |
embeds = self.t5(**inputs).last_hidden_state
|
410 |
embeds = self.output_proj(embeds.to(self.output_proj.weight))
|
@@ -426,204 +469,558 @@ class WaveformConditioner(BaseConditioner):
|
|
426 |
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
427 |
super().__init__(dim, output_dim)
|
428 |
self.device = device
|
|
|
|
|
429 |
|
430 |
-
def tokenize(self,
|
431 |
-
wav, length, path =
|
432 |
assert length is not None
|
433 |
-
return WavCondition(wav.to(self.device), length.to(self.device), path)
|
434 |
|
435 |
-
def _get_wav_embedding(self,
|
436 |
-
"""Gets as input a
|
437 |
raise NotImplementedError()
|
438 |
|
439 |
def _downsampling_factor(self):
|
440 |
"""Returns the downsampling factor of the embedding model."""
|
441 |
raise NotImplementedError()
|
442 |
|
443 |
-
def forward(self,
|
444 |
-
"""
|
445 |
Args:
|
446 |
-
|
447 |
Returns:
|
448 |
-
ConditionType:
|
449 |
"""
|
450 |
-
wav, lengths,
|
451 |
with torch.no_grad():
|
452 |
-
embeds = self._get_wav_embedding(
|
453 |
embeds = embeds.to(self.output_proj.weight)
|
454 |
embeds = self.output_proj(embeds)
|
455 |
|
456 |
-
if lengths is not None:
|
457 |
lengths = lengths / self._downsampling_factor()
|
458 |
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
459 |
else:
|
460 |
-
mask = torch.ones_like(embeds)
|
461 |
-
embeds = (embeds * mask.unsqueeze(
|
462 |
-
|
463 |
return embeds, mask
|
464 |
|
465 |
|
466 |
class ChromaStemConditioner(WaveformConditioner):
|
467 |
-
"""Chroma conditioner
|
468 |
-
|
469 |
-
|
|
|
470 |
|
471 |
Args:
|
472 |
output_dim (int): Output dimension for the conditioner.
|
473 |
sample_rate (int): Sample rate for the chroma extractor.
|
474 |
-
n_chroma (int): Number of chroma for the chroma extractor.
|
475 |
-
radix2_exp (int):
|
476 |
-
duration (
|
477 |
in case we are using chroma as prefix.
|
478 |
-
match_len_on_eval (bool, optional):
|
479 |
duration. Defaults to False.
|
480 |
-
eval_wavs (str, optional):
|
481 |
conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
|
482 |
Defaults to None.
|
483 |
-
n_eval_wavs (int, optional):
|
484 |
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
485 |
**kwargs: Additional parameters for the chroma extractor.
|
486 |
"""
|
487 |
def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
|
488 |
duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
|
489 |
-
n_eval_wavs: int = 0,
|
|
|
490 |
from demucs import pretrained
|
491 |
super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
|
492 |
-
self.autocast = TorchAutocast(enabled=device !=
|
493 |
self.sample_rate = sample_rate
|
494 |
self.match_len_on_eval = match_len_on_eval
|
|
|
|
|
495 |
self.duration = duration
|
496 |
-
self.__dict__[
|
497 |
-
|
498 |
-
self.
|
499 |
-
self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
|
500 |
-
|
501 |
self.chroma_len = self._get_chroma_len()
|
502 |
-
|
503 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
return self.chroma.winhop
|
505 |
|
506 |
-
def
|
507 |
-
"""
|
508 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
509 |
dummy_chr = self.chroma(dummy_wav)
|
510 |
return dummy_chr.shape[1]
|
511 |
|
512 |
@torch.no_grad()
|
513 |
-
def
|
|
|
514 |
from demucs.apply import apply_model
|
515 |
from demucs.audio import convert_audio
|
516 |
with self.autocast:
|
517 |
-
wav = convert_audio(
|
|
|
518 |
stems = apply_model(self.demucs, wav, device=self.device)
|
519 |
-
stems = stems[:, self.
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
return stems
|
524 |
|
525 |
@torch.no_grad()
|
526 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
# avoid 0-size tensors when we are working with null conds
|
528 |
if wav.shape[-1] == 1:
|
529 |
-
return self.
|
530 |
-
stems = self.
|
531 |
-
chroma = self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
|
533 |
if self.match_len_on_eval:
|
534 |
-
|
535 |
-
if
|
536 |
chroma = chroma[:, :self.chroma_len]
|
537 |
-
logger.debug(f
|
538 |
-
elif
|
539 |
-
|
540 |
-
n_repeat = int(math.ceil(self.chroma_len / t))
|
541 |
chroma = chroma.repeat(1, n_repeat, 1)
|
542 |
chroma = chroma[:, :self.chroma_len]
|
543 |
-
logger.debug(f
|
|
|
544 |
return chroma
|
545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
546 |
|
547 |
-
|
548 |
-
|
|
|
549 |
|
550 |
Args:
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
|
555 |
-
|
556 |
-
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
"""
|
561 |
-
def __init__(self,
|
562 |
-
|
563 |
-
|
564 |
-
super().__init__()
|
565 |
-
from librosa import filters
|
566 |
self.device = device
|
567 |
-
self.
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
583 |
with self.autocast:
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
606 |
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
607 |
-
If the condition is of type "wav", then nullify it using
|
608 |
-
If the condition is of any other type, set its
|
609 |
Works in-place.
|
610 |
"""
|
611 |
-
if condition_type not in [
|
612 |
raise ValueError(
|
613 |
"dropout_condition got an unexpected condition type!"
|
614 |
-
f" expected 'wav' or '
|
615 |
)
|
616 |
|
617 |
if condition not in getattr(sample, condition_type):
|
618 |
raise ValueError(
|
619 |
"dropout_condition received an unexpected condition!"
|
620 |
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
|
621 |
-
f"but got '{condition}' of type '{condition_type}'!"
|
622 |
)
|
623 |
|
624 |
-
if condition_type ==
|
625 |
-
|
626 |
-
sample.wav[condition] = nullify_wav(
|
|
|
|
|
|
|
627 |
else:
|
628 |
sample.text[condition] = None
|
629 |
|
@@ -631,7 +1028,7 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi
|
|
631 |
|
632 |
|
633 |
class DropoutModule(nn.Module):
|
634 |
-
"""Base
|
635 |
def __init__(self, seed: int = 1234):
|
636 |
super().__init__()
|
637 |
self.rng = torch.Generator()
|
@@ -639,10 +1036,11 @@ class DropoutModule(nn.Module):
|
|
639 |
|
640 |
|
641 |
class AttributeDropout(DropoutModule):
|
642 |
-
"""
|
643 |
-
|
644 |
-
"artist" can be dropped while "genre" remains.
|
645 |
-
where if "artist" is dropped "genre"
|
|
|
646 |
|
647 |
Args:
|
648 |
p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
|
@@ -665,21 +1063,19 @@ class AttributeDropout(DropoutModule):
|
|
665 |
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
666 |
"""
|
667 |
Args:
|
668 |
-
samples (
|
669 |
Returns:
|
670 |
-
|
671 |
"""
|
672 |
if not self.training and not self.active_on_eval:
|
673 |
return samples
|
674 |
|
675 |
samples = deepcopy(samples)
|
676 |
-
|
677 |
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
678 |
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
679 |
if torch.rand(1, generator=self.rng).item() < p:
|
680 |
for sample in samples:
|
681 |
dropout_condition(sample, condition_type, condition)
|
682 |
-
|
683 |
return samples
|
684 |
|
685 |
def __repr__(self):
|
@@ -687,8 +1083,8 @@ class AttributeDropout(DropoutModule):
|
|
687 |
|
688 |
|
689 |
class ClassifierFreeGuidanceDropout(DropoutModule):
|
690 |
-
"""
|
691 |
-
are dropped with the same probability.
|
692 |
|
693 |
Args:
|
694 |
p (float): Probability to apply condition dropout during training.
|
@@ -701,9 +1097,9 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
|
|
701 |
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
702 |
"""
|
703 |
Args:
|
704 |
-
samples (
|
705 |
Returns:
|
706 |
-
|
707 |
"""
|
708 |
if not self.training:
|
709 |
return samples
|
@@ -715,12 +1111,10 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
|
|
715 |
|
716 |
# nullify conditions of all attributes
|
717 |
samples = deepcopy(samples)
|
718 |
-
|
719 |
for condition_type in ["wav", "text"]:
|
720 |
for sample in samples:
|
721 |
for condition in sample.attributes[condition_type]:
|
722 |
dropout_condition(sample, condition_type, condition)
|
723 |
-
|
724 |
return samples
|
725 |
|
726 |
def __repr__(self):
|
@@ -728,29 +1122,25 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
|
|
728 |
|
729 |
|
730 |
class ConditioningProvider(nn.Module):
|
731 |
-
"""
|
732 |
|
733 |
Args:
|
734 |
conditioners (dict): Dictionary of conditioners.
|
735 |
-
|
736 |
-
into a single text condition. Defaults to 0.
|
737 |
-
drop_desc_p (float, optional): Probability to drop the original description
|
738 |
-
when merging all text sources into a single text condition. Defaults to 0.
|
739 |
-
device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
|
740 |
"""
|
741 |
-
def __init__(
|
742 |
-
self,
|
743 |
-
conditioners: tp.Dict[str, BaseConditioner],
|
744 |
-
merge_text_conditions_p: float = 0,
|
745 |
-
drop_desc_p: float = 0,
|
746 |
-
device: tp.Union[torch.device, str] = "cpu",
|
747 |
-
):
|
748 |
super().__init__()
|
749 |
self.device = device
|
750 |
-
self.merge_text_conditions_p = merge_text_conditions_p
|
751 |
-
self.drop_desc_p = drop_desc_p
|
752 |
self.conditioners = nn.ModuleDict(conditioners)
|
753 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
754 |
@property
|
755 |
def text_conditions(self):
|
756 |
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
@@ -769,33 +1159,36 @@ class ConditioningProvider(nn.Module):
|
|
769 |
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
770 |
|
771 |
Args:
|
772 |
-
inputs (list[
|
773 |
text and wav conditions.
|
774 |
"""
|
775 |
-
assert all([
|
776 |
-
"
|
777 |
f" but types were {set([type(x) for x in inputs])}"
|
|
|
778 |
|
779 |
output = {}
|
780 |
text = self._collate_text(inputs)
|
781 |
wavs = self._collate_wavs(inputs)
|
|
|
782 |
|
783 |
-
assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())),
|
784 |
-
f"
|
|
|
|
|
785 |
|
786 |
-
for attribute, batch in chain(text.items(), wavs.items()):
|
787 |
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
788 |
return output
|
789 |
|
790 |
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
|
791 |
-
"""Compute pairs of `(embedding, mask)` using the configured conditioners
|
792 |
-
|
793 |
-
|
794 |
-
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
-
}
|
799 |
|
800 |
Args:
|
801 |
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
@@ -820,51 +1213,22 @@ class ConditioningProvider(nn.Module):
|
|
820 |
"genre": ["Rock", "Hip-hop"],
|
821 |
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
822 |
}
|
823 |
-
"""
|
824 |
-
batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
825 |
-
|
826 |
-
def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
|
827 |
-
def is_valid(k, v):
|
828 |
-
k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
|
829 |
-
v_valid = v is not None and isinstance(v, (int, float, str, list))
|
830 |
-
return k_valid and v_valid
|
831 |
-
|
832 |
-
def process_value(v):
|
833 |
-
if isinstance(v, (int, float, str)):
|
834 |
-
return v
|
835 |
-
if isinstance(v, list):
|
836 |
-
return ", ".join(v)
|
837 |
-
else:
|
838 |
-
RuntimeError(f"unknown type for text value! ({type(v), v})")
|
839 |
-
|
840 |
-
desc = cond.text['description']
|
841 |
-
meta_data = ""
|
842 |
-
if random.uniform(0, 1) < merge_text_conditions_p:
|
843 |
-
meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
|
844 |
-
random.shuffle(meta_pairs)
|
845 |
-
meta_data = ". ".join(meta_pairs)
|
846 |
-
desc = desc if not random.uniform(0, 1) < drop_desc_p else None
|
847 |
-
|
848 |
-
if desc is None:
|
849 |
-
desc = meta_data if len(meta_data) > 1 else None
|
850 |
-
else:
|
851 |
-
desc = desc.rstrip('.') + ". " + meta_data
|
852 |
-
cond.text['description'] = desc.strip() if desc else None
|
853 |
-
|
854 |
-
if self.training and self.merge_text_conditions_p:
|
855 |
-
for sample in samples:
|
856 |
-
_merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
|
857 |
|
|
|
|
|
|
|
|
|
|
|
|
|
858 |
texts = [x.text for x in samples]
|
859 |
for text in texts:
|
860 |
for condition in self.text_conditions:
|
861 |
-
|
862 |
-
|
863 |
-
return batch_per_attribute
|
864 |
|
865 |
-
def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
|
866 |
"""Generate a dict where the keys are attributes by which we fetch similar wavs,
|
867 |
-
and the values are Tensors of wavs according to said
|
868 |
|
869 |
*Note*: by the time the samples reach this function, each sample should have some waveform
|
870 |
inside the "wav" attribute. It should be either:
|
@@ -873,27 +1237,89 @@ class ConditioningProvider(nn.Module):
|
|
873 |
3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
|
874 |
|
875 |
Args:
|
876 |
-
samples (
|
877 |
Returns:
|
878 |
-
dict: A
|
879 |
"""
|
880 |
wavs = defaultdict(list)
|
881 |
-
|
|
|
882 |
paths = defaultdict(list)
|
883 |
-
|
|
|
884 |
|
885 |
for sample in samples:
|
886 |
for attribute in self.wav_conditions:
|
887 |
-
wav, length, path = sample.wav[attribute]
|
888 |
-
|
889 |
-
|
890 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
891 |
|
892 |
# stack all wavs to a single tensor
|
893 |
for attribute in self.wav_conditions:
|
894 |
stacked_wav, _ = collate(wavs[attribute], dim=0)
|
895 |
-
out[attribute] = WavCondition(
|
896 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
897 |
|
898 |
return out
|
899 |
|
@@ -920,7 +1346,7 @@ class ConditionFuser(StreamingModule):
|
|
920 |
super().__init__()
|
921 |
assert all(
|
922 |
[k in self.FUSING_METHODS for k in fuse2cond.keys()]
|
923 |
-
), f"
|
924 |
self.cross_attention_pos_emb = cross_attention_pos_emb
|
925 |
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
|
926 |
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
|
@@ -931,16 +1357,16 @@ class ConditionFuser(StreamingModule):
|
|
931 |
|
932 |
def forward(
|
933 |
self,
|
934 |
-
input: Tensor,
|
935 |
conditions: tp.Dict[str, ConditionType]
|
936 |
-
) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
|
937 |
"""Fuse the conditions to the provided model input.
|
938 |
|
939 |
Args:
|
940 |
-
input (Tensor): Transformer input.
|
941 |
-
conditions (
|
942 |
Returns:
|
943 |
-
|
944 |
after the conditions have been fused. The second output tensor is the tensor
|
945 |
used for cross-attention or None if no cross attention inputs exist.
|
946 |
"""
|
@@ -959,16 +1385,16 @@ class ConditionFuser(StreamingModule):
|
|
959 |
cross_attention_output = None
|
960 |
for cond_type, (cond, cond_mask) in conditions.items():
|
961 |
op = self.cond2fuse[cond_type]
|
962 |
-
if op ==
|
963 |
input += cond
|
964 |
-
elif op ==
|
965 |
-
cond = rearrange(cond, "b t d -> b d t")
|
966 |
cond = F.interpolate(cond, size=input.shape[1])
|
967 |
-
input += rearrange(cond, "b d t -> b t d")
|
968 |
-
elif op ==
|
969 |
if first_step:
|
970 |
input = torch.cat([cond, input], dim=1)
|
971 |
-
elif op ==
|
972 |
if cross_attention_output is not None:
|
973 |
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
|
974 |
else:
|
|
|
 from itertools import chain
 import logging
 import math
+from pathlib import Path
 import random
 import re
 import typing as tp
 import warnings

+import einops
 from num2words import num2words
 import spacy
+from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
 import torch.nn.functional as F
 from torch.nn.utils.rnn import pad_sequence

+from .chroma import ChromaExtractor
 from .streaming import StreamingModule
 from .transformer import create_sin_embedding
+from ..data.audio import audio_read
 from ..data.audio_dataset import SegmentInfo
+from ..data.audio_utils import convert_audio
+from ..environment import AudioCraftEnvironment
+from ..quantization import ResidualVectorQuantizer
 from ..utils.autocast import TorchAutocast
+from ..utils.cache import EmbeddingCache
+from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once


 logger = logging.getLogger(__name__)
 TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
+ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask


 class WavCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    length: torch.Tensor
+    sample_rate: tp.List[int]
     path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []


+class JointEmbedCondition(tp.NamedTuple):
+    wav: torch.Tensor
+    text: tp.List[tp.Optional[str]]
+    length: torch.Tensor
+    sample_rate: tp.List[int]
+    path: tp.List[tp.Optional[str]] = []
+    seek_time: tp.List[tp.Optional[float]] = []
 

 @dataclass
 class ConditioningAttributes:
     text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
     wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+    joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)

     def __getitem__(self, item):
         return getattr(self, item)

     def wav_attributes(self):
         return self.wav.keys()

+    @property
+    def joint_embed_attributes(self):
+        return self.joint_embed.keys()
+
     @property
     def attributes(self):
+        return {
+            "text": self.text_attributes,
+            "wav": self.wav_attributes,
+            "joint_embed": self.joint_embed_attributes,
+        }

     def to_flat_dict(self):
         return {
             **{f"text.{k}": v for k, v in self.text.items()},
             **{f"wav.{k}": v for k, v in self.wav.items()},
+            **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
         }

     @classmethod
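To make the data model above concrete, here is a small usage sketch. It is not part of the diff: the "description"/"self_wav" attribute names are only illustrative of how MusicGen wires its conditioners, and the shapes are assumptions.

import torch

# Hypothetical single-sample conditioning: one second of silence at 32 kHz as the
# melody waveform, plus a text description.
wav = torch.zeros(1, 1, 32_000)  # [B=1, C=1, T]
melody = WavCondition(
    wav=wav,
    length=torch.tensor([wav.shape[-1]]),
    sample_rate=[32_000],
    path=[None],
    seek_time=[None],
)

attrs = ConditioningAttributes(text={"description": "lofi beat with soft piano"})
attrs.wav["self_wav"] = melody
print(attrs.attributes)  # keys for "text", "wav" and the new "joint_embed" group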
         raise NotImplementedError()


+def nullify_condition(condition: ConditionType, dim: int = 1):
+    """Transform an input condition to a null condition.
+    This is done by converting it to a single zero vector, similarly
+    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+    Args:
+        condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
+        dim (int): The dimension that will be truncated (should be the time dimension)
+        WARNING!: dim should not be the batch dimension!
+    Returns:
+        ConditionType: A tuple of null condition and mask
+    """
+    assert dim != 0, "dim cannot be the batch dimension!"
+    assert isinstance(condition, tuple) and \
+        isinstance(condition[0], torch.Tensor) and \
+        isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
+    cond, mask = condition
+    B = cond.shape[0]
+    last_dim = cond.dim() - 1
+    out = cond.transpose(dim, last_dim)
+    out = 0. * out[..., :1]
+    out = out.transpose(dim, last_dim)
+    mask = torch.zeros((B, 1), device=out.device).int()
+    assert cond.dim() == out.dim()
+    return out, mask
+
+
+def nullify_wav(cond: WavCondition) -> WavCondition:
+    """Transform a WavCondition to a nullified WavCondition.
+    It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
+
+    Args:
+        cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
+    Returns:
+        WavCondition: Nullified wav condition.
+    """
+    null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
+    return WavCondition(
+        wav=null_wav,
+        length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
+        sample_rate=cond.sample_rate,
+        path=[None] * cond.wav.shape[0],
+        seek_time=[None] * cond.wav.shape[0],
+    )
+
+
+def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
+    """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
+    and replacing metadata by dummy attributes.
+
+    Args:
+        embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
+    """
+    null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
+    return JointEmbedCondition(
+        wav=null_wav, text=[None] * len(embed.text),
+        length=torch.LongTensor([0]).to(embed.wav.device),
+        sample_rate=embed.sample_rate,
+        path=[None] * embed.wav.shape[0],
+        seek_time=[0] * embed.wav.shape[0],
+    )
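As a quick sanity check on the semantics above, the following snippet (not part of the diff) nullifies a dummy text-like condition and shows that the time dimension collapses to a single zero vector while the batch dimension is preserved:

import torch

# Dummy condition: batch of 2, 8 time steps, 16-dim embedding, with a full mask.
cond = torch.randn(2, 8, 16)
mask = torch.ones(2, 8, dtype=torch.int)

null_cond, null_mask = nullify_condition((cond, mask), dim=1)
print(null_cond.shape)  # torch.Size([2, 1, 16]) -> time dimension truncated to one zero vector
print(null_mask.shape)  # torch.Size([2, 1])     -> mask is all zeros
assert torch.all(null_cond == 0) and torch.all(null_mask == 0)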
179 |
+
|
180 |
class Tokenizer:
|
181 |
+
"""Base tokenizer implementation
|
182 |
(in case we want to introduce more advances tokenizers in the future).
|
183 |
"""
|
184 |
+
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
185 |
raise NotImplementedError()
|
186 |
|
187 |
|
|
|
192 |
[[78, 62, 31, 4, 78, 25, 19, 34],
|
193 |
[59, 77, 0, 0, 0, 0, 0, 0]]
|
194 |
"""
|
195 |
+
PUNCTUATION = "?:!.,;"
|
196 |
|
197 |
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
|
198 |
lemma: bool = True, stopwords: bool = True) -> None:
|
|
|
207 |
self.nlp = spacy.load(language)
|
208 |
|
209 |
@tp.no_type_check
|
210 |
+
def __call__(self, texts: tp.List[tp.Optional[str]],
|
211 |
+
return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
212 |
"""Take a list of strings and convert them to a tensor of indices.
|
213 |
|
214 |
Args:
|
215 |
+
texts (list[str]): List of strings.
|
216 |
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
|
217 |
Returns:
|
218 |
+
tuple[torch.Tensor, torch.Tensor]:
|
219 |
- Indices of words in the LUT.
|
220 |
- And a mask indicating where the padding tokens are
|
221 |
"""
|
|
|
224 |
for i, text in enumerate(texts):
|
225 |
# if current sample doesn't have a certain attribute, replace with pad token
|
226 |
if text is None:
|
227 |
+
output.append(torch.Tensor([self.pad_idx]))
|
228 |
lengths.append(0)
|
229 |
continue
|
230 |
|
|
|
235 |
# remove stopwords
|
236 |
if self.stopwords:
|
237 |
text = [w for w in text if not w.is_stop] # type: ignore
|
238 |
+
# remove punctuation
|
239 |
+
text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
|
240 |
# lemmatize if needed
|
241 |
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
|
242 |
|
243 |
texts[i] = " ".join(text)
|
244 |
lengths.append(len(text))
|
245 |
# convert to tensor
|
246 |
+
tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
|
247 |
output.append(tokens)
|
248 |
|
249 |
mask = length_to_mask(torch.IntTensor(lengths)).int()
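A hedged usage sketch for the whitespace tokenizer shown above (it assumes the spaCy "en_core_web_sm" model is available): words are hashed into a fixed number of bins with `hash_trick`, then padded, and a mask marks the padding positions.

tokenizer = WhiteSpaceTokenizer(n_bins=128)
tokens, mask = tokenizer(["a rock song with a guitar solo", None])
print(tokens.shape, mask.shape)  # both [2, T]; the None entry becomes a single pad index with mask 0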
|
|
|
267 |
self.n_bins = n_bins
|
268 |
self.pad_idx = pad_idx
|
269 |
|
270 |
+
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
271 |
output, lengths = [], []
|
272 |
for text in texts:
|
273 |
# if current sample doesn't have a certain attribute, replace with pad token
|
|
|
284 |
|
285 |
|
286 |
class BaseConditioner(nn.Module):
|
287 |
+
"""Base model for all conditioner modules.
|
288 |
+
We allow the output dim to be different than the hidden dim for two reasons:
|
289 |
+
1) keep our LUTs small when the vocab is large;
|
290 |
2) make all condition dims consistent.
|
291 |
|
292 |
Args:
|
293 |
+
dim (int): Hidden dim of the model.
|
294 |
output_dim (int): Output dim of the conditioner.
|
295 |
"""
|
296 |
+
def __init__(self, dim: int, output_dim: int):
|
297 |
super().__init__()
|
298 |
self.dim = dim
|
299 |
self.output_dim = output_dim
|
|
|
338 |
super().__init__(dim, output_dim)
|
339 |
self.embed = nn.Embedding(n_bins, dim)
|
340 |
self.tokenizer: Tokenizer
|
341 |
+
if tokenizer == 'whitespace':
|
342 |
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
|
343 |
+
elif tokenizer == 'noop':
|
344 |
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
|
345 |
else:
|
346 |
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
|
|
|
390 |
def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
|
391 |
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
|
392 |
normalize_text: bool = False):
|
393 |
+
assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
|
394 |
super().__init__(self.MODELS_DIMS[name], output_dim)
|
395 |
self.device = device
|
396 |
self.name = name
|
397 |
self.finetune = finetune
|
398 |
self.word_dropout = word_dropout
|
|
|
399 |
if autocast_dtype is None or self.device == 'cpu':
|
400 |
self.autocast = TorchAutocast(enabled=False)
|
401 |
if self.device != 'cpu':
|
|
|
421 |
else:
|
422 |
# this makes sure that the t5 models is not part
|
423 |
# of the saved checkpoint
|
424 |
+
self.__dict__['t5'] = t5.to(device)
|
425 |
|
426 |
self.normalize_text = normalize_text
|
427 |
if normalize_text:
|
|
|
441 |
|
442 |
empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
|
443 |
|
444 |
+
inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
|
445 |
+
mask = inputs['attention_mask']
|
446 |
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
447 |
return inputs
|
448 |
|
449 |
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
|
450 |
+
mask = inputs['attention_mask']
|
451 |
with torch.set_grad_enabled(self.finetune), self.autocast:
|
452 |
embeds = self.t5(**inputs).last_hidden_state
|
453 |
embeds = self.output_proj(embeds.to(self.output_proj.weight))
|
|
|
469 |
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
470 |
super().__init__(dim, output_dim)
|
471 |
self.device = device
|
472 |
+
# if False no masking is done, used in ChromaStemConditioner when completing by periodicity a sample.
|
473 |
+
self._use_masking = True
|
474 |
|
475 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
476 |
+
wav, length, sample_rate, path, seek_time = x
|
477 |
assert length is not None
|
478 |
+
return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
|
479 |
|
480 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
481 |
+
"""Gets as input a WavCondition and returns a dense embedding."""
|
482 |
raise NotImplementedError()
|
483 |
|
484 |
def _downsampling_factor(self):
|
485 |
"""Returns the downsampling factor of the embedding model."""
|
486 |
raise NotImplementedError()
|
487 |
|
488 |
+
def forward(self, x: WavCondition) -> ConditionType:
|
489 |
+
"""Extract condition embedding and mask from a waveform and its metadata.
|
490 |
Args:
|
491 |
+
x (WavCondition): Waveform condition containing raw waveform and metadata.
|
492 |
Returns:
|
493 |
+
ConditionType: a dense vector representing the conditioning along with its mask
|
494 |
"""
|
495 |
+
wav, lengths, *_ = x
|
496 |
with torch.no_grad():
|
497 |
+
embeds = self._get_wav_embedding(x)
|
498 |
embeds = embeds.to(self.output_proj.weight)
|
499 |
embeds = self.output_proj(embeds)
|
500 |
|
501 |
+
if lengths is not None and self._use_masking:
|
502 |
lengths = lengths / self._downsampling_factor()
|
503 |
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
504 |
else:
|
505 |
+
mask = torch.ones_like(embeds[..., 0])
|
506 |
+
embeds = (embeds * mask.unsqueeze(-1))
|
|
|
507 |
return embeds, mask
|
508 |
|
509 |
|
510 |
class ChromaStemConditioner(WaveformConditioner):
|
511 |
+
"""Chroma conditioner based on stems.
|
512 |
+
The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
|
513 |
+
the drums and bass often dominate the chroma leading to the chroma features
|
514 |
+
not containing information about the melody.
|
515 |
|
516 |
Args:
|
517 |
output_dim (int): Output dimension for the conditioner.
|
518 |
sample_rate (int): Sample rate for the chroma extractor.
|
519 |
+
n_chroma (int): Number of chroma bins for the chroma extractor.
|
520 |
+
radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
|
521 |
+
duration (int): duration used during training. This is later used for correct padding
|
522 |
in case we are using chroma as prefix.
|
523 |
+
match_len_on_eval (bool, optional): if True then all chromas are padded to the training
|
524 |
duration. Defaults to False.
|
525 |
+
eval_wavs (str, optional): path to a dataset manifest with waveforms; these waveforms are used as
|
526 |
conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
|
527 |
Defaults to None.
|
528 |
+
n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
|
529 |
device (tp.Union[torch.device, str], optional): Device for the conditioner.
|
530 |
**kwargs: Additional parameters for the chroma extractor.
|
531 |
"""
|
532 |
def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
|
533 |
duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
|
534 |
+
n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
|
535 |
+
device: tp.Union[torch.device, str] = 'cpu', **kwargs):
|
536 |
from demucs import pretrained
|
537 |
super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
|
538 |
+
self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
|
539 |
self.sample_rate = sample_rate
|
540 |
self.match_len_on_eval = match_len_on_eval
|
541 |
+
if match_len_on_eval:
|
542 |
+
self._use_masking = False
|
543 |
self.duration = duration
|
544 |
+
self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
|
545 |
+
stem_sources: list = self.demucs.sources # type: ignore
|
546 |
+
self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
|
547 |
+
self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
|
548 |
+
radix2_exp=radix2_exp, **kwargs).to(device)
|
549 |
self.chroma_len = self._get_chroma_len()
|
550 |
+
self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
|
551 |
+
self.cache = None
|
552 |
+
if cache_path is not None:
|
553 |
+
self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
554 |
+
compute_embed_fn=self._get_full_chroma_for_cache,
|
555 |
+
extract_embed_fn=self._extract_chroma_chunk)
|
556 |
+
|
557 |
+
def _downsampling_factor(self) -> int:
|
558 |
return self.chroma.winhop
|
559 |
|
560 |
+
def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
|
561 |
+
"""Load pre-defined waveforms from a json.
|
562 |
+
These waveforms will be used for chroma extraction during evaluation.
|
563 |
+
This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
|
564 |
+
"""
|
565 |
+
if path is None:
|
566 |
+
return None
|
567 |
+
|
568 |
+
logger.info(f"Loading evaluation wavs from {path}")
|
569 |
+
from audiocraft.data.audio_dataset import AudioDataset
|
570 |
+
dataset: AudioDataset = AudioDataset.from_meta(
|
571 |
+
path, segment_duration=self.duration, min_audio_duration=self.duration,
|
572 |
+
sample_rate=self.sample_rate, channels=1)
|
573 |
+
|
574 |
+
if len(dataset) > 0:
|
575 |
+
eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
|
576 |
+
logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
|
577 |
+
return eval_wavs
|
578 |
+
else:
|
579 |
+
raise ValueError("Could not find evaluation wavs, check lengths of wavs")
|
580 |
+
|
581 |
+
def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
|
582 |
+
self.eval_wavs = eval_wavs
|
583 |
+
|
584 |
+
def has_eval_wavs(self) -> bool:
|
585 |
+
return self.eval_wavs is not None
|
586 |
+
|
587 |
+
def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
|
588 |
+
"""Sample wavs from a predefined list."""
|
589 |
+
assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
|
590 |
+
total_eval_wavs = len(self.eval_wavs)
|
591 |
+
out = self.eval_wavs
|
592 |
+
if num_samples > total_eval_wavs:
|
593 |
+
out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
|
594 |
+
return out[torch.randperm(len(out))][:num_samples]
|
595 |
+
|
596 |
+
def _get_chroma_len(self) -> int:
|
597 |
+
"""Get length of chroma during training."""
|
598 |
+
dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
|
599 |
dummy_chr = self.chroma(dummy_wav)
|
600 |
return dummy_chr.shape[1]
|
601 |
|
602 |
@torch.no_grad()
|
603 |
+
def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
604 |
+
"""Get parts of the wav that holds the melody, extracting the main stems from the wav."""
|
605 |
from demucs.apply import apply_model
|
606 |
from demucs.audio import convert_audio
|
607 |
with self.autocast:
|
608 |
+
wav = convert_audio(
|
609 |
+
wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
|
610 |
stems = apply_model(self.demucs, wav, device=self.device)
|
611 |
+
stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
|
612 |
+
mix_wav = stems.sum(1) # merge extracted stems to single waveform
|
613 |
+
mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
|
614 |
+
return mix_wav
|
|
|
615 |
|
616 |
@torch.no_grad()
|
617 |
+
def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
|
618 |
+
"""Extract chroma features from the waveform."""
|
619 |
+
with self.autocast:
|
620 |
+
return self.chroma(wav)
|
621 |
+
|
622 |
+
@torch.no_grad()
|
623 |
+
def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
|
624 |
+
"""Compute wav embedding, applying stem and chroma extraction."""
|
625 |
# avoid 0-size tensors when we are working with null conds
|
626 |
if wav.shape[-1] == 1:
|
627 |
+
return self._extract_chroma(wav)
|
628 |
+
stems = self._get_stemmed_wav(wav, sample_rate)
|
629 |
+
chroma = self._extract_chroma(stems)
|
630 |
+
return chroma
|
631 |
+
|
632 |
+
@torch.no_grad()
|
633 |
+
def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
|
634 |
+
"""Extract chroma from the whole audio waveform at the given path."""
|
635 |
+
wav, sr = audio_read(path)
|
636 |
+
wav = wav[None].to(self.device)
|
637 |
+
wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
|
638 |
+
chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
|
639 |
+
return chroma
|
640 |
+
|
641 |
+
def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
|
642 |
+
"""Extract a chunk of chroma from the full chroma derived from the full waveform."""
|
643 |
+
wav_length = x.wav.shape[-1]
|
644 |
+
seek_time = x.seek_time[idx]
|
645 |
+
assert seek_time is not None, (
|
646 |
+
"WavCondition seek_time is required "
|
647 |
+
"when extracting chroma chunks from pre-computed chroma.")
|
648 |
+
full_chroma = full_chroma.float()
|
649 |
+
frame_rate = self.sample_rate / self._downsampling_factor()
|
650 |
+
target_length = int(frame_rate * wav_length / self.sample_rate)
|
651 |
+
index = int(frame_rate * seek_time)
|
652 |
+
out = full_chroma[index: index + target_length]
|
653 |
+
out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
|
654 |
+
return out.to(self.device)
|
655 |
+
|
656 |
+
@torch.no_grad()
|
657 |
+
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
|
658 |
+
"""Get the wav embedding from the WavCondition.
|
659 |
+
The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
|
660 |
+
or will rely on the embedding cache to load the pre-computed embedding if relevant.
|
661 |
+
"""
|
662 |
+
sampled_wav: tp.Optional[torch.Tensor] = None
|
663 |
+
if not self.training and self.eval_wavs is not None:
|
664 |
+
warn_once(logger, "Using precomputed evaluation wavs!")
|
665 |
+
sampled_wav = self._sample_eval_wavs(len(x.wav))
|
666 |
+
|
667 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
668 |
+
no_nullified_cond = x.wav.shape[-1] > 1
|
669 |
+
if sampled_wav is not None:
|
670 |
+
chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
|
671 |
+
elif self.cache is not None and no_undefined_paths and no_nullified_cond:
|
672 |
+
paths = [Path(p) for p in x.path if p is not None]
|
673 |
+
chroma = self.cache.get_embed_from_cache(paths, x)
|
674 |
+
else:
|
675 |
+
assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
|
676 |
+
chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
|
677 |
|
678 |
if self.match_len_on_eval:
|
679 |
+
B, T, C = chroma.shape
|
680 |
+
if T > self.chroma_len:
|
681 |
chroma = chroma[:, :self.chroma_len]
|
682 |
+
logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
|
683 |
+
elif T < self.chroma_len:
|
684 |
+
n_repeat = int(math.ceil(self.chroma_len / T))
|
|
|
685 |
chroma = chroma.repeat(1, n_repeat, 1)
|
686 |
chroma = chroma[:, :self.chroma_len]
|
687 |
+
logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
|
688 |
+
|
689 |
return chroma
|
690 |
|
691 |
+
def tokenize(self, x: WavCondition) -> WavCondition:
|
692 |
+
"""Apply WavConditioner tokenization and populate cache if needed."""
|
693 |
+
x = super().tokenize(x)
|
694 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
695 |
+
if self.cache is not None and no_undefined_paths:
|
696 |
+
paths = [Path(p) for p in x.path if p is not None]
|
697 |
+
self.cache.populate_embed_cache(paths, x)
|
698 |
+
return x
|
699 |
|
700 |
+
|
701 |
+
class JointEmbeddingConditioner(BaseConditioner):
|
702 |
+
"""Joint embedding conditioning supporting both audio or text conditioning.
|
703 |
|
704 |
Args:
|
705 |
+
dim (int): Dimension.
|
706 |
+
output_dim (int): Output dimension.
|
707 |
+
device (str): Device.
|
708 |
+
attribute (str): Attribute used by the conditioner.
|
709 |
+
autocast_dtype (str): Autocast for the conditioner.
|
710 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
711 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
712 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
713 |
+
kwargs: Additional parameters for residual vector quantizer.
|
714 |
"""
|
715 |
+
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
716 |
+
autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
|
717 |
+
n_q: int = 12, bins: int = 1024, **kwargs):
|
718 |
+
super().__init__(dim=dim, output_dim=output_dim)
|
|
|
719 |
self.device = device
|
720 |
+
self.attribute = attribute
|
721 |
+
if autocast_dtype is None or device == 'cpu':
|
722 |
+
self.autocast = TorchAutocast(enabled=False)
|
723 |
+
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
|
724 |
+
else:
|
725 |
+
dtype = getattr(torch, autocast_dtype)
|
726 |
+
assert isinstance(dtype, torch.dtype)
|
727 |
+
logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
|
728 |
+
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
|
729 |
+
# residual vector quantizer to discretize the conditioned embedding
|
730 |
+
self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
|
731 |
+
if quantize:
|
732 |
+
self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
|
733 |
+
|
734 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
735 |
+
"""Get joint embedding in latent space from the inputs.
|
736 |
+
|
737 |
+
Returns:
|
738 |
+
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
|
739 |
+
and corresponding empty indexes.
|
740 |
+
"""
|
741 |
+
raise NotImplementedError()
|
742 |
+
|
743 |
+
def forward(self, x: JointEmbedCondition) -> ConditionType:
|
744 |
with self.autocast:
|
745 |
+
embed, empty_idx = self._get_embed(x)
|
746 |
+
if self.quantizer is not None:
|
747 |
+
embed = embed.view(-1, self.dim, 1)
|
748 |
+
q_res = self.quantizer(embed, frame_rate=1)
|
749 |
+
out_embed = q_res.x.view(-1, self.dim)
|
750 |
+
else:
|
751 |
+
out_embed = embed
|
752 |
+
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
|
753 |
+
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
|
754 |
+
mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
|
755 |
+
out_embed = (out_embed * mask.unsqueeze(-1))
|
756 |
+
return out_embed, mask
|
757 |
+
|
758 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
759 |
+
return x
|
760 |
+
|
761 |
+
|
762 |
+
class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
|
763 |
+
"""Joint Embedding conditioner based on pre-trained CLAP model.
|
764 |
+
|
765 |
+
This CLAP-based conditioner supports a caching mechanism
|
766 |
+
over the computed embeddings for faster training.
|
767 |
+
|
768 |
+
Args:
|
769 |
+
dim (int): Dimension.
|
770 |
+
output_dim (int): Output dimension.
|
771 |
+
device (str): Device.
|
772 |
+
attribute (str): Attribute used by the conditioner.
|
773 |
+
quantize (bool): Whether to quantize the CLAP embedding.
|
774 |
+
n_q (int): Number of residual quantizers (used if quantize is true).
|
775 |
+
bins (int): Quantizers' codebooks size (used if quantize is true).
|
776 |
+
checkpoint (str): Path to CLAP checkpoint.
|
777 |
+
model_arch (str): CLAP model architecture.
|
778 |
+
enable_fusion (bool): Enable fusion for CLAP model.
|
779 |
+
sample_rate (int): Sample rate used by CLAP model.
|
780 |
+
max_audio_length (float): Maximum audio length for CLAP model.
|
781 |
+
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
|
782 |
+
normalize (bool): Whether to normalize the CLAP embedding.
|
783 |
+
text_p (float): Probability of using text representation instead of audio at train time.
|
784 |
+
batch_size (Optional[int]): Batch size for CLAP embedding computation.
|
785 |
+
autocast_dtype (str): Autocast for the conditioner.
|
786 |
+
cache_path (Optional[str]): Path for pre-computed embeddings caching.
|
787 |
+
kwargs: Additional parameters for residual vector quantizer.
|
788 |
+
"""
|
789 |
+
def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
|
790 |
+
quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
|
791 |
+
enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
|
792 |
+
normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
|
793 |
+
autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
|
794 |
+
try:
|
795 |
+
import laion_clap # type: ignore
|
796 |
+
except ImportError:
|
797 |
+
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
|
798 |
+
warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
|
799 |
+
"Please retrain all models.")
|
800 |
+
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
|
801 |
+
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
|
802 |
+
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
|
803 |
+
load_clap_state_dict(clap_model, checkpoint)
|
804 |
+
clap_model.eval()
|
805 |
+
clap_model.to(device)
|
806 |
+
super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
|
807 |
+
autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
|
808 |
+
**kwargs)
|
809 |
+
self.checkpoint = checkpoint
|
810 |
+
self.enable_fusion = enable_fusion
|
811 |
+
self.model_arch = model_arch
|
812 |
+
self.clap: laion_clap.CLAP_Module
|
813 |
+
self.clap_tokenize: RobertaTokenizer
|
814 |
+
self.clap_sample_rate = sample_rate
|
815 |
+
self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
|
816 |
+
self.clap_stride = int(self.clap_sample_rate * audio_stride)
|
817 |
+
self.batch_size = batch_size or 1
|
818 |
+
self.normalize = normalize
|
819 |
+
self.text_p = text_p
|
820 |
+
self.__dict__['clap_tokenize'] = clap_tokenize
|
821 |
+
self.__dict__['clap'] = clap_model
|
822 |
+
self.wav_cache, self.text_cache = None, None
|
823 |
+
if cache_path is not None:
|
824 |
+
self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
|
825 |
+
compute_embed_fn=self._get_wav_embedding_for_cache,
|
826 |
+
extract_embed_fn=self._extract_wav_embedding_chunk)
|
827 |
+
self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
|
828 |
+
compute_embed_fn=self._get_text_embedding_for_cache)
|
829 |
+
|
830 |
+
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
|
831 |
+
# we use the default params from CLAP module here as well
|
832 |
+
return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
|
833 |
+
|
834 |
+
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
|
835 |
+
"""Compute text embedding from CLAP model on a given a batch of text.
|
836 |
+
|
837 |
+
Args:
|
838 |
+
text (list[str]): List of text for the batch, with B items.
|
839 |
+
Returns:
|
840 |
+
torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
|
841 |
+
"""
|
842 |
+
with torch.no_grad():
|
843 |
+
embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
|
844 |
+
return embed.view(embed.size(0), 1, embed.size(-1))
|
845 |
+
|
846 |
+
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
|
847 |
+
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
848 |
+
"""Get text embedding function for the cache."""
|
849 |
+
text = x.text[idx]
|
850 |
+
text = text if text is not None else ""
|
851 |
+
return self._compute_text_embedding([text])[0]
|
852 |
+
|
853 |
+
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
|
854 |
+
"""Preprocess wav to expected format by CLAP model.
|
855 |
+
|
856 |
+
Args:
|
857 |
+
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
858 |
+
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
859 |
+
sample_rates (list[int]): Sample rates for each sample in the batch
|
860 |
+
Returns:
|
861 |
+
torch.Tensor: Audio wav of shape [B, T].
|
862 |
+
"""
|
863 |
+
assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
|
864 |
+
if sample_rates is not None:
|
865 |
+
_wav = []
|
866 |
+
for i, audio in enumerate(wav):
|
867 |
+
sr = sample_rates[i]
|
868 |
+
audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
|
869 |
+
_wav.append(audio)
|
870 |
+
wav = torch.stack(_wav, dim=0)
|
871 |
+
wav = wav.mean(dim=1)
|
872 |
+
return wav
|
873 |
+
|
874 |
+
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
|
875 |
+
sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
|
876 |
+
"""Compute audio wave embedding from CLAP model.
|
877 |
+
|
878 |
+
Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
|
879 |
+
we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
|
880 |
+
average the resulting embeddings.
|
881 |
+
|
882 |
+
Args:
|
883 |
+
wav (torch.Tensor): Audio wav, of shape [B, C, T].
|
884 |
+
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
|
885 |
+
sample_rates (list[int]): Sample rates for each sample in the batch.
|
886 |
+
reduce_mean (bool): Whether to get the average tensor.
|
887 |
+
Returns:
|
888 |
+
torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
|
889 |
+
"""
|
890 |
+
with torch.no_grad():
|
891 |
+
wav = self._preprocess_wav(wav, length, sample_rates)
|
892 |
+
B, T = wav.shape
|
893 |
+
if T >= self.clap_max_frames:
|
894 |
+
wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
|
895 |
+
else:
|
896 |
+
wav = wav.view(-1, 1, T) # [B, F, T] with F=1
|
897 |
+
wav = einops.rearrange(wav, 'b f t -> (b f) t')
|
898 |
+
embed_list = []
|
899 |
+
for i in range(0, wav.size(0), self.batch_size):
|
900 |
+
_wav = wav[i:i+self.batch_size, ...]
|
901 |
+
_embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
|
902 |
+
embed_list.append(_embed)
|
903 |
+
embed = torch.cat(embed_list, dim=0)
|
904 |
+
embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
|
905 |
+
if reduce_mean:
|
906 |
+
embed = embed.mean(dim=1, keepdim=True)
|
907 |
+
return embed # [B, F, D] with F=1 if reduce_mean is True
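To make the chunking in `_compute_wav_embedding` concrete, here is a small hedged sketch (the window and stride values are illustrative, not taken from the diff) of how `unfold` splits a long mono waveform into CLAP-sized windows before the per-window embeddings are averaged:

import torch

sample_rate = 48_000                      # CLAP sample rate after the v1.1.0 fix mentioned above
max_audio_length, stride = 10, 5          # hypothetical values, in seconds
clap_max_frames = sample_rate * max_audio_length
clap_stride = sample_rate * stride

wav = torch.randn(2, 30 * sample_rate)                   # [B, T] = 2 items of 30 s
chunks = wav.unfold(-1, clap_max_frames, clap_stride)    # [B, F, clap_max_frames]
print(chunks.shape[1])                                   # 5 windows per item: (30 - 10) / 5 + 1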
|
908 |
+
|
909 |
+
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
|
910 |
+
x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
911 |
+
"""Compute audio wave embedding for the cache.
|
912 |
+
The embedding is computed on a given audio read from file.
|
913 |
+
|
914 |
+
Args:
|
915 |
+
path (str or Path): Path to the full audio file.
|
916 |
+
Returns:
|
917 |
+
torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
|
918 |
+
"""
|
919 |
+
wav, sr = audio_read(path) # [C, T]
|
920 |
+
wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
|
921 |
+
wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
|
922 |
+
embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
|
923 |
+
return embed.squeeze(0) # [F, D]
|
924 |
+
|
925 |
+
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
|
926 |
+
"""Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
|
927 |
+
|
928 |
+
Args:
|
929 |
+
full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
|
930 |
+
x (JointEmbedCondition): Joint embedding condition for the full batch.
|
931 |
+
idx (int): Index considered for the given embedding to extract.
|
932 |
+
Returns:
|
933 |
+
torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
|
934 |
+
"""
|
935 |
+
sample_rate = x.sample_rate[idx]
|
936 |
+
seek_time = x.seek_time[idx]
|
937 |
+
seek_time = 0. if seek_time is None else seek_time
|
938 |
+
clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
|
939 |
+
end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
|
940 |
+
start_offset = int(seek_time * sample_rate // clap_stride)
|
941 |
+
end_offset = int(end_seek_time * sample_rate // clap_stride)
|
942 |
+
wav_embed = full_embed[start_offset:end_offset, ...]
|
943 |
+
wav_embed = wav_embed.mean(dim=0, keepdim=True)
|
944 |
+
return wav_embed.to(self.device) # [F, D]
|
945 |
+
|
946 |
+
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
947 |
+
"""Get CLAP embedding from a batch of text descriptions."""
|
948 |
+
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
949 |
+
if self.text_cache is not None and no_nullified_cond:
|
950 |
+
assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
|
951 |
+
paths = [Path(p) for p in x.path if p is not None]
|
952 |
+
embed = self.text_cache.get_embed_from_cache(paths, x)
|
953 |
+
else:
|
954 |
+
text = [xi if xi is not None else "" for xi in x.text]
|
955 |
+
embed = self._compute_text_embedding(text)
|
956 |
+
if self.normalize:
|
957 |
+
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
958 |
+
return embed
|
959 |
+
|
960 |
+
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
|
961 |
+
"""Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
|
962 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
963 |
+
no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
|
964 |
+
if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
|
965 |
+
paths = [Path(p) for p in x.path if p is not None]
|
966 |
+
embed = self.wav_cache.get_embed_from_cache(paths, x)
|
967 |
+
else:
|
968 |
+
embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
|
969 |
+
if self.normalize:
|
970 |
+
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
|
971 |
+
return embed
|
972 |
+
|
973 |
+
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
|
974 |
+
# Trying to limit as much as possible sync points when the cache is warm.
|
975 |
+
no_undefined_paths = all(p is not None for p in x.path)
|
976 |
+
if self.wav_cache is not None and no_undefined_paths:
|
977 |
+
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
978 |
+
paths = [Path(p) for p in x.path if p is not None]
|
979 |
+
self.wav_cache.populate_embed_cache(paths, x)
|
980 |
+
if self.text_cache is not None and no_undefined_paths:
|
981 |
+
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
|
982 |
+
paths = [Path(p) for p in x.path if p is not None]
|
983 |
+
self.text_cache.populate_embed_cache(paths, x)
|
984 |
+
return x
|
985 |
+
|
986 |
+
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
|
987 |
+
"""Extract shared latent representation from either the wav or the text using CLAP."""
|
988 |
+
# decide whether to use text embedding at train time or not
|
989 |
+
use_text_embed = random.random() < self.text_p
|
990 |
+
if self.training and not use_text_embed:
|
991 |
+
embed = self._get_wav_embedding(x)
|
992 |
+
empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
|
993 |
+
else:
|
994 |
+
embed = self._get_text_embedding(x)
|
995 |
+
empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
|
996 |
+
return embed, empty_idx
|
997 |
+
|
998 |
+
|
999 |
+
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
|
1000 |
"""Utility function for nullifying an attribute inside an ConditioningAttributes object.
|
1001 |
+
If the condition is of type "wav", then nullify it using `nullify_condition` function.
|
1002 |
+
If the condition is of any other type, set its value to None.
|
1003 |
Works in-place.
|
1004 |
"""
|
1005 |
+
if condition_type not in ['text', 'wav', 'joint_embed']:
|
1006 |
raise ValueError(
|
1007 |
"dropout_condition got an unexpected condition type!"
|
1008 |
+
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
|
1009 |
)
|
1010 |
|
1011 |
if condition not in getattr(sample, condition_type):
|
1012 |
raise ValueError(
|
1013 |
"dropout_condition received an unexpected condition!"
|
1014 |
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
|
1015 |
+
f" but got '{condition}' of type '{condition_type}'!"
|
1016 |
)
|
1017 |
|
1018 |
+
if condition_type == 'wav':
|
1019 |
+
wav_cond = sample.wav[condition]
|
1020 |
+
sample.wav[condition] = nullify_wav(wav_cond)
|
1021 |
+
elif condition_type == 'joint_embed':
|
1022 |
+
embed = sample.joint_embed[condition]
|
1023 |
+
sample.joint_embed[condition] = nullify_joint_embed(embed)
|
1024 |
else:
|
1025 |
sample.text[condition] = None
|
1026 |
|
|
|
1028 |
|
1029 |
|
1030 |
class DropoutModule(nn.Module):
|
1031 |
+
"""Base module for all dropout modules."""
|
1032 |
def __init__(self, seed: int = 1234):
|
1033 |
super().__init__()
|
1034 |
self.rng = torch.Generator()
|
|
|
1036 |
|
1037 |
|
1038 |
class AttributeDropout(DropoutModule):
|
1039 |
+
"""Dropout with a given probability per attribute.
|
1040 |
+
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
|
1041 |
+
to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
|
1042 |
+
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
|
1043 |
+
must also be dropped.
|
1044 |
|
1045 |
Args:
|
1046 |
p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
|
|
|
1063 |
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
1064 |
"""
|
1065 |
Args:
|
1066 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
1067 |
Returns:
|
1068 |
+
list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
|
1069 |
"""
|
1070 |
if not self.training and not self.active_on_eval:
|
1071 |
return samples
|
1072 |
|
1073 |
samples = deepcopy(samples)
|
|
|
1074 |
for condition_type, ps in self.p.items(): # for condition types [text, wav]
|
1075 |
for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
|
1076 |
if torch.rand(1, generator=self.rng).item() < p:
|
1077 |
for sample in samples:
|
1078 |
dropout_condition(sample, condition_type, condition)
|
|
|
1079 |
return samples
|
1080 |
|
1081 |
def __repr__(self):
|
|
|
1083 |
|
1084 |
|
1085 |
class ClassifierFreeGuidanceDropout(DropoutModule):
|
1086 |
+
"""Classifier Free Guidance dropout.
|
1087 |
+
All attributes are dropped with the same probability.
|
1088 |
|
1089 |
Args:
|
1090 |
p (float): Probability to apply condition dropout during training.
|
|
|
1097 |
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
|
1098 |
"""
|
1099 |
Args:
|
1100 |
+
samples (list[ConditioningAttributes]): List of conditions.
|
1101 |
Returns:
|
1102 |
+
list[ConditioningAttributes]: List of conditions after all attributes were set to None.
|
1103 |
"""
|
1104 |
if not self.training:
|
1105 |
return samples
|
|
|
1111 |
|
1112 |
# nullify conditions of all attributes
|
1113 |
samples = deepcopy(samples)
|
|
|
1114 |
for condition_type in ["wav", "text"]:
|
1115 |
for sample in samples:
|
1116 |
for condition in sample.attributes[condition_type]:
|
1117 |
dropout_condition(sample, condition_type, condition)
|
|
|
1118 |
return samples
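A hedged sketch of how these dropout modules are typically driven at train time (the probabilities and attribute names are illustrative, not values from the diff):

cfg_dropout = ClassifierFreeGuidanceDropout(p=0.3)
attr_dropout = AttributeDropout(p={"text": {"genre": 0.2, "artist": 0.5}})

cfg_dropout.train()
attr_dropout.train()
# `samples` is a list of ConditioningAttributes; CFG dropout nullifies everything
# with probability p, while attribute dropout nullifies attributes independently.
samples = attr_dropout(cfg_dropout(samples))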
|
1119 |
|
1120 |
def __repr__(self):
|
|
|
1122 |
|
1123 |
|
1124 |
class ConditioningProvider(nn.Module):
|
1125 |
+
"""Prepare and provide conditions given all the supported conditioners.
|
1126 |
|
1127 |
Args:
|
1128 |
conditioners (dict): Dictionary of conditioners.
|
1129 |
+
device (torch.device or str, optional): Device for conditioners and output condition types.
|
|
|
|
|
|
|
|
|
1130 |
"""
|
1131 |
+
def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
|
|
|
|
|
|
|
|
|
|
|
|
|
1132 |
super().__init__()
|
1133 |
self.device = device
|
|
|
|
|
1134 |
self.conditioners = nn.ModuleDict(conditioners)
|
1135 |
|
1136 |
+
@property
|
1137 |
+
def joint_embed_conditions(self):
|
1138 |
+
return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
|
1139 |
+
|
1140 |
+
@property
|
1141 |
+
def has_joint_embed_conditions(self):
|
1142 |
+
return len(self.joint_embed_conditions) > 0
|
1143 |
+
|
1144 |
@property
|
1145 |
def text_conditions(self):
|
1146 |
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
|
|
|
1159 |
This will return a dict matching conditioner names to their arbitrary tokenized representations.
|
1160 |
|
1161 |
Args:
|
1162 |
+
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
|
1163 |
text and wav conditions.
|
1164 |
"""
|
1165 |
+
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
|
1166 |
+
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
|
1167 |
f" but types were {set([type(x) for x in inputs])}"
|
1168 |
+
)
|
1169 |
|
1170 |
output = {}
|
1171 |
text = self._collate_text(inputs)
|
1172 |
wavs = self._collate_wavs(inputs)
|
1173 |
+
joint_embeds = self._collate_joint_embeds(inputs)
|
1174 |
|
1175 |
+
assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
|
1176 |
+
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
|
1177 |
+
f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
|
1178 |
+
)
|
1179 |
|
1180 |
+
for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
|
1181 |
output[attribute] = self.conditioners[attribute].tokenize(batch)
|
1182 |
return output
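Putting the provider together, a hedged sketch of the tokenize-then-forward pipeline (conditioner construction is elided; `provider` is assumed to hold a 'description' text conditioner):

tokenized = provider.tokenize(batch_of_conditioning_attributes)
conditions = provider(tokenized)           # dict: attribute -> (embedding, mask)
embed, mask = conditions['description']    # e.g. [B, T_desc, D] and [B, T_desc]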
|
1183 |
|
1184 |
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
|
1185 |
+
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
|
1186 |
+
The output is for example:
|
1187 |
+
{
|
1188 |
+
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
|
1189 |
+
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
|
1190 |
+
...
|
1191 |
+
}
|
|
|
1192 |
|
1193 |
Args:
|
1194 |
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
|
|
|
1213 |
"genre": ["Rock", "Hip-hop"],
|
1214 |
"description": ["A rock song with a guitar solo", "A hip-hop verse"]
|
1215 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1216 |
|
1217 |
+
Args:
|
1218 |
+
samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
|
1219 |
+
Returns:
|
1220 |
+
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
|
1221 |
+
"""
|
1222 |
+
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
|
1223 |
texts = [x.text for x in samples]
|
1224 |
for text in texts:
|
1225 |
for condition in self.text_conditions:
|
1226 |
+
out[condition].append(text[condition])
|
1227 |
+
return out
|
|
|
 
+    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
         """Generate a dict where the keys are attributes by which we fetch similar wavs,
+        and the values are Tensors of wavs according to said attributes.

         *Note*: by the time the samples reach this function, each sample should have some waveform
         inside the "wav" attribute. It should be either:

         3. A null waveform due to it being dropped in a dropout module (nullified by dropout)

         Args:
+            samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
         Returns:
+            dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
         """
         wavs = defaultdict(list)
+        lengths = defaultdict(list)
+        sample_rates = defaultdict(list)
         paths = defaultdict(list)
+        seek_times = defaultdict(list)
+        out: tp.Dict[str, WavCondition] = {}

         for sample in samples:
             for attribute in self.wav_conditions:
+                wav, length, sample_rate, path, seek_time = sample.wav[attribute]
+                assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
+                assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
+                # mono-channel conditioning
+                wav = wav.mean(1, keepdim=True)  # [1, 1, T]
+                wavs[attribute].append(wav.flatten())  # [T]
+                lengths[attribute].append(length)
+                sample_rates[attribute].extend(sample_rate)
+                paths[attribute].extend(path)
+                seek_times[attribute].extend(seek_time)

         # stack all wavs to a single tensor
         for attribute in self.wav_conditions:
             stacked_wav, _ = collate(wavs[attribute], dim=0)
+            out[attribute] = WavCondition(
+                stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
+                paths[attribute], seek_times[attribute])
+
+        return out
+
+    def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
+        """Generate a dict where the keys are attributes by which we compute joint embeddings,
+        and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
+
+        Args:
+            samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
+        Returns:
+            A dictionary mapping an attribute name to joint embeddings.
+        """
+        texts = defaultdict(list)
+        wavs = defaultdict(list)
+        lengths = defaultdict(list)
+        sample_rates = defaultdict(list)
+        paths = defaultdict(list)
+        seek_times = defaultdict(list)
+        channels: int = 0
+
+        out = {}
+        for sample in samples:
+            for attribute in self.joint_embed_conditions:
+                wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
+                assert wav.dim() == 3
+                if channels == 0:
+                    channels = wav.size(1)
+                else:
+                    assert channels == wav.size(1), "not all audio has same number of channels in batch"
+                assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
+                wav = einops.rearrange(wav, "b c t -> (b c t)")  # [1, C, T] => [C * T]
+                wavs[attribute].append(wav)
+                texts[attribute].extend(text)
+                lengths[attribute].append(length)
+                sample_rates[attribute].extend(sample_rate)
+                paths[attribute].extend(path)
+                seek_times[attribute].extend(seek_time)
+
+        for attribute in self.joint_embed_conditions:
+            stacked_texts = texts[attribute]
+            stacked_paths = paths[attribute]
+            stacked_seek_times = seek_times[attribute]
+            stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
+            stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
+            stacked_sample_rates = sample_rates[attribute]
+            stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
+            assert stacked_lengths.size(0) == stacked_wavs.size(0)
+            assert len(stacked_sample_rates) == stacked_wavs.size(0)
+            assert len(stacked_texts) == stacked_wavs.size(0)
+            out[attribute] = JointEmbedCondition(
+                text=stacked_texts, wav=stacked_wavs,
+                length=stacked_lengths, sample_rate=stacked_sample_rates,
+                path=stacked_paths, seek_time=stacked_seek_times)

         return out
         super().__init__()
         assert all(
             [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+        ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
         self.cross_attention_pos_emb = cross_attention_pos_emb
         self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
         self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond

     def forward(
         self,
+        input: torch.Tensor,
         conditions: tp.Dict[str, ConditionType]
+    ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
         """Fuse the conditions to the provided model input.

         Args:
+            input (torch.Tensor): Transformer input.
+            conditions (dict[str, ConditionType]): Dict of conditions.
         Returns:
+            tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
                 after the conditions have been fused. The second output tensor is the tensor
                 used for cross-attention or None if no cross attention inputs exist.
         """

         cross_attention_output = None
         for cond_type, (cond, cond_mask) in conditions.items():
             op = self.cond2fuse[cond_type]
+            if op == 'sum':
                 input += cond
+            elif op == 'input_interpolate':
+                cond = einops.rearrange(cond, "b t d -> b d t")
                 cond = F.interpolate(cond, size=input.shape[1])
+                input += einops.rearrange(cond, "b d t -> b t d")
+            elif op == 'prepend':
                 if first_step:
                     input = torch.cat([cond, input], dim=1)
+            elif op == 'cross':
                 if cross_attention_output is not None:
                     cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                 else:
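For context, a `fuse2cond` mapping compatible with the assertion above might look like the following hedged sketch (attribute names are illustrative; MusicGen typically routes text descriptions through cross-attention and melody chroma through prepending):

fuse2cond = {
    "cross": ["description"],       # text embeddings consumed via cross-attention
    "prepend": ["self_wav"],        # chroma tokens prepended to the transformer input
    "sum": [],
    "input_interpolate": [],
}
fuser = ConditionFuser(fuse2cond=fuse2cond, cross_attention_pos_emb=False)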
audiocraft/modules/conv.py
CHANGED
@@ -11,7 +11,7 @@ import warnings
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.nn.utils import spectral_norm, weight_norm
+from torch.nn.utils.parametrizations import spectral_norm, weight_norm
 
 
 CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
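This conv.py change tracks PyTorch's move of `weight_norm`/`spectral_norm` into `torch.nn.utils.parametrizations` (the old helpers are deprecated in recent releases). If older PyTorch versions still need to be supported, a hedged compatibility sketch could look like:

try:
    # Newer PyTorch: parametrization-based implementations
    from torch.nn.utils.parametrizations import spectral_norm, weight_norm
except ImportError:
    # Older PyTorch: fall back to the previous helpers
    from torch.nn.utils import spectral_norm, weight_norm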
audiocraft/modules/diffusion_schedule.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from collections import namedtuple
|
12 |
+
import random
|
13 |
+
import typing as tp
|
14 |
+
import julius
|
15 |
+
import torch
|
16 |
+
|
17 |
+
TrainingItem = namedtuple("TrainingItem", "noisy noise step")
|
18 |
+
|
19 |
+
|
20 |
+
def betas_from_alpha_bar(alpha_bar):
|
21 |
+
alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
|
22 |
+
return 1 - alphas
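Before the schedule classes below, it may help to see what `betas_from_alpha_bar` computes: given the cumulative product of alphas, it recovers the per-step betas. A small hedged round-trip check (values are arbitrary):

import torch

# alpha_bar[t] is the cumulative product of (1 - beta) up to step t.
betas = torch.tensor([0.1, 0.2, 0.3])
alpha_bar = torch.cumprod(1 - betas, dim=0)
recovered = betas_from_alpha_bar(alpha_bar)
print(torch.allclose(recovered, betas))  # True, up to floating point error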
|
23 |
+
|
24 |
+
|
25 |
+
class SampleProcessor(torch.nn.Module):
|
26 |
+
def project_sample(self, x: torch.Tensor):
|
27 |
+
"""Project the original sample to the 'space' where the diffusion will happen."""
|
28 |
+
return x
|
29 |
+
|
30 |
+
def return_sample(self, z: torch.Tensor):
|
31 |
+
"""Project back from diffusion space to the actual sample space."""
|
32 |
+
return z
|
33 |
+
|
34 |
+
|
35 |
+
class MultiBandProcessor(SampleProcessor):
|
36 |
+
"""
|
37 |
+
MultiBand sample processor. The input audio is splitted across
|
38 |
+
frequency bands evenly distributed in mel-scale.
|
39 |
+
|
40 |
+
Each band will be rescaled to match the power distribution
|
41 |
+
of Gaussian noise in that band, using online metrics
|
42 |
+
computed on the first few samples.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
n_bands (int): Number of mel-bands to split the signal over.
|
46 |
+
sample_rate (int): Sample rate of the audio.
|
47 |
+
num_samples (int): Number of samples to use to fit the rescaling
|
48 |
+
for each band. The processor won't be stable
|
49 |
+
until it has seen that many samples.
|
50 |
+
power_std (float or list/tensor): The rescaling factor computed to match the
|
51 |
+
power of Gaussian noise in each band is taken to
|
52 |
+
that power, i.e. `1.` means full correction of the energy
|
53 |
+
in each band, and values less than `1` means only partial
|
54 |
+
correction. Can be used to balance the relative importance
|
55 |
+
of low vs. high freq in typical audio signals.
|
56 |
+
"""
|
57 |
+
def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
|
58 |
+
num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
|
59 |
+
super().__init__()
|
60 |
+
self.n_bands = n_bands
|
61 |
+
self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
|
62 |
+
self.num_samples = num_samples
|
63 |
+
self.power_std = power_std
|
64 |
+
if isinstance(power_std, list):
|
65 |
+
assert len(power_std) == n_bands
|
66 |
+
power_std = torch.tensor(power_std)
|
67 |
+
self.register_buffer('counts', torch.zeros(1))
|
68 |
+
self.register_buffer('sum_x', torch.zeros(n_bands))
|
69 |
+
self.register_buffer('sum_x2', torch.zeros(n_bands))
|
70 |
+
self.register_buffer('sum_target_x2', torch.zeros(n_bands))
|
71 |
+
self.counts: torch.Tensor
|
72 |
+
self.sum_x: torch.Tensor
|
73 |
+
self.sum_x2: torch.Tensor
|
74 |
+
self.sum_target_x2: torch.Tensor
|
75 |
+
|
76 |
+
@property
|
77 |
+
def mean(self):
|
78 |
+
mean = self.sum_x / self.counts
|
79 |
+
return mean
|
80 |
+
|
81 |
+
@property
|
82 |
+
def std(self):
|
83 |
+
std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
|
84 |
+
return std
|
85 |
+
|
86 |
+
@property
|
87 |
+
def target_std(self):
|
88 |
+
target_std = self.sum_target_x2 / self.counts
|
89 |
+
return target_std
|
90 |
+
|
91 |
+
def project_sample(self, x: torch.Tensor):
|
92 |
+
assert x.dim() == 3
|
93 |
+
bands = self.split_bands(x)
|
94 |
+
if self.counts.item() < self.num_samples:
|
95 |
+
ref_bands = self.split_bands(torch.randn_like(x))
|
96 |
+
self.counts += len(x)
|
97 |
+
self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
|
98 |
+
self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
|
99 |
+
self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
|
100 |
+
rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
|
101 |
+
bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
|
102 |
+
return bands.sum(dim=0)
|
103 |
+
|
104 |
+
def return_sample(self, x: torch.Tensor):
|
105 |
+
assert x.dim() == 3
|
106 |
+
bands = self.split_bands(x)
|
107 |
+
rescale = (self.std / self.target_std) ** self.power_std
|
108 |
+
bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
|
109 |
+
return bands.sum(dim=0)
|
110 |
+
|
111 |
+
|
112 |
+
class NoiseSchedule:
|
113 |
+
"""Noise schedule for diffusion.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
beta_t0 (float): Variance of the first diffusion step.
|
117 |
+
beta_t1 (float): Variance of the last diffusion step.
|
118 |
+
beta_exp (float): Power schedule exponent
|
119 |
+
num_steps (int): Number of diffusion step.
|
120 |
+
variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
|
121 |
+
clip (float): clipping value for the denoising steps
|
122 |
+
rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
|
123 |
+
repartition (str): shape of the schedule only power schedule is supported
|
124 |
+
sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
|
125 |
+
noise_scale (float): Scaling factor for the noise
|
126 |
+
"""
|
127 |
+
def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
|
128 |
+
clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
|
129 |
+
repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
|
130 |
+
sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
|
131 |
+
|
132 |
+
self.beta_t0 = beta_t0
|
133 |
+
self.beta_t1 = beta_t1
|
134 |
+
self.variance = variance
|
135 |
+
self.num_steps = num_steps
|
136 |
+
self.clip = clip
|
137 |
+
self.sample_processor = sample_processor
|
138 |
+
self.rescale = rescale
|
139 |
+
self.n_bands = n_bands
|
140 |
+
self.noise_scale = noise_scale
|
141 |
+
assert n_bands is None
|
142 |
+
if repartition == "power":
|
143 |
+
self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
|
144 |
+
device=device, dtype=torch.float) ** beta_exp
|
145 |
+
else:
|
146 |
+
raise RuntimeError('Not implemented')
|
147 |
+
self.rng = random.Random(1234)
|
148 |
+
|
149 |
+
def get_beta(self, step: tp.Union[int, torch.Tensor]):
|
150 |
+
if self.n_bands is None:
|
151 |
+
return self.betas[step]
|
152 |
+
else:
|
153 |
+
return self.betas[:, step] # [n_bands, len(step)]
|
154 |
+
|
155 |
+
def get_initial_noise(self, x: torch.Tensor):
|
156 |
+
if self.n_bands is None:
|
157 |
+
return torch.randn_like(x)
|
158 |
+
return torch.randn((x.size(0), self.n_bands, x.size(2)))
|
159 |
+
|
160 |
+
def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
|
161 |
+
"""Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
|
162 |
+
if step is None:
|
163 |
+
return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands
|
164 |
+
if type(step) is int:
|
165 |
+
return (1 - self.betas[:step + 1]).prod()
|
166 |
+
else:
|
167 |
+
return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
|
168 |
+
|
169 |
+
def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
|
170 |
+
"""Create a noisy data item for diffusion model training:
|
171 |
+
|
172 |
+
Args:
|
173 |
+
x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
|
174 |
+
tensor_step (bool): If tensor_step = false, only one step t is sample,
|
175 |
+
the whole batch is diffused to the same step and t is int.
|
176 |
+
If tensor_step = true, t is a tensor of size (x.size(0),)
|
177 |
+
every element of the batch is diffused to a independently sampled.
|
178 |
+
"""
|
179 |
+
step: tp.Union[int, torch.Tensor]
|
180 |
+
if tensor_step:
|
181 |
+
bs = x.size(0)
|
182 |
+
step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
|
183 |
+
else:
|
184 |
+
step = self.rng.randrange(self.num_steps)
|
185 |
+
alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
|
186 |
+
|
187 |
+
x = self.sample_processor.project_sample(x)
|
188 |
+
noise = torch.randn_like(x)
|
189 |
+
noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
|
190 |
+
return TrainingItem(noisy, noise, step)
|
191 |
+
|
192 |
+
def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
|
193 |
+
condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
|
194 |
+
"""Full ddpm reverse process.
|
195 |
+
|
196 |
+
Args:
|
197 |
+
model (nn.Module): Diffusion model.
|
198 |
+
initial (tensor): Initial Noise.
|
199 |
+
condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
|
200 |
+
return_list (bool): Whether to return the whole process or only the sampled point.
|
201 |
+
"""
|
202 |
+
alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
|
203 |
+
current = initial
|
204 |
+
iterates = [initial]
|
205 |
+
for step in range(self.num_steps)[::-1]:
|
206 |
+
with torch.no_grad():
|
207 |
+
estimate = model(current, step, condition=condition).sample
|
208 |
+
alpha = 1 - self.betas[step]
|
209 |
+
previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
|
210 |
+
previous_alpha_bar = self.get_alpha_bar(step=step - 1)
|
211 |
+
if step == 0:
|
212 |
+
sigma2 = 0
|
213 |
+
elif self.variance == 'beta':
|
214 |
+
sigma2 = 1 - alpha
|
215 |
+
elif self.variance == 'beta_tilde':
|
216 |
+
sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
|
217 |
+
elif self.variance == 'none':
|
218 |
+
sigma2 = 0
|
219 |
+
else:
|
220 |
+
raise ValueError(f'Invalid variance type {self.variance}')
|
221 |
+
|
222 |
+
if sigma2 > 0:
|
223 |
+
previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
|
224 |
+
if self.clip:
|
225 |
+
previous = previous.clamp(-self.clip, self.clip)
|
226 |
+
current = previous
|
227 |
+
alpha_bar = previous_alpha_bar
|
228 |
+
if step == 0:
|
229 |
+
previous *= self.rescale
|
230 |
+
if return_list:
|
231 |
+
iterates.append(previous.cpu())
|
232 |
+
|
233 |
+
if return_list:
|
234 |
+
return iterates
|
235 |
+
else:
|
236 |
+
return self.sample_processor.return_sample(previous)
|
237 |
+
|
238 |
+
def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
|
239 |
+
condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
|
240 |
+
"""Reverse process that only goes through Markov chain states in step_list."""
|
241 |
+
if step_list is None:
|
242 |
+
step_list = list(range(1000))[::-50] + [0]
|
243 |
+
alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
|
244 |
+
alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
|
245 |
+
betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
|
246 |
+
current = initial * self.noise_scale
|
247 |
+
iterates = [current]
|
248 |
+
for idx, step in enumerate(step_list[:-1]):
|
249 |
+
with torch.no_grad():
|
250 |
+
estimate = model(current, step, condition=condition).sample * self.noise_scale
|
251 |
+
alpha = 1 - betas_subsampled[-1 - idx]
|
252 |
+
previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
|
253 |
+
previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
|
254 |
+
if step == step_list[-2]:
|
255 |
+
sigma2 = 0
|
256 |
+
previous_alpha_bar = torch.tensor(1.0)
|
257 |
+
else:
|
258 |
+
sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
|
259 |
+
if sigma2 > 0:
|
260 |
+
previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
|
261 |
+
if self.clip:
|
262 |
+
previous = previous.clamp(-self.clip, self.clip)
|
263 |
+
current = previous
|
264 |
+
alpha_bar = previous_alpha_bar
|
265 |
+
if step == 0:
|
266 |
+
previous *= self.rescale
|
267 |
+
if return_list:
|
268 |
+
iterates.append(previous.cpu())
|
269 |
+
if return_list:
|
270 |
+
return iterates
|
271 |
+
else:
|
272 |
+
return self.sample_processor.return_sample(previous)
|
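For orientation, a small sketch of how this schedule is meant to be driven. The diffusion `model` is a stand-in (any module whose output exposes a `.sample` tensor), so the calls that need it are left commented:

import torch
from audiocraft.modules.diffusion_schedule import NoiseSchedule

schedule = NoiseSchedule(num_steps=1000, variance='beta', device='cpu')

# Training side: corrupt clean audio and train the model to predict the injected noise.
x = torch.randn(4, 1, 16000)                      # (batch, channels, time) placeholder audio
item = schedule.get_training_item(x, tensor_step=True)
# loss = F.mse_loss(model(item.noisy, item.step).sample, item.noise)

# Sampling side: full reverse process starting from Gaussian noise.
# out = schedule.generate(model, initial=schedule.get_initial_noise(x), condition=cond)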
audiocraft/modules/rope.py
CHANGED
@@ -18,7 +18,7 @@ class XPos(nn.Module):
         dim (int): Embedding dimension.
         smoothing (float): Smoothing factor applied to the decay rates.
         base_scale (int): Base decay rate, given in terms of scaling time.
-        device (torch.device
+        device (torch.device, optional): Device on which to initialize the module.
         dtype (torch.dtype): dtype to use to generate the embedding.
     """
     def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
@@ -36,8 +36,7 @@ class XPos(nn.Module):
         self.decay: tp.Optional[torch.Tensor] = None

     def get_decay(self, start: int, end: int):
-        """Create complex decay tensor, cache values for fast computation.
-        """
+        """Create complex decay tensor, cache values for fast computation."""
         if self.decay is None or end > self.decay.shape[0]:
             assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
             idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
@@ -55,7 +54,7 @@ class RotaryEmbedding(nn.Module):
         max_period (float): Maximum period of the rotation frequencies.
         xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
         scale (float): Scale of positional embedding, set to 0 to deactivate.
-        device (torch.device
+        device (torch.device, optional): Device on which to initialize the module.
         dtype (torch.dtype): dtype to use to generate the embedding.
     """
     def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
@@ -74,8 +73,7 @@ class RotaryEmbedding(nn.Module):
         self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None

     def get_rotation(self, start: int, end: int):
-        """Create complex rotation tensor, cache values for fast computation.
-        """
+        """Create complex rotation tensor, cache values for fast computation."""
         if self.rotation is None or end > self.rotation.shape[0]:
             assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
             idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
@@ -83,14 +81,16 @@ class RotaryEmbedding(nn.Module):
         self.rotation = torch.polar(torch.ones_like(angles), angles)
         return self.rotation[start:end]

-    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
-        """Apply rope rotation to query or key tensor.
+    def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
+        """Apply rope rotation to query or key tensor."""
+        T = x.shape[time_dim]
+        target_shape = [1] * x.dim()
+        target_shape[time_dim] = T
+        target_shape[-1] = -1
+        rotation = self.get_rotation(start, start + T).view(target_shape)

         if self.xpos:
-            decay = self.xpos.get_decay(start, start + T).
+            decay = self.xpos.get_decay(start, start + T).view(target_shape)
         else:
             decay = 1.0
@@ -99,26 +99,27 @@ class RotaryEmbedding(nn.Module):

         x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
         scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
-        x_out = torch.view_as_real(x_complex * scaled_rotation).
+        x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)

         return x_out.type_as(x)

-    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
         """ Apply rope rotation to both query and key tensors.
         Supports streaming mode, in which query and key are not expected to have the same shape.
-        In streaming mode, key will be of
+        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
         query will be [C] (typically C == 1).

         Args:
             query (torch.Tensor): Query to rotate.
             key (torch.Tensor): Key to rotate.
             start (int): Start index of the sequence for time offset.
+            time_dim (int): which dimension represent the time steps.
         """
-        query_timesteps = query.shape[
-        key_timesteps = key.shape[
+        query_timesteps = query.shape[time_dim]
+        key_timesteps = key.shape[time_dim]
         streaming_offset = key_timesteps - query_timesteps

-        query_out = self.rotate(query, start + streaming_offset)
-        key_out = self.rotate(key, start, invert_decay=True)
+        query_out = self.rotate(query, start + streaming_offset, time_dim)
+        key_out = self.rotate(key, start, time_dim, invert_decay=True)

         return query_out, key_out
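The added `time_dim` argument exists so the cached rotation can broadcast against either attention layout (time on dim 1 or dim 2). A rough, self-contained illustration of the reshaping trick, with made-up shapes:

import torch

x = torch.randn(2, 8, 16, 64)           # e.g. [batch, heads, time, dim], so time_dim = 2
time_dim = 2
T = x.shape[time_dim]
target_shape = [1] * x.dim()
target_shape[time_dim] = T
target_shape[-1] = -1                    # complex pairs live on the last axis
rotation = torch.polar(torch.ones(T, 32), torch.rand(T, 32)).view(target_shape)
x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
out = torch.view_as_real(x_complex * rotation).view_as(x)   # rotation broadcasts over batch and heads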
audiocraft/modules/transformer.py
CHANGED
@@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'):
     _efficient_attention_backend = backend


-def _get_attention_time_dimension() -> int:
-    if _efficient_attention_backend == 'torch':
+def _get_attention_time_dimension(memory_efficient: bool) -> int:
+    if _efficient_attention_backend == 'torch' and memory_efficient:
         return 2
     else:
         return 1
@@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
     return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


-def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
+def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
     if n_rep == 1:
         return x
-    if _efficient_attention_backend == 'torch':
+    if _efficient_attention_backend == 'torch' and memory_efficient:
         bs, n_kv_heads, slen, head_dim = x.shape
         return (
             x[:, :, None, :, :]
@@ -111,14 +111,14 @@ def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:

 class LayerScale(nn.Module):
     """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
-    This rescales
+    This rescales diagonally the residual outputs close to 0, with a learnt scale.

     Args:
         channels (int): Number of channels.
         init (float): Initial scale.
         channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
-        device (torch.device or
-        dtype (torch.dtype
+        device (torch.device or str, optional): Device on which to initialize the module.
+        dtype (torch.dtype, optional): dtype to use to initialize the module.
     """
     def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                  device=None, dtype=None):
@@ -144,22 +144,22 @@ class StreamingMultiheadAttention(StreamingModule):
         dropout (float): Dropout level.
         bias (bool): Use bias in projections.
         causal (bool): Causal mask applied automatically.
-        past_context (int
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
         custom (bool): Use custom MHA implementation, for testing / benchmarking.
         memory_efficient (bool): Use xformers based memory efficient attention.
         attention_as_float32 (bool): Perform the attention as float32
             (especially important with memory_efficient as autocast won't do this automatically).
-        rope (`RotaryEmbedding
+        rope (`RotaryEmbedding`, optional): Rope embedding to use.
         cross_attention: Should be true when used as a cross attention.
             All keys and values must be available at once, streaming is only for the queries.
             Cannot be used with `causal` or `rope` (as it wouldn't make sens to
-
+            interpret the time steps in the keys relative to those in the queries).
         safe_streaming (bool): Bug fix, will go away with xformers update.
         qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
         kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
             This will lead to faster decoding time on A100 or other GPUs with tensorcore.
-        device (torch.device
-        dtype (torch.dtype
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
     """
     def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                  causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
@@ -234,14 +234,14 @@ class StreamingMultiheadAttention(StreamingModule):
         # Return a causal mask, accounting for potentially stored past keys/values
         # We actually return a bias for the attention score, as this has the same
         # convention both in the builtin MHA in Pytorch, and Xformers functions.
-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if self.memory_efficient:
             from xformers.ops import LowerTriangularMask
             if current_steps == 1:
                 # If we only have one step, then we do not need a mask.
                 return None
             elif 'past_keys' in self._streaming_state:
-                raise RuntimeError(
+                raise RuntimeError("Not supported at the moment")
             else:
                 # Then we can safely use a lower triangular mask
                 return LowerTriangularMask()
@@ -264,7 +264,7 @@ class StreamingMultiheadAttention(StreamingModule):
                 torch.full([], float('-inf'), device=device, dtype=dtype))

     def _complete_kv(self, k, v):
-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if self.cross_attention:
             # With cross attention we assume all keys and values
             # are already available, and streaming is with respect
@@ -298,8 +298,7 @@ class StreamingMultiheadAttention(StreamingModule):
         return nk, nv

     def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
-
-        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         # Apply rope embeddings to query and key tensors.
         assert self.rope is not None
         if 'past_keys' in self._streaming_state:
@@ -311,16 +310,16 @@ class StreamingMultiheadAttention(StreamingModule):
         else:
             past_context_offset = 0
         streaming_offset = past_context_offset + past_keys_offset
-        return self.rope.rotate_qk(query, key, start=streaming_offset)
+        return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)

     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                 key_padding_mask=None, need_weights=False, attn_mask=None,
                 average_attn_weights=True, is_causal=False):
         assert attn_mask is None
-        assert not is_causal, ("
+        assert not is_causal, ("New param added in torch 2.0.1 not supported, "
                                "use the causal args in the constructor.")

-        time_dim = _get_attention_time_dimension()
+        time_dim = _get_attention_time_dimension(self.memory_efficient)
         if time_dim == 2:
             layout = "b h t d"
         else:
@@ -394,8 +393,8 @@ class StreamingMultiheadAttention(StreamingModule):
             q, k = self._apply_rope(q, k)
         k, v = self._complete_kv(k, v)
         if self.kv_repeat > 1:
-            k = expand_repeated_kv(k, self.kv_repeat)
-            v = expand_repeated_kv(v, self.kv_repeat)
+            k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
+            v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
         if self.attention_as_float32:
             q, k, v = [x.float() for x in [q, k, v]]
         if self.memory_efficient:
@@ -455,7 +454,7 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
         bias_ff (bool): Use bias for FF.
         bias_attn (bool): Use bias for MHA.
         causal (bool): Causal mask applied automatically.
-        past_context (int
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
         custom (bool): Use custom MHA implementation, for testing / benchmarking.
         memory_efficient (bool): Use xformers based memory efficient attention.
         attention_as_float32 (bool): Perform the attention as float32
@@ -465,15 +464,15 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
         cross_attention (bool): If True, expect to get secondary input for cross-attention.
             Cross attention will use the default MHA, as it typically won't require
             special treatment.
-        layer_scale (float
+        layer_scale (float, optional): If not None, LayerScale will be used with
             the given value as initial scale.
-        rope (`RotaryEmbedding
-        attention_dropout (float
+        rope (`RotaryEmbedding`, optional): Rope embedding to use.
+        attention_dropout (float, optional): If not None, separate the value of the dimension dropout
             in FFN and of the attention dropout.
         kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
             This will lead to faster decoding time on A100 or other GPUs with tensorcore.
-        device (torch.device
-        dtype (torch.dtype
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
         **kwargs: See `nn.TransformerEncoderLayer`.
     """
     def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
@@ -576,30 +575,30 @@ class StreamingTransformer(StreamingModule):
         bias_ff (bool): Use bias for FF.
         bias_attn (bool): Use bias for MHA.
         causal (bool): Causal mask applied automatically.
-        past_context (int
+        past_context (int, optional): Receptive field for the causal mask, infinite if None.
         custom (bool): Use custom MHA implementation, for testing / benchmarking.
         memory_efficient (bool): Use xformers based memory efficient attention.
         attention_as_float32 (bool): Perform the attention as float32
             (especially important with memory_efficient as autocast won't do this automatically).
         cross_attention (bool): If True, expect to get secondary input for cross-attention.
-        layer_scale (float
+        layer_scale (float, optional): If not None, LayerScale will be used
             with the given value as initial scale.
         positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
         max_period (float): Maximum period of the time embedding.
         positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
         xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
-        lr (float
-        weight_decay (float
+        lr (float, optional): learning rate override through the `make_optim_group` API.
+        weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
         layer_class: (subclass of `StreamingTransformerLayer): class to use
-            to initialize the layers, allowing further customization outside of
+            to initialize the layers, allowing further customization outside of AudioCraft.
         checkpointing (str): Checkpointing strategy to reduce memory usage.
             No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
             if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
             minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
             a policy for opting-out some operations of the checkpointing like
             linear layers and attention, providing a middle ground between speed and memory.
-        device (torch.device
-        dtype (torch.dtype
+        device (torch.device, optional): Device on which to initialize.
+        dtype (torch.dtype, optional): dtype to use.
         **kwargs: See `nn.TransformerEncoderLayer`.
     """
     def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
@@ -649,7 +648,6 @@ class StreamingTransformer(StreamingModule):
             # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
            # backward hook inside of FSDP...
             layer._magma_checkpointed = True  # type: ignore
-            assert layer.layer_drop == 0., "Need further checking"  # type: ignore

     def _apply_layer(self, layer, *args, **kwargs):
         method = self.checkpointing
@@ -713,7 +711,7 @@ class StreamingTransformer(StreamingModule):
         return group


-# special attention
+# special attention related function

 def _verify_xformers_memory_efficient_compat():
     try:
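The helpers now take the attention module's `memory_efficient` flag instead of relying on the global backend alone, which decides where the time axis sits. A tiny standalone sketch of that decision, grounded in the hunks above:

def attention_time_dim(backend: str, memory_efficient: bool) -> int:
    # Mirrors the updated helper: time sits on dim 2 only when the torch
    # backend is combined with memory-efficient attention, otherwise on dim 1.
    return 2 if (backend == 'torch' and memory_efficient) else 1

assert attention_time_dim('torch', True) == 2
assert attention_time_dim('torch', False) == 1
assert attention_time_dim('xformers', True) == 1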
audiocraft/quantization/core_vq.py
CHANGED
@@ -75,7 +75,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
     return means, bins


-def
+def orthogonal_loss_fn(t):
     # eq (2) from https://arxiv.org/abs/2112.00384
     n = t.shape[0]
     normed_codes = l2norm(t)
@@ -237,7 +237,7 @@ class VectorQuantization(nn.Module):
         orthogonal_reg_weight (float): Orthogonal regularization weights.
         orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
         orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
-            for orthogonal
+            for orthogonal regularization.
         threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
             that have an exponential moving average cluster size less than the specified threshold with
             randomly selected vector from the current batch.
@@ -340,7 +340,7 @@ class VectorQuantization(nn.Module):
                 rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
                 codebook = codebook[rand_ids]

-            orthogonal_reg_loss =
+            orthogonal_reg_loss = orthogonal_loss_fn(codebook)
             loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

         quantize = self.project_out(quantize)
@@ -371,11 +371,16 @@ class ResidualVectorQuantization(nn.Module):

         for i, layer in enumerate(self.layers[:n_q]):
             quantized, indices, loss = layer(residual)
+            quantized = quantized.detach()
             residual = residual - quantized
             quantized_out = quantized_out + quantized
             all_indices.append(indices)
             all_losses.append(loss)

+        if self.training:
+            # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
+            quantized_out = x + (quantized_out - x).detach()
+
         out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
         return quantized_out, out_indices, out_losses
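The residual-VQ fix applies a single straight-through estimator over the summed quantized output: the forward value stays quantized while gradients flow as if quantization were the identity. A minimal standalone illustration, not tied to the AudioCraft classes:

import torch

x = torch.randn(4, 8, requires_grad=True)
quantized_out = torch.round(x * 4) / 4           # stand-in for the stacked RVQ output

# Straight-through estimator: forward uses quantized_out, backward sees identity w.r.t. x.
ste_out = x + (quantized_out - x).detach()
ste_out.sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))   # gradient passes straight through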
audiocraft/utils/cache.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from concurrent.futures import ThreadPoolExecutor
|
8 |
+
from collections import deque
|
9 |
+
from functools import partial
|
10 |
+
from hashlib import sha1
|
11 |
+
import logging
|
12 |
+
from pathlib import Path
|
13 |
+
import sys
|
14 |
+
import typing as tp
|
15 |
+
import zipfile
|
16 |
+
|
17 |
+
import flashy
|
18 |
+
import torch
|
19 |
+
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
|
25 |
+
"""Utility function for the EmbeddingCache, returning the full embedding without any chunking.
|
26 |
+
This method can be used in case there is no need in extracting a chunk of the full embedding
|
27 |
+
read from the cache.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
full_embed (torch.Tensor): The full embedding.
|
31 |
+
x (any): Batch object from which the full embedding is derived.
|
32 |
+
idx (torch.Tensor): Index of object to consider in the batch object.
|
33 |
+
Returns:
|
34 |
+
full_embed (torch.Tensor): The full embedding
|
35 |
+
"""
|
36 |
+
return full_embed.to(device)
|
37 |
+
|
38 |
+
|
39 |
+
class EmbeddingCache:
|
40 |
+
"""Cache around embeddings computation for faster execution.
|
41 |
+
The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
|
42 |
+
to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
|
43 |
+
using a user-provided function. When the cache is warm (all embeddings are pre-computed),
|
44 |
+
the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
|
45 |
+
Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
|
46 |
+
and synchronization points in the forward calls.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
|
50 |
+
device (str or torch.device): Device on which the embedding is returned.
|
51 |
+
compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
|
52 |
+
the embedding from a given object and path. This user provided function can compute the
|
53 |
+
embedding from the provided object or using the provided path as entry point. The last parameter
|
54 |
+
specify the index corresponding to the current embedding in the object that can represent batch metadata.
|
55 |
+
extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
|
56 |
+
the desired embedding chunk from the full embedding loaded from the cache. The last parameter
|
57 |
+
specify the index corresponding to the current embedding in the object that can represent batch metadata.
|
58 |
+
If not specified, will return the full embedding unmodified.
|
59 |
+
"""
|
60 |
+
def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
|
61 |
+
compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
|
62 |
+
extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
|
63 |
+
self.cache_path = Path(cache_path)
|
64 |
+
self.device = device
|
65 |
+
self._compute_embed_fn = compute_embed_fn
|
66 |
+
self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
|
67 |
+
if extract_embed_fn is not None:
|
68 |
+
self._extract_embed_fn = extract_embed_fn
|
69 |
+
else:
|
70 |
+
self._extract_embed_fn = partial(get_full_embed, device=device)
|
71 |
+
if self.cache_path is not None:
|
72 |
+
self.cache_path.mkdir(exist_ok=True, parents=True)
|
73 |
+
logger.info(f"Cache instantiated at: {self.cache_path}")
|
74 |
+
self.pool = ThreadPoolExecutor(8)
|
75 |
+
self.pool.__enter__()
|
76 |
+
self._current_batch_cache: dict = {}
|
77 |
+
self._memory_cache: dict = {}
|
78 |
+
|
79 |
+
def _get_cache_path(self, path: tp.Union[Path, str]):
|
80 |
+
"""Get cache path for the given file path."""
|
81 |
+
sig = sha1(str(path).encode()).hexdigest()
|
82 |
+
return self.cache_path / sig
|
83 |
+
|
84 |
+
@staticmethod
|
85 |
+
def _get_full_embed_from_cache(cache: Path):
|
86 |
+
"""Loads full pre-computed embedding from the cache."""
|
87 |
+
try:
|
88 |
+
embed = torch.load(cache, 'cpu')
|
89 |
+
except Exception as exc:
|
90 |
+
logger.error("Error loading %s: %r", cache, exc)
|
91 |
+
embed = None
|
92 |
+
return embed
|
93 |
+
|
94 |
+
def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
|
95 |
+
"""Get embedding from cache, computing and storing it to cache if not already cached.
|
96 |
+
The EmbeddingCache first tries to load the embedding from the in-memory cache
|
97 |
+
containing the pre-computed chunks populated through `populate_embed_cache`.
|
98 |
+
If not found, the full embedding is computed and stored on disk to be later accessed
|
99 |
+
to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
paths (list[Path or str]): List of paths from where the embeddings can be loaded.
|
103 |
+
x (any): Object from which the embedding is extracted.
|
104 |
+
"""
|
105 |
+
embeds = []
|
106 |
+
for idx, path in enumerate(paths):
|
107 |
+
cache = self._get_cache_path(path)
|
108 |
+
if cache in self._current_batch_cache:
|
109 |
+
embed = self._current_batch_cache[cache]
|
110 |
+
else:
|
111 |
+
full_embed = self._compute_embed_fn(path, x, idx)
|
112 |
+
try:
|
113 |
+
with flashy.utils.write_and_rename(cache, pid=True) as f:
|
114 |
+
torch.save(full_embed.cpu(), f)
|
115 |
+
except Exception as exc:
|
116 |
+
logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
|
117 |
+
else:
|
118 |
+
logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
|
119 |
+
embed = self._extract_embed_fn(full_embed, x, idx)
|
120 |
+
embeds.append(embed)
|
121 |
+
embed = torch.stack(embeds, dim=0)
|
122 |
+
return embed
|
123 |
+
|
124 |
+
def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
|
125 |
+
"""Populate in-memory caches for embeddings reading from the embeddings stored on disk.
|
126 |
+
The in-memory caches consist in a cache for the full embedding and another cache for the
|
127 |
+
final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
|
128 |
+
and reduce the IO footprint and synchronization points during forward passes.
|
129 |
+
|
130 |
+
Args:
|
131 |
+
paths (list[Path]): List of paths from where the embeddings can be loaded.
|
132 |
+
x (any): Object from which the embedding is extracted.
|
133 |
+
"""
|
134 |
+
self._current_batch_cache.clear()
|
135 |
+
if self.cache_path is not None:
|
136 |
+
futures: list = []
|
137 |
+
for path in paths:
|
138 |
+
assert path is not None, "Path is required for computation from cache"
|
139 |
+
cache = self._get_cache_path(path)
|
140 |
+
if cache in self._memory_cache or not cache.exists():
|
141 |
+
futures.append(None)
|
142 |
+
else:
|
143 |
+
futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
|
144 |
+
for idx, (path, future) in enumerate(zip(paths, futures)):
|
145 |
+
assert path is not None
|
146 |
+
cache = self._get_cache_path(path)
|
147 |
+
full_embed = None
|
148 |
+
if future is None:
|
149 |
+
if cache in self._memory_cache:
|
150 |
+
full_embed = self._memory_cache[cache]
|
151 |
+
else:
|
152 |
+
full_embed = future.result()
|
153 |
+
if full_embed is not None:
|
154 |
+
self._memory_cache[cache] = full_embed
|
155 |
+
full_embed = full_embed.to(self.device)
|
156 |
+
if full_embed is not None:
|
157 |
+
embed = self._extract_embed_fn(full_embed, x, idx)
|
158 |
+
self._current_batch_cache[cache] = embed
|
159 |
+
|
160 |
+
|
161 |
+
class CachedBatchWriter:
|
162 |
+
"""Write pre computed caches for mini batches. This can
|
163 |
+
make loading a lot more efficient depending on your filesystem.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
cache_folder (Path): folder in which the cached minibatches
|
167 |
+
will be stored.
|
168 |
+
|
169 |
+
Inside cache folder, the structure is the following:
|
170 |
+
`epoch_number / update_number.zip`
|
171 |
+
And the zip file contains one entry per batch item.
|
172 |
+
|
173 |
+
It is possible to use the cache with a batch size smaller than
|
174 |
+
created with but obviously not larger. Make sure to call the
|
175 |
+
`start_epoch(epoch)` method for indicating changes of epochs.
|
176 |
+
|
177 |
+
See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
|
178 |
+
for an example of how to warmup the cache.
|
179 |
+
"""
|
180 |
+
def __init__(self, cache_folder: Path):
|
181 |
+
self.cache_folder = cache_folder
|
182 |
+
self._current_epoch: tp.Optional[int] = None
|
183 |
+
self._current_index = 0
|
184 |
+
|
185 |
+
def start_epoch(self, epoch: int):
|
186 |
+
"""Call at the beginning of each epoch.
|
187 |
+
"""
|
188 |
+
self._current_epoch = epoch
|
189 |
+
self._current_index = 0
|
190 |
+
self._zip_path.parent.mkdir(exist_ok=True, parents=True)
|
191 |
+
|
192 |
+
@staticmethod
|
193 |
+
def _get_zip_path(cache_folder: Path, epoch: int, index: int):
|
194 |
+
return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
|
195 |
+
|
196 |
+
@property
|
197 |
+
def _zip_path(self):
|
198 |
+
assert self._current_epoch is not None
|
199 |
+
return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
|
200 |
+
|
201 |
+
def save(self, *content):
|
202 |
+
"""Save one mini batch. This function is distributed-aware
|
203 |
+
and will automatically merge all the items from the different
|
204 |
+
workers.
|
205 |
+
"""
|
206 |
+
all_contents = []
|
207 |
+
for rank in range(flashy.distrib.world_size()):
|
208 |
+
their_content = flashy.distrib.broadcast_object(content, src=rank)
|
209 |
+
all_contents.append(their_content)
|
210 |
+
|
211 |
+
if flashy.distrib.is_rank_zero():
|
212 |
+
idx = 0
|
213 |
+
with flashy.utils.write_and_rename(self._zip_path) as tmp:
|
214 |
+
with zipfile.ZipFile(tmp, 'w') as zf:
|
215 |
+
for content in all_contents:
|
216 |
+
for vals in zip(*content):
|
217 |
+
with zf.open(f'{idx}', 'w') as f: # type: ignore
|
218 |
+
torch.save(vals, f)
|
219 |
+
idx += 1
|
220 |
+
flashy.distrib.barrier()
|
221 |
+
self._current_index += 1
|
222 |
+
|
223 |
+
|
224 |
+
class CachedBatchLoader:
|
225 |
+
"""Loader for cached mini-batches dumped with `CachedBatchWriter`.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
        cache_folder (Path): folder in which the cached minibatches are stored.
        batch_size (int): batch size (per GPU) expected.
        num_workers (int): number of workers to use for loading.
        min_length (int): minimum expected length for each epoch. If some
            mini-batches are missing, an error is raised.

    This is iterable just like a regular DataLoader.
    """

    def __init__(self, cache_folder: Path, batch_size: int,
                 num_workers: int = 10, min_length: int = 1):
        self.cache_folder = cache_folder
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.min_length = min_length
        self._current_epoch: tp.Optional[int] = None
        self.sampler = None  # for compatibility with the regular DataLoader

    def __len__(self):
        path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
        return len([p for p in path.iterdir() if p.suffix == ".zip"])

    def start_epoch(self, epoch: int):
        """Call at the beginning of each epoch."""
        self._current_epoch = epoch

    def _zip_path(self, index: int):
        assert self._current_epoch is not None
        return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)

    def _load_one(self, index: int):
        zip_path = self._zip_path(index)
        if not zip_path.exists():
            if index < self.min_length:
                raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
            return None
        mode = "rb" if sys.version_info >= (3, 9) else "r"
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                rank = flashy.distrib.rank()
                world_size = flashy.distrib.world_size()
                root = zipfile.Path(zf)
                items = list(root.iterdir())
                total_batch_size = self.batch_size * world_size
                if len(items) < total_batch_size:
                    raise RuntimeError(
                        f"The cache can handle a max batch size of {len(items)}, "
                        f"but {total_batch_size} is needed.")
                start = rank * self.batch_size
                items = items[start: start + self.batch_size]
                assert len(items) == self.batch_size
                entries = [torch.load(item.open(mode), 'cpu') for item in items]  # type: ignore
                transposed = zip(*entries)
                out = []
                for part in transposed:
                    assert len(part) > 0
                    if isinstance(part[0], torch.Tensor):
                        out.append(torch.stack(part))
                    else:
                        assert isinstance(part, torch.Tensor)
                        out.append(part)
                return out
        except Exception:
            logger.error("Error when reading zip path %s", zip_path)
            raise

    def __iter__(self):
        """This will yield tuples, exactly as provided to the
        `CachedBatchWriter.save` method.
        """
        pool = ThreadPoolExecutor(self.num_workers)
        next_index = 0
        queue = deque()

        def _get_next():
            nonlocal next_index
            r = queue.popleft().result()
            if r is None:
                return None
            else:
                queue.append(pool.submit(self._load_one, next_index))
                next_index += 1
                return r

        with pool:
            # fill the buffer of fetching jobs.
            for _ in range(2 * self.num_workers):
                queue.append(pool.submit(self._load_one, next_index))
                next_index += 1
            while True:
                batch = _get_next()
                if batch is None:
                    return
                yield batch
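Usage sketch for the loader above (illustrative only, not part of this diff; the cache folder and the unpacking of each batch are assumptions that depend on what CachedBatchWriter previously stored):

    from pathlib import Path
    from audiocraft.utils.cache import CachedBatchLoader

    loader = CachedBatchLoader(Path("cache/epochs"), batch_size=16, num_workers=8)
    loader.start_epoch(0)          # select which epoch's zip files to read
    for batch in loader:           # yields tuples exactly as given to CachedBatchWriter.save
        tokens, conditions = batch  # hypothetical unpacking
        ...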
audiocraft/utils/cluster.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility functions for SLURM configuration and cluster settings.
"""

from enum import Enum
import os
import socket
import typing as tp

import omegaconf


class ClusterType(Enum):
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
    LOCAL_DARWIN = "darwin"
    DEFAULT = "default"  # used for any other cluster.


def _guess_cluster_type() -> ClusterType:
    uname = os.uname()
    fqdn = socket.getfqdn()
    if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
        return ClusterType.AWS

    if fqdn.endswith(".fair"):
        return ClusterType.FAIR

    if fqdn.endswith(".facebook.com"):
        return ClusterType.RSC

    if uname.sysname == "Darwin":
        return ClusterType.LOCAL_DARWIN

    return ClusterType.DEFAULT


def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> tp.Optional[ClusterType]:
    if cluster_type is None:
        return _guess_cluster_type()

    return cluster_type


def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Update SLURM parameters in configuration based on cluster type.
    If the cluster type is not specified, it infers it automatically.
    """
    from ..environment import AudioCraftEnvironment
    cluster_type = get_cluster_type(cluster_type)
    # apply cluster-specific adjustments
    if cluster_type == ClusterType.AWS:
        cfg["mem_per_gpu"] = None
        cfg["constraint"] = None
        cfg["setup"] = []
    elif cluster_type == ClusterType.RSC:
        cfg["mem_per_gpu"] = None
        cfg["setup"] = []
        cfg["constraint"] = None
        cfg["partition"] = "learn"
    slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
    if slurm_exclude is not None:
        cfg["exclude"] = slurm_exclude
    return cfg
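A minimal sketch of how these helpers could be called (the SLURM config keys below are illustrative assumptions, not taken from this diff):

    import omegaconf
    from audiocraft.utils.cluster import get_cluster_type, get_slurm_parameters

    slurm_cfg = omegaconf.OmegaConf.create(
        {"mem_per_gpu": 40, "constraint": "volta32gb", "setup": [],
         "partition": "dev", "exclude": None})
    print(get_cluster_type())                     # e.g. ClusterType.DEFAULT on an unknown host
    slurm_cfg = get_slurm_parameters(slurm_cfg)   # fields are overridden per detected cluster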
audiocraft/utils/export.py
CHANGED
@@ -11,46 +11,69 @@ Utility to export a training checkpoint to a lightweight release checkpoint.
 from pathlib import Path
 import typing as tp

-from omegaconf import OmegaConf, DictConfig
+from omegaconf import OmegaConf
 import torch

+from audiocraft import __version__

-def _clean_lm_cfg(cfg: DictConfig):
-    OmegaConf.set_struct(cfg, False)
-    # This used to be set automatically in the LM solver, need a more robust solution
-    # for the future.
-    cfg['transformer_lm']['card'] = 2048
-    cfg['transformer_lm']['n_q'] = 4
-    # Experimental params no longer supported.
-    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
-                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
-    for name in bad_params:
-        del cfg['transformer_lm'][name]
-    OmegaConf.set_struct(cfg, True)
-    return cfg
-
-
-def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
-    sig = Path(checkpoint_path).parent.name
-    assert len(sig) == 8, "Not a valid Dora signature"
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+    """Export only the best state from the given EnCodec checkpoint. This
+    should be used if you trained your own EnCodec model.
+    """
     pkg = torch.load(checkpoint_path, 'cpu')
     new_pkg = {
-        'best_state': pkg['ema']['state']['model'],
+        'best_state': pkg['best_state']['model'],
         'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+        'version': __version__,
+        'exported': True,
     }
-    out_file = Path(out_folder) / f'{sig}.th'
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
     torch.save(new_pkg, out_file)
     return out_file


-def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
-    sig = Path(checkpoint_path).parent.name
-    assert len(sig) == 8, "Not a valid Dora signature"
+def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
+    """Export a compression model (potentially EnCodec) from a pretrained model.
+    This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
+    Do not include the //pretrained/ prefix. For instance if you trained a model
+    with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
+
+    In that case, this will not actually include a copy of the model, simply the reference
+    to the model used.
+    """
+    if Path(pretrained_encodec).exists():
+        pkg = torch.load(pretrained_encodec)
+        assert 'best_state' in pkg
+        assert 'xp.cfg' in pkg
+        assert 'version' in pkg
+        assert 'exported' in pkg
+    else:
+        pkg = {
+            'pretrained': pretrained_encodec,
+            'exported': True,
+            'version': __version__,
+        }
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
+    torch.save(pkg, out_file)
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
+    """Export only the best state from the given MusicGen or AudioGen checkpoint.
+    """
     pkg = torch.load(checkpoint_path, 'cpu')
+    if pkg['fsdp_best_state']:
+        best_state = pkg['fsdp_best_state']['model']
+    else:
+        assert pkg['best_state']
+        best_state = pkg['best_state']['model']
     new_pkg = {
-        'best_state': pkg['fsdp_best_state']['model'],
-        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+        'best_state': best_state,
+        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+        'version': __version__,
+        'exported': True,
     }
-    out_file = Path(out_folder) / f'{sig}.th'
+
+    Path(out_file).parent.mkdir(exist_ok=True, parents=True)
     torch.save(new_pkg, out_file)
     return out_file
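Sketch of the export flow the new docstrings describe (checkpoint and output paths are placeholders, not from this diff):

    from audiocraft.utils import export

    # package a fine-tuned MusicGen/AudioGen LM, plus a reference to the
    # pretrained tokenizer it was trained with
    export.export_lm("checkpoints/my_xp/checkpoint.th", "release/state_dict.bin")
    export.export_pretrained_compression_model(
        "facebook/encodec_32khz", "release/compression_state_dict.bin")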
audiocraft/utils/export_legacy.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""

from pathlib import Path
import typing as tp

from omegaconf import OmegaConf, DictConfig
import torch


def _clean_lm_cfg(cfg: DictConfig):
    OmegaConf.set_struct(cfg, False)
    # This used to be set automatically in the LM solver, need a more robust solution
    # for the future.
    cfg['transformer_lm']['card'] = 2048
    cfg['transformer_lm']['n_q'] = 4
    # Experimental params no longer supported.
    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
    for name in bad_params:
        del cfg['transformer_lm'][name]
    OmegaConf.set_struct(cfg, True)
    return cfg


def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
audiocraft/utils/extend.py
CHANGED
@@ -179,7 +179,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
             descriptions=[text],
             melody_wavs=verse,
             sample_rate=sr,
-            progress=
+            progress=True,
             prompt=prompt_segment,
         )
         # If user selects a prompt segment, use the prompt segment for all segments
@@ -280,9 +280,10 @@ def load_font(font_name, font_size=16):
     if font is None:
         try:
             req = requests.get(font_name)
-            font = ImageFont.truetype(BytesIO(req.content), font_size)
-        except (FileNotFoundError, OSError):
-
+            font = ImageFont.truetype(BytesIO(req.content), font_size)
+        except (FileNotFoundError, OSError):
+            print(f"Font not found: {font_name} Using default font\n")
+
     if font:
         print(f"Font loaded {font.getname()}")
     else:
audiocraft/utils/utils.py
CHANGED
@@ -5,9 +5,12 @@
 # LICENSE file in the root directory of this source tree.

 from concurrent.futures import ProcessPoolExecutor
-from
+from contextlib import contextmanager
+from functools import wraps, lru_cache
 import hashlib
+import json
 import logging
+from pathlib import Path
 import typing as tp

 import flashy
@@ -20,6 +23,18 @@ from torch.nn.utils.rnn import pad_sequence
 logger = logging.getLogger(__name__)


+def model_hash(model: torch.nn.Module) -> str:
+    """Return a model hash. This should allow us to track regressions in model init
+    from the logs of past experiments.
+    """
+    hasher = hashlib.sha1()
+    for p in model.parameters():
+        hasher.update(p.data.cpu().numpy().tobytes())
+    return hasher.hexdigest()
+
+
+
+
 def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
     """Convenience function to map an omegaconf configuration to a dictionary.

@@ -172,7 +187,7 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t
     assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
     final_length = lengths.max().item() if not max_len else max_len
     final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
-    return torch.arange(final_length)[None, :]
+    return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]


 def hash_trick(word: str, vocab_size: int) -> int:
@@ -232,3 +247,54 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens
     padded_tensors = padded_tensors.transpose(0, 1)
     padded_tensors = padded_tensors.transpose(1, dim + 1)
     return padded_tensors, lens
+
+
+# TODO: Move to flashy?
+def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
+               dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
+    if isinstance(state, torch.Tensor):
+        if dtype is None or not state.is_floating_point():
+            dtype = state.dtype
+        return state.detach().to(device=device, dtype=dtype, copy=True)
+    elif isinstance(state, dict):
+        return {k: copy_state(v, device, dtype) for k, v in state.items()}
+    elif isinstance(state, list):
+        return [copy_state(v, device, dtype) for v in state]
+
+
+# TODO: Move to flashy?
+@contextmanager
+def swap_state(model, state, **kwargs):
+    old_state = copy_state(model.state_dict())
+    model.load_state_dict(state, **kwargs)
+    try:
+        yield
+    finally:
+        model.load_state_dict(old_state)
+
+
+@lru_cache(None)
+def warn_once(logger, msg):
+    """Warn about a given message only once."""
+    logger.warning(msg)
+
+
+def is_jsonable(x: tp.Any):
+    """Check if an object can be serialized into a json:"""
+    try:
+        json.dumps(x)
+        return True
+    except (TypeError, OverflowError):
+        return False
+
+
+def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
+    """Wrapper around state dict loading of CLAP model
+    addressing compatibility issues between CLAP and AudioCraft
+    HuggingFace transformer version.
+    See: https://github.com/LAION-AI/CLAP/issues/118
+    """
+    from clap_module.factory import load_state_dict  # type: ignore
+    pkg = load_state_dict(path)
+    pkg.pop('text_branch.embeddings.position_ids', None)
+    clap_model.model.load_state_dict(pkg)
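Short sketch of the new helpers added above (illustrative only; the tiny model here is a stand-in, not something from this diff):

    import torch
    from audiocraft.utils.utils import model_hash, copy_state, swap_state

    model = torch.nn.Linear(4, 4)
    stashed = copy_state(model.state_dict())   # detached copy, safe to keep around
    print(model_hash(model))                   # sha1 over all parameter bytes
    with swap_state(model, stashed):           # temporarily run with the stashed weights
        model(torch.zeros(1, 4))
    # the original weights are restored when the context exits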
modules/file_utils.py
ADDED
@@ -0,0 +1,91 @@
# file_utils
import os
import shutil
from pathlib import Path

def get_file_parts(file_path: str):
    # Split the path into directory and filename
    directory, filename = os.path.split(file_path)

    # Split the filename into name and extension
    name, ext = os.path.splitext(filename)

    # Convert the extension to lowercase
    new_ext = ext.lower()
    return directory, filename, name, ext, new_ext

def rename_file_to_lowercase_extension(file_path: str) -> str:
    """
    Renames a file's extension to lowercase in place.

    Parameters:
        file_path (str): The original file path.

    Returns:
        str: The new file path with the lowercase extension.

    Raises:
        OSError: If there is an error renaming the file (e.g., file not found, permissions issue).
    """
    directory, filename, name, ext, new_ext = get_file_parts(file_path)
    # If the extension changes, rename the file
    if ext != new_ext:
        new_filename = name + new_ext
        new_file_path = os.path.join(directory, new_filename)
        try:
            os.rename(file_path, new_file_path)
            print(f"Rename {file_path} to {new_file_path}\n")
        except Exception as e:
            print(f"os.rename failed: {e}. Falling back to binary copy operation.")
            try:
                # Read the file in binary mode and write it to new_file_path
                with open(file_path, 'rb') as f:
                    data = f.read()
                with open(new_file_path, 'wb') as f:
                    f.write(data)
                print(f"Copied {file_path} to {new_file_path}\n")
                # Optionally, remove the original file after copying
                # os.remove(file_path)
            except Exception as inner_e:
                print(f"Failed to copy file from {file_path} to {new_file_path}: {inner_e}")
                raise inner_e
        return new_file_path
    else:
        return file_path

def get_filename(file):
    # extract filename from file object
    filename = None
    if file is not None:
        filename = file.name
    return filename

def convert_title_to_filename(title):
    # convert title to filename
    filename = title.lower().replace(" ", "_").replace("/", "_")
    return filename

def get_filename_from_filepath(filepath):
    file_name = os.path.basename(filepath)
    file_base, file_extension = os.path.splitext(file_name)
    return file_base, file_extension

def delete_file(file_path: str) -> None:
    """
    Deletes the specified file.

    Parameters:
        file_path (str): The path to the file to delete.

    Raises:
        FileNotFoundError: If the file does not exist.
        Exception: If there is an error deleting the file.
    """
    try:
        path = Path(file_path)
        path.unlink()
        print(f"Deleted original file: {file_path}")
    except FileNotFoundError:
        print(f"File not found: {file_path}")
    except Exception as e:
        print(f"Error deleting file: {e}")
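Example of how these helpers fit together (hypothetical file names, for illustration only):

    from modules.file_utils import rename_file_to_lowercase_extension, get_filename_from_filepath

    path = rename_file_to_lowercase_extension("uploads/Melody.WAV")  # -> "uploads/Melody.wav"
    base, ext = get_filename_from_filepath(path)                     # -> ("Melody", ".wav")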
modules/gradio.py
ADDED
@@ -0,0 +1,272 @@
# modules.gradio
# holds updates and lost code from gradio changes
import os
import gradio as gr
import numpy as np
import PIL
import PIL.Image
import shutil
import subprocess
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import Any


class MatplotlibBackendMananger:
    def __enter__(self):
        try:
            import matplotlib

            self._original_backend = matplotlib.get_backend()
            matplotlib.use("agg")
        except ImportError:
            pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            import matplotlib

            matplotlib.use(self._original_backend)
        except ImportError:
            pass

gr.utils.MatplotlibBackendMananger = MatplotlibBackendMananger

def make_waveform(
    audio: str | tuple[int, np.ndarray],
    *,
    bg_color: str = "#f3f4f6",
    bg_image: str | None = None,
    fg_alpha: float = 0.75,
    bars_color: str | tuple[str, str] = ("#fbbf24", "#ea580c"),
    bar_count: int = 50,
    bar_width: float = 0.6,
    animate: bool = False,
    name: str = "",
) -> str:
    """
    Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
    Parameters:
        audio: Audio file path or tuple of (sample_rate, audio_data)
        bg_color: Background color of waveform (ignored if bg_image is provided)
        bg_image: Background image of waveform
        fg_alpha: Opacity of foreground waveform
        bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
        bar_count: Number of bars in waveform
        bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
        animate: If true, the audio waveform overlay will be animated, if false, it will be static.
    Returns:
        A filepath to the output video in mp4 format.
    """
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    if isinstance(audio, str):
        audio_file = audio
        audio = gr.processing_utils.audio_from_file(audio)
    else:
        tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False, prefix=name)
        gr.processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name, format="wav")
        audio_file = tmp_wav.name

    if not os.path.isfile(audio_file):
        raise ValueError("Audio file not found.")

    ffmpeg = shutil.which("ffmpeg")
    if not ffmpeg:
        raise RuntimeError("ffmpeg not found.")

    duration = round(len(audio[1]) / audio[0], 4)

    # Helper methods to create waveform
    def hex_to_rgb(hex_str):
        return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]

    def get_color_gradient(c1, c2, n):
        if n < 1:
            raise ValueError("Must have at least one stop in gradient")
        c1_rgb = np.array(hex_to_rgb(c1)) / 255
        c2_rgb = np.array(hex_to_rgb(c2)) / 255
        mix_pcts = [x / (n - 1) for x in range(n)]
        rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
        return [
            "#" + "".join(f"{int(round(val * 255)):02x}" for val in item)
            for item in rgb_colors
        ]

    # Reshape audio to have a fixed number of bars
    samples = audio[1]
    if len(samples.shape) > 1:
        samples = np.mean(samples, 1)
    bins_to_pad = bar_count - (len(samples) % bar_count)
    samples = np.pad(samples, [(0, bins_to_pad)])
    samples = np.reshape(samples, (bar_count, -1))
    samples = np.abs(samples)
    samples = np.max(samples, 1)

    with MatplotlibBackendMananger():
        plt.clf()
        # Plot waveform
        color = (
            bars_color
            if isinstance(bars_color, str)
            else get_color_gradient(bars_color[0], bars_color[1], bar_count)
        )

        if animate:
            fig = plt.figure(figsize=(5, 1), dpi=200, frameon=False)
            fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
        plt.axis("off")
        plt.margins(x=0)

        bar_alpha = fg_alpha if animate else 1.0
        barcollection = plt.bar(
            np.arange(0, bar_count),
            samples * 2,
            bottom=(-1 * samples),
            width=bar_width,
            color=color,
            alpha=bar_alpha,
        )

        tmp_img = NamedTemporaryFile(suffix=".png", delete=False, prefix=name)

        savefig_kwargs: dict[str, Any] = {"bbox_inches": "tight"}
        if bg_image is not None:
            savefig_kwargs["transparent"] = True
            if animate:
                savefig_kwargs["facecolor"] = "none"
        else:
            savefig_kwargs["facecolor"] = bg_color
        plt.savefig(tmp_img.name, **savefig_kwargs)

        if not animate:
            waveform_img = PIL.Image.open(tmp_img.name)
            waveform_img = waveform_img.resize((1000, 400))

            # Composite waveform with background image
            if bg_image is not None:
                waveform_array = np.array(waveform_img)
                waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
                waveform_img = PIL.Image.fromarray(waveform_array)

                bg_img = PIL.Image.open(bg_image)
                waveform_width, waveform_height = waveform_img.size
                bg_width, bg_height = bg_img.size
                if waveform_width != bg_width:
                    bg_img = bg_img.resize(
                        (
                            waveform_width,
                            2 * int(bg_height * waveform_width / bg_width / 2),
                        )
                    )
                    bg_width, bg_height = bg_img.size
                composite_height = max(bg_height, waveform_height)
                composite = PIL.Image.new(
                    "RGBA", (waveform_width, composite_height), "#FFFFFF"
                )
                composite.paste(bg_img, (0, composite_height - bg_height))
                composite.paste(
                    waveform_img, (0, composite_height - waveform_height), waveform_img
                )
                composite.save(tmp_img.name)
                img_width, img_height = composite.size
            else:
                img_width, img_height = waveform_img.size
                waveform_img.save(tmp_img.name)
        else:

            def _animate(_):
                for idx, b in enumerate(barcollection):
                    rand_height = np.random.uniform(0.8, 1.2)
                    b.set_height(samples[idx] * rand_height * 2)
                    b.set_y((-rand_height * samples)[idx])

            frames = int(duration * 10)
            anim = FuncAnimation(
                fig,  # type: ignore
                _animate,  # type: ignore
                repeat=False,
                blit=False,
                frames=frames,
                interval=100,
            )
            anim.save(
                tmp_img.name,
                writer="pillow",
                fps=10,
                codec="png",
                savefig_kwargs=savefig_kwargs,
            )

    # Convert waveform to video with ffmpeg
    output_mp4 = NamedTemporaryFile(suffix=".mp4", delete=False, prefix=name)

    if animate and bg_image is not None:
        ffmpeg_cmd = [
            ffmpeg,
            "-loop",
            "1",
            "-i",
            bg_image,
            "-i",
            tmp_img.name,
            "-i",
            audio_file,
            "-filter_complex",
            "[0:v]scale=w=trunc(iw/2)*2:h=trunc(ih/2)*2[bg];[1:v]format=rgba,colorchannelmixer=aa=1.0[ov];[bg][ov]overlay=(main_w-overlay_w*0.9)/2:main_h-overlay_h*0.9/2[output]",
            "-t",
            str(duration),
            "-map",
            "[output]",
            "-map",
            "2:a",
            "-c:v",
            "libx264",
            "-c:a",
            "aac",
            "-shortest",
            "-y",
            output_mp4.name,
        ]
    elif animate and bg_image is None:
        ffmpeg_cmd = [
            ffmpeg,
            "-i",
            tmp_img.name,
            "-i",
            audio_file,
            "-filter_complex",
            "[0:v][1:a]concat=n=1:v=1:a=1[v];[v]scale=1000:400,format=yuv420p[v_scaled]",
            "-map",
            "[v_scaled]",
            "-map",
            "1:a",
            "-c:v",
            "libx264",
            "-c:a",
            "aac",
            "-shortest",
            "-y",
            output_mp4.name,
        ]
    else:
        ffmpeg_cmd = [
            ffmpeg,
            "-loop",
            "1",
            "-i",
            tmp_img.name,
            "-i",
            audio_file,
            "-vf",
            f"color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1",  # type: ignore
            "-t",
            str(duration),
            "-y",
            output_mp4.name,
        ]

    subprocess.check_call(ffmpeg_cmd)
    return output_mp4.name

gr.make_waveform = make_waveform
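Usage sketch for the patched helper (illustrative; the audio path and styling values are placeholders, not taken from this diff):

    import modules.gradio  # importing the module monkey-patches gr.make_waveform
    import gradio as gr

    video_path = gr.make_waveform(
        "path/to/audio.mp3", bg_color="#0f172a", bar_count=60, animate=False)
    # video_path points at an .mp4 that can be passed to a gr.Video component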
modules/user_history.py
ADDED
@@ -0,0 +1,598 @@
"""
User History is a plugin that you can add to your Spaces to cache generated images for your users.

Key features:
- 🤗 Sign in with Hugging Face
- Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc.
- Export your history as zip.
- Delete your history to respect privacy.
- Compatible with Persistent Storage for long-term storage.
- Admin panel to check configuration and disk usage.

Useful links:
- Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history
- README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md
- Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py
- Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions
"""

__version__ = "0.2.0"

import json
import os
import shutil
import warnings
from datetime import datetime
from functools import cache
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Any
from uuid import uuid4

import gradio as gr
import numpy as np
import requests
from filelock import FileLock
from PIL.Image import Image
import filetype
import wave
from mutagen.mp3 import MP3, EasyMP3
import torchaudio
import subprocess


def setup(folder_path: str | Path | None = None) -> None:
    user_history = _UserHistory()
    user_history.folder_path = _resolve_folder_path(folder_path)
    user_history.initialized = True


def render() -> None:
    user_history = _UserHistory()

    # initialize with default config
    if not user_history.initialized:
        print("Initializing user history with default config. Use `user_history.setup(...)` to customize folder_path.")
        setup()

    # Render user history tab
    gr.Markdown(
        "## Your past generations\n\nLog in to keep a gallery of your previous generations. Your history will be saved"
        " and available on your next visit. Make sure to export your images from time to time as this gallery may be"
        " deleted in the future."
    )

    if os.getenv("SYSTEM") == "spaces" and not os.path.exists("/data"):
        gr.Markdown(
            "**⚠️ Persistent storage is disabled, meaning your history will be lost if the Space gets restarted."
            " Only the Space owner can setup a Persistent Storage. If you are not the Space owner, consider"
            " duplicating this Space to set your own storage.⚠️**"
        )

    with gr.Row():
        gr.LoginButton(min_width=250)
        #gr.LogoutButton(min_width=250)
        refresh_button = gr.Button(
            "Refresh",
            icon="./assets/icon_refresh.png",
        )
        export_button = gr.Button(
            "Export",
            icon="./assets/icon_download.png",
        )
        delete_button = gr.Button(
            "Delete history",
            icon="./assets/icon_delete.png",
        )

    # "Export zip" row (hidden by default)
    with gr.Row():
        export_file = gr.File(file_count="single", file_types=[".zip"], label="Exported history", visible=False)

    # "Config deletion" row (hidden by default)
    with gr.Row():
        confirm_button = gr.Button("Confirm delete all history", variant="stop", visible=False)
        cancel_button = gr.Button("Cancel", visible=False)

    # Gallery
    gallery = gr.Gallery(
        label="Past images",
        show_label=True,
        elem_id="gradio_user_history_gallery",
        object_fit="cover",
        columns=5,
        height=600,
        preview=False,
        show_share_button=False,
        show_download_button=True,
    )
    gr.Markdown(
        "User history is powered by"
        " [Wauplin/gradio-user-history](https://huggingface.co/spaces/Wauplin/gradio-user-history). Integrate it to"
        " your own Space in just a few lines of code!"
    )
    gallery.attach_load_event(_fetch_user_history, every=None)

    # Interactions
    refresh_button.click(fn=_fetch_user_history, inputs=[], outputs=[gallery], queue=False)
    export_button.click(fn=_export_user_history, inputs=[], outputs=[export_file], queue=False)

    # Taken from https://github.com/gradio-app/gradio/issues/3324#issuecomment-1446382045
    delete_button.click(
        lambda: [gr.update(visible=True), gr.update(visible=True)],
        outputs=[confirm_button, cancel_button],
        queue=False,
    )
    cancel_button.click(
        lambda: [gr.update(visible=False), gr.update(visible=False)],
        outputs=[confirm_button, cancel_button],
        queue=False,
    )
    confirm_button.click(_delete_user_history).then(
        lambda: [gr.update(visible=False), gr.update(visible=False)],
        outputs=[confirm_button, cancel_button],
        queue=False,
    )

    # Admin section (only shown locally or when logged in as Space owner)
    _admin_section()


def save_image(
    profile: gr.OAuthProfile | None,
    image: Image | np.ndarray | str | Path,
    label: str | None = None,
    metadata: Dict | None = None,
):
    # Ignore images from logged out users
    if profile is None:
        return
    username = profile["preferred_username"]

    # Ignore images if user history not used
    user_history = _UserHistory()
    if not user_history.initialized:
        warnings.warn(
            "User history is not set in Gradio demo. Saving image is ignored. You must use `user_history.render(...)`"
            " first."
        )
        return

    # Copy image to storage
    image_path = _copy_image(image, dst_folder=user_history._user_images_path(username))

    # Save new image + metadata
    if metadata is None:
        metadata = {}
    if "datetime" not in metadata:
        metadata["datetime"] = str(datetime.now())
    data = {"path": str(image_path), "label": label, "metadata": metadata}
    with user_history._user_lock(username):
        with user_history._user_jsonl_path(username).open("a") as f:
            f.write(json.dumps(data) + "\n")

def save_file(
    profile: gr.OAuthProfile | None,
    image: Image | np.ndarray | str | Path | None = None,
    video: str | Path | None = None,
    audio: str | Path | None = None,
    document: str | Path | None = None,
    label: str | None = None,
    metadata: Dict | None = None,
):
    # Ignore files from logged out users
    if profile is None:
        return
    username = profile["preferred_username"]

    # Ignore files if user history not used
    user_history = _UserHistory()
    if not user_history.initialized:
        warnings.warn(
            "User history is not set in Gradio demo. Saving files is ignored. You must use `user_history.render(...)`"
            " first."
        )
        return

    # Save new files + metadata
    if metadata is None:
        metadata = {}
    if "datetime" not in metadata:
        metadata["datetime"] = str(datetime.now())

    # Copy image to storage
    image_path = None
    if image is not None:
        image_path = _copy_image(image, dst_folder=user_history._user_images_path(username))
        image_path = _add_metadata(image_path, metadata)

    # Copy video to storage
    if video is not None:
        video_path = _copy_file(video, dst_folder=user_history._user_file_path(username, "videos"))
        video_path = _add_metadata(video_path, metadata)

    # Copy audio to storage
    if audio is not None:
        audio_path = _copy_file(audio, dst_folder=user_history._user_file_path(username, "audios"))
        audio_path = _add_metadata(audio_path, metadata)

    # Copy document to storage
    if document is not None:
        document_path = _copy_file(document, dst_folder=user_history._user_file_path(username, "documents"))
        document_path = _add_metadata(document_path, metadata)

    # Save Json file
    data = {"image_path": str(image_path), "video_path": str(video_path), "audio_path": str(audio_path), "document_path": str(document_path), "label": label, "metadata": metadata}
    with user_history._user_lock(username):
        with user_history._user_jsonl_path(username).open("a") as f:
            f.write(json.dumps(data) + "\n")


#############
# Internals #
#############


class _UserHistory(object):
    _instance = None
    initialized: bool = False
    folder_path: Path

    def __new__(cls):
        # Using singleton pattern => we don't want to expose an object (more complex to use) but still want to keep
        # state between `render` and `save_image` calls.
        if cls._instance is None:
            cls._instance = super(_UserHistory, cls).__new__(cls)
        return cls._instance

    def _user_path(self, username: str) -> Path:
        path = self.folder_path / username
        path.mkdir(parents=True, exist_ok=True)
        return path

    def _user_lock(self, username: str) -> FileLock:
        """Ensure history is not corrupted if concurrent calls."""
        return FileLock(self.folder_path / f"{username}.lock")  # lock outside of folder => better when exporting ZIP

    def _user_jsonl_path(self, username: str) -> Path:
        return self._user_path(username) / "history.jsonl"

    def _user_images_path(self, username: str) -> Path:
        path = self._user_path(username) / "images"
        path.mkdir(parents=True, exist_ok=True)
        return path

    def _user_file_path(self, username: str, filetype: str = "images") -> Path:
        path = self._user_path(username) / filetype
        path.mkdir(parents=True, exist_ok=True)
        return path


def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]]:
    """Return saved history for that user, if it exists."""
    # Cannot load history for logged out users
    if profile is None:
        return []
    username = profile["preferred_username"]

    user_history = _UserHistory()
    if not user_history.initialized:
        warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
        return []

    with user_history._user_lock(username):
        # No file => no history saved yet
        jsonl_path = user_history._user_jsonl_path(username)
        if not jsonl_path.is_file():
            return []

        # Read history
        images = []
        for line in jsonl_path.read_text().splitlines():
            data = json.loads(line)
            images.append((data["path"], data["label"] or ""))
        return list(reversed(images))


def _export_user_history(profile: gr.OAuthProfile | None) -> Dict | None:
    """Zip all history for that user, if it exists and return it as a downloadable file."""
    # Cannot load history for logged out users
    if profile is None:
        return None
    username = profile["preferred_username"]

    user_history = _UserHistory()
    if not user_history.initialized:
        warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
        return None

    # Zip history
    with user_history._user_lock(username):
        path = shutil.make_archive(
            str(_archives_path() / f"history_{username}"), "zip", user_history._user_path(username)
        )

    return gr.update(visible=True, value=path)


def _delete_user_history(profile: gr.OAuthProfile | None) -> None:
    """Delete all history for that user."""
    # Cannot load history for logged out users
    if profile is None:
        return
    username = profile["preferred_username"]

    user_history = _UserHistory()
    if not user_history.initialized:
        warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
        return

    with user_history._user_lock(username):
        shutil.rmtree(user_history._user_path(username))


####################
# Internal helpers #
####################


def _copy_image(image: Image | np.ndarray | str | Path, dst_folder: Path) -> Path:
    try:
        """Copy image to the images folder."""
        # Already a path => copy it
        if isinstance(image, str):
            image = Path(image)
        if isinstance(image, Path):
            dst = dst_folder / f"{uuid4().hex}_{Path(image).name}"  # keep file ext
            shutil.copyfile(image, dst)
            return dst

        # Still a Python object => serialize it
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(image, Image):
            dst = dst_folder / f"{Path(file).name}_{uuid4().hex}.png"
            image.save(dst)
            return dst

        raise ValueError(f"Unsupported image type: {type(image)}")

    except Exception as e:
        print(f"An error occurred: {e}")
        if not isinstance(dst, Path):
            dst = Path(image)
        return dst  # Return the original file_location if an error occurs

def _copy_file(file: Any | np.ndarray | str | Path, dst_folder: Path) -> Path:
    try:
        """Copy file to the appropriate folder."""
        # Already a path => copy it
        if isinstance(file, str):
            file = Path(file)
        if isinstance(file, Path):
            dst = dst_folder / f"{file.stem}_{uuid4().hex}{file.suffix}"  # keep file ext
            shutil.copyfile(file, dst)
            return dst

        # Still a Python object => serialize it
        if isinstance(file, np.ndarray):
            file = Image.fromarray(file)
            dst = dst_folder / f"{file.filename}_{uuid4().hex}{file.suffix}"
            file.save(dst)
            return dst

        # try other file types
        kind = filetype.guess(file)
        if kind is not None:
            dst = dst_folder / f"{Path(file).stem}_{uuid4().hex}.{kind.extension}"
            shutil.copyfile(file, dst)
            return dst
        raise ValueError(f"Unsupported file type: {type(file)}")

    except Exception as e:
        print(f"An error occurred: {e}")
        if not isinstance(dst, Path):
            dst = Path(file)
        return dst  # Return the original file_location if an error occurs


def _add_metadata(file_location: Path, metadata: Dict[str, Any]) -> Path:
    try:
        file_type = file_location.suffix
        valid_file_types = [".wav", ".mp3", ".mp4", ".png"]
        if file_type not in valid_file_types:
            raise ValueError("Invalid file type. Valid file types are .wav, .mp3, .mp4, .png")

        if file_type == ".wav":
            # Open and process .wav file
            with wave.open(file_location, 'rb') as wav_file:
                # Get the current metadata
                current_metadata = {key: value for key, value in wav_file.getparams()._asdict().items() if isinstance(value, (int, float))}

                # Update metadata
                current_metadata.update(metadata)

                # Reopen the WAV file in write mode
                with wave.open(file_location, 'wb') as wav_output_file:
                    # Set the new metadata
                    wav_output_file.setparams(wav_file.getparams())

                    # Save the WAV file (overwriting the previous version)
                    wav_output_file.close()
        elif file_type == ".mp3":
            # Open and process .mp3 file
            audio = EasyMP3(file_location)

            # Add metadata to the file
            for key, value in metadata.items():
                audio[key] = value

            # Save the MP3 file (overwriting the previous version)
            audio.save()
        elif file_type == ".mp4":
            # Open and process .mp4 file
            # Add metadata to the file
            wav_file_location = file_location.with_suffix(".wav")
            wave_exists = wav_file_location.exists()
            if not wave_exists:
                # Use torchaudio to create the WAV file if it doesn't exist
                audio, sample_rate = torchaudio.load(file_location, normalize=True)
                torchaudio.save(wav_file_location, audio, sample_rate, format='wav')

            # Use ffmpeg to add metadata to the video file
            metadata_args = [f"{key}={value}" for key, value in metadata.items()]
            ffmpeg_metadata = ":".join(metadata_args)
            ffmpeg_cmd = f'ffmpeg -i "{file_location}" -i "{wav_file_location}" -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac -metadata "{ffmpeg_metadata}" "{file_location}"'
            subprocess.run(ffmpeg_cmd, shell=True, check=True)

            # Remove temporary WAV file
            if not wave_exists:
                wav_file_location.unlink()
        elif file_type == ".png":
            # Open and process .png file
            image = Image.open(file_location)
            exif_data = image.info.get("exif", {})
            exif_data.update(metadata)
            # Add metadata to the file
            image.save(file_location, exif=exif_data)

        return file_location  # Return the path to the modified file

    except Exception as e:
        print(f"An error occurred: {e}")
        return file_location  # Return the original file_location if an error occurs

def _resolve_folder_path(folder_path: str | Path | None) -> Path:
    if folder_path is not None:
        return Path(folder_path).expanduser().resolve()

    if os.getenv("SYSTEM") == "spaces" and os.path.exists("/data"):  # Persistent storage is enabled!
        return Path("/data") / "_user_history"

    # Not in a Space or Persistent storage not enabled => local folder
    return Path("_user_history").resolve()


def _archives_path() -> Path:
    # Doesn't have to be on persistent storage as it's only used for download
    path = Path(__file__).parent / "_user_history_exports"
    path.mkdir(parents=True, exist_ok=True)
    return path


#################
# Admin section #
#################


def _admin_section() -> None:
    title = gr.Markdown()
    title.attach_load_event(_display_if_admin(), every=None)


def _display_if_admin() -> Callable:
    def _inner(profile: gr.OAuthProfile | None) -> str:
        if profile is None:
            return ""
        if profile["preferred_username"] in _fetch_admins():
            return _admin_content()
        return ""

    return _inner


def _admin_content() -> str:
    return f"""
## Admin section

Running on **{os.getenv("SYSTEM", "local")}** (id: {os.getenv("SPACE_ID")}). {_get_msg_is_persistent_storage_enabled()}

Admins: {', '.join(_fetch_admins())}

{_get_nb_users()} user(s), {_get_nb_images()} image(s)

### Configuration

History folder: *{_UserHistory().folder_path}*

Exports folder: *{_archives_path()}*

### Disk usage

{_disk_space_warning_message()}
"""


def _get_nb_users() -> int:
    user_history = _UserHistory()
    if not user_history.initialized:
        return 0
    if user_history.folder_path is not None and user_history.folder_path.exists():
        return len([path for path in user_history.folder_path.iterdir() if path.is_dir()])
    return 0


def _get_nb_images() -> int:
    user_history = _UserHistory()
    if not user_history.initialized:
        return 0
    if user_history.folder_path is not None and user_history.folder_path.exists():
        return len([path for path in user_history.folder_path.glob("*/images/*")])
    return 0


def _get_msg_is_persistent_storage_enabled() -> str:
    if os.getenv("SYSTEM") == "spaces":
        if os.path.exists("/data"):
            return "Persistent storage is enabled."
        else:
            return (
                "Persistent storage is not enabled. This means that user histories will be deleted when the Space is"
                " restarted. Consider adding a Persistent Storage in your Space settings."
            )
    return ""


def _disk_space_warning_message() -> str:
    user_history = _UserHistory()
    if not user_history.initialized:
        return ""

    message = ""
    if user_history.folder_path is not None:
        total, used, _ = _get_disk_usage(user_history.folder_path)
        message += f"History folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)."

    total, used, _ = _get_disk_usage(_archives_path())
    message += f"\n\nExports folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)."

    return f"{message.strip()}"


def _get_disk_usage(path: Path) -> Tuple[int, int, int]:
    for path in [path] + list(path.parents):  # first check target_dir, then each parents one by one
        try:
            return shutil.disk_usage(path)
        except OSError:  # if doesn't exist or can't read => fail silently and try parent one
            pass
    return 0, 0, 0


@cache
def _fetch_admins() -> List[str]:
    # Running locally => fake user is admin
    if os.getenv("SYSTEM") != "spaces":
        return ["FakeGradioUser"]

    # Running in Space but no space_id => ???
    space_id = os.getenv("SPACE_ID")
    if space_id is None:
        return ["Unknown"]

    # Running in Space => try to fetch organization members
    # Otherwise, it's not an organization => namespace is the user
    namespace = space_id.split("/")[0]
    response = requests.get(f"https://huggingface.co/api/organizations/{namespace}/members")
    if response.status_code == 200:
        return sorted((member["user"] for member in response.json()), key=lambda x: x.lower())
    return [namespace]
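Minimal integration sketch based on the module above (the Blocks layout, tab names, and folder path are assumptions for illustration, not part of this diff):

    import gradio as gr
    import modules.user_history as user_history

    user_history.setup(folder_path="_user_history")  # optional; defaults are used otherwise

    with gr.Blocks() as demo:
        with gr.Tab("Generate"):
            ...  # generation UI; call user_history.save_file(profile, audio=..., metadata=...) on outputs
        with gr.Tab("Past generations"):
            user_history.render()

    demo.launch()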
modules/version_info.py
ADDED
@@ -0,0 +1,123 @@
# modules/version_info.py

from audiocraft import __version__ as audiocraft_version
import subprocess
import os
import sys
import gc
import gradio as gr

git = os.environ.get('GIT', "git")

def commit_hash():
    try:
        return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
    except Exception:
        return "<none>"

def get_xformers_version():
    try:
        import xformers
        return xformers.__version__
    except Exception:
        return "<none>"
def get_transformers_version():
    try:
        import transformers
        return transformers.__version__
    except Exception:
        return "<none>"

def get_accelerate_version():
    try:
        import accelerate
        return accelerate.__version__
    except Exception:
        return "<none>"
def get_safetensors_version():
    try:
        import safetensors
        return safetensors.__version__
    except Exception:
        return "<none>"
def get_diffusers_version():
    try:
        import diffusers
        return diffusers.__version__
    except Exception:
        return "<none>"

def get_torch_info():
    from torch import __version__ as torch_version_, version, cuda, backends
    device_type = initialize_cuda()
    if device_type == "cuda":
        try:
            info = [torch_version_, f"CUDA Version:{version.cuda}", f"Available:{cuda.is_available()}", f"flash attention enabled: {backends.cuda.flash_sdp_enabled()}", f"Capabilities: {cuda.get_device_capability(0)}", f"Device Name: {cuda.get_device_name(0)}", f"Device Count: {cuda.device_count()}"]
            del torch_version_, version, cuda, backends
            return info
        except Exception:
            del torch_version_, version, cuda, backends
            return "<none>"
    else:
        return "Not Recognized"

def release_torch_resources():
    from torch import cuda
    # Clear the CUDA cache
    cuda.empty_cache()
    cuda.ipc_collect()
    # Delete any objects that are using GPU memory
    #for obj in gc.get_objects():
    #    if is_tensor(obj) or (hasattr(obj, 'data') and is_tensor(obj.data)):
    #        del obj
    # Run garbage collection
    del cuda
    gc.collect()


def initialize_cuda():
    from torch import cuda, version
    if cuda.is_available():
        device = cuda.device("cuda")
        print(f"CUDA is available. Using device: {cuda.get_device_name(0)} with CUDA version: {version.cuda}")
        result = "cuda"
    else:
        print("CUDA is not available. Using CPU.")
        result = "cpu"
    return result

def versions_html():
    from torch import __version__ as torch_version_
    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
    commit = commit_hash()

    # Define the Toggle Dark Mode link with JavaScript
    toggle_dark_link = '''
    <a href="#" onclick="document.body.classList.toggle('dark'); return false;" style="cursor: pointer; text-decoration: underline;">
        Toggle Dark Mode
    </a>
    '''

    v_html = f"""
version: <a href="https://huggingface.co/spaces/Surn/UnlimitedMusicGen/commit/{"huggingface" if commit == "<none>" else commit}" target="_blank">{"huggingface" if commit == "<none>" else commit}</a>
 • 
audiocraft: {audiocraft_version}
 • 
python: <span title="{sys.version}">{python_version}</span>
 • 
torch: {torch_version_}
 • 
xformers: {get_xformers_version()}
 • 
transformers: {get_transformers_version()}
 • 
safetensors: {get_safetensors_version()}
 • 
gradio: {gr.__version__}
 • 
{toggle_dark_link}
<br>
Full GPU Info:{get_torch_info()}
"""
    del torch_version_
    return v_html
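A minimal usage sketch (not part of the diff) showing how the version banner might be rendered in the app, assuming `modules/` is importable from `app.py`; the `elem_id` lines up with the `#versions` rule added in `style_20250331.css` further down.

```python
import gradio as gr
from modules.version_info import versions_html

with gr.Blocks() as demo:
    gr.Markdown("# UnlimitedMusicGen")
    # Render the version/commit banner as raw HTML in the page footer.
    gr.HTML(versions_html(), elem_id="versions")

demo.launch()
```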
pre-requirements.txt
CHANGED
@@ -1 +1 @@
-pip>=
+pip>=24.0
requirements.txt
CHANGED
@@ -1,22 +1,38 @@
 # please make sure you have already a pytorch install that is cuda enabled!
-av
+av==11.0.0
 einops
 flashy>=0.0.1
 hydra-core>=1.1
 hydra_colorlog
-
-
-numpy
-sentencepiece
-spacy==3.5.2
-torch==2.0.1
-torchaudio==2.0.2
+torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124
+torchaudio>=2.0.0,<2.6.2 --extra-index-url https://download.pytorch.org/whl/cu124
 soundfile
 huggingface_hub
 tqdm
-transformers>=4.
-xformers>=0.0.
+transformers>=4.48.0 # need Encodec there.
+xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124
 demucs
 librosa
-
-
+soundfile
+gradio==5.23.3
+gradio[oauth]
+pillow
+torchmetrics
+encodec
+protobuf>=3.20.1
+filetype
+wave
+mutagen
+fastapi>=0.88.0
+pydantic
+typer
+torchvision>=0.21.0 --extra-index-url https://download.pytorch.org/whl/cu124
+#torchtext
+pesq
+pystoi
+julius
+spacy==3.7.6
+sentencepiece
+num2words
+numpy<1.26.4
+matplotlib
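Since torch, torchaudio, torchvision, and xformers are now pinned against the cu124 wheel index, a quick runtime check (illustrative only, not part of the diff) can confirm that the CUDA-enabled builds actually resolved in the Space:

```python
import torch

print("torch:", torch.__version__)                    # expected to report a +cu124 build with the pin above
print("CUDA available:", torch.cuda.is_available())   # True on a GPU Space with the correct wheels
print("CUDA runtime bundled with torch:", torch.version.cuda)
```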
style_20250331.css
ADDED
@@ -0,0 +1,215 @@
.interface-wrapper {
    max-width: 1024px;
    margin: 0 auto;
}

.centered {
    margin: 0 auto;
    display: block;
    text-align:center;
}

.solid {
    opacity: 1.0 !important;
    height: auto !important;
}

.intro {
    font-size: 1.2em !important;
    font-weight: bold;
    text-align: center;
    background-color: rgba(242, 218, 163, 0.62);
}

.dark .gradio-container.gradio-container-5-23-3 .contain .intro .prose {
    background-color: rgba(41, 18, 5, 0.38) !important;
}
.toast-body.info {
    background-color: rgba(242, 218, 163, 0.75);
}
.dark .toast-body.info {
    background-color: rgba(128, 128, 128, 0.75);
}

.small {
    font-size: smaller !important;
    text-align: center;
}

.imgcontainer img {
    object-fit: contain !important;
}

#examples {
    font-weight: bolder;
}

--background-fill-primary: #FBCE50 !important;
#col-container {
    max-width: 1024px;
    margin-left: auto;
    margin-right: auto;
}

a {
    text-decoration-line: underline;
    font-weight: 600;
}

#btn-generate {
    background-image: linear-gradient(to right bottom, rgb(157, 255, 157), rgb(229, 255, 235));
    color: var(--primary-800);
}

#btn-generate:hover {
    background-image: linear-gradient(to right bottom, rgb(229, 255, 229), rgb(255, 255, 255));
}

#btn-generate:active {
    background-image: linear-gradient(to right bottom, rgb(229, 255, 235), rgb(157, 255, 157));
}

#versions {
    margin-top: 1em;
    width: 100%;
    text-align: center;
}

.small-btn {
    max-width: 75px;
}

#gallery .thumbnails, #lora_gallery .thumbnails {
    flex-direction: column !important;
    display: inline-flex !important;
    flex-wrap: wrap !important;
    position: relative !important;
}

#gallery caption.caption, #lora_gallery caption.caption {
    flex-direction: row !important;
    display: inline-flex !important;
    flex-wrap: wrap;
    white-space: unset !important;
}

#gallery .image-button img.with-caption, #lora_gallery .image-button img.with-caption {
    object-fit: cover !important;
    object-position: center !important;
}

#gallery button.preview, #lora_gallery button.preview {
    position: relative !important;
}

.gradio-container::before {
    content: ' ';
    display: block;
    position: absolute;
    left: 0;
    top: 0;
    width: 100%;
    height: 100%;
    opacity: 0.5;
    background-image: url('gradio_api/file=./assets/Vermilion-Musical-Notes-Typography-No-Background.svg');
    background-repeat: no-repeat;
    background-position: 50% 25%;
    /*background-color: rgba(0,0,0,0.5);*/
    background-size: 45vh;
    overflow: hidden;
}

.gradio-container::after {
    content: '';
    position: absolute;
    top: 0;
    left: -60%; /* Start off-screen */
    width: 30%;
    height: 100%;
    background: linear-gradient( 120deg, rgba(255, 255, 255, 0) 10%, rgba(255, 255, 255, 0.60) 50%, rgba(255, 255, 255, 0) 90% );
    animation: shine 30s infinite;
}

#component-0, #component-1 {
    opacity: 0.9;
}

#excluded_colors {
    width: 95%;
    margin: 0 auto;
    font-size: smaller;
}

@media only screen and (min-width: 1920px) {
    .gradio-container, .gradio-container::before {
        max-width: 1920px !important;
    }
}

.sidebar .toggle-button::before {
    content: 'Sketch Pad';
    font-weight: bold;
    transform: rotate(180deg);
    margin-right: -120px;
    width: 120px;
    background-color: rgba(242, 218, 163, 0.62);
}
.dark .sidebar .toggle-button::before {
    background-color: rgba(41, 18, 5, 0.38) !important;
}
.sidebar.open .toggle-button::before {
    content: '';
}

#sketchpd, #filters, #image_gen, #accordian_3d {
    outline-color: #bbf7d0;
    outline-style:solid;
    outline-width: 1px;
    outline-offset: 1px;
    padding: 2px;
    border-radius:6px;
}
.outline-important {
    outline-color: var(--accordion-text-color);
    outline-style: solid;
    outline-width: 2px;
    outline-offset: 2px;
    padding: 2px;
    border-radius: 6px;
}
.selected.svelte-1tcem6n.svelte-1tcem6n {
    font-size: large;
    font-weight: bold;
    color: var(--body-text-color);
}
.tab-wrapper.svelte-1tcem6n.svelte-1tcem6n {
    height: var(--size-12);
    padding-bottom: var(--size-1);
    text-align: center;
    background-blend-mode: multiply;
    border-radius: var(--block-radius);
    background-color: var(--block-background-fill);

    outline-color: var(--accordion-text-color);
    outline-style: solid;
    outline-width: 2px;
    outline-offset: 2px;
    padding: 2px;
    border-radius: 6px;
}


@keyframes shine {
    0% {
        left: -100%;
    }

    20% {
        left: 100%;
    }

    100% {
        left: 125%;
    }
}
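A sketch of how this stylesheet might be wired into the app, assuming it sits next to `app.py`; the file name and component IDs are the ones used in this diff, everything else is illustrative. `gr.Blocks` accepts the raw CSS string directly.

```python
import gradio as gr

# Read the stylesheet and hand it to Blocks as a raw CSS string.
with open("style_20250331.css", encoding="utf-8") as f:
    css = f.read()

with gr.Blocks(css=css) as demo:
    gr.Button("Generate Music", elem_id="btn-generate")  # picks up the #btn-generate gradient rules

demo.launch()
```

Note that the `.gradio-container::before` rule references `gradio_api/file=./assets/...`, so the `assets` folder likely also needs to be exposed to Gradio's file server (for example via `gr.set_static_paths` or `launch(allowed_paths=[...])`) for the background image to load.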