Surn committed on
Commit e7edd0b · 1 Parent(s): fef074d

Major Update


- metadata
- new models
- background and style
- readme
- requirements
- gradio extension

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +2 -0
  2. CHANGELOG.md +61 -12
  3. README.md +48 -7
  4. app.py +270 -202
  5. assets/KuritaSurnLogox64.png +0 -0
  6. assets/Vermilion-Musical-Notes-Typography-No-Background.svg +0 -0
  7. assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 +0 -0
  8. assets/icon_delete.png +0 -0
  9. assets/icon_download.png +0 -0
  10. assets/icon_refresh.png +0 -0
  11. assets/logo_animation_256.gif +3 -0
  12. assets/screenshot.png +3 -0
  13. assets/sirens_and_a_humming_engine_approach_and_pass.mp3 +0 -0
  14. audiocraft/__init__.py +1 -1
  15. audiocraft/data/__init__.py +1 -1
  16. audiocraft/data/audio.py +86 -1
  17. audiocraft/data/audio_dataset.py +93 -31
  18. audiocraft/data/audio_utils.py +9 -6
  19. audiocraft/data/info_audio_dataset.py +110 -0
  20. audiocraft/data/zip.py +8 -6
  21. audiocraft/environment.py +176 -0
  22. audiocraft/models/__init__.py +7 -0
  23. audiocraft/models/builders.py +89 -49
  24. audiocraft/models/encodec.py +267 -63
  25. audiocraft/models/lm.py +35 -29
  26. audiocraft/models/loaders.py +146 -5
  27. audiocraft/models/musicgen.py +108 -34
  28. audiocraft/models/unet.py +214 -0
  29. audiocraft/modules/__init__.py +1 -0
  30. audiocraft/modules/chroma.py +66 -0
  31. audiocraft/modules/codebooks_patterns.py +17 -12
  32. audiocraft/modules/conditioners.py +724 -298
  33. audiocraft/modules/conv.py +1 -1
  34. audiocraft/modules/diffusion_schedule.py +272 -0
  35. audiocraft/modules/rope.py +20 -19
  36. audiocraft/modules/transformer.py +36 -38
  37. audiocraft/quantization/core_vq.py +8 -3
  38. audiocraft/utils/cache.py +324 -0
  39. audiocraft/utils/cluster.py +75 -0
  40. audiocraft/utils/export.py +50 -27
  41. audiocraft/utils/export_legacy.py +56 -0
  42. audiocraft/utils/extend.py +5 -4
  43. audiocraft/utils/utils.py +68 -2
  44. modules/file_utils.py +91 -0
  45. modules/gradio.py +272 -0
  46. modules/user_history.py +598 -0
  47. modules/version_info.py +123 -0
  48. pre-requirements.txt +1 -1
  49. requirements.txt +28 -12
  50. style_20250331.css +215 -0
.gitattributes CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/logo_animation_256.gif filter=lfs diff=lfs merge=lfs -text
+assets/screenshot.png filter=lfs diff=lfs merge=lfs -text
CHANGELOG.md CHANGED
@@ -1,13 +1,69 @@
-## [0.0.2a2] - 2023-07-20
-
-Music Generation set to a max of 720 seconds (12 minutes) to avoid memory issues.
-
-Video editing options (thanks @Surn and @oncorporation).
-
-Music Conditioning segment options
-
-
-## [0.0.2a] - TBD
-
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+## [1.2.Surn] - 2025-04-02
+
+Implemented Unlimited Music Generation (UMG) with the [hf checkpoints](https://huggingface.co/facebook/unlimited-music-generation).
+
+## [1.4.0a2] - 2025-01-14
+
+Add training and inference code for JASCO (https://arxiv.org/abs/2406.10970) along with the [hf checkpoints](https://huggingface.co/facebook/jasco-chords-drums-melody-1B).
+
+## [1.4.0a1] - 2024-06-03
+
+Adding new metric PesqMetric ([Perceptual Evaluation of Speech Quality](https://doi.org/10.5281/zenodo.6549559))
+
+Adding multiple audio augmentation functions: generating pink noises, up-/downsampling, low-/highpass filtering, bandpass filtering, smoothing, duck masking, boosting. All are wrapped in the `audiocraft.utils.audio_effects.AudioEffects` and can be called with the API `audiocraft.utils.audio_effects.select_audio_effects`.
+
+Add training code for AudioSeal (https://arxiv.org/abs/2401.17264) along with the [hf checkpoints](https://huggingface.co/facebook/audioseal).
+
+## [1.3.0] - 2024-05-02
+
+Adding the MAGNeT model (https://arxiv.org/abs/2401.04577) along with hf checkpoints and a gradio demo app.
+
+Typo fixes.
+
+Fixing setup.py to install only audiocraft, not the unit tests and scripts.
+
+Fix FSDP support with PyTorch 2.1.0.
+
+## [1.2.0] - 2024-01-11
+
+Adding stereo models.
+
+Fixed the commitment loss, which was until now only applied to the first RVQ layer.
+
+Removed compression model state from the LM checkpoints, for consistency, it
+should always be loaded from the original `compression_model_checkpoint`.
+
+
+## [1.1.0] - 2023-11-06
+
+Not using torchaudio anymore when writing audio files, relying instead directly on the commandline ffmpeg. Also not using it anymore for reading audio files, for similar reasons.
+
+Fixed DAC support with non default number of codebooks.
+
+Fixed bug when `two_step_cfg` was overridden when calling `generate()`.
+
+Fixed samples being always prompted with audio, rather than having both prompted and unprompted.
+
+**Backward incompatible change:** A `torch.no_grad` around the computation of the conditioning made its way in the public release.
+The released models were trained without this. Those impact linear layers applied to the output of the T5 or melody conditioners.
+We removed it, so you might need to retrain models.
+
+**Backward incompatible change:** Fixing wrong sample rate in CLAP (WARNING if you trained model with CLAP before).
+
+**Backward incompatible change:** Renamed VALLEPattern to CoarseFirstPattern, as it was wrongly named. Probably no one
+retrained a model with this pattern, so hopefully this won't impact you!
+
+
+## [1.0.0] - 2023-09-07
+
+Major revision, added training code for EnCodec, AudioGen, MusicGen, and MultiBandDiffusion.
+Added pretrained model for AudioGen and MultiBandDiffusion.
+
+## [0.0.2] - 2023-08-01
+
 Improved demo, fixed top p (thanks @jnordberg).
 
@@ -24,10 +80,3 @@ Note that other implementations exist: https://github.com/camenduru/MusicGen-col
 ## [0.0.1] - 2023-06-09
 
 Initial release, with model evaluation only.
-
-
-# Changelog
-
-All notable changes to this project will be documented in this file.
-
-The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
README.md CHANGED
@@ -4,13 +4,21 @@ emoji: 🎼
 colorFrom: gray
 colorTo: red
 sdk: gradio
-sdk_version: 3.38.0
+sdk_version: 5.23.3
+python_version: 3.12.8
 app_file: app.py
-pinned: false
+pinned: true
 license: creativeml-openrail-m
 tags:
 - musicgen
 - unlimited
+- user history
+- metadata
+hf_oauth: true
+disable_embedding: true
+short_description: 'unlimited Audio generation with a few added features '
+thumbnail: >-
+  https://cdn-uploads.huggingface.co/production/uploads/6346595c9e5f0fe83fc60444/Z8E8OaKV84zuVAvvGpMDJ.png
 ---
 
 [arxiv]: https://arxiv.org/abs/2306.05284
@@ -18,7 +26,18 @@ tags:
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 # UnlimitedMusicGen
-This is my modification of the Audiocraft project to enable unlimited Audio generation. I have added a few features to the original project to enable this. I have also added a few features to the gradio interface to make it easier to use.
+Charles Fettinger's modification of the Audiocraft project to enable unlimited Audio generation. I have added a few features to the original project to enable this. I have also added a few features to the gradio interface to make it easier to use.
+
+Please review my other AI related spaces at https://huggingface.co/Surn
+
+Check your video's generative metadata with https://mediaarea.net/en/MediaInfo
+
+Also note that I wrote an extension to Gradio for the waveform in the video after v4.48.0 removed it.
+
+The key update here is in the extend utility. We segment melody input and then condition the next segment with current tensors and tensors from the current time in the conditioning melody file.
+This allows us to follow the same arrangement of the original melody.
+
+**Thank you Huggingface for the community grant to run this project**!!
 
 # Audiocraft
 ![docs badge](https://github.com/facebookresearch/audiocraft/workflows/audiocraft_docs/badge.svg)
@@ -46,12 +65,12 @@ Check out our [sample page][musicgen_samples] or test the available demo!
 We use 20K hours of licensed music to train MusicGen. Specifically, we rely on an internal dataset of 10K high-quality music tracks, and on the ShutterStock and Pond5 music data.
 
 ## Installation
-Audiocraft requires Python 3.9, PyTorch 2.0.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
+Audiocraft requires Python 3.9, PyTorch 2.1.0, and a GPU with at least 16 GB of memory (for the medium-sized model). To install Audiocraft, you can run the following:
 
 ```shell
 # Best to make sure you have torch installed first, in particular before installing xformers.
 # Don't run this if you already have PyTorch installed.
-pip install 'torch>=2.0'
+pip install 'torch>=2.1'
 # Then proceed to one of the following
 pip install -U audiocraft # stable release
 pip install -U git+https://[email protected]/facebookresearch/audiocraft#egg=audiocraft # bleeding edge
@@ -60,7 +79,7 @@ pip install -e . # or if you cloned the repo locally
 
 ## Usage
 We offer a number of way to interact with MusicGen:
-1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/facebook/MusicGen) (huge thanks to all the HF team for their support).
+1. A demo is also available on the [`facebook/MusicGen` HuggingFace Space](https://huggingface.co/spaces/Surn/UnlimitedMusicGen) (huge thanks to all the HF team for their support).
 2. You can run the Gradio demo in Colab: [colab notebook](https://colab.research.google.com/drive/1-Xe9NCdIs2sCUbiSmwHXozK6AAhMm7_i?usp=sharing).
 3. You can use the gradio demo locally by running `python app.py`.
 4. You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./demo.ipynb) locally (if you have a GPU).
@@ -178,6 +197,25 @@ For more details on using the MusicGen model for inference using the 🤗 Transf
 [MusicGen docs](https://huggingface.co/docs/transformers/main/en/model_doc/musicgen) or the hands-on
 [Google Colab](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/MusicGen.ipynb).
 
+## User History
+
+User History is a plugin that you can add to your Spaces to cache generated images for your users.
+
+Key features:
+- 🤗 Sign in with Hugging Face
+- Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc.
+- Export your history as zip.
+- Delete your history to respect privacy.
+- Compatible with Persistent Storage for long-term storage.
+- Admin panel to check configuration and disk usage.
+
+Useful links:
+- Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history
+- README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md
+- Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py
+- Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions
+
+![Image preview](./assets/screenshot.png)
 
 ## Model Card
 
@@ -212,4 +250,7 @@ Check [@camenduru tutorial on Youtube](https://www.youtube.com/watch?v=EGfxuTy9E
 * The code in this repository is released under the MIT license as found in the [LICENSE file](LICENSE).
 * The weights in this repository are released under the CC-BY-NC 4.0 license as found in the [LICENSE_weights file](LICENSE_weights).
 [arxiv]: https://arxiv.org/abs/2306.05284
-[musicgen_samples]: https://ai.honu.io/papers/musicgen/
+
+[arxiv]: https://arxiv.org/abs/2306.05284
+[musicgen_samples]: https://ai.honu.io/papers/musicgen/
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
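To illustrate the segment-and-continue approach the README describes for the extend utility, here is a rough Python sketch. It is not the repository's actual code; the checkpoint name, 30-second segment length, 2-second overlap, and total duration are assumed values.

```python
# Rough sketch of the segment-and-continue idea described in the README above.
# Not the project's implementation; model name, durations and overlap are assumptions.
import torch
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained("melody")      # assumed checkpoint name
model.set_generation_params(duration=30)       # per-segment length in seconds
text = "4/4 100bpm dark industrial soundtrack"
overlap = 2                                    # seconds of audio reused as the next prompt
remaining = 90                                 # total requested duration in seconds

segments = []
while remaining > 0:
    if not segments:
        # First segment: plain text-conditioned generation.
        segment = model.generate(descriptions=[text], progress=True)
        remaining -= 30
    else:
        # Later segments: continue from the tail of the previous segment.
        last_chunk = segments[-1][:, :, -overlap * model.sample_rate:]
        segment = model.generate_continuation(last_chunk, model.sample_rate,
                                              descriptions=[text], progress=True)
        remaining -= 30 - overlap
    segments.append(segment)

# Drop the prompt overlap from each continuation before concatenating on the time axis
# (the app itself blends/crossfades these regions rather than simply trimming them).
trimmed = [segments[0]] + [seg[:, :, overlap * model.sample_rate:] for seg in segments[1:]]
audio = torch.cat(trimmed, dim=2)
```

The real extend utility additionally aligns each segment with the corresponding window of the conditioning melody file, which is what lets the output follow the original arrangement.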
app.py CHANGED
@@ -21,11 +21,17 @@ from audiocraft.models import MusicGen
21
  from audiocraft.data.audio import audio_write
22
  from audiocraft.data.audio_utils import apply_fade, apply_tafade, apply_splice_effect
23
  from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, INTERRUPTING
 
24
  import numpy as np
25
  import random
26
- #from pathlib import Path
 
27
  #from typing import List, Union
28
  import librosa
 
 
 
 
29
 
30
  MODEL = None
31
  MODELS = None
@@ -35,7 +41,12 @@ UNLOAD_MODEL = False
35
  MOVE_TO_CPU = False
36
  MAX_PROMPT_INDEX = 0
37
  git = os.environ.get('GIT', "git")
38
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
 
 
 
 
 
39
 
40
  def interrupt_callback():
41
  return INTERRUPTED
@@ -72,7 +83,7 @@ def toggle_audio_src(choice):
72
  else:
73
  return gr.update(source="upload", value=None, label="File")
74
 
75
- def make_waveform(*args, **kwargs):
76
  # Further remove some warnings.
77
  be = time.time()
78
  with warnings.catch_warnings():
@@ -80,6 +91,7 @@ def make_waveform(*args, **kwargs):
80
  out = gr.make_waveform(*args, **kwargs)
81
  print("Make a video took", time.time() - be)
82
  return out
 
83
 
84
  def load_model(version):
85
  global MODEL, MODELS, UNLOAD_MODEL
@@ -102,32 +114,12 @@ def load_model(version):
102
  print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
103
  return result
104
 
105
- def get_filename(file):
106
- # extract filename from file object
107
- filename = None
108
- if file is not None:
109
- filename = file.name
110
- return filename
111
-
112
- def get_filename_from_filepath(filepath):
113
- file_name = os.path.basename(filepath)
114
- file_base, file_extension = os.path.splitext(file_name)
115
- return file_base, file_extension
116
-
117
  def get_melody(melody_filepath):
118
  audio_data= list(librosa.load(melody_filepath, sr=None))
119
  audio_data[0], audio_data[1] = audio_data[1], audio_data[0]
120
  melody = tuple(audio_data)
121
  return melody
122
 
123
-
124
- def commit_hash():
125
- try:
126
- return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
127
- except Exception:
128
- return "<none>"
129
-
130
-
131
  def git_tag():
132
  try:
133
  return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
@@ -140,28 +132,6 @@ def git_tag():
140
  except Exception:
141
  return "<none>"
142
 
143
- def versions_html():
144
- import torch
145
-
146
- python_version = ".".join([str(x) for x in sys.version_info[0:3]])
147
- commit = commit_hash()
148
- #tag = git_tag()
149
-
150
- import xformers
151
- xformers_version = xformers.__version__
152
-
153
- return f"""
154
- version: <a href="https://github.com/Oncorporation/audiocraft/commit/{"huggingface" if commit == "<none>" else commit}" target="_blank">{"huggingface" if commit == "<none>" else commit}</a>
155
- &#x2000;•&#x2000;
156
- python: <span title="{sys.version}">{python_version}</span>
157
- &#x2000;•&#x2000;
158
- torch: {getattr(torch, '__long_version__',torch.__version__)}
159
- &#x2000;•&#x2000;
160
- xformers: {xformers_version}
161
- &#x2000;•&#x2000;
162
- gradio: {gr.__version__}
163
- """
164
-
165
  def load_melody_filepath(melody_filepath, title):
166
  # get melody filename
167
  #$Union[str, os.PathLike]
@@ -187,12 +157,13 @@ def load_melody_filepath(melody_filepath, title):
187
  print(f"Melody length: {len(melody_data)}, Melody segments: {total_melodys}\n")
188
  MAX_PROMPT_INDEX = total_melodys
189
 
190
- return gr.Textbox.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value="melody-large", interactive=True)
191
 
192
  def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False):
193
  global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
194
  output_segments = None
195
  melody_name = "Not Used"
 
196
  melody = None
197
  if melody_filepath:
198
  melody_name, melody_extension = get_filename_from_filepath(melody_filepath)
@@ -201,17 +172,23 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
201
  INTERRUPTED = False
202
  INTERRUPTING = False
203
  if temperature < 0:
 
204
  raise gr.Error("Temperature must be >= 0.")
205
  if topk < 0:
 
206
  raise gr.Error("Topk must be non-negative.")
207
  if topp < 0:
 
208
  raise gr.Error("Topp must be non-negative.")
209
 
210
- if MODEL is None or MODEL.name != model:
211
- MODEL = load_model(model)
212
- else:
213
- if MOVE_TO_CPU:
214
- MODEL.to('cuda')
 
 
 
215
 
216
  # prevent hacking
217
  duration = min(duration, 720)
@@ -251,35 +228,41 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
251
  rep_penalty=0.5
252
  )
253
 
254
- if melody:
255
- # todo return excess duration, load next model and continue in loop structure building up output_segments
256
- if duration > MODEL.lm.cfg.dataset.segment_duration:
257
- output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False)
258
- else:
259
- # pure original code
260
- sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
261
- print(melody.shape)
262
- if melody.dim() == 2:
263
- melody = melody[None]
264
- melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
265
- output = MODEL.generate_with_chroma(
266
- descriptions=[text],
267
- melody_wavs=melody,
268
- melody_sample_rate=sr,
269
- progress=False
270
- )
271
- # All output_segments are populated, so we can break the loop or set duration to 0
272
- break
273
- else:
274
- #output = MODEL.generate(descriptions=[text], progress=False)
275
- if not output_segments:
276
- next_segment = MODEL.generate(descriptions=[text], progress=False)
277
- duration -= segment_duration
278
  else:
279
- last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
280
- next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=False)
281
- duration -= segment_duration - overlap
282
- output_segments.append(next_segment)
283
 
284
  if INTERRUPTING:
285
  INTERRUPTED = True
@@ -287,6 +270,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
287
  print("Function execution interrupted!")
288
  raise gr.Error("Interrupted.")
289
 
 
290
  if output_segments:
291
  try:
292
  # Combine the output segments into one long audio file or stack tracks
@@ -312,7 +296,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
312
  ##overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=1) #stack tracks
313
  ##print(f" overlap size stack:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
314
  #overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=2) #stack tracks
315
- #print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
316
  output = torch.cat([output[:, :, :-overlap_samples], overlapping_output, output_segments[i][:, :, overlap_samples:]], dim=dimension)
317
  else:
318
  output = torch.cat([output, output_segments[i]], dim=dimension)
@@ -321,143 +305,227 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
321
  print(f"Error combining segments: {e}. Using the first segment only.")
322
  output = output_segments[0].detach().cpu().float()[0]
323
  else:
324
- output = output.detach().cpu().float()[0]
325
-
326
- with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
 
 
 
 
327
  video_description = f"{text}\n Duration: {str(initial_duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}\n Model: {model}\n Melody Condition:{melody_name}\n Sample Segment: {prompt_index}"
328
  if include_settings or include_title:
329
  background = add_settings_to_image(title if include_title else "", video_description if include_settings else "", background_path=background, font=settings_font, font_color=settings_font_color)
330
  audio_write(
331
  file.name, output, MODEL.sample_rate, strategy="loudness",
332
- loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2)
333
- waveform_video = make_waveform(file.name,bg_image=background, bar_count=45)
334
  if MOVE_TO_CPU:
335
  MODEL.to('cpu')
336
  if UNLOAD_MODEL:
337
  MODEL = None
338
  torch.cuda.empty_cache()
339
  torch.cuda.ipc_collect()
340
- return waveform_video, file.name, seed
341
 
 
342
  def ui(**kwargs):
343
- css="""
344
- #col-container {max-width: 910px; margin-left: auto; margin-right: auto;}
345
- a {text-decoration-line: underline; font-weight: 600;}
346
- #btn-generate {background-image:linear-gradient(to right bottom, rgb(157, 255, 157), rgb(229, 255, 235));}
347
- #btn-generate:hover {background-image:linear-gradient(to right bottom, rgb(229, 255, 229), rgb(255, 255, 255));}
348
- #btn-generate:active {background-image:linear-gradient(to right bottom, rgb(229, 255, 235), rgb(157, 255, 157));}
349
- #versions {margin-top: 1em; width:100%; text-align:center;}
350
- .small-btn {max-width:75px;}
351
- """
352
- with gr.Blocks(title="UnlimitedMusicGen", css=css) as demo:
353
- gr.Markdown(
354
- """
355
  # UnlimitedMusicGen
356
  This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation
357
  presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
358
 
359
  Disclaimer: This won't run on CPU only. Clone this App and run on GPU instance!
360
 
361
- Todo: Working on improved Interrupt and new Models.
362
- """
363
- )
364
- if IS_SHARED_SPACE and not torch.cuda.is_available():
365
- gr.Markdown("""
366
- This Space doesn't work in this shared UI ⚠
367
-
368
- <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
369
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
370
- to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
371
- """)
372
- with gr.Row():
373
- with gr.Column():
374
- with gr.Row():
375
- text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi")
376
- with gr.Column():
377
- duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True)
378
- model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-melody", "stereo-medium", "stereo-small", "stereo-large", "stereo-melody-large"], label="AI Model", value="melody-large", interactive=True)
379
- with gr.Row():
380
- submit = gr.Button("Generate", elem_id="btn-generate")
381
- # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
382
- _ = gr.Button("Interrupt", elem_id="btn-interrupt").click(fn=interrupt, queue=False)
383
- with gr.Row():
384
- with gr.Column():
385
- radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic")
386
- melody_filepath = gr.Audio(source="upload", type="filepath", label="Melody Condition (optional)", interactive=True, elem_id="melody-input")
387
  with gr.Column():
388
- harmony_only = gr.Radio(label="Use Harmony Only",choices=["No", "Yes"], value="No", interactive=True, info="Remove Drums?")
389
- prompt_index = gr.Slider(label="Melody Condition Sample Segment", minimum=-1, maximum=MAX_PROMPT_INDEX, step=1, value=0, interactive=True, info="Which 30 second segment to condition with, - 1 condition each segment independantly")
390
- with gr.Accordion("Video", open=False):
391
- with gr.Row():
392
- background= gr.Image(value="./assets/background.png", source="upload", label="Background", shape=(768,512), type="filepath", interactive=True)
393
- with gr.Column():
394
- include_title = gr.Checkbox(label="Add Title", value=True, interactive=True)
395
- include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
396
- with gr.Row():
397
- title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
398
- settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True)
399
- settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#c87f05", interactive=True)
400
- with gr.Accordion("Expert", open=False):
401
- with gr.Row():
402
- overlap = gr.Slider(minimum=0, maximum=15, value=2, step=1, label="Verse Overlap", interactive=True)
403
- dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segements of audio. (1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True)
404
- with gr.Row():
405
- topk = gr.Number(label="Top-k", value=280, precision=0, interactive=True)
406
- topp = gr.Number(label="Top-p", value=1150, precision=0, interactive=True)
407
- temperature = gr.Number(label="Randomness Temperature", value=0.7, precision=None, interactive=True)
408
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=8.5, precision=None, interactive=True)
409
- with gr.Row():
410
- seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
411
- gr.Button('\U0001f3b2\ufe0f', elem_classes="small-btn").click(fn=lambda: -1, outputs=[seed], queue=False)
412
- reuse_seed = gr.Button('\u267b\ufe0f', elem_classes="small-btn")
413
- with gr.Column() as c:
414
- output = gr.Video(label="Generated Music")
415
- wave_file = gr.File(label=".wav file", elem_id="output_wavefile", interactive=True)
416
- seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
417
-
418
- radio.change(toggle_audio_src, radio, [melody_filepath], queue=False, show_progress=False)
419
- melody_filepath.change(load_melody_filepath, inputs=[melody_filepath, title], outputs=[title, prompt_index , model], api_name="melody_filepath_change", queue=False)
420
- reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False, api_name="reuse_seed")
421
- submit.click(predict, inputs=[model, text,melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only], outputs=[output, wave_file, seed_used], api_name="submit")
422
- gr.Examples(
423
- fn=predict,
424
- examples=[
425
- [
426
- "4/4 120bpm 320kbps 48khz, An 80s driving pop song with heavy drums and synth pads in the background",
427
- "./assets/bach.mp3",
428
- "melody",
429
- "80s Pop Synth"
  ],
431
- [
432
- "4/4 120bpm 320kbps 48khz, A cheerful country song with acoustic guitars",
433
- "./assets/bolero_ravel.mp3",
434
- "melody",
435
- "Country Guitar"
436
- ],
437
- [
438
- "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums",
439
- None,
440
- "medium",
441
- "90s Rock Guitar"
442
- ],
443
- [
444
- "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
445
- "./assets/bach.mp3",
446
- "melody",
447
- "EDM my Bach"
448
- ],
449
- [
450
- "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples",
451
- None,
452
- "medium",
453
- "LoFi Chill"
454
- ],
455
- ],
456
- inputs=[text, melody_filepath, model, title],
457
- outputs=[output]
458
- )
459
- gr.HTML(value=versions_html(), visible=True, elem_id="versions")
460
-
461
  # Show the interface
462
  launch_kwargs = {}
463
  share = kwargs.get('share', False)
@@ -471,10 +539,10 @@ def ui(**kwargs):
471
  if share:
472
  launch_kwargs['share'] = share
473
  launch_kwargs['favicon_path']= "./assets/favicon.ico"
 
474
 
475
 
476
-
477
- demo.queue(max_size=10, concurrency_count=1, api_open=False).launch(**launch_kwargs)
478
 
479
  if __name__ == "__main__":
480
  parser = argparse.ArgumentParser()
@@ -518,7 +586,7 @@ if __name__ == "__main__":
518
  args = parser.parse_args()
519
 
520
  launch_kwargs = {}
521
- launch_kwargs['server_name'] = args.listen
522
 
523
  if args.username and args.password:
524
  launch_kwargs['auth'] = (args.username, args.password)
@@ -528,7 +596,7 @@ if __name__ == "__main__":
528
  launch_kwargs['inbrowser'] = args.inbrowser
529
  if args.share:
530
  launch_kwargs['share'] = args.share
531
- launch_kwargs['favicon_path']= "./assets/favicon.ico"
532
 
533
 
534
  UNLOAD_MODEL = args.unload_model
@@ -538,6 +606,6 @@ if __name__ == "__main__":
538
 
539
  ui(
540
  unload_to_cpu = MOVE_TO_CPU,
541
- share=args.share
542
-
543
  )
 
21
  from audiocraft.data.audio import audio_write
22
  from audiocraft.data.audio_utils import apply_fade, apply_tafade, apply_splice_effect
23
  from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, INTERRUPTING
24
+ from audiocraft.utils import utils
25
  import numpy as np
26
  import random
27
+ import shutil
28
+ from mutagen.mp4 import MP4
29
  #from typing import List, Union
30
  import librosa
31
+ import modules.user_history
32
+ from modules.version_info import versions_html, commit_hash, get_xformers_version
33
+ from modules.gradio import *
34
+ from modules.file_utils import get_file_parts, get_filename_from_filepath, convert_title_to_filename, get_filename, delete_file
35
 
36
  MODEL = None
37
  MODELS = None
 
41
  MOVE_TO_CPU = False
42
  MAX_PROMPT_INDEX = 0
43
  git = os.environ.get('GIT', "git")
44
+ #s.environ["CUDA_LAUNCH_BLOCKING"] = "1"
45
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
46
+ os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
47
+ os.environ['CUDA_MODULE_LOADING']='LAZY'
48
+ os.environ['USE_FLASH_ATTENTION'] = '1'
49
+ os.environ['XFORMERS_FORCE_DISABLE_TRITON']= '1'
50
 
51
  def interrupt_callback():
52
  return INTERRUPTED
 
83
  else:
84
  return gr.update(source="upload", value=None, label="File")
85
 
86
+ def get_waveform(*args, **kwargs):
87
  # Further remove some warnings.
88
  be = time.time()
89
  with warnings.catch_warnings():
 
91
  out = gr.make_waveform(*args, **kwargs)
92
  print("Make a video took", time.time() - be)
93
  return out
94
+
95
 
96
  def load_model(version):
97
  global MODEL, MODELS, UNLOAD_MODEL
 
114
  print("Cached model loaded in %.2fs" % (time.monotonic() - t1))
115
  return result
116
117
  def get_melody(melody_filepath):
118
  audio_data= list(librosa.load(melody_filepath, sr=None))
119
  audio_data[0], audio_data[1] = audio_data[1], audio_data[0]
120
  melody = tuple(audio_data)
121
  return melody
122
123
  def git_tag():
124
  try:
125
  return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip()
 
132
  except Exception:
133
  return "<none>"
134
135
  def load_melody_filepath(melody_filepath, title):
136
  # get melody filename
137
  #$Union[str, os.PathLike]
 
157
  print(f"Melody length: {len(melody_data)}, Melody segments: {total_melodys}\n")
158
  MAX_PROMPT_INDEX = total_melodys
159
 
160
+ return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value="melody", interactive=True)
161
 
162
  def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False):
163
  global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
164
  output_segments = None
165
  melody_name = "Not Used"
166
+ melody_extension = "Not Used"
167
  melody = None
168
  if melody_filepath:
169
  melody_name, melody_extension = get_filename_from_filepath(melody_filepath)
 
172
  INTERRUPTED = False
173
  INTERRUPTING = False
174
  if temperature < 0:
175
+ temperature -0
176
  raise gr.Error("Temperature must be >= 0.")
177
  if topk < 0:
178
+ topk = 1
179
  raise gr.Error("Topk must be non-negative.")
180
  if topp < 0:
181
+ topp =1
182
  raise gr.Error("Topp must be non-negative.")
183
 
184
+ try:
185
+ if MODEL is None or MODEL.name != model:
186
+ MODEL = load_model(model)
187
+ else:
188
+ if MOVE_TO_CPU:
189
+ MODEL.to('cuda')
190
+ except Exception as e:
191
+ raise gr.Error(f"Error loading model '{model}': {str(e)}. Try a different model.")
192
 
193
  # prevent hacking
194
  duration = min(duration, 720)
 
228
  rep_penalty=0.5
229
  )
230
 
231
+ try:
232
+ if melody:
233
+ # return excess duration, load next model and continue in loop structure building up output_segments
234
+ if duration > MODEL.lm.cfg.dataset.segment_duration:
235
+ output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False)
236
+ else:
237
+ # pure original code
238
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
239
+ print(melody.shape)
240
+ if melody.dim() == 2:
241
+ melody = melody[None]
242
+ melody = melody[..., :int(sr * MODEL.lm.cfg.dataset.segment_duration)]
243
+ output = MODEL.generate_with_chroma(
244
+ descriptions=[text],
245
+ melody_wavs=melody,
246
+ melody_sample_rate=sr,
247
+ progress=True
248
+ )
249
+ # All output_segments are populated, so we can break the loop or set duration to 0
250
+ break
 
 
 
 
251
  else:
252
+ #output = MODEL.generate(descriptions=[text], progress=False)
253
+ if not output_segments:
254
+ next_segment = MODEL.generate(descriptions=[text], progress=True)
255
+ duration -= segment_duration
256
+ else:
257
+ last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
258
+ next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=True)
259
+ duration -= segment_duration - overlap
260
+ if next_segment != None:
261
+ output_segments.append(next_segment)
262
+ except Exception as e:
263
+ print(f"Error generating audio: {e}")
264
+ gr.Error(f"Error generating audio: {e}")
265
+ return None, None, seed
266
 
267
  if INTERRUPTING:
268
  INTERRUPTED = True
 
270
  print("Function execution interrupted!")
271
  raise gr.Error("Interrupted.")
272
 
273
+ print(f"\nOutput segments: {len(output_segments)}\n")
274
  if output_segments:
275
  try:
276
  # Combine the output segments into one long audio file or stack tracks
 
296
  ##overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=1) #stack tracks
297
  ##print(f" overlap size stack:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
298
  #overlapping_output = torch.cat([output[:, :, -overlap_samples:], output_segments[i][:, :, :overlap_samples]], dim=2) #stack tracks
299
+ #print(f" overlap size cat:{overlapping_output.size()}\n output: {output.size()}\n segment: {output_segments[i].size()}")
300
  output = torch.cat([output[:, :, :-overlap_samples], overlapping_output, output_segments[i][:, :, overlap_samples:]], dim=dimension)
301
  else:
302
  output = torch.cat([output, output_segments[i]], dim=dimension)
 
305
  print(f"Error combining segments: {e}. Using the first segment only.")
306
  output = output_segments[0].detach().cpu().float()[0]
307
  else:
308
+ if (output is None) or (output.dim() == 0):
309
+ return None, None, seed
310
+ else:
311
+ output = output.detach().cpu().float()[0]
312
+ profile: gr.OAuthProfile | None = None
313
+ title_file_name = convert_title_to_filename(title)
314
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False, prefix = title_file_name) as file:
315
  video_description = f"{text}\n Duration: {str(initial_duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}\n Model: {model}\n Melody Condition:{melody_name}\n Sample Segment: {prompt_index}"
316
  if include_settings or include_title:
317
  background = add_settings_to_image(title if include_title else "", video_description if include_settings else "", background_path=background, font=settings_font, font_color=settings_font_color)
318
  audio_write(
319
  file.name, output, MODEL.sample_rate, strategy="loudness",
320
+ loudness_headroom_db=18, loudness_compressor=True, add_suffix=False, channels=2)
321
+ waveform_video_path = get_waveform(file.name,bg_image=background, bar_count=45, name = title_file_name)
322
+ # Remove the extension from file.name
323
+ file_name_without_extension = os.path.splitext(file.name)[0]
324
+ # Get the directory, filename, name, extension, and new extension of the waveform video path
325
+ video_dir, video_name, video_name, video_ext, video_new_ext = get_file_parts(waveform_video_path)
326
+
327
+ new_video_path = os.path.join(video_dir, title_file_name + video_new_ext)
328
+
329
+ mp4 = MP4(waveform_video_path)
330
+ mp4["©nam"] = title_file_name # Title tag
331
+ mp4["desc"] = f"{text}\n Duration: {str(initial_duration)}" # Description tag
332
+
333
+ commit = commit_hash()
334
+ metadata={
335
+ "prompt": text,
336
+ "negative_prompt": "",
337
+ "Seed": seed,
338
+ "steps": 1,
339
+ "width": "768px",
340
+ "height":"512px",
341
+ "Dimension": dimension,
342
+ "Top-k": topk,
343
+ "Top-p":topp,
344
+ "Randomness": temperature,
345
+ "cfg":cfg_coef,
346
+ "overlap": overlap,
347
+ "Melody Condition": melody_name,
348
+ "Sample Segment": prompt_index,
349
+ "Duration": initial_duration,
350
+ "Audio": file.name,
351
+ "font": settings_font,
352
+ "font_color": settings_font_color,
353
+ "harmony_only": harmony_only,
354
+ "background": background,
355
+ "include_title": include_title,
356
+ "include_settings": include_settings,
357
+ "profile": profile,
358
+ "commit": commit_hash(),
359
+ "tag": git_tag(),
360
+ "version": gr.__version__,
361
+ "model_version": MODEL.version,
362
+ "model_name": MODEL.name,
363
+ "model_description": f"{MODEL.audio_channels} channels, {MODEL.sample_rate} Hz",
364
+ "melody_name" : melody_name if melody_name else "",
365
+ "melody_extension" : melody_extension if melody_extension else "",
366
+ "hostname": "https://huggingface.co/spaces/Surn/UnlimitedMusicGen",
367
+ "version" : f"""https://huggingface.co/spaces/Surn/UnlimitedMusicGen/commit/{"huggingface" if commit == "<none>" else commit}""",
368
+ "python" : sys.version,
369
+ "torch" : getattr(torch, '__long_version__',torch.__version__),
370
+ "xformers": get_xformers_version(),
371
+ "gradio": gr.__version__,
372
+ "huggingface_space": os.environ.get('SPACE_ID', ''),
373
+ "CUDA": f"""{"CUDA is available. device: " + torch.cuda.get_device_name(0) + " version: " + torch.version.cuda if torch.cuda.is_available() else "CUDA is not available."}""",
374
+ }
375
+ # Add additional metadata from the metadata dictionary (if it exists)
376
+ for key, value in metadata.items():
377
+ mp4[key] = str(value) # Convert values to strings as required by mutagen
378
+
379
+ # Save the metadata changes to the file
380
+ mp4.save()
381
+
382
+ try:
383
+ if os.path.exists(new_video_path):
384
+ delete_file(new_video_path)
385
+ # Open the original MP4 file in binary read mode and the new file in binary write mode
386
+ with open(waveform_video_path, "rb") as src, open(new_video_path, "wb") as dst:
387
+ if os.path.exists(waveform_video_path):
388
+ # Copy the contents from the source file to the destination file
389
+ shutil.copyfileobj(src, dst)
390
+ waveform_video_path = new_video_path
391
+ except Exception as e:
392
+ print(f"Error copying file: {e}")
393
+
394
+ if waveform_video_path:
395
+ modules.user_history.save_file(
396
+ profile=profile,
397
+ image=background,
398
+ audio=file,
399
+ video=waveform_video_path,
400
+ label=text,
401
+ metadata=metadata,
402
+ )
403
+
404
+
405
  if MOVE_TO_CPU:
406
  MODEL.to('cpu')
407
  if UNLOAD_MODEL:
408
  MODEL = None
409
  torch.cuda.empty_cache()
410
  torch.cuda.ipc_collect()
411
+ return waveform_video_path, file.name, seed
412
 
413
+ gr.set_static_paths(paths=["fonts/","assets/"])
414
  def ui(**kwargs):
415
+ with gr.Blocks(title="UnlimitedMusicGen",css_paths="style_20250331.css", theme='Surn/beeuty') as interface:
416
+ with gr.Tab("UnlimitedMusicGen"):
417
+ gr.Markdown(
418
+ """
 
 
 
 
 
 
 
 
419
  # UnlimitedMusicGen
420
  This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation
421
  presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
422
 
423
  Disclaimer: This won't run on CPU only. Clone this App and run on GPU instance!
424
 
425
+ Todo: Working on improved Interrupt.
426
+ Theme Available at ["Surn/Beeuty"](https://huggingface.co/spaces/Surn/Beeuty)
427
+
428
+ """
429
+ )
430
+ if IS_SHARED_SPACE and not torch.cuda.is_available():
431
+ gr.Markdown("""
432
+ This Space doesn't work in this shared UI ⚠
433
+
434
+ <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
435
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
436
+ to use it privately, or use the <a href="https://huggingface.co/spaces/facebook/MusicGen">public demo</a>
437
+ """)
438
+ with gr.Row():
439
  with gr.Column():
440
+ with gr.Row():
441
+ text = gr.Text(label="Describe your music", interactive=True, value="4/4 100bpm 320kbps 48khz, Industrial/Electronic Soundtrack, Dark, Intense, Sci-Fi")
442
+ with gr.Column():
443
+ duration = gr.Slider(minimum=1, maximum=720, value=10, label="Duration (s)", interactive=True)
444
+ model = gr.Radio(["melody", "medium", "small", "large", "melody-large", "stereo-small", "stereo-medium", "stereo-large", "stereo-melody", "stereo-melody-large"], label="AI Model", value="melody", interactive=True)
445
+ with gr.Row():
446
+ submit = gr.Button("Generate", elem_id="btn-generate")
447
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
448
+ _ = gr.Button("Interrupt", elem_id="btn-interrupt").click(fn=interrupt, queue=False)
449
+ with gr.Row():
450
+ with gr.Column():
451
+ radio = gr.Radio(["file", "mic"], value="file", label="Condition on a melody (optional) File or Mic")
452
+ melody_filepath = gr.Audio(sources=["upload"], type="filepath", label="Melody Condition (optional)", interactive=True, elem_id="melody-input")
453
+ with gr.Column():
454
+ harmony_only = gr.Radio(label="Use Harmony Only",choices=["No", "Yes"], value="No", interactive=True, info="Remove Drums?")
455
+ prompt_index = gr.Slider(label="Melody Condition Sample Segment", minimum=-1, maximum=MAX_PROMPT_INDEX, step=1, value=0, interactive=True, info="Which 30 second segment to condition with, - 1 condition each segment independantly")
456
+ with gr.Accordion("Video", open=False):
457
+ with gr.Row():
458
+ background= gr.Image(value="./assets/background.png", sources=["upload"], label="Background", width=768, height=512, type="filepath", interactive=True)
459
+ with gr.Column():
460
+ include_title = gr.Checkbox(label="Add Title", value=True, interactive=True)
461
+ include_settings = gr.Checkbox(label="Add Settings to background", value=True, interactive=True)
462
+ with gr.Row():
463
+ title = gr.Textbox(label="Title", value="UnlimitedMusicGen", interactive=True)
464
+ settings_font = gr.Text(label="Settings Font", value="./assets/arial.ttf", interactive=True)
465
+ settings_font_color = gr.ColorPicker(label="Settings Font Color", value="#c87f05", interactive=True)
466
+ with gr.Accordion("Expert", open=False):
467
+ with gr.Row():
468
+ overlap = gr.Slider(minimum=0, maximum=15, value=2, step=1, label="Verse Overlap", interactive=True)
469
+ dimension = gr.Slider(minimum=-2, maximum=2, value=2, step=1, label="Dimension", info="determines which direction to add new segements of audio. (1 = stack tracks, 2 = lengthen, -2..0 = ?)", interactive=True)
470
+ with gr.Row():
471
+ topk = gr.Number(label="Top-k", value=280, precision=0, interactive=True)
472
+ topp = gr.Number(label="Top-p", value=1150, precision=0, interactive=True)
473
+ temperature = gr.Number(label="Randomness Temperature", value=0.7, precision=None, interactive=True)
474
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=8.5, precision=None, interactive=True)
475
+ with gr.Row():
476
+ seed = gr.Number(label="Seed", value=-1, precision=0, interactive=True)
477
+ gr.Button('\U0001f3b2\ufe0f', elem_classes="small-btn").click(fn=lambda: -1, outputs=[seed], queue=False)
478
+ reuse_seed = gr.Button('\u267b\ufe0f', elem_classes="small-btn")
479
+ with gr.Column() as c:
480
+ output = gr.Video(label="Generated Music")
481
+ wave_file = gr.File(label=".wav file", elem_id="output_wavefile", interactive=True)
482
+ seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
483
+
484
+ radio.change(toggle_audio_src, radio, [melody_filepath], queue=False, show_progress=False)
485
+ melody_filepath.change(load_melody_filepath, inputs=[melody_filepath, title], outputs=[title, prompt_index , model], api_name="melody_filepath_change", queue=False)
486
+ reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False, api_name="reuse_seed")
487
+ submit.click(predict, inputs=[model, text,melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only], outputs=[output, wave_file, seed_used], api_name="submit")
488
+ gr.Examples(
489
+ fn=predict,
490
+ examples=[
491
+ [
492
+ "4/4 120bpm 320kbps 48khz, An 80s driving pop song with heavy drums and synth pads in the background",
493
+ "./assets/bach.mp3",
494
+ "stereo-melody-large",
495
+ "80s Pop Synth"
496
+ ],
497
+ [
498
+ "4/4 120bpm 320kbps 48khz, A cheerful country song with acoustic guitars",
499
+ "./assets/bolero_ravel.mp3",
500
+ "melody",
501
+ "Country Guitar"
502
+ ],
503
+ [
504
+ "4/4 120bpm 320kbps 48khz, 90s rock song with electric guitar and heavy drums",
505
+ None,
506
+ "stereo-medium",
507
+ "90s Rock Guitar"
508
+ ],
509
+ [
510
+ "4/4 120bpm 320kbps 48khz, a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
511
+ "./assets/bach.mp3",
512
+ "melody-large",
513
+ "EDM my Bach"
514
+ ],
515
+ [
516
+ "4/4 320kbps 48khz, lofi slow bpm electro chill with organic samples",
517
+ None,
518
+ "medium",
519
+ "LoFi Chill"
520
+ ],
521
  ],
522
+ inputs=[text, melody_filepath, model, title],
523
+ outputs=[output]
524
+ )
525
+ gr.HTML(value=versions_html(), visible=True, elem_id="versions")
526
+ with gr.Tab("User History") as history_tab:
527
+ modules.user_history.render()
528
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  # Show the interface
530
  launch_kwargs = {}
531
  share = kwargs.get('share', False)
 
539
  if share:
540
  launch_kwargs['share'] = share
541
  launch_kwargs['favicon_path']= "./assets/favicon.ico"
542
+
543
 
544
 
545
+ interface.queue(max_size=10, api_open=False).launch(**launch_kwargs)
 
546
 
547
  if __name__ == "__main__":
548
  parser = argparse.ArgumentParser()
 
586
  args = parser.parse_args()
587
 
588
  launch_kwargs = {}
589
+ launch_kwargs['listen'] = args.listen
590
 
591
  if args.username and args.password:
592
  launch_kwargs['auth'] = (args.username, args.password)
 
596
  launch_kwargs['inbrowser'] = args.inbrowser
597
  if args.share:
598
  launch_kwargs['share'] = args.share
599
+ launch_kwargs['favicon_path']= "./assets/favicon.ico"
600
 
601
 
602
  UNLOAD_MODEL = args.unload_model
 
606
 
607
  ui(
608
  unload_to_cpu = MOVE_TO_CPU,
609
+ share=args.share,
610
+ **launch_kwargs,
611
  )
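Because `predict()` above now embeds the generation settings into the output video's MP4 tags via mutagen, they can be inspected afterwards. A hedged read-back sketch follows; the file path is hypothetical, and the tag names mirror the ones written in the diff above.

```python
# Sketch only: read back the metadata that predict() embeds with mutagen.
from mutagen.mp4 import MP4

video_path = "UnlimitedMusicGen.mp4"   # hypothetical path to a video produced by the app

mp4 = MP4(video_path)
print(mp4.tags.get("©nam"))   # title, written above as mp4["©nam"]
print(mp4.tags.get("desc"))   # prompt + duration description

# Dump every tag the app stored (prompt, Seed, model_name, torch/xformers versions, ...)
for key, value in (mp4.tags or {}).items():
    print(f"{key}: {value}")
```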
assets/KuritaSurnLogox64.png ADDED
assets/Vermilion-Musical-Notes-Typography-No-Background.svg ADDED
assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3 ADDED
Binary file (15.2 kB).
 
assets/icon_delete.png ADDED
assets/icon_download.png ADDED
assets/icon_refresh.png ADDED
assets/logo_animation_256.gif ADDED

Git LFS Details

  • SHA256: 84aa8c95f88f4c9d110dc87c344ec92786e8b5c464ac8141a9c3b12bedf2ed71
  • Pointer size: 131 Bytes
  • Size of remote file: 140 kB
assets/screenshot.png ADDED

Git LFS Details

  • SHA256: 89abfaffefc18124ffe8f0775eb4dcdbc589bb97befc76dd3f3fc48992991e2e
  • Pointer size: 131 Bytes
  • Size of remote file: 388 kB
assets/sirens_and_a_humming_engine_approach_and_pass.mp3 ADDED
Binary file (15.2 kB).
 
audiocraft/__init__.py CHANGED
@@ -7,4 +7,4 @@
 # flake8: noqa
 from . import data, modules, models
 
-__version__ = '0.0.2a2'
+__version__ = '1.4.Surn'
audiocraft/data/__init__.py CHANGED
@@ -5,4 +5,4 @@
 # LICENSE file in the root directory of this source tree.
 
 # flake8: noqa
-from . import audio, audio_dataset
+from . import audio, audio_dataset, info_audio_dataset
audiocraft/data/audio.py CHANGED
@@ -21,6 +21,7 @@ from torch.nn import functional as F
21
  import torchaudio as ta
22
 
23
  import av
 
24
 
25
  from .audio_utils import f32_pcm, i16_pcm, normalize_audio, convert_audio
26
 
@@ -149,7 +150,17 @@ def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
149
  wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
150
  return wav, sr
151
 
152
-
 
 
 
 
 
 
 
 
 
 
153
  def audio_write(stem_name: tp.Union[str, Path],
154
  wav: torch.Tensor, sample_rate: int,
155
  format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
@@ -215,3 +226,77 @@ def audio_write(stem_name: tp.Union[str, Path],
215
  path.unlink()
216
  raise
217
  return path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  import torchaudio as ta
22
 
23
  import av
24
+ import subprocess as sp
25
 
26
  from .audio_utils import f32_pcm, i16_pcm, normalize_audio, convert_audio
27
 
 
150
  wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
151
  return wav, sr
152
 
153
+ def _piping_to_ffmpeg(out_path: tp.Union[str, Path], wav: torch.Tensor, sample_rate: int, flags: tp.List[str]):
154
+ # ffmpeg is always installed and torchaudio is a bit unstable lately, so let's bypass it entirely.
155
+ assert wav.dim() == 2, wav.shape
156
+ command = [
157
+ 'ffmpeg',
158
+ '-loglevel', 'error',
159
+ '-y', '-f', 'f32le', '-ar', str(sample_rate), '-ac', str(wav.shape[0]),
160
+ '-i', '-'] + flags + [str(out_path)]
161
+ input_ = f32_pcm(wav).t().detach().cpu().numpy().tobytes()
162
+ sp.run(command, input=input_, check=True)
163
+
164
  def audio_write(stem_name: tp.Union[str, Path],
165
  wav: torch.Tensor, sample_rate: int,
166
  format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
 
226
  path.unlink()
227
  raise
228
  return path
229
+
230
+ def audio_write2(stem_name: tp.Union[str, Path],
231
+ wav: torch.Tensor, sample_rate: int,
232
+ format: str = 'wav', mp3_rate: int = 320, ogg_rate: tp.Optional[int] = None,
233
+ normalize: bool = True, strategy: str = 'peak', peak_clip_headroom_db: float = 1,
234
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
235
+ loudness_compressor: bool = False,
236
+ log_clipping: bool = True, make_parent_dir: bool = True,
237
+ add_suffix: bool = True) -> Path:
238
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
239
+
240
+ Args:
241
+ stem_name (str or Path): Filename without extension which will be added automatically.
242
+ wav (torch.Tensor): Audio data to save.
243
+ sample_rate (int): Sample rate of audio data.
244
+ format (str): Either "wav", "mp3", "ogg", or "flac".
245
+ mp3_rate (int): kbps when using mp3s.
246
+ ogg_rate (int): kbps when using ogg/vorbis. If not provided, let ffmpeg decide for itself.
247
+ normalize (bool): if `True` (default), normalizes according to the prescribed
248
+ strategy (see after). If `False`, the strategy is only used in case clipping
249
+ would happen.
250
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
251
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
252
+ with extra headroom to avoid clipping. 'clip' just clips.
253
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
254
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
255
+ than the `peak_clip` one to avoid further clipping.
256
+ loudness_headroom_db (float): Target loudness for loudness normalization.
257
+ loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
258
+ log_clipping (bool): If True, basic logging on stderr when clipping still
259
+ occurs despite strategy (only for 'rms').
260
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
261
+ Returns:
262
+ Path: Path of the saved audio.
263
+ """
264
+ assert wav.dtype.is_floating_point, "wav is not floating point"
265
+ if wav.dim() == 1:
266
+ wav = wav[None]
267
+ elif wav.dim() > 2:
268
+ raise ValueError("Input wav should be at most 2 dimension.")
269
+ assert wav.isfinite().all()
270
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
271
+ rms_headroom_db, loudness_headroom_db, loudness_compressor,
272
+ log_clipping=log_clipping, sample_rate=sample_rate,
273
+ stem_name=str(stem_name))
274
+ if format == 'mp3':
275
+ suffix = '.mp3'
276
+ flags = ['-f', 'mp3', '-c:a', 'libmp3lame', '-b:a', f'{mp3_rate}k']
277
+ elif format == 'wav':
278
+ suffix = '.wav'
279
+ flags = ['-f', 'wav', '-c:a', 'pcm_s16le']
280
+ elif format == 'ogg':
281
+ suffix = '.ogg'
282
+ flags = ['-f', 'ogg', '-c:a', 'libvorbis']
283
+ if ogg_rate is not None:
284
+ flags += ['-b:a', f'{ogg_rate}k']
285
+ elif format == 'flac':
286
+ suffix = '.flac'
287
+ flags = ['-f', 'flac']
288
+ else:
289
+ raise RuntimeError(f"Invalid format {format}. Only wav, mp3, ogg, or flac are supported.")
290
+ if not add_suffix:
291
+ suffix = ''
292
+ path = Path(str(stem_name) + suffix)
293
+ if make_parent_dir:
294
+ path.parent.mkdir(exist_ok=True, parents=True)
295
+ try:
296
+ _piping_to_ffmpeg(path, wav, sample_rate, flags)
297
+ except Exception:
298
+ if path.exists():
299
+ # we do not want to leave half written files around.
300
+ path.unlink()
301
+ raise
302
+ return path
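
For illustration, a minimal usage sketch of the ffmpeg-backed writer added above (not part of the commit). It assumes ffmpeg is available on PATH, that the package is importable as audiocraft.data.audio, and the output stem and audio tensor are placeholders:

import torch
from audiocraft.data.audio import audio_write2

t = torch.arange(32000) / 32000.0
wav = torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)   # placeholder audio: 1 s of 440 Hz sine, shape [1, T]
out_path = audio_write2("output/demo", wav, sample_rate=32000,
                        format="mp3", mp3_rate=320, strategy="peak")
print(out_path)  # expected: output/demo.mp3 (suffix appended because add_suffix=True)
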
audiocraft/data/audio_dataset.py CHANGED
@@ -3,12 +3,16 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
-
 
 
 
7
  import argparse
8
  import copy
9
  from concurrent.futures import ThreadPoolExecutor, Future
10
  from dataclasses import dataclass, fields
11
  from contextlib import ExitStack
 
12
  import gzip
13
  import json
14
  import logging
@@ -81,9 +85,12 @@ class AudioMeta(BaseInfo):
81
  class SegmentInfo(BaseInfo):
82
  meta: AudioMeta
83
  seek_time: float
84
- n_frames: int # actual number of frames without padding
 
 
85
  total_frames: int # total number of frames, padding included
86
- sample_rate: int # actual sample rate
 
87
 
88
 
89
  DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
@@ -114,8 +121,8 @@ def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
114
 
115
  Args:
116
  m (AudioMeta): Audio meta to resolve.
117
- fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
118
- Only valid on Linux/Mac.
119
  Returns:
120
  AudioMeta: Audio meta with resolved path.
121
  """
@@ -151,7 +158,7 @@ def find_audio_files(path: tp.Union[Path, str],
151
  progress (bool): Whether to log progress on audio files collection.
152
  workers (int): number of parallel workers, if 0, use only the current thread.
153
  Returns:
154
- List[AudioMeta]: List of audio file path and its metadata.
155
  """
156
  audio_files = []
157
  futures: tp.List[Future] = []
@@ -203,7 +210,7 @@ def load_audio_meta(path: tp.Union[str, Path],
203
  resolve (bool): Whether to resolve the path from AudioMeta (default=True).
204
  fast (bool): activates some tricks to make things faster.
205
  Returns:
206
- List[AudioMeta]: List of audio file path and its total duration.
207
  """
208
  open_fn = gzip.open if str(path).lower().endswith('.gz') else open
209
  with open_fn(path, 'rb') as fp: # type: ignore
@@ -250,9 +257,14 @@ class AudioDataset:
250
  allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
251
  original audio meta.
252
 
 
 
 
 
 
253
  Args:
254
- meta (tp.List[AudioMeta]): List of audio files metadata.
255
- segment_duration (float): Optional segment duration of audio to load.
256
  If not specified, the dataset will load the full audio segment from the file.
257
  shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
258
  sample_rate (int): Target sample rate of the loaded audio samples.
@@ -266,10 +278,19 @@ class AudioDataset:
266
  is shorter than the desired segment.
267
  max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
268
  return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
269
- min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
270
  audio shorter than this will be filtered out.
271
- max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
272
  audio longer than this will be filtered out.
 
 
 
 
 
 
 
 
 
273
  """
274
  def __init__(self,
275
  meta: tp.List[AudioMeta],
@@ -285,16 +306,14 @@ class AudioDataset:
285
  max_read_retry: int = 10,
286
  return_info: bool = False,
287
  min_audio_duration: tp.Optional[float] = None,
288
- max_audio_duration: tp.Optional[float] = None
 
 
 
289
  ):
290
- assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
291
  assert segment_duration is None or segment_duration > 0
292
  assert segment_duration is None or min_segment_ratio >= 0
293
- logging.debug(f'sample_on_duration: {sample_on_duration}')
294
- logging.debug(f'sample_on_weight: {sample_on_weight}')
295
- logging.debug(f'pad: {pad}')
296
- logging.debug(f'min_segment_ratio: {min_segment_ratio}')
297
-
298
  self.segment_duration = segment_duration
299
  self.min_segment_ratio = min_segment_ratio
300
  self.max_audio_duration = max_audio_duration
@@ -317,13 +336,25 @@ class AudioDataset:
317
  self.sampling_probabilities = self._get_sampling_probabilities()
318
  self.max_read_retry = max_read_retry
319
  self.return_info = return_info
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  def __len__(self):
322
  return self.num_samples
323
 
324
  def _get_sampling_probabilities(self, normalized: bool = True):
325
- """Return the sampling probabilities for each file inside `self.meta`.
326
- """
327
  scores: tp.List[float] = []
328
  for file_meta in self.meta:
329
  score = 1.
@@ -337,12 +368,32 @@ class AudioDataset:
337
  probabilities /= probabilities.sum()
338
  return probabilities
339
 
340
- def sample_file(self, rng: torch.Generator) -> AudioMeta:
341
- """Sample a given file from `self.meta`. Can be overriden in subclasses.
 
 
 
 
 
 
 
 
 
342
  This is only called if `segment_duration` is not None.
343
 
344
  You must use the provided random number generator `rng` for reproducibility.
 
345
  """
 
 
 
 
 
 
 
 
 
 
346
  if not self.sample_on_weight and not self.sample_on_duration:
347
  file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
348
  else:
@@ -350,6 +401,15 @@ class AudioDataset:
350
 
351
  return self.meta[file_index]
352
 
 
 
 
 
 
 
 
 
 
353
  def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
354
  if self.segment_duration is None:
355
  file_meta = self.meta[index]
@@ -357,18 +417,22 @@ class AudioDataset:
357
  out = convert_audio(out, sr, self.sample_rate, self.channels)
358
  n_frames = out.shape[-1]
359
  segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
360
- sample_rate=self.sample_rate)
361
  else:
362
  rng = torch.Generator()
363
  if self.shuffle:
364
- # We use index, plus extra randomness
365
- rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
 
 
 
 
366
  else:
367
  # We only use index
368
  rng.manual_seed(index)
369
 
370
  for retry in range(self.max_read_retry):
371
- file_meta = self.sample_file(rng)
372
  # We add some variance in the file position even if audio file is smaller than segment
373
  # without ending up with empty segments
374
  max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
@@ -381,7 +445,7 @@ class AudioDataset:
381
  if self.pad:
382
  out = F.pad(out, (0, target_frames - n_frames))
383
  segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
384
- sample_rate=self.sample_rate)
385
  except Exception as exc:
386
  logger.warning("Error opening file %s: %r", file_meta.path, exc)
387
  if retry == self.max_read_retry - 1:
@@ -423,7 +487,7 @@ class AudioDataset:
423
  if to_pad:
424
  # Each wav could be of a different duration as they are not segmented.
425
  for i in range(len(samples)):
426
- # Determines the total legth of the signal with padding, so we update here as we pad.
427
  segment_infos[i].total_frames = max_len
428
  wavs[i] = _pad_wav(wavs[i])
429
 
@@ -436,9 +500,7 @@ class AudioDataset:
436
  return torch.stack(samples)
437
 
438
  def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
439
- """Filters out audio files with short durations.
440
- Removes from meta files that have durations that will not allow to samples examples from them.
441
- """
442
  orig_len = len(meta)
443
 
444
  # Filter data that is too short.
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """AudioDataset support. In order to handle a larger number of files
7
+ without having to scan the folders again, we precompute some metadata
8
+ (filename, sample rate, duration), and use that to efficiently sample audio segments.
9
+ """
10
  import argparse
11
  import copy
12
  from concurrent.futures import ThreadPoolExecutor, Future
13
  from dataclasses import dataclass, fields
14
  from contextlib import ExitStack
15
+ from functools import lru_cache
16
  import gzip
17
  import json
18
  import logging
 
85
  class SegmentInfo(BaseInfo):
86
  meta: AudioMeta
87
  seek_time: float
88
+ # The following values are given once the audio is processed, e.g.
89
+ # at the target sample rate and target number of channels.
90
+ n_frames: int # actual number of frames without padding
91
  total_frames: int # total number of frames, padding included
92
+ sample_rate: int # actual sample rate
93
+ channels: int # number of audio channels.
94
 
95
 
96
  DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
 
121
 
122
  Args:
123
  m (AudioMeta): Audio meta to resolve.
124
+ fast (bool): If True, uses a really fast check for determining if a file
125
+ is already absolute or not. Only valid on Linux/Mac.
126
  Returns:
127
  AudioMeta: Audio meta with resolved path.
128
  """
 
158
  progress (bool): Whether to log progress on audio files collection.
159
  workers (int): number of parallel workers, if 0, use only the current thread.
160
  Returns:
161
+ list of AudioMeta: List of audio file path and its metadata.
162
  """
163
  audio_files = []
164
  futures: tp.List[Future] = []
 
210
  resolve (bool): Whether to resolve the path from AudioMeta (default=True).
211
  fast (bool): activates some tricks to make things faster.
212
  Returns:
213
+ list of AudioMeta: List of audio file path and its total duration.
214
  """
215
  open_fn = gzip.open if str(path).lower().endswith('.gz') else open
216
  with open_fn(path, 'rb') as fp: # type: ignore
 
257
  allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
258
  original audio meta.
259
 
260
+ Note that you can call `start_epoch(epoch)` in order to get
261
+ a deterministic "randomization" for `shuffle=True`.
262
+ For a given epoch and dataset index, this will always return the same extract.
263
+ You can get back some diversity by setting the `shuffle_seed` param.
264
+
265
  Args:
266
+ meta (list of AudioMeta): List of audio files metadata.
267
+ segment_duration (float, optional): Optional segment duration of audio to load.
268
  If not specified, the dataset will load the full audio segment from the file.
269
  shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
270
  sample_rate (int): Target sample rate of the loaded audio samples.
 
278
  is shorter than the desired segment.
279
  max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
280
  return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
281
+ min_audio_duration (float, optional): Minimum audio file duration, in seconds, if provided
282
  audio shorter than this will be filtered out.
283
+ max_audio_duration (float, optional): Maximal audio file duration in seconds, if provided
284
  audio longer than this will be filtered out.
285
+ shuffle_seed (int): can be used to further randomize the deterministic shuffling across runs.
286
+ load_wav (bool): if False, skip loading the wav but return a tensor of zeros
287
+ with the expected segment_duration (which must be provided if load_wav is False).
288
+ permutation_on_files (bool): only if `sample_on_weight` and `sample_on_duration`
289
+ are False. Will ensure a permutation on files when going through the dataset.
290
+ In that case the epoch number must be provided in order for the model
291
+ to continue the permutation across epochs. In that case, it is assumed
292
+ that `num_samples = total_batch_size * num_updates_per_epoch`, with
293
+ `total_batch_size` the overall batch size accounting for all gpus.
294
  """
295
  def __init__(self,
296
  meta: tp.List[AudioMeta],
 
306
  max_read_retry: int = 10,
307
  return_info: bool = False,
308
  min_audio_duration: tp.Optional[float] = None,
309
+ max_audio_duration: tp.Optional[float] = None,
310
+ shuffle_seed: int = 0,
311
+ load_wav: bool = True,
312
+ permutation_on_files: bool = False,
313
  ):
314
+ assert len(meta) > 0, "No audio meta provided to AudioDataset. Please check loading of audio meta."
315
  assert segment_duration is None or segment_duration > 0
316
  assert segment_duration is None or min_segment_ratio >= 0
 
 
 
 
 
317
  self.segment_duration = segment_duration
318
  self.min_segment_ratio = min_segment_ratio
319
  self.max_audio_duration = max_audio_duration
 
336
  self.sampling_probabilities = self._get_sampling_probabilities()
337
  self.max_read_retry = max_read_retry
338
  self.return_info = return_info
339
+ self.shuffle_seed = shuffle_seed
340
+ self.current_epoch: tp.Optional[int] = None
341
+ self.load_wav = load_wav
342
+ if not load_wav:
343
+ assert segment_duration is not None
344
+ self.permutation_on_files = permutation_on_files
345
+ if permutation_on_files:
346
+ assert not self.sample_on_duration
347
+ assert not self.sample_on_weight
348
+ assert self.shuffle
349
+
350
+ def start_epoch(self, epoch: int):
351
+ self.current_epoch = epoch
352
 
353
  def __len__(self):
354
  return self.num_samples
355
 
356
  def _get_sampling_probabilities(self, normalized: bool = True):
357
+ """Return the sampling probabilities for each file inside `self.meta`."""
 
358
  scores: tp.List[float] = []
359
  for file_meta in self.meta:
360
  score = 1.
 
368
  probabilities /= probabilities.sum()
369
  return probabilities
370
 
371
+ @staticmethod
372
+ @lru_cache(16)
373
+ def _get_file_permutation(num_files: int, permutation_index: int, base_seed: int):
374
+ # Used to keep the most recent files permutation in memory implicitly.
375
+ # This will work unless someone is using a lot of Datasets in parallel.
376
+ rng = torch.Generator()
377
+ rng.manual_seed(base_seed + permutation_index)
378
+ return torch.randperm(num_files, generator=rng)
379
+
380
+ def sample_file(self, index: int, rng: torch.Generator) -> AudioMeta:
381
+ """Sample a given file from `self.meta`. Can be overridden in subclasses.
382
  This is only called if `segment_duration` is not None.
383
 
384
  You must use the provided random number generator `rng` for reproducibility.
385
+ You can further make use of the index accessed.
386
  """
387
+ if self.permutation_on_files:
388
+ assert self.current_epoch is not None
389
+ total_index = self.current_epoch * len(self) + index
390
+ permutation_index = total_index // len(self.meta)
391
+ relative_index = total_index % len(self.meta)
392
+ permutation = AudioDataset._get_file_permutation(
393
+ len(self.meta), permutation_index, self.shuffle_seed)
394
+ file_index = permutation[relative_index]
395
+ return self.meta[file_index]
396
+
397
  if not self.sample_on_weight and not self.sample_on_duration:
398
  file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
399
  else:
 
401
 
402
  return self.meta[file_index]
403
 
404
+ def _audio_read(self, path: str, seek_time: float = 0, duration: float = -1):
405
+ # Override this method in subclass if needed.
406
+ if self.load_wav:
407
+ return audio_read(path, seek_time, duration, pad=False)
408
+ else:
409
+ assert self.segment_duration is not None
410
+ n_frames = int(self.sample_rate * self.segment_duration)
411
+ return torch.zeros(self.channels, n_frames), self.sample_rate
412
+
413
  def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
414
  if self.segment_duration is None:
415
  file_meta = self.meta[index]
 
417
  out = convert_audio(out, sr, self.sample_rate, self.channels)
418
  n_frames = out.shape[-1]
419
  segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
420
+ sample_rate=self.sample_rate, channels=out.shape[0])
421
  else:
422
  rng = torch.Generator()
423
  if self.shuffle:
424
+ # We use index, plus extra randomness: either totally random if we don't know the epoch,
425
+ # otherwise we make use of the epoch number and optional shuffle_seed.
426
+ if self.current_epoch is None:
427
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
428
+ else:
429
+ rng.manual_seed(index + self.num_samples * (self.current_epoch + self.shuffle_seed))
430
  else:
431
  # We only use index
432
  rng.manual_seed(index)
433
 
434
  for retry in range(self.max_read_retry):
435
+ file_meta = self.sample_file(index, rng)
436
  # We add some variance in the file position even if audio file is smaller than segment
437
  # without ending up with empty segments
438
  max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
 
445
  if self.pad:
446
  out = F.pad(out, (0, target_frames - n_frames))
447
  segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
448
+ sample_rate=self.sample_rate, channels=out.shape[0])
449
  except Exception as exc:
450
  logger.warning("Error opening file %s: %r", file_meta.path, exc)
451
  if retry == self.max_read_retry - 1:
 
487
  if to_pad:
488
  # Each wav could be of a different duration as they are not segmented.
489
  for i in range(len(samples)):
490
+ # Determines the total length of the signal with padding, so we update here as we pad.
491
  segment_infos[i].total_frames = max_len
492
  wavs[i] = _pad_wav(wavs[i])
493
 
 
500
  return torch.stack(samples)
501
 
502
  def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
503
+ """Filters out audio files with audio durations that will not allow to sample examples from them."""
 
 
504
  orig_len = len(meta)
505
 
506
  # Filter data that is too short.
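
For illustration, a hedged sketch of the new deterministic-shuffle API added above (not part of the commit). The manifest path is hypothetical, and `meta` is assumed to come from `load_audio_meta`:

from audiocraft.data.audio_dataset import AudioDataset, load_audio_meta

meta = load_audio_meta("egs/my_dataset/data.jsonl.gz")   # hypothetical precomputed manifest of AudioMeta entries
dataset = AudioDataset(meta, segment_duration=10.0, sample_rate=32000, channels=1,
                       shuffle=True, shuffle_seed=42, return_info=True)
dataset.start_epoch(0)   # fixes the "randomization" for this epoch
wav, info = dataset[0]   # the same epoch and index always yield the same extract
print(info.sample_rate, info.channels, info.n_frames)
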
audiocraft/data/audio_utils.py CHANGED
@@ -3,7 +3,8 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
-
 
7
  import sys
8
  import typing as tp
9
 
@@ -150,17 +151,19 @@ def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
150
  """
151
  if wav.dtype.is_floating_point:
152
  return wav
153
- else:
154
- assert wav.dtype == torch.int16
155
  return wav.float() / 2**15
 
 
 
156
 
157
 
158
  def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
159
  """Convert audio to int 16 bits PCM format.
160
 
161
- ..Warning:: There exist many formula for doing this convertion. None are perfect
162
- due to the asymetry of the int16 range. One either have possible clipping, DC offset,
163
- or inconsistancies with f32_pcm. If the given wav doesn't have enough headroom,
164
  it is possible that `i16_pcm(f32_pcm)) != Identity`.
165
  """
166
  if wav.dtype.is_floating_point:
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """Various utilities for audio conversion (PCM format, sample rate and channels),
7
+ and volume normalization."""
8
  import sys
9
  import typing as tp
10
 
 
151
  """
152
  if wav.dtype.is_floating_point:
153
  return wav
154
+ elif wav.dtype == torch.int16:
 
155
  return wav.float() / 2**15
156
+ elif wav.dtype == torch.int32:
157
+ return wav.float() / 2**31
158
+ raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
159
 
160
 
161
  def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
162
  """Convert audio to int 16 bits PCM format.
163
 
164
+ ..Warning:: There exist many formulas for doing this conversion. None are perfect
165
+ due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
166
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
167
  it is possible that `i16_pcm(f32_pcm)) != Identity`.
168
  """
169
  if wav.dtype.is_floating_point:
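
For illustration, a quick hedged check of the PCM helpers touched above; int32 input support is the new code path, and the tensors are purely illustrative:

import torch
from audiocraft.data.audio_utils import f32_pcm, i16_pcm

x16 = torch.tensor([0, 2**14, -2**15], dtype=torch.int16)
x32 = torch.tensor([0, 2**30], dtype=torch.int32)
f16 = f32_pcm(x16)   # int16 -> float32, scaled by 1 / 2**15
f32 = f32_pcm(x32)   # int32 -> float32, scaled by 1 / 2**31 (new branch)
back = i16_pcm(f16)  # float -> int16; may differ by 1 LSB near full scale, as the warning explains
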
audiocraft/data/info_audio_dataset.py ADDED
@@ -0,0 +1,110 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Base classes for the datasets that also provide non-audio metadata,
7
+ e.g. description, text transcription etc.
8
+ """
9
+ from dataclasses import dataclass
10
+ import logging
11
+ import math
12
+ import re
13
+ import typing as tp
14
+
15
+ import torch
16
+
17
+ from .audio_dataset import AudioDataset, AudioMeta
18
+ from ..environment import AudioCraftEnvironment
19
+ from ..modules.conditioners import SegmentWithAttributes, ConditioningAttributes
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
26
+ """Monkey-patch meta to match cluster specificities."""
27
+ meta.path = AudioCraftEnvironment.apply_dataset_mappers(meta.path)
28
+ if meta.info_path is not None:
29
+ meta.info_path.zip_path = AudioCraftEnvironment.apply_dataset_mappers(meta.info_path.zip_path)
30
+ return meta
31
+
32
+
33
+ def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
34
+ """Monkey-patch all meta to match cluster specificities."""
35
+ return [_clusterify_meta(m) for m in meta]
36
+
37
+
38
+ @dataclass
39
+ class AudioInfo(SegmentWithAttributes):
40
+ """Dummy SegmentInfo with empty attributes.
41
+
42
+ The InfoAudioDataset is expected to return metadata that inherits
43
+ from SegmentWithAttributes class and can return conditioning attributes.
44
+
45
+ This basically guarantees all datasets will be compatible with current
46
+ solvers that contain conditioners requiring this.
47
+ """
48
+ audio_tokens: tp.Optional[torch.Tensor] = None # populated when using cached batch for training a LM.
49
+
50
+ def to_condition_attributes(self) -> ConditioningAttributes:
51
+ return ConditioningAttributes()
52
+
53
+
54
+ class InfoAudioDataset(AudioDataset):
55
+ """AudioDataset that always returns metadata as SegmentWithAttributes along with the audio waveform.
56
+
57
+ See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
58
+ """
59
+ def __init__(self, meta: tp.List[AudioMeta], **kwargs):
60
+ super().__init__(clusterify_all_meta(meta), **kwargs)
61
+
62
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
63
+ if not self.return_info:
64
+ wav = super().__getitem__(index)
65
+ assert isinstance(wav, torch.Tensor)
66
+ return wav
67
+ wav, meta = super().__getitem__(index)
68
+ return wav, AudioInfo(**meta.to_dict())
69
+
70
+
71
+ def get_keyword_or_keyword_list(value: tp.Optional[str]) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
72
+ """Preprocess a single keyword or possibly a list of keywords."""
73
+ if isinstance(value, list):
74
+ return get_keyword_list(value)
75
+ else:
76
+ return get_keyword(value)
77
+
78
+
79
+ def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
80
+ """Preprocess a single keyword."""
81
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
82
+ return None
83
+ else:
84
+ return value.strip()
85
+
86
+
87
+ def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
88
+ """Preprocess a single keyword."""
89
+ if value is None or (not isinstance(value, str)) or len(value) == 0 or value == 'None':
90
+ return None
91
+ else:
92
+ return value.strip().lower()
93
+
94
+
95
+ def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
96
+ """Preprocess a list of keywords."""
97
+ if isinstance(values, str):
98
+ values = [v.strip() for v in re.split(r'[,\s]', values)]
99
+ elif isinstance(values, float) and math.isnan(values):
100
+ values = []
101
+ if not isinstance(values, list):
102
+ logger.debug(f"Unexpected keyword list {values}")
103
+ values = [str(values)]
104
+
105
+ kws = [get_keyword(v) for v in values]
106
+ kw_list = [k for k in kws if k is not None]
107
+ if len(kw_list) == 0:
108
+ return None
109
+ else:
110
+ return kw_list
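
For illustration, a hedged example of the keyword preprocessing helpers defined in this new module; the expected results in the comments follow directly from the code above:

from audiocraft.data.info_audio_dataset import get_string, get_keyword, get_keyword_list

get_string("  Ambient Piano ")           # -> 'Ambient Piano' (stripped only)
get_keyword("  Rock ")                   # -> 'rock' (stripped and lower-cased)
get_keyword("None")                      # -> None (treated as missing)
get_keyword_list("rock, jazz  blues")    # -> ['rock', 'jazz', 'blues'] (split on commas and whitespace)
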
audiocraft/data/zip.py CHANGED
@@ -3,6 +3,8 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
 
 
6
 
7
  import typing
8
  import zipfile
@@ -18,13 +20,13 @@ MODE = Literal['r', 'w', 'x', 'a']
18
 
19
  @dataclass(order=True)
20
  class PathInZip:
21
- """Class for holding a path of file within a zip file.
22
 
23
  Args:
24
- path: The convention is <path_to_zip>:<relative_path_inside_zip>
25
  Let's assume there is a zip file /some/location/foo.zip
26
  and inside of it is a json file located at /data/file1.json,
27
- Then we expect path = "/some/location/foo.zip:/data/file1.json"
28
  """
29
 
30
  INFO_PATH_SEP = ':'
@@ -55,7 +57,7 @@ def set_zip_cache_size(max_size: int):
55
  """Sets the maximal LRU caching for zip file opening.
56
 
57
  Args:
58
- max_size: the maximal LRU cache.
59
  """
60
  global _cached_open_zip
61
  _cached_open_zip = lru_cache(max_size)(_open_zip)
@@ -65,8 +67,8 @@ def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
65
  """Opens a file stored inside a zip and returns a file-like object.
66
 
67
  Args:
68
- path_in_zip: A PathInZip object representing the file to return a file-like object of.
69
- mode: The mode in which to open the file with.
70
  Returns:
71
  A file-like object for PathInZip.
72
  """
 
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """Utility for reading some info from inside a zip file.
7
+ """
8
 
9
  import typing
10
  import zipfile
 
20
 
21
  @dataclass(order=True)
22
  class PathInZip:
23
+ """Hold a path of file within a zip file.
24
 
25
  Args:
26
+ path (str): The convention is <path_to_zip>:<relative_path_inside_zip>.
27
  Let's assume there is a zip file /some/location/foo.zip
28
  and inside of it is a json file located at /data/file1.json,
29
+ Then we expect path = "/some/location/foo.zip:/data/file1.json".
30
  """
31
 
32
  INFO_PATH_SEP = ':'
 
57
  """Sets the maximal LRU caching for zip file opening.
58
 
59
  Args:
60
+ max_size (int): the maximal LRU cache.
61
  """
62
  global _cached_open_zip
63
  _cached_open_zip = lru_cache(max_size)(_open_zip)
 
67
  """Opens a file stored inside a zip and returns a file-like object.
68
 
69
  Args:
70
+ path_in_zip (PathInZip): A PathInZip object representing the file to return a file-like object of.
71
+ mode (str): The mode in which to open the file with.
72
  Returns:
73
  A file-like object for PathInZip.
74
  """
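
For illustration, a hedged example of the path convention documented above; the zip file and member names are placeholders:

from audiocraft.data.zip import PathInZip, open_file_in_zip

p = PathInZip("/some/location/foo.zip:/data/file1.json")  # <path_to_zip>:<relative_path_inside_zip>
with open_file_in_zip(p, 'r') as fh:                      # file-like object for the member inside the zip
    data = fh.read()
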
audiocraft/environment.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Provides cluster and tools configuration across clusters (slurm, dora, utilities).
9
+ """
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ import re
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+
19
+ from .utils.cluster import _guess_cluster_type
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class AudioCraftEnvironment:
26
+ """Environment configuration for teams and clusters.
27
+
28
+ AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment
29
+ or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment
30
+ provides pointers to a reference folder resolved automatically across clusters that is shared across team members,
31
+ allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically
32
+ map dataset file paths to new locations across clusters, allowing to use the same manifest of files across clusters.
33
+
34
+ The cluster type is identified automatically and base configuration file is read from config/teams.yaml.
35
+ Use the following environment variables to specify the cluster, team or configuration:
36
+
37
+ AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type
38
+ cannot be inferred automatically.
39
+ AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration.
40
+ If not set, configuration is read from config/teams.yaml.
41
+ AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team.
42
+ Cluster configuration are shared across teams to match compute allocation,
43
+ specify your cluster configuration in the configuration file under a key mapping
44
+ your team name.
45
+ """
46
+ _instance = None
47
+ DEFAULT_TEAM = "default"
48
+
49
+ def __init__(self) -> None:
50
+ """Loads configuration."""
51
+ self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM)
52
+ cluster_type = _guess_cluster_type()
53
+ cluster = os.getenv(
54
+ "AUDIOCRAFT_CLUSTER", cluster_type.value
55
+ )
56
+ logger.info("Detecting cluster type %s", cluster_type)
57
+
58
+ self.cluster: str = cluster
59
+
60
+ config_path = os.getenv(
61
+ "AUDIOCRAFT_CONFIG",
62
+ Path(__file__)
63
+ .parent.parent.joinpath("config/teams", self.team)
64
+ .with_suffix(".yaml"),
65
+ )
66
+ self.config = omegaconf.OmegaConf.load(config_path)
67
+ self._dataset_mappers = []
68
+ cluster_config = self._get_cluster_config()
69
+ if "dataset_mappers" in cluster_config:
70
+ for pattern, repl in cluster_config["dataset_mappers"].items():
71
+ regex = re.compile(pattern)
72
+ self._dataset_mappers.append((regex, repl))
73
+
74
+ def _get_cluster_config(self) -> omegaconf.DictConfig:
75
+ assert isinstance(self.config, omegaconf.DictConfig)
76
+ return self.config[self.cluster]
77
+
78
+ @classmethod
79
+ def instance(cls):
80
+ if cls._instance is None:
81
+ cls._instance = cls()
82
+ return cls._instance
83
+
84
+ @classmethod
85
+ def reset(cls):
86
+ """Clears the environment and forces a reload on next invocation."""
87
+ cls._instance = None
88
+
89
+ @classmethod
90
+ def get_team(cls) -> str:
91
+ """Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var.
92
+ If not defined, defaults to "default".
93
+ """
94
+ return cls.instance().team
95
+
96
+ @classmethod
97
+ def get_cluster(cls) -> str:
98
+ """Gets the detected cluster.
99
+ This value can be overridden by the AUDIOCRAFT_CLUSTER env var.
100
+ """
101
+ return cls.instance().cluster
102
+
103
+ @classmethod
104
+ def get_dora_dir(cls) -> Path:
105
+ """Gets the path to the dora directory for the current team and cluster.
106
+ Value is overridden by the AUDIOCRAFT_DORA_DIR env var.
107
+ """
108
+ cluster_config = cls.instance()._get_cluster_config()
109
+ dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"])
110
+ logger.warning(f"Dora directory: {dora_dir}")
111
+ return Path(dora_dir)
112
+
113
+ @classmethod
114
+ def get_reference_dir(cls) -> Path:
115
+ """Gets the path to the reference directory for the current team and cluster.
116
+ Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var.
117
+ """
118
+ cluster_config = cls.instance()._get_cluster_config()
119
+ return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"]))
120
+
121
+ @classmethod
122
+ def get_slurm_exclude(cls) -> tp.Optional[str]:
123
+ """Get the list of nodes to exclude for that cluster."""
124
+ cluster_config = cls.instance()._get_cluster_config()
125
+ return cluster_config.get("slurm_exclude")
126
+
127
+ @classmethod
128
+ def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str:
129
+ """Gets the requested partitions for the current team and cluster as a comma-separated string.
130
+
131
+ Args:
132
+ partition_types (list[str], optional): partition types to retrieve. Values must be
133
+ from ['global', 'team']. If not provided, the global partition is returned.
134
+ """
135
+ if not partition_types:
136
+ partition_types = ["global"]
137
+
138
+ cluster_config = cls.instance()._get_cluster_config()
139
+ partitions = [
140
+ cluster_config["partitions"][partition_type]
141
+ for partition_type in partition_types
142
+ ]
143
+ return ",".join(partitions)
144
+
145
+ @classmethod
146
+ def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path:
147
+ """Converts reference placeholder in path with configured reference dir to resolve paths.
148
+
149
+ Args:
150
+ path (str or Path): Path to resolve.
151
+ Returns:
152
+ Path: Resolved path.
153
+ """
154
+ path = str(path)
155
+
156
+ if path.startswith("//reference"):
157
+ reference_dir = cls.get_reference_dir()
158
+ logger.warning(f"Reference directory: {reference_dir}")
159
+ assert (
160
+ reference_dir.exists() and reference_dir.is_dir()
161
+ ), f"Reference directory does not exist: {reference_dir}."
162
+ path = re.sub("^//reference", str(reference_dir), path)
163
+
164
+ return Path(path)
165
+
166
+ @classmethod
167
+ def apply_dataset_mappers(cls, path: str) -> str:
168
+ """Applies dataset mapping regex rules as defined in the configuration.
169
+ If no rules are defined, the path is returned as-is.
170
+ """
171
+ instance = cls.instance()
172
+
173
+ for pattern, repl in instance._dataset_mappers:
174
+ path = pattern.sub(repl, path)
175
+
176
+ return path
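
For illustration, a hedged sketch of how the new environment module is driven by environment variables (not part of the commit); the config path below is hypothetical and must exist for loading to succeed:

import os
from audiocraft.environment import AudioCraftEnvironment

os.environ["AUDIOCRAFT_TEAM"] = "default"                      # team section of the teams yaml
os.environ["AUDIOCRAFT_CONFIG"] = "config/teams/default.yaml"  # hypothetical override of the config path
AudioCraftEnvironment.reset()                                  # drop any cached instance; reload on next call
print(AudioCraftEnvironment.get_team(), AudioCraftEnvironment.get_cluster())
path = AudioCraftEnvironment.apply_dataset_mappers("/datasets/music/track.wav")  # returned as-is if no mappers are configured
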
audiocraft/models/__init__.py CHANGED
@@ -4,7 +4,14 @@
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
 
 
 
7
  # flake8: noqa
 
 
 
 
8
  from .musicgen import MusicGen
9
  from .lm import LMModel
10
  from .encodec import CompressionModel, EncodecModel
 
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
 
7
+ """
8
+ Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
9
+ """
10
  # flake8: noqa
11
+ from . import builders, loaders
12
+ from .encodec import (
13
+ CompressionModel, EncodecModel, DAC,
14
+ HFEncodecModel, HFEncodecCompressionModel)
15
  from .musicgen import MusicGen
16
  from .lm import LMModel
17
  from .encodec import CompressionModel, EncodecModel
audiocraft/models/builders.py CHANGED
@@ -10,32 +10,34 @@ from the Hydra config.
10
  """
11
 
12
  import typing as tp
13
- import warnings
14
 
15
  import audiocraft
16
  import omegaconf
17
  import torch
18
 
19
- from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa
20
  from .lm import LMModel
21
  from ..modules.codebooks_patterns import (
22
  CodebooksPatternProvider,
23
  DelayedPatternProvider,
 
24
  ParallelPatternProvider,
25
  UnrolledPatternProvider,
26
- VALLEPattern,
27
- MusicLMPattern,
28
  )
29
  from ..modules.conditioners import (
30
  BaseConditioner,
 
 
 
31
  ConditioningProvider,
32
  LUTConditioner,
33
  T5Conditioner,
34
- ConditionFuser,
35
- ChromaStemConditioner,
36
  )
 
37
  from .. import quantization as qt
38
  from ..utils.utils import dict_from_config
 
39
 
40
 
41
  def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
@@ -60,12 +62,11 @@ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
60
  decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
61
  return encoder, decoder
62
  else:
63
- raise KeyError(f'Unexpected compression model {cfg.compression_model}')
64
 
65
 
66
  def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
67
- """Instantiate a compression model.
68
- """
69
  if cfg.compression_model == 'encodec':
70
  kwargs = dict_from_config(getattr(cfg, 'encodec'))
71
  encoder_name = kwargs.pop('autoencoder')
@@ -73,20 +74,17 @@ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
73
  encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
74
  quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
75
  frame_rate = kwargs['sample_rate'] // encoder.hop_length
76
- renormalize = kwargs.pop('renormalize', None)
77
- renorm = kwargs.pop('renorm')
78
- if renormalize is None:
79
- renormalize = renorm is not None
80
- warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
81
  return EncodecModel(encoder, decoder, quantizer,
82
  frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
83
  else:
84
- raise KeyError(f'Unexpected compression model {cfg.compression_model}')
85
 
86
 
87
  def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
88
- """Instantiate a transformer LM.
89
- """
90
  if cfg.lm_model == 'transformer_lm':
91
  kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
92
  n_q = kwargs['n_q']
@@ -94,14 +92,14 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
94
  codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
95
  attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
96
  cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
97
- cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
98
  fuser = get_condition_fuser(cfg)
99
  condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
100
- if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programatically
101
  kwargs['cross_attention'] = True
102
  if codebooks_pattern_cfg.modeling is None:
103
  assert q_modeling is not None, \
104
- 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
105
  codebooks_pattern_cfg = omegaconf.OmegaConf.create(
106
  {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
107
  )
@@ -118,45 +116,50 @@ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
118
  **kwargs
119
  ).to(cfg.device)
120
  else:
121
- raise KeyError(f'Unexpected LM model {cfg.lm_model}')
122
 
123
 
124
  def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
125
- """Instantiate a conditioning model.
126
- """
127
  device = cfg.device
128
  duration = cfg.dataset.segment_duration
129
- cfg = getattr(cfg, "conditioners")
130
- cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
131
  conditioners: tp.Dict[str, BaseConditioner] = {}
132
- with omegaconf.open_dict(cfg):
133
- condition_provider_args = cfg.pop('args', {})
134
- for cond, cond_cfg in cfg.items():
135
- model_type = cond_cfg["model"]
 
 
136
  model_args = cond_cfg[model_type]
137
- if model_type == "t5":
138
  conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
139
- elif model_type == "lut":
140
  conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
141
- elif model_type == "chroma_stem":
142
- model_args.pop('cache_path', None)
143
  conditioners[str(cond)] = ChromaStemConditioner(
144
  output_dim=output_dim,
145
  duration=duration,
146
  device=device,
147
  **model_args
148
  )
 
 
 
 
 
 
149
  else:
150
- raise ValueError(f"unrecognized conditioning model: {model_type}")
151
  conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
152
  return conditioner
153
 
154
 
155
  def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
156
- """Instantiate a condition fuser object.
157
- """
158
- fuser_cfg = getattr(cfg, "fuser")
159
- fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
160
  fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
161
  kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
162
  fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
@@ -164,13 +167,12 @@ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
164
 
165
 
166
  def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
167
- """Instantiate a codebooks pattern provider object.
168
- """
169
  pattern_providers = {
170
  'parallel': ParallelPatternProvider,
171
  'delay': DelayedPatternProvider,
172
  'unroll': UnrolledPatternProvider,
173
- 'valle': VALLEPattern,
174
  'musiclm': MusicLMPattern,
175
  }
176
  name = cfg.modeling
@@ -179,14 +181,20 @@ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> Codeb
179
  return klass(n_q, **kwargs)
180
 
181
 
182
- def get_debug_compression_model(device='cpu'):
183
- """Instantiate a debug compression model to be used for unit tests.
184
- """
185
- seanet_kwargs = {
 
 
 
 
 
 
186
  'n_filters': 4,
187
  'n_residual_layers': 1,
188
  'dimension': 32,
189
- 'ratios': [10, 8, 16] # 25 Hz at 32kHz
190
  }
191
  encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
192
  decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
@@ -195,13 +203,31 @@ def get_debug_compression_model(device='cpu'):
195
  quantizer(init_x, 1) # initialize kmeans etc.
196
  compression_model = EncodecModel(
197
  encoder, decoder, quantizer,
198
- frame_rate=25, sample_rate=32000, channels=1).to(device)
199
  return compression_model.eval()
200
 
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  def get_debug_lm_model(device='cpu'):
203
- """Instantiate a debug LM to be used for unit tests.
204
- """
205
  pattern = DelayedPatternProvider(n_q=4)
206
  dim = 16
207
  providers = {
@@ -216,3 +242,17 @@ def get_debug_lm_model(device='cpu'):
216
  n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
217
  cross_attention=True, causal=True)
218
  return lm.to(device).eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
 
12
  import typing as tp
 
13
 
14
  import audiocraft
15
  import omegaconf
16
  import torch
17
 
18
+ from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
19
  from .lm import LMModel
20
  from ..modules.codebooks_patterns import (
21
  CodebooksPatternProvider,
22
  DelayedPatternProvider,
23
+ MusicLMPattern,
24
  ParallelPatternProvider,
25
  UnrolledPatternProvider,
26
+ CoarseFirstPattern,
 
27
  )
28
  from ..modules.conditioners import (
29
  BaseConditioner,
30
+ ChromaStemConditioner,
31
+ CLAPEmbeddingConditioner,
32
+ ConditionFuser,
33
  ConditioningProvider,
34
  LUTConditioner,
35
  T5Conditioner,
 
 
36
  )
37
+ from .unet import DiffusionUnet
38
  from .. import quantization as qt
39
  from ..utils.utils import dict_from_config
40
+ from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
41
 
42
 
43
  def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
 
62
  decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
63
  return encoder, decoder
64
  else:
65
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
66
 
67
 
68
  def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
69
+ """Instantiate a compression model."""
 
70
  if cfg.compression_model == 'encodec':
71
  kwargs = dict_from_config(getattr(cfg, 'encodec'))
72
  encoder_name = kwargs.pop('autoencoder')
 
74
  encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
75
  quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
76
  frame_rate = kwargs['sample_rate'] // encoder.hop_length
77
+ renormalize = kwargs.pop('renormalize', False)
78
+ # deprecated params
79
+ kwargs.pop('renorm', None)
 
 
80
  return EncodecModel(encoder, decoder, quantizer,
81
  frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
82
  else:
83
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
84
 
85
 
86
  def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
87
+ """Instantiate a transformer LM."""
 
88
  if cfg.lm_model == 'transformer_lm':
89
  kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
90
  n_q = kwargs['n_q']
 
92
  codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
93
  attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
94
  cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
95
+ cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
96
  fuser = get_condition_fuser(cfg)
97
  condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
98
+ if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
99
  kwargs['cross_attention'] = True
100
  if codebooks_pattern_cfg.modeling is None:
101
  assert q_modeling is not None, \
102
+ "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
103
  codebooks_pattern_cfg = omegaconf.OmegaConf.create(
104
  {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
105
  )
 
116
  **kwargs
117
  ).to(cfg.device)
118
  else:
119
+ raise KeyError(f"Unexpected LM model {cfg.lm_model}")
120
 
121
 
122
  def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
123
+ """Instantiate a conditioning model."""
 
124
  device = cfg.device
125
  duration = cfg.dataset.segment_duration
126
+ cfg = getattr(cfg, 'conditioners')
127
+ dict_cfg = {} if cfg is None else dict_from_config(cfg)
128
  conditioners: tp.Dict[str, BaseConditioner] = {}
129
+ condition_provider_args = dict_cfg.pop('args', {})
130
+ condition_provider_args.pop('merge_text_conditions_p', None)
131
+ condition_provider_args.pop('drop_desc_p', None)
132
+
133
+ for cond, cond_cfg in dict_cfg.items():
134
+ model_type = cond_cfg['model']
135
  model_args = cond_cfg[model_type]
136
+ if model_type == 't5':
137
  conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
138
+ elif model_type == 'lut':
139
  conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
140
+ elif model_type == 'chroma_stem':
 
141
  conditioners[str(cond)] = ChromaStemConditioner(
142
  output_dim=output_dim,
143
  duration=duration,
144
  device=device,
145
  **model_args
146
  )
147
+ elif model_type == 'clap':
148
+ conditioners[str(cond)] = CLAPEmbeddingConditioner(
149
+ output_dim=output_dim,
150
+ device=device,
151
+ **model_args
152
+ )
153
  else:
154
+ raise ValueError(f"Unrecognized conditioning model: {model_type}")
155
  conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
156
  return conditioner
157
 
158
 
159
  def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
160
+ """Instantiate a condition fuser object."""
161
+ fuser_cfg = getattr(cfg, 'fuser')
162
+ fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
 
163
  fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
164
  kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
165
  fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
 
167
 
168
 
169
  def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
170
+ """Instantiate a codebooks pattern provider object."""
 
171
  pattern_providers = {
172
  'parallel': ParallelPatternProvider,
173
  'delay': DelayedPatternProvider,
174
  'unroll': UnrolledPatternProvider,
175
+ 'coarse_first': CoarseFirstPattern,
176
  'musiclm': MusicLMPattern,
177
  }
178
  name = cfg.modeling
 
181
  return klass(n_q, **kwargs)
182
 
183
 
184
+ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
185
+ """Instantiate a debug compression model to be used for unit tests."""
186
+ assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
187
+ model_ratios = {
188
+ 16000: [10, 8, 8], # 25 Hz at 16kHz
189
+ 32000: [10, 8, 16] # 25 Hz at 32kHz
190
+ }
191
+ ratios: tp.List[int] = model_ratios[sample_rate]
192
+ frame_rate = 25
193
+ seanet_kwargs: dict = {
194
  'n_filters': 4,
195
  'n_residual_layers': 1,
196
  'dimension': 32,
197
+ 'ratios': ratios,
198
  }
199
  encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
200
  decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
 
203
  quantizer(init_x, 1) # initialize kmeans etc.
204
  compression_model = EncodecModel(
205
  encoder, decoder, quantizer,
206
+ frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
207
  return compression_model.eval()
208
 
209
 
210
+ def get_diffusion_model(cfg: omegaconf.DictConfig):
211
+ # TODO Find a way to infer the channels from dset
212
+ channels = cfg.channels
213
+ num_steps = cfg.schedule.num_steps
214
+ return DiffusionUnet(
215
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
216
+
217
+
218
+ def get_processor(cfg, sample_rate: int = 24000):
219
+ sample_processor = SampleProcessor()
220
+ if cfg.use:
221
+ kw = dict(cfg)
222
+ kw.pop('use')
223
+ kw.pop('name')
224
+ if cfg.name == "multi_band_processor":
225
+ sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
226
+ return sample_processor
227
+
228
+
229
  def get_debug_lm_model(device='cpu'):
230
+ """Instantiate a debug LM to be used for unit tests."""
 
231
  pattern = DelayedPatternProvider(n_q=4)
232
  dim = 16
233
  providers = {
 
242
  n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
243
  cross_attention=True, causal=True)
244
  return lm.to(device).eval()
245
+
246
+
247
+ def get_wrapped_compression_model(
248
+ compression_model: CompressionModel,
249
+ cfg: omegaconf.DictConfig) -> CompressionModel:
250
+ if hasattr(cfg, 'interleave_stereo_codebooks'):
251
+ if cfg.interleave_stereo_codebooks.use:
252
+ kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
253
+ kwargs.pop('use')
254
+ compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
255
+ if hasattr(cfg, 'compression_model_n_q'):
256
+ if cfg.compression_model_n_q is not None:
257
+ compression_model.set_num_codebooks(cfg.compression_model_n_q)
258
+ return compression_model
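
For illustration, a hedged sketch of the new config-driven wrapping hook at the end of builders.py; the cfg values are made up, and `compression_model` stands for an already-loaded (stereo) CompressionModel obtained elsewhere:

import omegaconf
from audiocraft.models.builders import get_wrapped_compression_model

cfg = omegaconf.OmegaConf.create({
    'interleave_stereo_codebooks': {'use': True},  # wrap into InterleaveStereoCompressionModel
})
wrapped = get_wrapped_compression_model(compression_model, cfg)
# if cfg also carried a non-None `compression_model_n_q`, the active number of codebooks would be restricted too
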
audiocraft/models/encodec.py CHANGED
@@ -3,18 +3,32 @@
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
 
 
 
6
 
7
  from abc import ABC, abstractmethod
 
 
 
8
  import typing as tp
9
 
10
  from einops import rearrange
 
11
  import torch
12
  from torch import nn
 
13
 
14
  from .. import quantization as qt
15
 
16
 
 
 
 
17
  class CompressionModel(ABC, nn.Module):
 
 
 
18
 
19
  @abstractmethod
20
  def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
@@ -22,12 +36,17 @@ class CompressionModel(ABC, nn.Module):
22
 
23
  @abstractmethod
24
  def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
25
- """See `EncodecModel.encode`"""
26
  ...
27
 
28
  @abstractmethod
29
  def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
30
- """See `EncodecModel.decode`"""
 
 
 
 
 
31
  ...
32
 
33
  @property
@@ -37,7 +56,7 @@ class CompressionModel(ABC, nn.Module):
37
 
38
  @property
39
  @abstractmethod
40
- def frame_rate(self) -> int:
41
  ...
42
 
43
  @property
@@ -62,10 +81,46 @@ class CompressionModel(ABC, nn.Module):
62
 
63
  @abstractmethod
64
  def set_num_codebooks(self, n: int):
65
- """Set the active number of codebooks used by the quantizer.
66
- """
67
  ...
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  class EncodecModel(CompressionModel):
71
  """Encodec model operating on the raw waveform.
@@ -80,9 +135,9 @@ class EncodecModel(CompressionModel):
80
  causal (bool): Whether to use a causal version of the model.
81
  renormalize (bool): Whether to renormalize the audio before running the model.
82
  """
83
- # we need assignement to override the property in the abstract class,
84
  # I couldn't find a better way...
85
- frame_rate: int = 0
86
  sample_rate: int = 0
87
  channels: int = 0
88
 
@@ -111,25 +166,21 @@ class EncodecModel(CompressionModel):
111
 
112
  @property
113
  def total_codebooks(self):
114
- """Total number of quantizer codebooks available.
115
- """
116
  return self.quantizer.total_codebooks
117
 
118
  @property
119
  def num_codebooks(self):
120
- """Active number of codebooks used by the quantizer.
121
- """
122
  return self.quantizer.num_codebooks
123
 
124
  def set_num_codebooks(self, n: int):
125
- """Set the active number of codebooks used by the quantizer.
126
- """
127
  self.quantizer.set_num_codebooks(n)
128
 
129
  @property
130
  def cardinality(self):
131
- """Cardinality of each codebook.
132
- """
133
  return self.quantizer.bins
134
 
135
  def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
@@ -176,7 +227,7 @@ class EncodecModel(CompressionModel):
176
  x (torch.Tensor): Float tensor of shape [B, C, T]
177
 
178
  Returns:
179
- codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
180
  codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
181
  scale a float tensor containing the scale for audio renormalizealization.
182
  """
@@ -192,41 +243,174 @@ class EncodecModel(CompressionModel):
192
 
193
  Args:
194
  codes (torch.Tensor): Int tensor of shape [B, K, T]
195
- scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
196
 
197
  Returns:
198
  out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
199
  """
200
- emb = self.quantizer.decode(codes)
201
  out = self.decoder(emb)
202
  out = self.postprocess(out, scale)
203
  # out contains extra padding added by the encoder and decoder
204
  return out
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- class FlattenedCompressionModel(CompressionModel):
208
- """Wraps a CompressionModel and flatten its codebooks, e.g.
209
- instead of returning [B, K, T], return [B, S, T * (K // S)] with
210
- S the number of codebooks per step, and `K // S` the number of 'virtual steps'
211
- for each real time step.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  Args:
214
- model (CompressionModel): compression model to wrap.
215
- codebooks_per_step (int): number of codebooks to keep per step,
216
- this must divide the number of codebooks provided by the wrapped model.
217
- extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
218
- if each codebook has a cardinality N, then the first codebook will
219
- use the range [0, N - 1], and the second [N, 2 N - 1] etc.
220
- On decoding, this can lead to potentially invalid sequences.
221
- Any invalid entry will be silently remapped to the proper range
222
- with a modulo.
223
  """
224
- def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
225
- extend_cardinality: bool = True):
226
  super().__init__()
227
  self.model = model
228
- self.codebooks_per_step = codebooks_per_step
229
- self.extend_cardinality = extend_cardinality
230
 
231
  @property
232
  def total_codebooks(self):
@@ -236,30 +420,27 @@ class FlattenedCompressionModel(CompressionModel):
236
  def num_codebooks(self):
237
  """Active number of codebooks used by the quantizer.
238
 
239
- ..Warning:: this reports the number of codebooks after the flattening
240
  of the codebooks!
241
  """
242
- assert self.model.num_codebooks % self.codebooks_per_step == 0
243
- return self.codebooks_per_step
244
 
245
  def set_num_codebooks(self, n: int):
246
  """Set the active number of codebooks used by the quantizer.
247
 
248
- ..Warning:: this sets the number of codebooks **before** the flattening
249
- of the codebooks.
250
  """
251
- assert n % self.codebooks_per_step == 0
252
  self.model.set_num_codebooks(n)
253
 
254
  @property
255
- def num_virtual_steps(self) -> int:
256
  """Return the number of virtual steps, e.g. one real step
257
  will be split into that many steps.
258
  """
259
- return self.model.num_codebooks // self.codebooks_per_step
260
 
261
  @property
262
- def frame_rate(self) -> int:
263
  return self.model.frame_rate * self.num_virtual_steps
264
 
265
  @property
@@ -268,35 +449,58 @@ class FlattenedCompressionModel(CompressionModel):
268
 
269
  @property
270
  def channels(self) -> int:
271
- return self.model.channels
272
 
273
  @property
274
  def cardinality(self):
275
  """Cardinality of each codebook.
276
  """
277
- if self.extend_cardinality:
278
- return self.model.cardinality * self.num_virtual_steps
279
- else:
280
- return self.model.cardinality
281
 
282
  def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
283
  raise NotImplementedError("Not supported, use encode and decode.")
284
 
285
  def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
286
- indices, scales = self.model.encode(x)
287
- B, K, T = indices.shape
288
- indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
289
- if self.extend_cardinality:
290
- for virtual_step in range(1, self.num_virtual_steps):
291
- indices[..., virtual_step] += self.model.cardinality * virtual_step
292
- indices = rearrange(indices, 'b k t v -> b k (t v)')
293
  return (indices, scales)
294
295
  def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
296
  B, K, T = codes.shape
297
- assert T % self.num_virtual_steps == 0
298
- codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
299
- # We silently ignore potential errors from the LM when
300
- # using extend_cardinality.
301
- codes = codes % self.model.cardinality
302
- return self.model.decode(codes, scale)
3
  #
4
  # This source code is licensed under the license found in the
5
  # LICENSE file in the root directory of this source tree.
6
+ """Compression models or wrapper around existing models.
7
+ Also defines the main interface that a model must follow to be usable as an audio tokenizer.
8
+ """
9
 
10
  from abc import ABC, abstractmethod
11
+ import logging
12
+ import math
13
+ from pathlib import Path
14
  import typing as tp
15
 
16
  from einops import rearrange
17
+ import numpy as np
18
  import torch
19
  from torch import nn
20
+ from transformers import EncodecModel as HFEncodecModel
21
 
22
  from .. import quantization as qt
23
 
24
 
25
+ logger = logging.getLogger()
26
+
27
+
28
  class CompressionModel(ABC, nn.Module):
29
+ """Base API for all compression model that aim at being used as audio tokenizers
30
+ with a language model.
31
+ """
32
 
33
  @abstractmethod
34
  def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
 
36
 
37
  @abstractmethod
38
  def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
39
+ """See `EncodecModel.encode`."""
40
  ...
41
 
42
  @abstractmethod
43
  def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
44
+ """See `EncodecModel.decode`."""
45
+ ...
46
+
47
+ @abstractmethod
48
+ def decode_latent(self, codes: torch.Tensor):
49
+ """Decode from the discrete codes to continuous latent space."""
50
  ...
51
 
52
  @property
 
56
 
57
  @property
58
  @abstractmethod
59
+ def frame_rate(self) -> float:
60
  ...
61
 
62
  @property
 
81
 
82
  @abstractmethod
83
  def set_num_codebooks(self, n: int):
84
+ """Set the active number of codebooks used by the quantizer."""
 
85
  ...
86
 
87
+ @staticmethod
88
+ def get_pretrained(
89
+ name: str, device: tp.Union[torch.device, str] = 'cpu'
90
+ ) -> 'CompressionModel':
91
+ """Instantiate a CompressionModel from a given pretrained model.
92
+
93
+ Args:
94
+ name (Path or str): name of the pretrained model. See after.
95
+ device (torch.device or str): Device on which the model is loaded.
96
+
97
+ Pretrained models:
98
+ - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
99
+ - dac_24khz (same)
100
+ - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
101
+ - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
102
+ - your own model on HuggingFace. Export instructions to come...
103
+ """
104
+
105
+ from . import builders, loaders
106
+ model: CompressionModel
107
+ if name in ['dac_44khz', 'dac_24khz']:
108
+ model_type = name.split('_')[1]
109
+ logger.info("Getting pretrained compression model from DAC %s", model_type)
110
+ model = DAC(model_type)
111
+ elif name in ['debug_compression_model']:
112
+ logger.info("Getting pretrained compression model for debug")
113
+ model = builders.get_debug_compression_model()
114
+ elif Path(name).exists():
115
+ # We assume here if the paths exist that it is in fact an AC checkpoint
116
+ # that was exported using `audiocraft.utils.export` functions.
117
+ model = loaders.load_compression_model(name, device=device)
118
+ else:
119
+ logger.info("Getting pretrained compression model from HF %s", name)
120
+ hf_model = HFEncodecModel.from_pretrained(name)
121
+ model = HFEncodecCompressionModel(hf_model).to(device)
122
+ return model.to(device).eval()
123
+
124
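
A minimal usage sketch of the new `get_pretrained` entry point. The checkpoint name comes from the docstring above; the dummy waveform and its shapes are illustrative assumptions:

```python
# Sketch: load a pretrained codec and round-trip one second of dummy audio.
import torch
from audiocraft.models.encodec import CompressionModel

codec = CompressionModel.get_pretrained('facebook/encodec_32khz', device='cpu')
wav = torch.randn(1, codec.channels, codec.sample_rate)   # [B, C, T], roughly 1 second
with torch.no_grad():
    codes, scale = codec.encode(wav)       # codes: [B, K, T_frames]
    recon = codec.decode(codes, scale)     # [B, C, T'] (may include codec padding)
```
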
 
125
  class EncodecModel(CompressionModel):
126
  """Encodec model operating on the raw waveform.
 
135
  causal (bool): Whether to use a causal version of the model.
136
  renormalize (bool): Whether to renormalize the audio before running the model.
137
  """
138
+ # we need assignment to override the property in the abstract class,
139
  # I couldn't find a better way...
140
+ frame_rate: float = 0
141
  sample_rate: int = 0
142
  channels: int = 0
143
 
 
166
 
167
  @property
168
  def total_codebooks(self):
169
+ """Total number of quantizer codebooks available."""
 
170
  return self.quantizer.total_codebooks
171
 
172
  @property
173
  def num_codebooks(self):
174
+ """Active number of codebooks used by the quantizer."""
 
175
  return self.quantizer.num_codebooks
176
 
177
  def set_num_codebooks(self, n: int):
178
+ """Set the active number of codebooks used by the quantizer."""
 
179
  self.quantizer.set_num_codebooks(n)
180
 
181
  @property
182
  def cardinality(self):
183
+ """Cardinality of each codebook."""
 
184
  return self.quantizer.bins
185
 
186
  def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
 
227
  x (torch.Tensor): Float tensor of shape [B, C, T]
228
 
229
  Returns:
230
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
231
  codes a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
232
  scale a float tensor containing the scale for audio renormalizealization.
233
  """
 
243
 
244
  Args:
245
  codes (torch.Tensor): Int tensor of shape [B, K, T]
246
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
247
 
248
  Returns:
249
  out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
250
  """
251
+ emb = self.decode_latent(codes)
252
  out = self.decoder(emb)
253
  out = self.postprocess(out, scale)
254
  # out contains extra padding added by the encoder and decoder
255
  return out
256
 
257
+ def decode_latent(self, codes: torch.Tensor):
258
+ """Decode from the discrete codes to continuous latent space."""
259
+ return self.quantizer.decode(codes)
260
+
261
+
262
+ class DAC(CompressionModel):
263
+ def __init__(self, model_type: str = "44khz"):
264
+ super().__init__()
265
+ try:
266
+ import dac.utils
267
+ except ImportError:
268
+ raise RuntimeError("Could not import dac, make sure it is installed, "
269
+ "please run `pip install descript-audio-codec`")
270
+ self.model = dac.utils.load_model(model_type=model_type)
271
+ self.n_quantizers = self.total_codebooks
272
+ self.model.eval()
273
+
274
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
275
+ # We don't support training with this.
276
+ raise NotImplementedError("Forward and training with DAC not supported.")
277
+
278
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
279
+ codes = self.model.encode(x, self.n_quantizers)[1]
280
+ return codes[:, :self.n_quantizers], None
281
+
282
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
283
+ assert scale is None
284
+ z_q = self.decode_latent(codes)
285
+ return self.model.decode(z_q)
286
+
287
+ def decode_latent(self, codes: torch.Tensor):
288
+ """Decode from the discrete codes to continuous latent space."""
289
+ return self.model.quantizer.from_codes(codes)[0]
290
+
291
+ @property
292
+ def channels(self) -> int:
293
+ return 1
294
+
295
+ @property
296
+ def frame_rate(self) -> float:
297
+ return self.model.sample_rate / self.model.hop_length
298
+
299
+ @property
300
+ def sample_rate(self) -> int:
301
+ return self.model.sample_rate
302
+
303
+ @property
304
+ def cardinality(self) -> int:
305
+ return self.model.codebook_size
306
+
307
+ @property
308
+ def num_codebooks(self) -> int:
309
+ return self.n_quantizers
310
 
311
+ @property
312
+ def total_codebooks(self) -> int:
313
+ return self.model.n_codebooks
314
+
315
+ def set_num_codebooks(self, n: int):
316
+ """Set the active number of codebooks used by the quantizer.
317
+ """
318
+ assert n >= 1
319
+ assert n <= self.total_codebooks
320
+ self.n_quantizers = n
321
+
322
+
323
+ class HFEncodecCompressionModel(CompressionModel):
324
+ """Wrapper around HuggingFace Encodec.
325
+ """
326
+ def __init__(self, model: HFEncodecModel):
327
+ super().__init__()
328
+ self.model = model
329
+ bws = self.model.config.target_bandwidths
330
+ num_codebooks = [
331
+ bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
332
+ for bw in bws
333
+ ]
334
+ deltas = [nc - int(nc) for nc in num_codebooks]
335
+ # Checking we didn't do some bad maths and we indeed have integers!
336
+ assert all(deltas) <= 1e-3, deltas
337
+ self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
338
+ self.set_num_codebooks(max(self.possible_num_codebooks))
339
+
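
The `num_codebooks` computation above inverts the codec bitrate: each frame carries K codes of `log2(cardinality)` bits at `frame_rate` frames per second, so `K = bw * 1000 / (frame_rate * log2(cardinality))`. A quick sanity check, using numbers assumed to match the public `facebook/encodec_32khz` configuration rather than taken from this diff:

```python
# Sketch: bandwidth -> codebook count, assuming a 50 Hz frame rate and
# 2048-entry codebooks (11 bits per code), as in facebook/encodec_32khz.
import math

frame_rate = 50
cardinality = 2048
bandwidth_kbps = 2.2            # one entry of config.target_bandwidths

k = bandwidth_kbps * 1000 / (frame_rate * math.log2(cardinality))
print(k)                        # 4.0 -> four active codebooks
```
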
340
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
341
+ # We don't support training with this.
342
+ raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
343
+
344
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
345
+ bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
346
+ bandwidth = self.model.config.target_bandwidths[bandwidth_index]
347
+ res = self.model.encode(x, None, bandwidth)
348
+ assert len(res[0]) == 1
349
+ assert len(res[1]) == 1
350
+ return res[0][0], res[1][0]
351
+
352
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
353
+ if scale is None:
354
+ scales = [None] # type: ignore
355
+ else:
356
+ scales = scale # type: ignore
357
+ res = self.model.decode(codes[None], scales)
358
+ return res[0]
359
+
360
+ def decode_latent(self, codes: torch.Tensor):
361
+ """Decode from the discrete codes to continuous latent space."""
362
+ return self.model.quantizer.decode(codes.transpose(0, 1))
363
+
364
+ @property
365
+ def channels(self) -> int:
366
+ return self.model.config.audio_channels
367
+
368
+ @property
369
+ def frame_rate(self) -> float:
370
+ hop_length = int(np.prod(self.model.config.upsampling_ratios))
371
+ return self.sample_rate / hop_length
372
+
373
+ @property
374
+ def sample_rate(self) -> int:
375
+ return self.model.config.sampling_rate
376
+
377
+ @property
378
+ def cardinality(self) -> int:
379
+ return self.model.config.codebook_size
380
+
381
+ @property
382
+ def num_codebooks(self) -> int:
383
+ return self._num_codebooks
384
+
385
+ @property
386
+ def total_codebooks(self) -> int:
387
+ return max(self.possible_num_codebooks)
388
+
389
+ def set_num_codebooks(self, n: int):
390
+ """Set the active number of codebooks used by the quantizer.
391
+ """
392
+ if n not in self.possible_num_codebooks:
393
+ raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
394
+ self._num_codebooks = n
395
+
396
+
397
+ class InterleaveStereoCompressionModel(CompressionModel):
398
+ """Wraps a CompressionModel to support stereo inputs. The wrapped model
399
+ will be applied independently to the left and right channels, and both codebooks
400
+ will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
401
+ channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
402
+ `per_timestep`.
403
 
404
  Args:
405
+ model (CompressionModel): Compression model to wrap.
406
+ per_timestep (bool): Whether to interleave on the timestep dimension
407
+ or on the codebooks dimension.
408
  """
409
+ def __init__(self, model: CompressionModel, per_timestep: bool = False):
 
410
  super().__init__()
411
  self.model = model
412
+ self.per_timestep = per_timestep
413
+ assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
414
 
415
  @property
416
  def total_codebooks(self):
 
420
  def num_codebooks(self):
421
  """Active number of codebooks used by the quantizer.
422
 
423
+ ..Warning:: this reports the number of codebooks after the interleaving
424
  of the codebooks!
425
  """
426
+ return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
 
427
 
428
  def set_num_codebooks(self, n: int):
429
  """Set the active number of codebooks used by the quantizer.
430
 
431
+ ..Warning:: this sets the number of codebooks before the interleaving!
 
432
  """
 
433
  self.model.set_num_codebooks(n)
434
 
435
  @property
436
+ def num_virtual_steps(self) -> float:
437
  """Return the number of virtual steps, e.g. one real step
438
  will be split into that many steps.
439
  """
440
+ return 2 if self.per_timestep else 1
441
 
442
  @property
443
+ def frame_rate(self) -> float:
444
  return self.model.frame_rate * self.num_virtual_steps
445
 
446
  @property
 
449
 
450
  @property
451
  def channels(self) -> int:
452
+ return 2
453
 
454
  @property
455
  def cardinality(self):
456
  """Cardinality of each codebook.
457
  """
458
+ return self.model.cardinality
 
 
 
459
 
460
  def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
461
  raise NotImplementedError("Not supported, use encode and decode.")
462
 
463
  def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
464
+ B, C, T = x.shape
465
+ assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
466
+
467
+ indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
468
+ indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
469
+ indices = torch.stack([indices_c0, indices_c1], dim=0)
470
+ scales: tp.Optional[torch.Tensor] = None
471
+ if scales_c0 is not None and scales_c1 is not None:
472
+ scales = torch.stack([scales_c0, scales_c1], dim=1)
473
+
474
+ if self.per_timestep:
475
+ indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
476
+ else:
477
+ indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
478
+
479
  return (indices, scales)
480
 
481
+ def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
482
+ if self.per_timestep:
483
+ codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
484
+ else:
485
+ codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
486
+ return codes[0], codes[1]
487
+
488
  def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
489
  B, K, T = codes.shape
490
+ assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
491
+ assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
492
+
493
+ scale_c0, scale_c1 = None, None
494
+ if scale is not None:
495
+ assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
496
+ scale_c0 = scale[0, ...]
497
+ scale_c1 = scale[1, ...]
498
+
499
+ codes_c0, codes_c1 = self.get_left_right_codes(codes)
500
+ audio_c0 = self.model.decode(codes_c0, scale_c0)
501
+ audio_c1 = self.model.decode(codes_c1, scale_c1)
502
+ return torch.cat([audio_c0, audio_c1], dim=1)
503
+
504
+ def decode_latent(self, codes: torch.Tensor):
505
+ """Decode from the discrete codes to continuous latent space."""
506
+ raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
audiocraft/models/lm.py CHANGED
@@ -41,7 +41,7 @@ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None
41
  method (str): Method name for init function. Valid options are:
42
  'gaussian', 'uniform'.
43
  input_dim (int): Input dimension of the initialized module.
44
- init_depth (Optional[int]): Optional init depth value used to rescale
45
  the standard deviation if defined.
46
  """
47
  # Compute std
@@ -70,7 +70,7 @@ def init_layer(m: nn.Module,
70
  Args:
71
  m (nn.Module): Module to initialize.
72
  method (str): Method name for the init function.
73
- init_depth (Optional[int]): Optional init depth value used to rescale
74
  the standard deviation if defined.
75
  zero_bias_init (bool): Whether to initialize the bias to 0 or not.
76
  """
@@ -130,10 +130,10 @@ class LMModel(StreamingModule):
130
  hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
131
  norm (str): Normalization method.
132
  norm_first (bool): Use pre-norm instead of post-norm.
133
- emb_lr (Optional[float]): Embedding-specific learning rate.
134
  bias_proj (bool): Use bias for output projections.
135
- weight_init (Optional[str]): Method for weight initialization.
136
- depthwise_init (Optional[str]): Method for depthwise weight initialization.
137
  zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
138
  cfg_dropout (float): Classifier-free guidance dropout.
139
  cfg_coef (float): Classifier-free guidance coefficient.
@@ -179,11 +179,11 @@ class LMModel(StreamingModule):
179
  """Initialization of the transformer module weights.
180
 
181
  Args:
182
- weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
183
- depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid:
184
  'current' where the depth corresponds to the current layer index or 'global' where the total number
185
  of layer is used as depth. If not set, no depthwise initialization strategy is used.
186
- zero_bias_init (bool): Whether to initalize bias to zero or not.
187
  """
188
  assert depthwise_init is None or depthwise_init in ['current', 'global']
189
  assert depthwise_init is None or weight_init is not None, \
@@ -225,17 +225,17 @@ class LMModel(StreamingModule):
225
  S the sequence steps, return the logits with shape [B, card, K, S].
226
 
227
  Args:
228
- indices (torch.Tensor): indices of the codes to model.
229
- conditions (list[ConditioningAttributes]): conditionings to use when modeling
230
  the given codes. Note that when evaluating multiple time with the same conditioning
231
  you should pre-compute those and pass them as `condition_tensors`.
232
- condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
233
  tensors, see `conditions`.
234
  Returns:
235
  torch.Tensor: Logits.
236
  """
237
  B, K, S = sequence.shape
238
- assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
239
  input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
240
  if condition_tensors is None:
241
  assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
@@ -271,10 +271,10 @@ class LMModel(StreamingModule):
271
  Args:
272
  codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
273
  K the number of codebooks and T the number of timesteps.
274
- conditions (list[ConditioningAttributes]): conditionings to use when modeling
275
  the given codes. Note that when evaluating multiple time with the same conditioning
276
  you should pre-compute those and pass them as `condition_tensors`.
277
- condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
278
  tensors, see `conditions`.
279
  Returns:
280
  LMOutput: Language model outputs
@@ -314,7 +314,8 @@ class LMModel(StreamingModule):
314
  temp: float = 1.0,
315
  top_k: int = 0,
316
  top_p: float = 0.0,
317
- cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
 
318
  """Sample next token from the model given a sequence and a set of conditions. The model supports
319
  multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
320
 
@@ -322,21 +323,22 @@ class LMModel(StreamingModule):
322
  sequence (torch.Tensor): Current sequence of shape [B, K, S]
323
  with K corresponding to the number of codebooks and S the number of sequence steps.
324
  S = 1 in streaming mode, except for the first step that contains a bigger prompt.
325
- condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used,
326
  should be twice the batch size, being the concatenation of the conditions + null conditions.
327
  use_sampling (bool): Whether to use a sampling strategy or not.
328
  temp (float): Sampling temperature.
329
  top_k (int): K for "top-k" sampling.
330
  top_p (float): P for "top-p" sampling.
331
- cfg_coef (float): classifier free guidance coefficient
332
  Returns:
333
  next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
334
  """
335
  B = sequence.shape[0]
336
  cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
337
  model = self if self._fsdp is None else self._fsdp
338
- if self.two_step_cfg and cfg_conditions != {}:
339
- assert isinstance(cfg_conditions, tuple)
 
340
  condition_tensors, null_condition_tensors = cfg_conditions
341
  cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
342
  state = self.get_streaming_state()
@@ -388,7 +390,7 @@ class LMModel(StreamingModule):
388
  top_k: int = 250,
389
  top_p: float = 0.0,
390
  cfg_coef: tp.Optional[float] = None,
391
- two_step_cfg: bool = False,
392
  remove_prompts: bool = False,
393
  check: bool = False,
394
  callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
@@ -396,15 +398,19 @@ class LMModel(StreamingModule):
396
  be perform in a greedy fashion or using sampling with top K and top P strategies.
397
 
398
  Args:
399
- prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
400
- conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None.
401
- num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
402
  max_gen_len (int): Maximum generation length.
403
  use_sampling (bool): Whether to use a sampling strategy or not.
404
  temp (float): Sampling temperature.
405
  top_k (int): K for "top-k" sampling.
406
  top_p (float): P for "top-p" sampling.
 
 
407
  remove_prompts (bool): Whether to remove prompts from generation or not.
 
 
408
  Returns:
409
  torch.Tensor: Generated tokens.
410
  """
@@ -412,7 +418,7 @@ class LMModel(StreamingModule):
412
  first_param = next(iter(self.parameters()))
413
  device = first_param.device
414
 
415
- # Checking all input shapes are consistents.
416
  possible_num_samples = []
417
  if num_samples is not None:
418
  possible_num_samples.append(num_samples)
@@ -422,7 +428,7 @@ class LMModel(StreamingModule):
422
  possible_num_samples.append(len(conditions))
423
  else:
424
  possible_num_samples.append(1)
425
- assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes"
426
  num_samples = possible_num_samples[0]
427
 
428
  # below we create set of conditions: one conditional and one unconditional
@@ -432,7 +438,7 @@ class LMModel(StreamingModule):
432
  # 1. it is about x2 faster than doing 2 forward passes
433
  # 2. avoid the streaming API treating the 2 passes as part of different time steps
434
  # We also support doing two different passes, in particular to ensure that
435
- # the padding structure is exactly the same between train anf test.
436
  # With a batch size of 1, this can be slower though.
437
  cfg_conditions: CFGConditions
438
  two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
@@ -457,8 +463,8 @@ class LMModel(StreamingModule):
457
  B, K, T = prompt.shape
458
  start_offset = T
459
  print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
460
- assert start_offset <= max_gen_len
461
-
462
  pattern = self.pattern_provider.get_pattern(max_gen_len)
463
  # this token is used as default value for codes that are not generated yet
464
  unknown_token = -1
@@ -490,7 +496,7 @@ class LMModel(StreamingModule):
490
  # sample next token from the model, next token shape is [B, K, 1]
491
  next_token = self._sample_next_token(
492
  curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
493
- cfg_coef=cfg_coef)
494
  # ensure the tokens that should be masked are properly set to special_token_id
495
  # as the model never output special_token_id
496
  valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
 
41
  method (str): Method name for init function. Valid options are:
42
  'gaussian', 'uniform'.
43
  input_dim (int): Input dimension of the initialized module.
44
+ init_depth (int, optional): Optional init depth value used to rescale
45
  the standard deviation if defined.
46
  """
47
  # Compute std
 
70
  Args:
71
  m (nn.Module): Module to initialize.
72
  method (str): Method name for the init function.
73
+ init_depth (int, optional): Optional init depth value used to rescale
74
  the standard deviation if defined.
75
  zero_bias_init (bool): Whether to initialize the bias to 0 or not.
76
  """
 
130
  hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
131
  norm (str): Normalization method.
132
  norm_first (bool): Use pre-norm instead of post-norm.
133
+ emb_lr (float, optional): Embedding-specific learning rate.
134
  bias_proj (bool): Use bias for output projections.
135
+ weight_init (str, optional): Method for weight initialization.
136
+ depthwise_init (str, optional): Method for depthwise weight initialization.
137
  zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
138
  cfg_dropout (float): Classifier-free guidance dropout.
139
  cfg_coef (float): Classifier-free guidance coefficient.
 
179
  """Initialization of the transformer module weights.
180
 
181
  Args:
182
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
183
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
184
  'current' where the depth corresponds to the current layer index or 'global' where the total number
185
  of layer is used as depth. If not set, no depthwise initialization strategy is used.
186
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
187
  """
188
  assert depthwise_init is None or depthwise_init in ['current', 'global']
189
  assert depthwise_init is None or weight_init is not None, \
 
225
  S the sequence steps, return the logits with shape [B, card, K, S].
226
 
227
  Args:
228
+ indices (torch.Tensor): Indices of the codes to model.
229
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
230
  the given codes. Note that when evaluating multiple time with the same conditioning
231
  you should pre-compute those and pass them as `condition_tensors`.
232
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
233
  tensors, see `conditions`.
234
  Returns:
235
  torch.Tensor: Logits.
236
  """
237
  B, K, S = sequence.shape
238
+ assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
239
  input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
240
  if condition_tensors is None:
241
  assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
 
271
  Args:
272
  codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
273
  K the number of codebooks and T the number of timesteps.
274
+ conditions (list of ConditioningAttributes): conditionings to use when modeling
275
  the given codes. Note that when evaluating multiple time with the same conditioning
276
  you should pre-compute those and pass them as `condition_tensors`.
277
+ condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
278
  tensors, see `conditions`.
279
  Returns:
280
  LMOutput: Language model outputs
 
314
  temp: float = 1.0,
315
  top_k: int = 0,
316
  top_p: float = 0.0,
317
+ cfg_coef: tp.Optional[float] = None,
318
+ two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
319
  """Sample next token from the model given a sequence and a set of conditions. The model supports
320
  multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
321
 
 
323
  sequence (torch.Tensor): Current sequence of shape [B, K, S]
324
  with K corresponding to the number of codebooks and S the number of sequence steps.
325
  S = 1 in streaming mode, except for the first step that contains a bigger prompt.
326
+ condition_tensors (dict[str, ConditionType]): Set of conditions. If CFG is used,
327
  should be twice the batch size, being the concatenation of the conditions + null conditions.
328
  use_sampling (bool): Whether to use a sampling strategy or not.
329
  temp (float): Sampling temperature.
330
  top_k (int): K for "top-k" sampling.
331
  top_p (float): P for "top-p" sampling.
332
+ cfg_coef (float, optional): classifier free guidance coefficient
333
  Returns:
334
  next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
335
  """
336
  B = sequence.shape[0]
337
  cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
338
  model = self if self._fsdp is None else self._fsdp
339
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
340
+ if two_step_cfg and cfg_conditions != {}:
341
+ assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
342
  condition_tensors, null_condition_tensors = cfg_conditions
343
  cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
344
  state = self.get_streaming_state()
 
390
  top_k: int = 250,
391
  top_p: float = 0.0,
392
  cfg_coef: tp.Optional[float] = None,
393
+ two_step_cfg: tp.Optional[bool] = None,
394
  remove_prompts: bool = False,
395
  check: bool = False,
396
  callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
 
398
  be perform in a greedy fashion or using sampling with top K and top P strategies.
399
 
400
  Args:
401
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
402
+ conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
403
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
404
  max_gen_len (int): Maximum generation length.
405
  use_sampling (bool): Whether to use a sampling strategy or not.
406
  temp (float): Sampling temperature.
407
  top_k (int): K for "top-k" sampling.
408
  top_p (float): P for "top-p" sampling.
409
+ cfg_coeff (float, optional): Classifier-free guidance coefficient.
410
+ two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
411
  remove_prompts (bool): Whether to remove prompts from generation or not.
412
+ check (bool): Whether to apply further checks on generated sequence.
413
+ callback (Callback, optional): Callback function to report generation progress.
414
  Returns:
415
  torch.Tensor: Generated tokens.
416
  """
 
418
  first_param = next(iter(self.parameters()))
419
  device = first_param.device
420
 
421
+ # Checking all input shapes are consistent.
422
  possible_num_samples = []
423
  if num_samples is not None:
424
  possible_num_samples.append(num_samples)
 
428
  possible_num_samples.append(len(conditions))
429
  else:
430
  possible_num_samples.append(1)
431
+ assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
432
  num_samples = possible_num_samples[0]
433
 
434
  # below we create set of conditions: one conditional and one unconditional
 
438
  # 1. it is about x2 faster than doing 2 forward passes
439
  # 2. avoid the streaming API treating the 2 passes as part of different time steps
440
  # We also support doing two different passes, in particular to ensure that
441
+ # the padding structure is exactly the same between train and test.
442
  # With a batch size of 1, this can be slower though.
443
  cfg_conditions: CFGConditions
444
  two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
 
463
  B, K, T = prompt.shape
464
  start_offset = T
465
  print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
466
+ assert start_offset < max_gen_len
467
+
468
  pattern = self.pattern_provider.get_pattern(max_gen_len)
469
  # this token is used as default value for codes that are not generated yet
470
  unknown_token = -1
 
496
  # sample next token from the model, next token shape is [B, K, 1]
497
  next_token = self._sample_next_token(
498
  curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
499
+ cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
500
  # ensure the tokens that should be masked are properly set to special_token_id
501
  # as the model never output special_token_id
502
  valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
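
For readers unfamiliar with the `temp` / `top_k` / `top_p` knobs documented above, the following is a generic, simplified illustration of that filtering step. It is not audiocraft's own `utils.sample_top_k` / `sample_top_p` code, just a sketch of the idea:

```python
# Generic sketch of temperature, top-k and top-p filtering before sampling a token.
import torch

def sample_filtered(logits: torch.Tensor, temp: float = 1.0,
                    top_k: int = 0, top_p: float = 0.0) -> torch.Tensor:
    logits = logits / max(temp, 1e-8)                       # temperature scaling
    if top_k > 0:                                           # keep only the k largest logits
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    probs = torch.softmax(logits, dim=-1)
    if top_p > 0.0:                                         # keep the smallest nucleus with mass > p
        sorted_probs, idx = torch.sort(probs, descending=True, dim=-1)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
        probs = torch.zeros_like(probs).scatter_(-1, idx, sorted_probs)
    return torch.multinomial(probs, num_samples=1)          # [B, 1]

next_token = sample_filtered(torch.randn(2, 2048), temp=1.0, top_k=250, top_p=0.0)
```
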
audiocraft/models/loaders.py CHANGED
@@ -24,10 +24,16 @@ from huggingface_hub import hf_hub_download
24
  import typing as tp
25
  import os
26
 
27
- from omegaconf import OmegaConf
28
  import torch
29
 
 
30
  from . import builders
31
 
32
 
33
  HF_MODEL_CHECKPOINTS_MAP = {
@@ -50,6 +56,8 @@ def _get_state_dict(
50
  device='cpu',
51
  cache_dir: tp.Optional[str] = None,
52
  ):
 
 
53
  # Return the state dict either from a file or url
54
  file_or_url_or_id = str(file_or_url_or_id)
55
  assert isinstance(file_or_url_or_id, str)
@@ -72,21 +80,120 @@ def _get_state_dict(
72
  return torch.load(file, map_location=device)
73
 
74
  else:
75
- raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
76
 
77
 
78
  def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
79
- pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
80
- cfg = OmegaConf.create(pkg['xp.cfg'])
81
  cfg.device = str(device)
82
  model = builders.get_compression_model(cfg)
83
  model.load_state_dict(pkg['best_state'])
84
  model.eval()
85
  return model
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
89
- pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
90
  cfg = OmegaConf.create(pkg['xp.cfg'])
91
  cfg.device = str(device)
92
  if cfg.device == 'cpu':
@@ -95,8 +202,42 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
95
  cfg.dtype = 'float32'
96
  else:
97
  cfg.dtype = 'float16'
 
 
 
98
  model = builders.get_lm_model(cfg)
99
  model.load_state_dict(pkg['best_state'])
100
  model.eval()
101
  model.cfg = cfg
102
  return model
24
  import typing as tp
25
  import os
26
 
27
+ from omegaconf import OmegaConf, DictConfig
28
  import torch
29
 
30
+ import audiocraft
31
  from . import builders
32
+ from .encodec import CompressionModel
33
+
34
+
35
+ def get_audiocraft_cache_dir() -> tp.Optional[str]:
36
+ return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
37
 
38
 
39
  HF_MODEL_CHECKPOINTS_MAP = {
 
56
  device='cpu',
57
  cache_dir: tp.Optional[str] = None,
58
  ):
59
+ if cache_dir is None:
60
+ cache_dir = get_audiocraft_cache_dir()
61
  # Return the state dict either from a file or url
62
  file_or_url_or_id = str(file_or_url_or_id)
63
  assert isinstance(file_or_url_or_id, str)
 
80
  return torch.load(file, map_location=device)
81
 
82
  else:
83
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
84
+
85
+ file = hf_hub_download(
86
+ repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
87
+ library_name="audiocraft", library_version=audiocraft.__version__)
88
+ return torch.load(file, map_location=device)
89
+
90
+ def create_melody_config(model_id: str, device: str) -> DictConfig:
91
+ """Create a fallback configuration for melody models.
92
+
93
+ Args:
94
+ model_id: The model identifier
95
+ device: The device to use
96
+
97
+ Returns:
98
+ A compatible OmegaConf DictConfig
99
+ """
100
+ base_cfg = {
101
+ "device": str(device),
102
+ "channels": 2 if "stereo" in model_id else 1,
103
+ "sample_rate": 32000,
104
+ "audio_channels": 2 if "stereo" in model_id else 1,
105
+ "frame_rate": 50,
106
+ "codec_name": "encodec",
107
+ "codec": {
108
+ "dim": 128,
109
+ "hidden_dim": 1024,
110
+ "stride": 320,
111
+ "n_q": 4,
112
+ "codebook_size": 2048,
113
+ "normalize": True,
114
+ }
115
+ }
116
+ return OmegaConf.create(base_cfg)
117
+
118
+ def create_default_config(model_id: str, device: str) -> DictConfig:
119
+ """Create a fallback configuration for standard models.
120
+
121
+ Args:
122
+ model_id: The model identifier
123
+ device: The device to use
124
+
125
+ Returns:
126
+ A compatible OmegaConf DictConfig
127
+ """
128
+ base_cfg = {
129
+ "device": str(device),
130
+ "channels": 2 if "stereo" in model_id else 1,
131
+ "sample_rate": 32000,
132
+ "audio_channels": 2 if "stereo" in model_id else 1,
133
+ "frame_rate": 50,
134
+ "codec_name": "encodec",
135
+ "codec": {
136
+ "dim": 128,
137
+ "hidden_dim": 1024,
138
+ "stride": 320,
139
+ "n_q": 4,
140
+ "codebook_size": 1024,
141
+ "normalize": True,
142
+ }
143
+ }
144
+ return OmegaConf.create(base_cfg)
145
+
146
+
147
+ def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
148
+ return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
149
 
150
 
151
  def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
152
+ pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
153
+ if 'pretrained' in pkg:
154
+ return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
155
+
156
+ # Handle newer model formats that might not have xp.cfg
157
+ if 'xp.cfg' not in pkg:
158
+ if file_or_url_or_id in ['melody-large', 'stereo-melody', 'stereo-medium',
159
+ 'stereo-small', 'stereo-large', 'stereo-melody-large']:
160
+ print(f"Using fallback configuration for {file_or_url_or_id}")
161
+ # Create a default configuration based on the model type
162
+ # This is where you'd need to add model-specific configurations
163
+ if 'melody' in file_or_url_or_id:
164
+ cfg = create_melody_config(file_or_url_or_id, device)
165
+ else:
166
+ cfg = create_default_config(file_or_url_or_id, device)
167
+ else:
168
+ raise KeyError(f"Missing configuration for model {file_or_url_or_id}")
169
+ else:
170
+ cfg = OmegaConf.create(pkg['xp.cfg'])
171
+
172
  cfg.device = str(device)
173
  model = builders.get_compression_model(cfg)
174
  model.load_state_dict(pkg['best_state'])
175
  model.eval()
176
  return model
177
 
178
+ def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
179
+ return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
180
+
181
+
182
+ def _delete_param(cfg: DictConfig, full_name: str):
183
+ parts = full_name.split('.')
184
+ for part in parts[:-1]:
185
+ if part in cfg:
186
+ cfg = cfg[part]
187
+ else:
188
+ return
189
+ OmegaConf.set_struct(cfg, False)
190
+ if parts[-1] in cfg:
191
+ del cfg[parts[-1]]
192
+ OmegaConf.set_struct(cfg, True)
193
+
194
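
`_delete_param` is what lets `load_lm_model` below drop legacy conditioner keys before rebuilding the model. A tiny sketch of its effect on a config:

```python
# Sketch: removing a stale key from an OmegaConf config, as done in load_lm_model.
from omegaconf import OmegaConf
from audiocraft.models.loaders import _delete_param

cfg = OmegaConf.create({'conditioners': {'args': {'merge_text_conditions_p': 0.25,
                                                  'drop_desc_p': 0.5}}})
_delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
assert 'merge_text_conditions_p' not in cfg.conditioners.args
```
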
 
195
  def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
196
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
197
  cfg = OmegaConf.create(pkg['xp.cfg'])
198
  cfg.device = str(device)
199
  if cfg.device == 'cpu':
 
202
  cfg.dtype = 'float32'
203
  else:
204
  cfg.dtype = 'float16'
205
+ _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
206
+ _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
207
+ _delete_param(cfg, 'conditioners.args.drop_desc_p')
208
  model = builders.get_lm_model(cfg)
209
  model.load_state_dict(pkg['best_state'])
210
  model.eval()
211
  model.cfg = cfg
212
  return model
213
+
214
+
215
+ def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
216
+ filename: tp.Optional[str] = None,
217
+ cache_dir: tp.Optional[str] = None):
218
+ return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
219
+
220
+
221
+ def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
222
+ device='cpu',
223
+ filename: tp.Optional[str] = None,
224
+ cache_dir: tp.Optional[str] = None):
225
+ pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
226
+ models = []
227
+ processors = []
228
+ cfgs = []
229
+ sample_rate = pkg['sample_rate']
230
+ for i in range(pkg['n_bands']):
231
+ cfg = pkg[i]['cfg']
232
+ model = builders.get_diffusion_model(cfg)
233
+ model_dict = pkg[i]['model_state']
234
+ model.load_state_dict(model_dict)
235
+ model.to(device)
236
+ processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
237
+ processor_dict = pkg[i]['processor_state']
238
+ processor.load_state_dict(processor_dict)
239
+ processor.to(device)
240
+ models.append(model)
241
+ processors.append(processor)
242
+ cfgs.append(cfg)
243
+ return models, processors, cfgs
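
A short usage sketch of the refactored loaders; the repository id and cache directory below are illustrative assumptions, not values mandated by this code:

```python
# Sketch: point audiocraft at a local cache, then load the codec and LM checkpoints
# exported as compression_state_dict.bin / state_dict.bin on the Hub.
import os
from audiocraft.models import loaders

os.environ['AUDIOCRAFT_CACHE_DIR'] = '/tmp/audiocraft-cache'  # read by get_audiocraft_cache_dir()

codec = loaders.load_compression_model('facebook/musicgen-small', device='cpu')
lm = loaders.load_lm_model('facebook/musicgen-small', device='cpu')
```
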
audiocraft/models/musicgen.py CHANGED
@@ -11,18 +11,19 @@ and provide easy access to the generation API.
11
 
12
  import os
13
  import typing as tp
 
14
 
 
15
  import torch
16
 
17
  from .encodec import CompressionModel
18
  from .lm import LMModel
19
- from .builders import get_debug_compression_model, get_debug_lm_model
20
  from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
21
  from ..data.audio_utils import convert_audio
22
  from ..modules.conditioners import ConditioningAttributes, WavCondition
23
  from ..utils.autocast import TorchAutocast
24
 
25
-
26
  MelodyList = tp.List[tp.Optional[torch.Tensor]]
27
  MelodyType = tp.Union[torch.Tensor, MelodyList]
28
 
@@ -35,11 +36,32 @@ class MusicGen:
35
  compression_model (CompressionModel): Compression model
36
  used to map audio to invertible discrete representations.
37
  lm (LMModel): Language model over discrete representations.
 
 
38
  """
39
- def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: float = 30):
40
  self.name = name
41
  self.compression_model = compression_model
42
  self.lm = lm
43
  self.max_duration = max_duration
44
  self.duration = 15.0 # default duration
45
  self.device = next(iter(lm.parameters())).device
@@ -53,7 +75,12 @@ class MusicGen:
53
  enabled=True, device_type=self.device.type, dtype=torch.float16)
54
 
55
  @property
56
- def frame_rate(self) -> int:
57
  """Roughly the number of AR steps per seconds."""
58
  return self.compression_model.frame_rate
59
 
@@ -100,12 +127,15 @@ class MusicGen:
100
  f"{name} is not a valid checkpoint name. "
101
  f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
102
  )
 
 
103
 
104
  cache_dir = os.environ.get('MUSICGEN_ROOT', None)
105
  compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
106
  lm = load_lm_model(name, device=device, cache_dir=cache_dir)
107
- if name == 'melody':
108
  lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
 
109
 
110
  return MusicGen(name, compression_model, lm)
111
 
@@ -125,6 +155,9 @@ class MusicGen:
125
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
126
  instead of batching together the two. This has some impact on how things
127
  are padded but seems to have little impact in practice.
 
 
 
128
  rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
129
  """
130
  assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
@@ -137,47 +170,61 @@ class MusicGen:
137
  'top_k': top_k,
138
  'top_p': top_p,
139
  'cfg_coef': cfg_coef,
140
- 'two_step_cfg': two_step_cfg,
141
  }
142
 
143
  def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
144
  """Override the default progress callback."""
145
  self._progress_callback = progress_callback
146
 
147
- def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
 
 
148
  """Generate samples in an unconditional manner.
149
 
150
  Args:
151
  num_samples (int): Number of samples to be generated.
152
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
 
153
  """
154
  descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
155
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
156
- return self._generate_tokens(attributes, prompt_tokens, progress)
 
 
 
157
 
158
- def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
 
159
  """Generate samples conditioned on text.
160
 
161
  Args:
162
- descriptions (tp.List[str]): A list of strings used as text conditioning.
163
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
 
164
  """
165
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
166
  assert prompt_tokens is None
167
- return self._generate_tokens(attributes, prompt_tokens, progress)
 
 
 
168
 
169
  def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
170
- melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
 
 
171
  """Generate samples conditioned on text and melody.
172
 
173
  Args:
174
- descriptions (tp.List[str]): A list of strings used as text conditioning.
175
  melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
176
  melody conditioning. Should have shape [B, C, T] with B matching the description length,
177
  C=1 or 2. It can be [C, T] if there is a single description. It can also be
178
  a list of [C, T] tensors.
179
  melody_sample_rate: (int): Sample rate of the melody waveforms.
180
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
 
181
  """
182
  if isinstance(melody_wavs, torch.Tensor):
183
  if melody_wavs.dim() == 2:
@@ -197,10 +244,14 @@ class MusicGen:
197
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
198
  melody_wavs=melody_wavs)
199
  assert prompt_tokens is None
200
- return self._generate_tokens(attributes, prompt_tokens, progress)
 
 
 
201
 
202
  def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
203
- sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
 
204
  """Generate samples conditioned on text and melody and audio prompts.
205
  Args:
206
  descriptions (tp.List[str]): A list of strings used as text conditioning.
@@ -249,19 +300,24 @@ class MusicGen:
249
  assert prompt_tokens is not None
250
  else:
251
  assert prompt_tokens is None
252
- return self._generate_tokens(attributes, prompt_tokens, progress)
 
 
 
253
 
254
  def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
255
  descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
256
- progress: bool = False) -> torch.Tensor:
 
257
  """Generate samples conditioned on audio prompts.
258
 
259
  Args:
260
  prompt (torch.Tensor): A batch of waveforms used for continuation.
261
  Prompt should be [B, C, T], or [C, T] if only one sample is generated.
262
  prompt_sample_rate (int): Sampling rate of the given audio waveforms.
263
- descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
264
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
 
265
  """
266
  if prompt.dim() == 2:
267
  prompt = prompt[None]
@@ -272,7 +328,10 @@ class MusicGen:
272
  descriptions = [None] * len(prompt)
273
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
274
  assert prompt_tokens is not None
275
- return self._generate_tokens(attributes, prompt_tokens, progress)
 
 
 
276
 
277
  @torch.no_grad()
278
  def _prepare_tokens_and_attributes(
@@ -284,9 +343,9 @@ class MusicGen:
284
  """Prepare model inputs.
285
 
286
  Args:
287
- descriptions (tp.List[str]): A list of strings used as text conditioning.
288
  prompt (torch.Tensor): A batch of waveforms used for continuation.
289
- melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
290
  used as melody conditioning. Defaults to None.
291
  """
292
  attributes = [
@@ -296,11 +355,12 @@ class MusicGen:
296
  if melody_wavs is None:
297
  for attr in attributes:
298
  attr.wav['self_wav'] = WavCondition(
299
- torch.zeros((1, 1), device=self.device),
300
  torch.tensor([0], device=self.device),
301
- path='null_wav') # type: ignore
 
302
  else:
303
- if self.name != "melody":
304
  raise RuntimeError("This model doesn't support melody conditioning. "
305
  "Use the `melody` model.")
306
  assert len(melody_wavs) == len(descriptions), \
@@ -309,13 +369,17 @@ class MusicGen:
309
  for attr, melody in zip(attributes, melody_wavs):
310
  if melody is None:
311
  attr.wav['self_wav'] = WavCondition(
312
- torch.zeros((1, 1), device=self.device),
313
  torch.tensor([0], device=self.device),
314
- path='null_wav') # type: ignore
 
315
  else:
316
  attr.wav['self_wav'] = WavCondition(
317
- melody.to(device=self.device),
318
- torch.tensor([melody.shape[-1]], device=self.device))
 
 
 
319
 
320
  if prompt is not None:
321
  if descriptions is not None:
@@ -396,8 +460,10 @@ class MusicGen:
396
  positions = torch.arange(initial_position,
397
  initial_position + wav_target_length, device=self.device)
398
  attr.wav['self_wav'] = WavCondition(
399
- ref_wav[0][:, positions % wav_length],
400
- torch.full_like(ref_wav[1], wav_target_length))
 
 
401
  with self.autocast:
402
  gen_tokens = self.lm.generate(
403
  prompt_tokens, attributes,
@@ -411,13 +477,21 @@ class MusicGen:
411
  current_gen_offset += stride_tokens
412
 
413
  gen_tokens = torch.cat(all_tokens, dim=-1)
 
414
 
415
  # generate audio
416
- assert gen_tokens.dim() == 3
417
- with torch.no_grad():
418
- gen_audio = self.compression_model.decode(gen_tokens, None)
419
- return gen_audio
420
 
 
 
 
 
 
 
 
 
 
 
 
421
  #def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
422
  # prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
423
  # """Generate discrete audio tokens given audio prompt and/or conditions.
 
11
 
12
  import os
13
  import typing as tp
14
+ import warnings
15
 
16
+ import omegaconf
17
  import torch
18
 
19
  from .encodec import CompressionModel
20
  from .lm import LMModel
21
+ from .builders import get_debug_compression_model, get_debug_lm_model, get_wrapped_compression_model
22
  from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
23
  from ..data.audio_utils import convert_audio
24
  from ..modules.conditioners import ConditioningAttributes, WavCondition
25
  from ..utils.autocast import TorchAutocast
26
 
 
27
  MelodyList = tp.List[tp.Optional[torch.Tensor]]
28
  MelodyType = tp.Union[torch.Tensor, MelodyList]
29
 
 
36
  compression_model (CompressionModel): Compression model
37
  used to map audio to invertible discrete representations.
38
  lm (LMModel): Language model over discrete representations.
39
+ max_duration (float, optional): maximum duration the model can produce,
40
+ otherwise, inferred from the training params.
41
  """
42
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel, max_duration: tp.Optional[float] = 30):
43
  self.name = name
44
  self.compression_model = compression_model
45
  self.lm = lm
46
+ self.cfg: tp.Optional[omegaconf.DictConfig] = None
47
+ # Just to be safe, let's put everything in eval mode.
48
+ self.compression_model.eval()
49
+ self.lm.eval()
50
+
51
+ if hasattr(lm, 'cfg'):
52
+ cfg = lm.cfg
53
+ assert isinstance(cfg, omegaconf.DictConfig)
54
+ self.cfg = cfg
55
+
56
+ if self.cfg is not None:
57
+ self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
58
+
59
+ if max_duration is None:
60
+ if self.cfg is not None:
61
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
62
+ else:
63
+ raise ValueError("You must provide max_duration when building directly MusicGen")
64
+ assert max_duration is not None
65
  self.max_duration = max_duration
66
  self.duration = 15.0 # default duration
67
  self.device = next(iter(lm.parameters())).device
 
75
  enabled=True, device_type=self.device.type, dtype=torch.float16)
76
 
77
  @property
78
+ def version(self) -> str:
79
+ from audiocraft import __version__ as audiocraft_version
80
+ return audiocraft_version
81
+
82
+ @property
83
+ def frame_rate(self) -> float:
84
  """Roughly the number of AR steps per seconds."""
85
  return self.compression_model.frame_rate
86
 
 
127
  f"{name} is not a valid checkpoint name. "
128
  f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
129
  )
130
+ else:
131
+ name = HF_MODEL_CHECKPOINTS_MAP[name]
132
 
133
  cache_dir = os.environ.get('MUSICGEN_ROOT', None)
134
  compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
135
  lm = load_lm_model(name, device=device, cache_dir=cache_dir)
136
+ if name.__contains__('melody') or 'self_wav' in lm.condition_provider.conditioners:
137
  lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
138
+ lm.condition_provider.conditioners['self_wav']._use_masking = False
139
 
140
  return MusicGen(name, compression_model, lm)
141
 
 
155
  two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
156
  instead of batching together the two. This has some impact on how things
157
  are padded but seems to have little impact in practice.
158
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
159
+ should we extend the audio each time. Larger values will mean less context is
160
+ preserved, and shorter value will require extra computations.
161
  rep_penalty (float, optional): If set, use repetition penalty during generation. Not Implemented.
162
  """
163
  assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
 
170
  'top_k': top_k,
171
  'top_p': top_p,
172
  'cfg_coef': cfg_coef,
173
+ 'two_step_cfg': two_step_cfg,
174
  }
175
 
176
  def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
177
  """Override the default progress callback."""
178
  self._progress_callback = progress_callback
179
 
180
+ def generate_unconditional(self, num_samples: int, progress: bool = False,
181
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
182
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
183
  """Generate samples in an unconditional manner.
184
 
185
  Args:
186
  num_samples (int): Number of samples to be generated.
187
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
188
+ return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
189
  """
190
  descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
191
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
192
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
193
+ if return_tokens:
194
+ return self.generate_audio(tokens), tokens
195
+ return self.generate_audio(tokens)
196
 
197
+ def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
198
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
199
  """Generate samples conditioned on text.
200
 
201
  Args:
202
+ descriptions (list of str): A list of strings used as text conditioning.
203
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
204
+ return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
205
  """
206
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
207
  assert prompt_tokens is None
208
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
209
+ if return_tokens:
210
+ return self.generate_audio(tokens), tokens
211
+ return self.generate_audio(tokens)
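A hedged usage sketch for plain text-to-music generation with the new `return_tokens` flag:

```python
descriptions = ["lofi hip hop beat with warm electric piano", "driving 140 bpm techno"]
wavs, tokens = model.generate(descriptions, progress=True, return_tokens=True)
# wavs: [B, C, T] float audio at model.sample_rate; tokens: the discrete codes
# that generate_audio() decoded, useful for caching or re-decoding later.
```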
212
 
213
  def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
214
+ melody_sample_rate: int, progress: bool = False,
215
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
216
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
217
  """Generate samples conditioned on text and melody.
218
 
219
  Args:
220
+ descriptions (list of str): A list of strings used as text conditioning.
221
  melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
222
  melody conditioning. Should have shape [B, C, T] with B matching the description length,
223
  C=1 or 2. It can be [C, T] if there is a single description. It can also be
224
  a list of [C, T] tensors.
225
  melody_sample_rate: (int): Sample rate of the melody waveforms.
226
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
227
+ return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
228
  """
229
  if isinstance(melody_wavs, torch.Tensor):
230
  if melody_wavs.dim() == 2:
 
244
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
245
  melody_wavs=melody_wavs)
246
  assert prompt_tokens is None
247
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
248
+ if return_tokens:
249
+ return self.generate_audio(tokens), tokens
250
+ return self.generate_audio(tokens)
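Melody conditioning in practice, as a sketch; the audio path is a placeholder and the waveform is passed at its own sample rate:

```python
import torchaudio

melody, sr = torchaudio.load('melody.mp3')   # placeholder path
wav = model.generate_with_chroma(
    descriptions=["80s synthwave with a driving bassline"],
    melody_wavs=melody[None],                # [B, C, T], B matching the descriptions
    melody_sample_rate=sr,
    progress=True,
)
```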
251
 
252
  def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
253
+ sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None, return_tokens: bool = False) \
254
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
255
  """Generate samples conditioned on text and melody and audio prompts.
256
  Args:
257
  descriptions (tp.List[str]): A list of strings used as text conditioning.
 
300
  assert prompt_tokens is not None
301
  else:
302
  assert prompt_tokens is None
303
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
304
+ if return_tokens:
305
+ return self.generate_audio(tokens), tokens
306
+ return self.generate_audio(tokens)
307
 
308
  def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
309
  descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
310
+ progress: bool = False, return_tokens: bool = False) \
311
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
312
  """Generate samples conditioned on audio prompts.
313
 
314
  Args:
315
  prompt (torch.Tensor): A batch of waveforms used for continuation.
316
  Prompt should be [B, C, T], or [C, T] if only one sample is generated.
317
  prompt_sample_rate (int): Sampling rate of the given audio waveforms.
318
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
319
  progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
320
+ return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
321
  """
322
  if prompt.dim() == 2:
323
  prompt = prompt[None]
 
328
  descriptions = [None] * len(prompt)
329
  attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
330
  assert prompt_tokens is not None
331
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
332
+ if return_tokens:
333
+ return self.generate_audio(tokens), tokens
334
+ return self.generate_audio(tokens)
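Continuation works from a waveform prompt; a hedged sketch that keeps only the last few seconds of context (the path is a placeholder):

```python
import torchaudio

prompt, sr = torchaudio.load('intro.wav')    # placeholder path
prompt = prompt[..., -int(5 * sr):]          # keep roughly 5 s of context
continuation = model.generate_continuation(
    prompt, prompt_sample_rate=sr,
    descriptions=["continue with a soaring guitar solo"],
    progress=True,
)
```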
335
 
336
  @torch.no_grad()
337
  def _prepare_tokens_and_attributes(
 
343
  """Prepare model inputs.
344
 
345
  Args:
346
+ descriptions (list of str): A list of strings used as text conditioning.
347
  prompt (torch.Tensor): A batch of waveforms used for continuation.
348
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
349
  used as melody conditioning. Defaults to None.
350
  """
351
  attributes = [
 
355
  if melody_wavs is None:
356
  for attr in attributes:
357
  attr.wav['self_wav'] = WavCondition(
358
+ torch.zeros((1, 1, 1), device=self.device),
359
  torch.tensor([0], device=self.device),
360
+ sample_rate=[self.sample_rate],
361
+ path=[None]) # type: ignore
362
  else:
363
+ if 'self_wav' not in self.lm.condition_provider.conditioners:
364
  raise RuntimeError("This model doesn't support melody conditioning. "
365
  "Use the `melody` model.")
366
  assert len(melody_wavs) == len(descriptions), \
 
369
  for attr, melody in zip(attributes, melody_wavs):
370
  if melody is None:
371
  attr.wav['self_wav'] = WavCondition(
372
+ torch.zeros((1, 1, 1), device=self.device),
373
  torch.tensor([0], device=self.device),
374
+ sample_rate=[self.sample_rate],
375
+ path=[None]) # type: ignore
376
  else:
377
  attr.wav['self_wav'] = WavCondition(
378
+ melody[None].to(device=self.device),
379
+ torch.tensor([melody.shape[-1]], device=self.device),
380
+ sample_rate=[self.sample_rate],
381
+ path=[None],
382
+ )
383
 
384
  if prompt is not None:
385
  if descriptions is not None:
 
460
  positions = torch.arange(initial_position,
461
  initial_position + wav_target_length, device=self.device)
462
  attr.wav['self_wav'] = WavCondition(
463
+ ref_wav[0][..., positions % wav_length],
464
+ torch.full_like(ref_wav[1], wav_target_length),
465
+ [self.sample_rate] * ref_wav[0].size(0),
466
+ [None], [0.])
467
  with self.autocast:
468
  gen_tokens = self.lm.generate(
469
  prompt_tokens, attributes,
 
477
  current_gen_offset += stride_tokens
478
 
479
  gen_tokens = torch.cat(all_tokens, dim=-1)
480
+ return gen_tokens
481
 
482
  # generate audio
 
 
 
 
483
 
484
+ def generate_audio(self, gen_tokens: torch.Tensor):
485
+ """Generate audio from the generated tokens."""
486
+ try:
487
+ assert gen_tokens.dim() == 3
488
+ with torch.no_grad():
489
+ gen_audio = self.compression_model.decode(gen_tokens, None)
490
+ return gen_audio
491
+ except Exception as e:
492
+ print(f"Error generating audio: {e}")
493
+ return None
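Because `generate_audio` now swallows decoding exceptions and returns `None`, callers should check the result; a short sketch assuming `tokens` came from an earlier `return_tokens=True` call:

```python
audio = model.generate_audio(tokens)
if audio is None:
    raise RuntimeError("decoding the generated tokens failed")
```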
494
+
495
  #def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
496
  # prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
497
  # """Generate discrete audio tokens given audio prompt and/or conditions.
audiocraft/models/unet.py ADDED
@@ -0,0 +1,214 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Pytorch Unet Module used for diffusion.
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ import typing as tp
13
+
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+ from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
18
+
19
+
20
+ @dataclass
21
+ class Output:
22
+ sample: torch.Tensor
23
+
24
+
25
+ def get_model(cfg, channels: int, side: int, num_steps: int):
26
+ if cfg.model == 'unet':
27
+ return DiffusionUnet(
28
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
29
+ else:
30
+ raise RuntimeError('Not Implemented')
31
+
32
+
33
+ class ResBlock(nn.Module):
34
+ def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
35
+ dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
36
+ dropout: float = 0.):
37
+ super().__init__()
38
+ stride = 1
39
+ padding = dilation * (kernel - stride) // 2
40
+ Conv = nn.Conv1d
41
+ Drop = nn.Dropout1d
42
+ self.norm1 = nn.GroupNorm(norm_groups, channels)
43
+ self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
44
+ self.activation1 = activation()
45
+ self.dropout1 = Drop(dropout)
46
+
47
+ self.norm2 = nn.GroupNorm(norm_groups, channels)
48
+ self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
49
+ self.activation2 = activation()
50
+ self.dropout2 = Drop(dropout)
51
+
52
+ def forward(self, x):
53
+ h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
54
+ h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
55
+ return x + h
56
+
57
+
58
+ class DecoderLayer(nn.Module):
59
+ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
60
+ norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
61
+ dropout: float = 0.):
62
+ super().__init__()
63
+ padding = (kernel - stride) // 2
64
+ self.res_blocks = nn.Sequential(
65
+ *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
66
+ for idx in range(res_blocks)])
67
+ self.norm = nn.GroupNorm(norm_groups, chin)
68
+ ConvTr = nn.ConvTranspose1d
69
+ self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
70
+ self.activation = activation()
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ x = self.res_blocks(x)
74
+ x = self.norm(x)
75
+ x = self.activation(x)
76
+ x = self.convtr(x)
77
+ return x
78
+
79
+
80
+ class EncoderLayer(nn.Module):
81
+ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
82
+ norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
83
+ dropout: float = 0.):
84
+ super().__init__()
85
+ padding = (kernel - stride) // 2
86
+ Conv = nn.Conv1d
87
+ self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
88
+ self.norm = nn.GroupNorm(norm_groups, chout)
89
+ self.activation = activation()
90
+ self.res_blocks = nn.Sequential(
91
+ *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
92
+ for idx in range(res_blocks)])
93
+
94
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
95
+ B, C, T = x.shape
96
+ stride, = self.conv.stride
97
+ pad = (stride - (T % stride)) % stride
98
+ x = F.pad(x, (0, pad))
99
+
100
+ x = self.conv(x)
101
+ x = self.norm(x)
102
+ x = self.activation(x)
103
+ x = self.res_blocks(x)
104
+ return x
105
+
106
+
107
+ class BLSTM(nn.Module):
108
+ """BiLSTM with same hidden units as input dim.
109
+ """
110
+ def __init__(self, dim, layers=2):
111
+ super().__init__()
112
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
113
+ self.linear = nn.Linear(2 * dim, dim)
114
+
115
+ def forward(self, x):
116
+ x = x.permute(2, 0, 1)
117
+ x = self.lstm(x)[0]
118
+ x = self.linear(x)
119
+ x = x.permute(1, 2, 0)
120
+ return x
121
+
122
+
123
+ class DiffusionUnet(nn.Module):
124
+ def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
125
+ max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
126
+ bilstm: bool = False, transformer: bool = False,
127
+ codec_dim: tp.Optional[int] = None, **kwargs):
128
+ super().__init__()
129
+ self.encoders = nn.ModuleList()
130
+ self.decoders = nn.ModuleList()
131
+ self.embeddings: tp.Optional[nn.ModuleList] = None
132
+ self.embedding = nn.Embedding(num_steps, hidden)
133
+ if emb_all_layers:
134
+ self.embeddings = nn.ModuleList()
135
+ self.condition_embedding: tp.Optional[nn.Module] = None
136
+ for d in range(depth):
137
+ encoder = EncoderLayer(chin, hidden, **kwargs)
138
+ decoder = DecoderLayer(hidden, chin, **kwargs)
139
+ self.encoders.append(encoder)
140
+ self.decoders.insert(0, decoder)
141
+ if emb_all_layers and d > 0:
142
+ assert self.embeddings is not None
143
+ self.embeddings.append(nn.Embedding(num_steps, hidden))
144
+ chin = hidden
145
+ hidden = min(int(chin * growth), max_channels)
146
+ self.bilstm: tp.Optional[nn.Module]
147
+ if bilstm:
148
+ self.bilstm = BLSTM(chin)
149
+ else:
150
+ self.bilstm = None
151
+ self.use_transformer = transformer
152
+ self.cross_attention = False
153
+ if transformer:
154
+ self.cross_attention = cross_attention
155
+ self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
156
+ cross_attention=cross_attention)
157
+
158
+ self.use_codec = False
159
+ if codec_dim is not None:
160
+ self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
161
+ self.use_codec = True
162
+
163
+ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
164
+ skips = []
165
+ bs = x.size(0)
166
+ z = x
167
+ view_args = [1]
168
+ if type(step) is torch.Tensor:
169
+ step_tensor = step
170
+ else:
171
+ step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
172
+
173
+ for idx, encoder in enumerate(self.encoders):
174
+ z = encoder(z)
175
+ if idx == 0:
176
+ z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
177
+ elif self.embeddings is not None:
178
+ z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
179
+
180
+ skips.append(z)
181
+
182
+ if self.use_codec: # insert condition in the bottleneck
183
+ assert condition is not None, "Model defined for conditional generation"
184
+ condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
185
+ assert condition_emb.size(-1) <= 2 * z.size(-1), \
186
+ f"You are downsampling the conditioning with factor >=2: {condition_emb.size(-1)=} and {z.size(-1)=}"
187
+ if not self.cross_attention:
188
+
189
+ condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
190
+ assert z.size() == condition_emb.size()
191
+ z += condition_emb
192
+ cross_attention_src = None
193
+ else:
194
+ cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
195
+ B, T, C = cross_attention_src.shape
196
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
197
+ pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
198
+ cross_attention_src = cross_attention_src + pos_emb
199
+ if self.use_transformer:
200
+ z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
201
+ else:
202
+ if self.bilstm is None:
203
+ z = torch.zeros_like(z)
204
+ else:
205
+ z = self.bilstm(z)
206
+
207
+ for decoder in self.decoders:
208
+ s = skips.pop(-1)
209
+ z = z[:, :, :s.shape[2]]
210
+ z = z + s
211
+ z = decoder(z)
212
+
213
+ z = z[:, :, :x.shape[2]]
214
+ return Output(z)
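A hedged smoke test for the new `DiffusionUnet`; the channel and step counts are illustrative, real values come from the diffusion config passed to `get_model`:

```python
import torch
from audiocraft.models.unet import DiffusionUnet

unet = DiffusionUnet(chin=1, hidden=24, depth=3, num_steps=1000)
x = torch.randn(2, 1, 16000)      # [B, C, T] batch of noisy waveforms
out = unet(x, step=10)            # step may be an int or a [B] LongTensor
print(out.sample.shape)           # torch.Size([2, 1, 16000])
```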
audiocraft/modules/__init__.py CHANGED
@@ -18,3 +18,4 @@ from .conv import (
18
  )
19
  from .lstm import StreamableLSTM
20
  from .seanet import SEANetEncoder, SEANetDecoder
 
 
18
  )
19
  from .lstm import StreamableLSTM
20
  from .seanet import SEANetEncoder, SEANetDecoder
21
+ from .transformer import StreamingTransformer
audiocraft/modules/chroma.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import typing as tp
7
+
8
+ from einops import rearrange
9
+ from librosa import filters
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+ import torchaudio
14
+
15
+
16
+ class ChromaExtractor(nn.Module):
17
+ """Chroma extraction and quantization.
18
+
19
+ Args:
20
+ sample_rate (int): Sample rate for the chroma extraction.
21
+ n_chroma (int): Number of chroma bins for the chroma extraction.
22
+ radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
23
+ nfft (int, optional): Number of FFT.
24
+ winlen (int, optional): Window length.
25
+ winhop (int, optional): Window hop size.
26
+ argmax (bool, optional): Whether to use argmax. Defaults to False.
27
+ norm (float, optional): Norm for chroma normalization. Defaults to inf.
28
+ """
29
+ def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
30
+ winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
31
+ norm: float = torch.inf):
32
+ super().__init__()
33
+ self.winlen = winlen or 2 ** radix2_exp
34
+ self.nfft = nfft or self.winlen
35
+ self.winhop = winhop or (self.winlen // 4)
36
+ self.sample_rate = sample_rate
37
+ self.n_chroma = n_chroma
38
+ self.norm = norm
39
+ self.argmax = argmax
40
+ self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
41
+ n_chroma=self.n_chroma)), persistent=False)
42
+ self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
43
+ hop_length=self.winhop, power=2, center=True,
44
+ pad=0, normalized=True)
45
+
46
+ def forward(self, wav: torch.Tensor) -> torch.Tensor:
47
+ T = wav.shape[-1]
48
+ # in case we are getting a wav that was dropped out (nullified)
49
+ # from the conditioner, make sure wav length is no less that nfft
50
+ if T < self.nfft:
51
+ pad = self.nfft - T
52
+ r = 0 if pad % 2 == 0 else 1
53
+ wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
54
+ assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"
55
+
56
+ spec = self.spec(wav).squeeze(1)
57
+ raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
58
+ norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
59
+ norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')
60
+
61
+ if self.argmax:
62
+ idx = norm_chroma.argmax(-1, keepdim=True)
63
+ norm_chroma[:] = 0
64
+ norm_chroma.scatter_(dim=-1, index=idx, value=1)
65
+
66
+ return norm_chroma
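The extractor can be exercised on its own; a minimal sketch with a random waveform (shapes only, not meaningful audio):

```python
import torch
from audiocraft.modules.chroma import ChromaExtractor

chroma = ChromaExtractor(sample_rate=32000, n_chroma=12, radix2_exp=12)
wav = torch.randn(1, 1, 2 * 32000)   # [B, C, T]
feats = chroma(wav)                  # -> [B, frames, n_chroma]
print(feats.shape)
```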
audiocraft/modules/codebooks_patterns.py CHANGED
@@ -122,7 +122,7 @@ class Pattern:
122
  Args:
123
  timesteps (int): Maximum number of timesteps steps to consider.
124
  keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
125
- device (Union[torch.device, str]): Device for created tensors.
126
  Returns:
127
  indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
128
  mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
@@ -189,9 +189,9 @@ class Pattern:
189
  keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
190
  Steps that are beyond valid steps will be replaced by the special_token in that case.
191
  is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
192
- device (Union[torch.device, str]): Device for created tensors.
193
  Returns:
194
- torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
195
  mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
196
  """
197
  ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
@@ -295,7 +295,7 @@ class CodebooksPatternProvider(ABC):
295
  """Builds pattern with specific interleaving between codebooks.
296
 
297
  Args:
298
- timesteps (int): Total numer of timesteps.
299
  """
300
  raise NotImplementedError()
301
 
@@ -318,7 +318,7 @@ class DelayedPatternProvider(CodebooksPatternProvider):
318
 
319
  Args:
320
  n_q (int): Number of codebooks.
321
- delays (Optional[List[int]]): Delay for each of the codebooks.
322
  If delays not defined, each codebook is delayed by 1 compared to the previous one.
323
  flatten_first (int): Flatten the first N timesteps.
324
  empty_initial (int): Prepend with N empty list of coordinates.
@@ -406,10 +406,10 @@ class UnrolledPatternProvider(CodebooksPatternProvider):
406
 
407
  Args:
408
  n_q (int): Number of codebooks.
409
- flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
410
  the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
411
  have n_q extra steps for each timestep.
412
- delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
413
  no delay is added and therefore will default to [0] * ``n_q``.
414
  Note that two codebooks that will be flattened to the same inner step
415
  should have the same delay, otherwise the pattern is considered as invalid.
@@ -462,7 +462,7 @@ class UnrolledPatternProvider(CodebooksPatternProvider):
462
  """Builds pattern for delay across codebooks.
463
 
464
  Args:
465
- timesteps (int): Total numer of timesteps.
466
  """
467
  # the PatternLayout is built as a tuple of sequence position and list of coordinates
468
  # so that it can be reordered properly given the required delay between codebooks of given timesteps
@@ -486,13 +486,18 @@ class UnrolledPatternProvider(CodebooksPatternProvider):
486
  return Pattern(out, n_q=self.n_q, timesteps=timesteps)
487
 
488
 
489
- class VALLEPattern(CodebooksPatternProvider):
490
- """Almost VALL-E style pattern. We futher allow some delays for the
491
- codebooks other than the first one.
 
 
 
 
 
492
 
493
  Args:
494
  n_q (int): Number of codebooks.
495
- delays (Optional[List[int]]): Delay for each of the codebooks.
496
  If delays not defined, each codebook is delayed by 1 compared to the previous one.
497
  """
498
  def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
 
122
  Args:
123
  timesteps (int): Maximum number of timesteps steps to consider.
124
  keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
125
+ device (torch.device or str): Device for created tensors.
126
  Returns:
127
  indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
128
  mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
 
189
  keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
190
  Steps that are beyond valid steps will be replaced by the special_token in that case.
191
  is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
192
+ device (torch.device or str): Device for created tensors.
193
  Returns:
194
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
195
  mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
196
  """
197
  ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
 
295
  """Builds pattern with specific interleaving between codebooks.
296
 
297
  Args:
298
+ timesteps (int): Total number of timesteps.
299
  """
300
  raise NotImplementedError()
301
 
 
318
 
319
  Args:
320
  n_q (int): Number of codebooks.
321
+ delays (list of int, optional): Delay for each of the codebooks.
322
  If delays not defined, each codebook is delayed by 1 compared to the previous one.
323
  flatten_first (int): Flatten the first N timesteps.
324
  empty_initial (int): Prepend with N empty list of coordinates.
 
406
 
407
  Args:
408
  n_q (int): Number of codebooks.
409
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
410
  the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
411
  have n_q extra steps for each timestep.
412
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
413
  no delay is added and therefore will default to [0] * ``n_q``.
414
  Note that two codebooks that will be flattened to the same inner step
415
  should have the same delay, otherwise the pattern is considered as invalid.
 
462
  """Builds pattern for delay across codebooks.
463
 
464
  Args:
465
+ timesteps (int): Total number of timesteps.
466
  """
467
  # the PatternLayout is built as a tuple of sequence position and list of coordinates
468
  # so that it can be reordered properly given the required delay between codebooks of given timesteps
 
486
  return Pattern(out, n_q=self.n_q, timesteps=timesteps)
487
 
488
 
489
+ class CoarseFirstPattern(CodebooksPatternProvider):
490
+ """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
491
+ potentially with delays.
492
+
493
+ ..Warning:: You must always generate the full training duration at test time, for instance,
494
+ 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
495
+ location. This is due to the non causality of the remaining codebooks with respect to
496
+ the first ones.
497
 
498
  Args:
499
  n_q (int): Number of codebooks.
500
+ delays (list of int, optional): Delay for each of the codebooks.
501
  If delays not defined, each codebook is delayed by 1 compared to the previous one.
502
  """
503
  def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
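For reference, the pattern providers are used the same way as before; a hedged sketch of the delayed interleaving that MusicGen relies on:

```python
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=4)       # 4 codebooks, default 1-step delays
pattern = provider.get_pattern(timesteps=100)
print(pattern.n_q, pattern.timesteps)          # 4, 100 (the sequence adds the maximum delay)
```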
audiocraft/modules/conditioners.py CHANGED
@@ -10,87 +10,61 @@ from dataclasses import dataclass, field
10
  from itertools import chain
11
  import logging
12
  import math
 
13
  import random
14
  import re
15
  import typing as tp
16
  import warnings
17
 
18
- from einops import rearrange
19
  from num2words import num2words
20
  import spacy
21
- from transformers import T5EncoderModel, T5Tokenizer # type: ignore
22
- import torchaudio
23
  import torch
24
  from torch import nn
25
- from torch import Tensor
26
  import torch.nn.functional as F
27
  from torch.nn.utils.rnn import pad_sequence
28
 
 
29
  from .streaming import StreamingModule
30
  from .transformer import create_sin_embedding
 
31
  from ..data.audio_dataset import SegmentInfo
 
 
 
32
  from ..utils.autocast import TorchAutocast
33
- from ..utils.utils import hash_trick, length_to_mask, collate
 
34
 
35
 
36
  logger = logging.getLogger(__name__)
37
  TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
38
- ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask
39
 
40
 
41
  class WavCondition(tp.NamedTuple):
42
- wav: Tensor
43
- length: Tensor
 
44
  path: tp.List[tp.Optional[str]] = []
 
45
 
46
 
47
- def nullify_condition(condition: ConditionType, dim: int = 1):
48
- """This function transforms an input condition to a null condition.
49
- The way it is done by converting it to a single zero vector similarly
50
- to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
51
-
52
- Args:
53
- condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
54
- dim (int): the dimension that will be truncated (should be the time dimension)
55
- WARNING!: dim should not be the batch dimension!
56
- Returns:
57
- ConditionType: a tuple of null condition and mask
58
- """
59
- assert dim != 0, "dim cannot be the batch dimension!"
60
- assert type(condition) == tuple and \
61
- type(condition[0]) == Tensor and \
62
- type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
63
- cond, mask = condition
64
- B = cond.shape[0]
65
- last_dim = cond.dim() - 1
66
- out = cond.transpose(dim, last_dim)
67
- out = 0. * out[..., :1]
68
- out = out.transpose(dim, last_dim)
69
- mask = torch.zeros((B, 1), device=out.device).int()
70
- assert cond.dim() == out.dim()
71
- return out, mask
72
-
73
-
74
- def nullify_wav(wav: Tensor) -> WavCondition:
75
- """Create a nullified WavCondition from a wav tensor with appropriate shape.
76
-
77
- Args:
78
- wav (Tensor): tensor of shape [B, T]
79
- Returns:
80
- WavCondition: wav condition with nullified wav.
81
- """
82
- null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
83
- return WavCondition(
84
- wav=null_wav,
85
- length=torch.tensor([0] * wav.shape[0], device=wav.device),
86
- path=['null_wav'] * wav.shape[0]
87
- )
88
 
89
 
90
  @dataclass
91
  class ConditioningAttributes:
92
  text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
93
  wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
 
94
 
95
  def __getitem__(self, item):
96
  return getattr(self, item)
@@ -103,14 +77,23 @@ class ConditioningAttributes:
103
  def wav_attributes(self):
104
  return self.wav.keys()
105
 
 
 
 
 
106
  @property
107
  def attributes(self):
108
- return {"text": self.text_attributes, "wav": self.wav_attributes}
 
 
 
 
109
 
110
  def to_flat_dict(self):
111
  return {
112
  **{f"text.{k}": v for k, v in self.text.items()},
113
  **{f"wav.{k}": v for k, v in self.wav.items()},
 
114
  }
115
 
116
  @classmethod
@@ -131,11 +114,74 @@ class SegmentWithAttributes(SegmentInfo):
131
  raise NotImplementedError()
132
 
133
 
 
 
 
 
134
  class Tokenizer:
135
- """Base class for all tokenizers
136
  (in case we want to introduce more advances tokenizers in the future).
137
  """
138
- def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
139
  raise NotImplementedError()
140
 
141
 
@@ -146,7 +192,7 @@ class WhiteSpaceTokenizer(Tokenizer):
146
  [[78, 62, 31, 4, 78, 25, 19, 34],
147
  [59, 77, 0, 0, 0, 0, 0, 0]]
148
  """
149
- PUNCTUATIONS = "?:!.,;"
150
 
151
  def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
152
  lemma: bool = True, stopwords: bool = True) -> None:
@@ -161,18 +207,15 @@ class WhiteSpaceTokenizer(Tokenizer):
161
  self.nlp = spacy.load(language)
162
 
163
  @tp.no_type_check
164
- def __call__(
165
- self,
166
- texts: tp.List[tp.Optional[str]],
167
- return_text: bool = False
168
- ) -> tp.Tuple[Tensor, Tensor]:
169
  """Take a list of strings and convert them to a tensor of indices.
170
 
171
  Args:
172
- texts (tp.List[str]): List of strings.
173
  return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
174
  Returns:
175
- tp.Tuple[Tensor, Tensor]:
176
  - Indices of words in the LUT.
177
  - And a mask indicating where the padding tokens are
178
  """
@@ -181,7 +224,7 @@ class WhiteSpaceTokenizer(Tokenizer):
181
  for i, text in enumerate(texts):
182
  # if current sample doesn't have a certain attribute, replace with pad token
183
  if text is None:
184
- output.append(Tensor([self.pad_idx]))
185
  lengths.append(0)
186
  continue
187
 
@@ -192,15 +235,15 @@ class WhiteSpaceTokenizer(Tokenizer):
192
  # remove stopwords
193
  if self.stopwords:
194
  text = [w for w in text if not w.is_stop] # type: ignore
195
- # remove punctuations
196
- text = [w for w in text if w.text not in self.PUNCTUATIONS] # type: ignore
197
  # lemmatize if needed
198
  text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
199
 
200
  texts[i] = " ".join(text)
201
  lengths.append(len(text))
202
  # convert to tensor
203
- tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
204
  output.append(tokens)
205
 
206
  mask = length_to_mask(torch.IntTensor(lengths)).int()
@@ -224,7 +267,7 @@ class NoopTokenizer(Tokenizer):
224
  self.n_bins = n_bins
225
  self.pad_idx = pad_idx
226
 
227
- def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
228
  output, lengths = [], []
229
  for text in texts:
230
  # if current sample doesn't have a certain attribute, replace with pad token
@@ -241,15 +284,16 @@ class NoopTokenizer(Tokenizer):
241
 
242
 
243
  class BaseConditioner(nn.Module):
244
- """Base model for all conditioner modules. We allow the output dim to be different
245
- than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large;
 
246
  2) make all condition dims consistent.
247
 
248
  Args:
249
- dim (int): Hidden dim of the model (text-encoder/LUT).
250
  output_dim (int): Output dim of the conditioner.
251
  """
252
- def __init__(self, dim, output_dim):
253
  super().__init__()
254
  self.dim = dim
255
  self.output_dim = output_dim
@@ -294,9 +338,9 @@ class LUTConditioner(TextConditioner):
294
  super().__init__(dim, output_dim)
295
  self.embed = nn.Embedding(n_bins, dim)
296
  self.tokenizer: Tokenizer
297
- if tokenizer == "whitespace":
298
  self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
299
- elif tokenizer == "noop":
300
  self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
301
  else:
302
  raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
@@ -346,13 +390,12 @@ class T5Conditioner(TextConditioner):
346
  def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
347
  autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
348
  normalize_text: bool = False):
349
- assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})"
350
  super().__init__(self.MODELS_DIMS[name], output_dim)
351
  self.device = device
352
  self.name = name
353
  self.finetune = finetune
354
  self.word_dropout = word_dropout
355
-
356
  if autocast_dtype is None or self.device == 'cpu':
357
  self.autocast = TorchAutocast(enabled=False)
358
  if self.device != 'cpu':
@@ -378,7 +421,7 @@ class T5Conditioner(TextConditioner):
378
  else:
379
  # this makes sure that the t5 models is not part
380
  # of the saved checkpoint
381
- self.__dict__["t5"] = t5.to(device)
382
 
383
  self.normalize_text = normalize_text
384
  if normalize_text:
@@ -398,13 +441,13 @@ class T5Conditioner(TextConditioner):
398
 
399
  empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
400
 
401
- inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
402
- mask = inputs["attention_mask"]
403
  mask[empty_idx, :] = 0 # zero-out index where the input is non-existant
404
  return inputs
405
 
406
  def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
407
- mask = inputs["attention_mask"]
408
  with torch.set_grad_enabled(self.finetune), self.autocast:
409
  embeds = self.t5(**inputs).last_hidden_state
410
  embeds = self.output_proj(embeds.to(self.output_proj.weight))
@@ -426,204 +469,558 @@ class WaveformConditioner(BaseConditioner):
426
  def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
427
  super().__init__(dim, output_dim)
428
  self.device = device
 
 
429
 
430
- def tokenize(self, wav_length: WavCondition) -> WavCondition:
431
- wav, length, path = wav_length
432
  assert length is not None
433
- return WavCondition(wav.to(self.device), length.to(self.device), path)
434
 
435
- def _get_wav_embedding(self, wav: Tensor) -> Tensor:
436
- """Gets as input a wav and returns a dense vector of conditions."""
437
  raise NotImplementedError()
438
 
439
  def _downsampling_factor(self):
440
  """Returns the downsampling factor of the embedding model."""
441
  raise NotImplementedError()
442
 
443
- def forward(self, inputs: WavCondition) -> ConditionType:
444
- """
445
  Args:
446
- input (WavCondition): Tuple of (waveform, lengths).
447
  Returns:
448
- ConditionType: Dense vector representing the conditioning along with its' mask.
449
  """
450
- wav, lengths, path = inputs
451
  with torch.no_grad():
452
- embeds = self._get_wav_embedding(wav)
453
  embeds = embeds.to(self.output_proj.weight)
454
  embeds = self.output_proj(embeds)
455
 
456
- if lengths is not None:
457
  lengths = lengths / self._downsampling_factor()
458
  mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
459
  else:
460
- mask = torch.ones_like(embeds)
461
- embeds = (embeds * mask.unsqueeze(2).to(self.device))
462
-
463
  return embeds, mask
464
 
465
 
466
  class ChromaStemConditioner(WaveformConditioner):
467
- """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by
468
- the insight the drums and bass often dominate the chroma, leading to the chroma not containing the
469
- information about melody.
 
470
 
471
  Args:
472
  output_dim (int): Output dimension for the conditioner.
473
  sample_rate (int): Sample rate for the chroma extractor.
474
- n_chroma (int): Number of chroma for the chroma extractor.
475
- radix2_exp (int): Radix2 exponent for the chroma extractor.
476
- duration (float): Duration used during training. This is later used for correct padding
477
  in case we are using chroma as prefix.
478
- match_len_on_eval (bool, optional): If True then all chromas are padded to the training
479
  duration. Defaults to False.
480
- eval_wavs (str, optional): Path to a json egg with waveform, this waveforms are used as
481
  conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
482
  Defaults to None.
483
- n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
484
  device (tp.Union[torch.device, str], optional): Device for the conditioner.
485
  **kwargs: Additional parameters for the chroma extractor.
486
  """
487
  def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
488
  duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
489
- n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
 
490
  from demucs import pretrained
491
  super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
492
- self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
493
  self.sample_rate = sample_rate
494
  self.match_len_on_eval = match_len_on_eval
 
 
495
  self.duration = duration
496
- self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
497
- self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
498
- self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
499
- self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
500
- device=device, **kwargs)
501
  self.chroma_len = self._get_chroma_len()
502
-
503
- def _downsampling_factor(self):
 
 
 
 
 
 
504
  return self.chroma.winhop
505
 
506
- def _get_chroma_len(self):
507
- """Get length of chroma during training"""
508
- dummy_wav = torch.zeros((1, self.sample_rate * self.duration), device=self.device)
 
 
 
 
 
 
 
 
509
  dummy_chr = self.chroma(dummy_wav)
510
  return dummy_chr.shape[1]
511
 
512
  @torch.no_grad()
513
- def _get_filtered_wav(self, wav):
 
514
  from demucs.apply import apply_model
515
  from demucs.audio import convert_audio
516
  with self.autocast:
517
- wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
 
518
  stems = apply_model(self.demucs, wav, device=self.device)
519
- stems = stems[:, self.stem_idx] # extract stem
520
- stems = stems.sum(1) # merge extracted stems
521
- stems = stems.mean(1, keepdim=True) # mono
522
- stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
523
- return stems
524
 
525
  @torch.no_grad()
526
- def _get_wav_embedding(self, wav):
 
 
 
 
 
 
 
527
  # avoid 0-size tensors when we are working with null conds
528
  if wav.shape[-1] == 1:
529
- return self.chroma(wav)
530
- stems = self._get_filtered_wav(wav)
531
- chroma = self.chroma(stems)
 
 
 
 
 
 
 
 
 
532
 
533
  if self.match_len_on_eval:
534
- b, t, c = chroma.shape
535
- if t > self.chroma_len:
536
  chroma = chroma[:, :self.chroma_len]
537
- logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
538
- elif t < self.chroma_len:
539
- # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
540
- n_repeat = int(math.ceil(self.chroma_len / t))
541
  chroma = chroma.repeat(1, n_repeat, 1)
542
  chroma = chroma[:, :self.chroma_len]
543
- logger.debug(f'chroma was zero-padded! ({t} -> {chroma.shape[1]})')
 
544
  return chroma
545
 
 
 
 
 
 
 
 
 
546
 
547
- class ChromaExtractor(nn.Module):
548
- """Chroma extraction class, handles chroma extraction and quantization.
 
549
 
550
  Args:
551
- sample_rate (int): Sample rate.
552
- n_chroma (int): Number of chroma to consider.
553
- radix2_exp (int): Radix2 exponent.
554
- nfft (tp.Optional[int], optional): Number of FFT.
555
- winlen (tp.Optional[int], optional): Window length.
556
- winhop (tp.Optional[int], optional): Window hop size.
557
- argmax (bool, optional): Whether to use argmax. Defaults to False.
558
- norm (float, optional): Norm for chroma normalization. Defaults to inf.
559
- device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
560
  """
561
- def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
562
- nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
563
- argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
564
- super().__init__()
565
- from librosa import filters
566
  self.device = device
567
- self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
568
- self.winlen = winlen or 2 ** radix2_exp
569
- self.nfft = nfft or self.winlen
570
- self.winhop = winhop or (self.winlen // 4)
571
- self.sr = sample_rate
572
- self.n_chroma = n_chroma
573
- self.norm = norm
574
- self.argmax = argmax
575
- self.window = torch.hann_window(self.winlen).to(device)
576
- self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
577
- n_chroma=self.n_chroma)).to(device)
578
- self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
579
- hop_length=self.winhop, power=2, center=True,
580
- pad=0, normalized=True).to(device)
581
-
582
- def forward(self, wav):
 
 
 
 
 
 
 
 
583
  with self.autocast:
584
- T = wav.shape[-1]
585
- # in case we are getting a wav that was dropped out (nullified)
586
- # make sure wav length is no less that nfft
587
- if T < self.nfft:
588
- pad = self.nfft - T
589
- r = 0 if pad % 2 == 0 else 1
590
- wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
591
- assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
592
- spec = self.spec(wav).squeeze(1)
593
- raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
594
- norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
595
- norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
596
-
597
- if self.argmax:
598
- idx = norm_chroma.argmax(-1, keepdims=True)
599
- norm_chroma[:] = 0
600
- norm_chroma.scatter_(dim=-1, index=idx, value=1)
601
-
602
- return norm_chroma
603
-
604
-
605
- def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
 
 
 
 
 
 
606
  """Utility function for nullifying an attribute inside an ConditioningAttributes object.
607
- If the condition is of type "wav", then nullify it using "nullify_condition".
608
- If the condition is of any other type, set its' value to None.
609
  Works in-place.
610
  """
611
- if condition_type not in ["text", "wav"]:
612
  raise ValueError(
613
  "dropout_condition got an unexpected condition type!"
614
- f" expected 'wav' or 'text' but got '{condition_type}'"
615
  )
616
 
617
  if condition not in getattr(sample, condition_type):
618
  raise ValueError(
619
  "dropout_condition received an unexpected condition!"
620
  f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
621
- f"but got '{condition}' of type '{condition_type}'!"
622
  )
623
 
624
- if condition_type == "wav":
625
- wav, length, path = sample.wav[condition]
626
- sample.wav[condition] = nullify_wav(wav)
 
 
 
627
  else:
628
  sample.text[condition] = None
629
 
@@ -631,7 +1028,7 @@ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condi
631
 
632
 
633
  class DropoutModule(nn.Module):
634
- """Base class for all dropout modules."""
635
  def __init__(self, seed: int = 1234):
636
  super().__init__()
637
  self.rng = torch.Generator()
@@ -639,10 +1036,11 @@ class DropoutModule(nn.Module):
639
 
640
 
641
  class AttributeDropout(DropoutModule):
642
- """Applies dropout with a given probability per attribute. This is different from the behavior of
643
- ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example,
644
- "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout
645
- where if "artist" is dropped "genre" must also be dropped.
 
646
 
647
  Args:
648
  p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
@@ -665,21 +1063,19 @@ class AttributeDropout(DropoutModule):
665
  def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
666
  """
667
  Args:
668
- samples (tp.List[ConditioningAttributes]): List of conditions.
669
  Returns:
670
- tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
671
  """
672
  if not self.training and not self.active_on_eval:
673
  return samples
674
 
675
  samples = deepcopy(samples)
676
-
677
  for condition_type, ps in self.p.items(): # for condition types [text, wav]
678
  for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
679
  if torch.rand(1, generator=self.rng).item() < p:
680
  for sample in samples:
681
  dropout_condition(sample, condition_type, condition)
682
-
683
  return samples
684
 
685
  def __repr__(self):
@@ -687,8 +1083,8 @@ class AttributeDropout(DropoutModule):
687
 
688
 
689
  class ClassifierFreeGuidanceDropout(DropoutModule):
690
- """Applies Classifier Free Guidance dropout, meaning all attributes
691
- are dropped with the same probability.
692
 
693
  Args:
694
  p (float): Probability to apply condition dropout during training.
@@ -701,9 +1097,9 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
701
  def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
702
  """
703
  Args:
704
- samples (tp.List[ConditioningAttributes]): List of conditions.
705
  Returns:
706
- tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
707
  """
708
  if not self.training:
709
  return samples
@@ -715,12 +1111,10 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
715
 
716
  # nullify conditions of all attributes
717
  samples = deepcopy(samples)
718
-
719
  for condition_type in ["wav", "text"]:
720
  for sample in samples:
721
  for condition in sample.attributes[condition_type]:
722
  dropout_condition(sample, condition_type, condition)
723
-
724
  return samples
725
 
726
  def __repr__(self):
@@ -728,29 +1122,25 @@ class ClassifierFreeGuidanceDropout(DropoutModule):
728
 
729
 
730
  class ConditioningProvider(nn.Module):
731
- """Main class to provide conditions given all the supported conditioners.
732
 
733
  Args:
734
  conditioners (dict): Dictionary of conditioners.
735
- merge_text_conditions_p (float, optional): Probability to merge all text sources
736
- into a single text condition. Defaults to 0.
737
- drop_desc_p (float, optional): Probability to drop the original description
738
- when merging all text sources into a single text condition. Defaults to 0.
739
- device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
740
  """
741
- def __init__(
742
- self,
743
- conditioners: tp.Dict[str, BaseConditioner],
744
- merge_text_conditions_p: float = 0,
745
- drop_desc_p: float = 0,
746
- device: tp.Union[torch.device, str] = "cpu",
747
- ):
748
  super().__init__()
749
  self.device = device
750
- self.merge_text_conditions_p = merge_text_conditions_p
751
- self.drop_desc_p = drop_desc_p
752
  self.conditioners = nn.ModuleDict(conditioners)
753
 
 
 
 
 
 
 
 
 
754
  @property
755
  def text_conditions(self):
756
  return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
@@ -769,33 +1159,36 @@ class ConditioningProvider(nn.Module):
769
  This will return a dict matching conditioner names to their arbitrary tokenized representations.
770
 
771
  Args:
772
- inputs (list[ConditioningAttribres]): List of ConditioningAttributes objects containing
773
  text and wav conditions.
774
  """
775
- assert all([type(x) == ConditioningAttributes for x in inputs]), \
776
- "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
777
  f" but types were {set([type(x) for x in inputs])}"
 
778
 
779
  output = {}
780
  text = self._collate_text(inputs)
781
  wavs = self._collate_wavs(inputs)
 
782
 
783
- assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
784
- f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"
 
 
785
 
786
- for attribute, batch in chain(text.items(), wavs.items()):
787
  output[attribute] = self.conditioners[attribute].tokenize(batch)
788
  return output
789
 
790
  def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
791
- """Compute pairs of `(embedding, mask)` using the configured conditioners
792
- and the tokenized representations. The output is for example:
793
-
794
- {
795
- "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
796
- "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
797
- ...
798
- }
799
 
800
  Args:
801
  tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
@@ -820,51 +1213,22 @@ class ConditioningProvider(nn.Module):
820
  "genre": ["Rock", "Hip-hop"],
821
  "description": ["A rock song with a guitar solo", "A hip-hop verse"]
822
  }
823
- """
824
- batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
825
-
826
- def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
827
- def is_valid(k, v):
828
- k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
829
- v_valid = v is not None and isinstance(v, (int, float, str, list))
830
- return k_valid and v_valid
831
-
832
- def process_value(v):
833
- if isinstance(v, (int, float, str)):
834
- return v
835
- if isinstance(v, list):
836
- return ", ".join(v)
837
- else:
838
- RuntimeError(f"unknown type for text value! ({type(v), v})")
839
-
840
- desc = cond.text['description']
841
- meta_data = ""
842
- if random.uniform(0, 1) < merge_text_conditions_p:
843
- meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
844
- random.shuffle(meta_pairs)
845
- meta_data = ". ".join(meta_pairs)
846
- desc = desc if not random.uniform(0, 1) < drop_desc_p else None
847
-
848
- if desc is None:
849
- desc = meta_data if len(meta_data) > 1 else None
850
- else:
851
- desc = desc.rstrip('.') + ". " + meta_data
852
- cond.text['description'] = desc.strip() if desc else None
853
-
854
- if self.training and self.merge_text_conditions_p:
855
- for sample in samples:
856
- _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
857
 
 
 
 
 
 
 
858
  texts = [x.text for x in samples]
859
  for text in texts:
860
  for condition in self.text_conditions:
861
- batch_per_attribute[condition].append(text[condition])
862
-
863
- return batch_per_attribute
864
 
865
- def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
866
  """Generate a dict where the keys are attributes by which we fetch similar wavs,
867
- and the values are Tensors of wavs according to said attribtues.
868
 
869
  *Note*: by the time the samples reach this function, each sample should have some waveform
870
  inside the "wav" attribute. It should be either:
@@ -873,27 +1237,89 @@ class ConditioningProvider(nn.Module):
873
  3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
874
 
875
  Args:
876
- samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
877
  Returns:
878
- dict: A dicionary mapping an attribute name to wavs.
879
  """
880
  wavs = defaultdict(list)
881
- lens = defaultdict(list)
 
882
  paths = defaultdict(list)
883
- out = {}
 
884
 
885
  for sample in samples:
886
  for attribute in self.wav_conditions:
887
- wav, length, path = sample.wav[attribute]
888
- wavs[attribute].append(wav.flatten())
889
- lens[attribute].append(length)
890
- paths[attribute].append(path)
 
 
 
 
 
 
891
 
892
  # stack all wavs to a single tensor
893
  for attribute in self.wav_conditions:
894
  stacked_wav, _ = collate(wavs[attribute], dim=0)
895
- out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
896
- torch.cat(lens['self_wav']), paths[attribute]) # type: ignore
 
 
 
 
 
 
 
 
 
897
 
898
  return out
899
 
@@ -920,7 +1346,7 @@ class ConditionFuser(StreamingModule):
920
  super().__init__()
921
  assert all(
922
  [k in self.FUSING_METHODS for k in fuse2cond.keys()]
923
- ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}"
924
  self.cross_attention_pos_emb = cross_attention_pos_emb
925
  self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
926
  self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
@@ -931,16 +1357,16 @@ class ConditionFuser(StreamingModule):
931
 
932
  def forward(
933
  self,
934
- input: Tensor,
935
  conditions: tp.Dict[str, ConditionType]
936
- ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
937
  """Fuse the conditions to the provided model input.
938
 
939
  Args:
940
- input (Tensor): Transformer input.
941
- conditions (tp.Dict[str, ConditionType]): Dict of conditions.
942
  Returns:
943
- tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
944
  after the conditions have been fused. The second output tensor is the tensor
945
  used for cross-attention or None if no cross attention inputs exist.
946
  """
@@ -959,16 +1385,16 @@ class ConditionFuser(StreamingModule):
959
  cross_attention_output = None
960
  for cond_type, (cond, cond_mask) in conditions.items():
961
  op = self.cond2fuse[cond_type]
962
- if op == "sum":
963
  input += cond
964
- elif op == "input_interpolate":
965
- cond = rearrange(cond, "b t d -> b d t")
966
  cond = F.interpolate(cond, size=input.shape[1])
967
- input += rearrange(cond, "b d t -> b t d")
968
- elif op == "prepend":
969
  if first_step:
970
  input = torch.cat([cond, input], dim=1)
971
- elif op == "cross":
972
  if cross_attention_output is not None:
973
  cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
974
  else:
 
10
  from itertools import chain
11
  import logging
12
  import math
13
+ from pathlib import Path
14
  import random
15
  import re
16
  import typing as tp
17
  import warnings
18
 
19
+ import einops
20
  from num2words import num2words
21
  import spacy
22
+ from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer # type: ignore
 
23
  import torch
24
  from torch import nn
 
25
  import torch.nn.functional as F
26
  from torch.nn.utils.rnn import pad_sequence
27
 
28
+ from .chroma import ChromaExtractor
29
  from .streaming import StreamingModule
30
  from .transformer import create_sin_embedding
31
+ from ..data.audio import audio_read
32
  from ..data.audio_dataset import SegmentInfo
33
+ from ..data.audio_utils import convert_audio
34
+ from ..environment import AudioCraftEnvironment
35
+ from ..quantization import ResidualVectorQuantizer
36
  from ..utils.autocast import TorchAutocast
37
+ from ..utils.cache import EmbeddingCache
38
+ from ..utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once
39
 
40
 
41
  logger = logging.getLogger(__name__)
42
  TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
43
+ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask
44
 
45
 
46
  class WavCondition(tp.NamedTuple):
47
+ wav: torch.Tensor
48
+ length: torch.Tensor
49
+ sample_rate: tp.List[int]
50
  path: tp.List[tp.Optional[str]] = []
51
+ seek_time: tp.List[tp.Optional[float]] = []
52
 
53
 
54
+ class JointEmbedCondition(tp.NamedTuple):
55
+ wav: torch.Tensor
56
+ text: tp.List[tp.Optional[str]]
57
+ length: torch.Tensor
58
+ sample_rate: tp.List[int]
59
+ path: tp.List[tp.Optional[str]] = []
60
+ seek_time: tp.List[tp.Optional[float]] = []
61
 
62
 
63
  @dataclass
64
  class ConditioningAttributes:
65
  text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
66
  wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
67
+ joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict)
68
 
69
  def __getitem__(self, item):
70
  return getattr(self, item)
 
77
  def wav_attributes(self):
78
  return self.wav.keys()
79
 
80
+ @property
81
+ def joint_embed_attributes(self):
82
+ return self.joint_embed.keys()
83
+
84
  @property
85
  def attributes(self):
86
+ return {
87
+ "text": self.text_attributes,
88
+ "wav": self.wav_attributes,
89
+ "joint_embed": self.joint_embed_attributes,
90
+ }
91
 
92
  def to_flat_dict(self):
93
  return {
94
  **{f"text.{k}": v for k, v in self.text.items()},
95
  **{f"wav.{k}": v for k, v in self.wav.items()},
96
+ **{f"joint_embed.{k}": v for k, v in self.joint_embed.items()}
97
  }
98
 
99
  @classmethod
 
114
  raise NotImplementedError()
115
 
116
 
117
+ def nullify_condition(condition: ConditionType, dim: int = 1):
118
+ """Transform an input condition to a null condition.
119
+ This is done by converting it to a single zero vector, similarly
120
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
121
+
122
+ Args:
123
+ condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor])
124
+ dim (int): The dimension that will be truncated (should be the time dimension)
125
+ WARNING!: dim should not be the batch dimension!
126
+ Returns:
127
+ ConditionType: A tuple of null condition and mask
128
+ """
129
+ assert dim != 0, "dim cannot be the batch dimension!"
130
+ assert isinstance(condition, tuple) and \
131
+ isinstance(condition[0], torch.Tensor) and \
132
+ isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!"
133
+ cond, mask = condition
134
+ B = cond.shape[0]
135
+ last_dim = cond.dim() - 1
136
+ out = cond.transpose(dim, last_dim)
137
+ out = 0. * out[..., :1]
138
+ out = out.transpose(dim, last_dim)
139
+ mask = torch.zeros((B, 1), device=out.device).int()
140
+ assert cond.dim() == out.dim()
141
+ return out, mask
142
+
143
+
144
+ def nullify_wav(cond: WavCondition) -> WavCondition:
145
+ """Transform a WavCondition to a nullified WavCondition.
146
+ It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes.
147
+
148
+ Args:
149
+ cond (WavCondition): Wav condition with wav, tensor of shape [B, T].
150
+ Returns:
151
+ WavCondition: Nullified wav condition.
152
+ """
153
+ null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1)
154
+ return WavCondition(
155
+ wav=null_wav,
156
+ length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device),
157
+ sample_rate=cond.sample_rate,
158
+ path=[None] * cond.wav.shape[0],
159
+ seek_time=[None] * cond.wav.shape[0],
160
+ )
161
+
162
+
163
+ def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition:
164
+ """Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0,
165
+ and replacing metadata by dummy attributes.
166
+
167
+ Args:
168
+ embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
169
+ """
170
+ null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1)
171
+ return JointEmbedCondition(
172
+ wav=null_wav, text=[None] * len(embed.text),
173
+ length=torch.LongTensor([0]).to(embed.wav.device),
174
+ sample_rate=embed.sample_rate,
175
+ path=[None] * embed.wav.shape[0],
176
+ seek_time=[0] * embed.wav.shape[0],
177
+ )
178
+
179
+
180
  class Tokenizer:
181
+ """Base tokenizer implementation
182
  (in case we want to introduce more advanced tokenizers in the future).
183
  """
184
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
185
  raise NotImplementedError()
186
 
187
 
 
192
  [[78, 62, 31, 4, 78, 25, 19, 34],
193
  [59, 77, 0, 0, 0, 0, 0, 0]]
194
  """
195
+ PUNCTUATION = "?:!.,;"
196
 
197
  def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
198
  lemma: bool = True, stopwords: bool = True) -> None:
 
207
  self.nlp = spacy.load(language)
208
 
209
  @tp.no_type_check
210
+ def __call__(self, texts: tp.List[tp.Optional[str]],
211
+ return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]:
212
  """Take a list of strings and convert them to a tensor of indices.
213
 
214
  Args:
215
+ texts (list[str]): List of strings.
216
  return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
217
  Returns:
218
+ tuple[torch.Tensor, torch.Tensor]:
219
  - Indices of words in the LUT.
220
  - And a mask indicating where the padding tokens are
221
  """
 
224
  for i, text in enumerate(texts):
225
  # if current sample doesn't have a certain attribute, replace with pad token
226
  if text is None:
227
+ output.append(torch.Tensor([self.pad_idx]))
228
  lengths.append(0)
229
  continue
230
 
 
235
  # remove stopwords
236
  if self.stopwords:
237
  text = [w for w in text if not w.is_stop] # type: ignore
238
+ # remove punctuation
239
+ text = [w for w in text if w.text not in self.PUNCTUATION] # type: ignore
240
  # lemmatize if needed
241
  text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
242
 
243
  texts[i] = " ".join(text)
244
  lengths.append(len(text))
245
  # convert to tensor
246
+ tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text])
247
  output.append(tokens)
248
 
249
  mask = length_to_mask(torch.IntTensor(lengths)).int()
 
267
  self.n_bins = n_bins
268
  self.pad_idx = pad_idx
269
 
270
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
271
  output, lengths = [], []
272
  for text in texts:
273
  # if current sample doesn't have a certain attribute, replace with pad token
 
284
 
285
 
286
  class BaseConditioner(nn.Module):
287
+ """Base model for all conditioner modules.
288
+ We allow the output dim to be different than the hidden dim for two reasons:
289
+ 1) keep our LUTs small when the vocab is large;
290
  2) make all condition dims consistent.
291
 
292
  Args:
293
+ dim (int): Hidden dim of the model.
294
  output_dim (int): Output dim of the conditioner.
295
  """
296
+ def __init__(self, dim: int, output_dim: int):
297
  super().__init__()
298
  self.dim = dim
299
  self.output_dim = output_dim
 
338
  super().__init__(dim, output_dim)
339
  self.embed = nn.Embedding(n_bins, dim)
340
  self.tokenizer: Tokenizer
341
+ if tokenizer == 'whitespace':
342
  self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
343
+ elif tokenizer == 'noop':
344
  self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
345
  else:
346
  raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
 
390
  def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
391
  autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
392
  normalize_text: bool = False):
393
+ assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
394
  super().__init__(self.MODELS_DIMS[name], output_dim)
395
  self.device = device
396
  self.name = name
397
  self.finetune = finetune
398
  self.word_dropout = word_dropout
 
399
  if autocast_dtype is None or self.device == 'cpu':
400
  self.autocast = TorchAutocast(enabled=False)
401
  if self.device != 'cpu':
 
421
  else:
422
  # this makes sure that the t5 models is not part
423
  # of the saved checkpoint
424
+ self.__dict__['t5'] = t5.to(device)
425
 
426
  self.normalize_text = normalize_text
427
  if normalize_text:
 
441
 
442
  empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
443
 
444
+ inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device)
445
+ mask = inputs['attention_mask']
446
  mask[empty_idx, :] = 0 # zero-out index where the input is non-existent
447
  return inputs
448
 
449
  def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
450
+ mask = inputs['attention_mask']
451
  with torch.set_grad_enabled(self.finetune), self.autocast:
452
  embeds = self.t5(**inputs).last_hidden_state
453
  embeds = self.output_proj(embeds.to(self.output_proj.weight))
 
469
  def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
470
  super().__init__(dim, output_dim)
471
  self.device = device
472
+ # if False, no masking is done; used in ChromaStemConditioner when completing a sample by periodicity.
473
+ self._use_masking = True
474
 
475
+ def tokenize(self, x: WavCondition) -> WavCondition:
476
+ wav, length, sample_rate, path, seek_time = x
477
  assert length is not None
478
+ return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time)
479
 
480
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
481
+ """Gets as input a WavCondition and returns a dense embedding."""
482
  raise NotImplementedError()
483
 
484
  def _downsampling_factor(self):
485
  """Returns the downsampling factor of the embedding model."""
486
  raise NotImplementedError()
487
 
488
+ def forward(self, x: WavCondition) -> ConditionType:
489
+ """Extract condition embedding and mask from a waveform and its metadata.
490
  Args:
491
+ x (WavCondition): Waveform condition containing raw waveform and metadata.
492
  Returns:
493
+ ConditionType: a dense vector representing the conditioning along with its mask
494
  """
495
+ wav, lengths, *_ = x
496
  with torch.no_grad():
497
+ embeds = self._get_wav_embedding(x)
498
  embeds = embeds.to(self.output_proj.weight)
499
  embeds = self.output_proj(embeds)
500
 
501
+ if lengths is not None and self._use_masking:
502
  lengths = lengths / self._downsampling_factor()
503
  mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
504
  else:
505
+ mask = torch.ones_like(embeds[..., 0])
506
+ embeds = (embeds * mask.unsqueeze(-1))
 
507
  return embeds, mask
508
 
509
 
510
  class ChromaStemConditioner(WaveformConditioner):
511
+ """Chroma conditioner based on stems.
512
+ The ChromaStemConditioner uses DEMUCS to first filter out drums and bass, as
513
+ the drums and bass often dominate the chroma, leading to chroma features
514
+ that do not contain information about the melody.
515
 
516
  Args:
517
  output_dim (int): Output dimension for the conditioner.
518
  sample_rate (int): Sample rate for the chroma extractor.
519
+ n_chroma (int): Number of chroma bins for the chroma extractor.
520
+ radix2_exp (int): Size of stft window for the chroma extractor (power of 2, e.g. 12 -> 2^12).
521
+ duration (int): duration used during training. This is later used for correct padding
522
  in case we are using chroma as prefix.
523
+ match_len_on_eval (bool, optional): if True, all chromas are padded to the training
524
  duration. Defaults to True.
525
+ eval_wavs (str, optional): path to a dataset manifest with waveforms; these waveforms are used as
526
  conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
527
  Defaults to None.
528
+ n_eval_wavs (int, optional): limits the number of waveforms used for conditioning. Defaults to 0.
529
  device (tp.Union[torch.device, str], optional): Device for the conditioner.
530
  **kwargs: Additional parameters for the chroma extractor.
531
  """
532
  def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
533
  duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
534
+ n_eval_wavs: int = 0, cache_path: tp.Optional[tp.Union[str, Path]] = None,
535
+ device: tp.Union[torch.device, str] = 'cpu', **kwargs):
536
  from demucs import pretrained
537
  super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
538
+ self.autocast = TorchAutocast(enabled=device != 'cpu', device_type=self.device, dtype=torch.float32)
539
  self.sample_rate = sample_rate
540
  self.match_len_on_eval = match_len_on_eval
541
+ if match_len_on_eval:
542
+ self._use_masking = False
543
  self.duration = duration
544
+ self.__dict__['demucs'] = pretrained.get_model('htdemucs').to(device)
545
+ stem_sources: list = self.demucs.sources # type: ignore
546
+ self.stem_indices = torch.LongTensor([stem_sources.index('vocals'), stem_sources.index('other')]).to(device)
547
+ self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma,
548
+ radix2_exp=radix2_exp, **kwargs).to(device)
549
  self.chroma_len = self._get_chroma_len()
550
+ self.eval_wavs: tp.Optional[torch.Tensor] = self._load_eval_wavs(eval_wavs, n_eval_wavs)
551
+ self.cache = None
552
+ if cache_path is not None:
553
+ self.cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
554
+ compute_embed_fn=self._get_full_chroma_for_cache,
555
+ extract_embed_fn=self._extract_chroma_chunk)
556
+
557
+ def _downsampling_factor(self) -> int:
558
  return self.chroma.winhop
559
 
560
+ def _load_eval_wavs(self, path: tp.Optional[str], num_samples: int) -> tp.Optional[torch.Tensor]:
561
+ """Load pre-defined waveforms from a json.
562
+ These waveforms will be used for chroma extraction during evaluation.
563
+ This is done to make the evaluation on MusicCaps fair (we shouldn't see the chromas of MusicCaps).
564
+ """
565
+ if path is None:
566
+ return None
567
+
568
+ logger.info(f"Loading evaluation wavs from {path}")
569
+ from audiocraft.data.audio_dataset import AudioDataset
570
+ dataset: AudioDataset = AudioDataset.from_meta(
571
+ path, segment_duration=self.duration, min_audio_duration=self.duration,
572
+ sample_rate=self.sample_rate, channels=1)
573
+
574
+ if len(dataset) > 0:
575
+ eval_wavs = dataset.collater([dataset[i] for i in range(num_samples)]).to(self.device)
576
+ logger.info(f"Using {len(eval_wavs)} evaluation wavs for chroma-stem conditioner")
577
+ return eval_wavs
578
+ else:
579
+ raise ValueError("Could not find evaluation wavs, check lengths of wavs")
580
+
581
+ def reset_eval_wavs(self, eval_wavs: tp.Optional[torch.Tensor]) -> None:
582
+ self.eval_wavs = eval_wavs
583
+
584
+ def has_eval_wavs(self) -> bool:
585
+ return self.eval_wavs is not None
586
+
587
+ def _sample_eval_wavs(self, num_samples: int) -> torch.Tensor:
588
+ """Sample wavs from a predefined list."""
589
+ assert self.eval_wavs is not None, "Cannot sample eval wavs as no eval wavs provided."
590
+ total_eval_wavs = len(self.eval_wavs)
591
+ out = self.eval_wavs
592
+ if num_samples > total_eval_wavs:
593
+ out = self.eval_wavs.repeat(num_samples // total_eval_wavs + 1, 1, 1)
594
+ return out[torch.randperm(len(out))][:num_samples]
595
+
596
+ def _get_chroma_len(self) -> int:
597
+ """Get length of chroma during training."""
598
+ dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
599
  dummy_chr = self.chroma(dummy_wav)
600
  return dummy_chr.shape[1]
601
 
602
  @torch.no_grad()
603
+ def _get_stemmed_wav(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
604
+ """Get parts of the wav that holds the melody, extracting the main stems from the wav."""
605
  from demucs.apply import apply_model
606
  from demucs.audio import convert_audio
607
  with self.autocast:
608
+ wav = convert_audio(
609
+ wav, sample_rate, self.demucs.samplerate, self.demucs.audio_channels) # type: ignore
610
  stems = apply_model(self.demucs, wav, device=self.device)
611
+ stems = stems[:, self.stem_indices] # extract relevant stems for melody conditioning
612
+ mix_wav = stems.sum(1) # merge extracted stems to single waveform
613
+ mix_wav = convert_audio(mix_wav, self.demucs.samplerate, self.sample_rate, 1) # type: ignore
614
+ return mix_wav
 
615
 
616
  @torch.no_grad()
617
+ def _extract_chroma(self, wav: torch.Tensor) -> torch.Tensor:
618
+ """Extract chroma features from the waveform."""
619
+ with self.autocast:
620
+ return self.chroma(wav)
621
+
622
+ @torch.no_grad()
623
+ def _compute_wav_embedding(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
624
+ """Compute wav embedding, applying stem and chroma extraction."""
625
  # avoid 0-size tensors when we are working with null conds
626
  if wav.shape[-1] == 1:
627
+ return self._extract_chroma(wav)
628
+ stems = self._get_stemmed_wav(wav, sample_rate)
629
+ chroma = self._extract_chroma(stems)
630
+ return chroma
631
+
632
+ @torch.no_grad()
633
+ def _get_full_chroma_for_cache(self, path: tp.Union[str, Path], x: WavCondition, idx: int) -> torch.Tensor:
634
+ """Extract chroma from the whole audio waveform at the given path."""
635
+ wav, sr = audio_read(path)
636
+ wav = wav[None].to(self.device)
637
+ wav = convert_audio(wav, sr, self.sample_rate, to_channels=1)
638
+ chroma = self._compute_wav_embedding(wav, self.sample_rate)[0]
639
+ return chroma
640
+
641
+ def _extract_chroma_chunk(self, full_chroma: torch.Tensor, x: WavCondition, idx: int) -> torch.Tensor:
642
+ """Extract a chunk of chroma from the full chroma derived from the full waveform."""
643
+ wav_length = x.wav.shape[-1]
644
+ seek_time = x.seek_time[idx]
645
+ assert seek_time is not None, (
646
+ "WavCondition seek_time is required "
647
+ "when extracting chroma chunks from pre-computed chroma.")
648
+ full_chroma = full_chroma.float()
649
+ frame_rate = self.sample_rate / self._downsampling_factor()
650
+ target_length = int(frame_rate * wav_length / self.sample_rate)
651
+ index = int(frame_rate * seek_time)
652
+ out = full_chroma[index: index + target_length]
653
+ out = F.pad(out[None], (0, 0, 0, target_length - out.shape[0]))[0]
654
+ return out.to(self.device)
655
+
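To make the cache lookup above concrete, here is a standalone sketch of the frame arithmetic used by `_extract_chroma_chunk`; the sample rate, hop size and seek time are illustrative values, not taken from this diff:

```python
# Hedged sketch of the chunk-indexing arithmetic (all numbers are illustrative).
sample_rate = 32000
winhop = 4096                       # hop size of the chroma extractor
frame_rate = sample_rate / winhop   # chroma frames per second of audio
wav_length = 10 * sample_rate       # a 10-second conditioning segment
seek_time = 30.0                    # segment starts 30 s into the full track

target_length = int(frame_rate * wav_length / sample_rate)  # frames to keep
index = int(frame_rate * seek_time)                         # first frame to read
print(index, target_length)         # 234 78
```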
656
+ @torch.no_grad()
657
+ def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor:
658
+ """Get the wav embedding from the WavCondition.
659
+ The conditioner will either extract the embedding on-the-fly computing it from the condition wav directly
660
+ or will rely on the embedding cache to load the pre-computed embedding if relevant.
661
+ """
662
+ sampled_wav: tp.Optional[torch.Tensor] = None
663
+ if not self.training and self.eval_wavs is not None:
664
+ warn_once(logger, "Using precomputed evaluation wavs!")
665
+ sampled_wav = self._sample_eval_wavs(len(x.wav))
666
+
667
+ no_undefined_paths = all(p is not None for p in x.path)
668
+ no_nullified_cond = x.wav.shape[-1] > 1
669
+ if sampled_wav is not None:
670
+ chroma = self._compute_wav_embedding(sampled_wav, self.sample_rate)
671
+ elif self.cache is not None and no_undefined_paths and no_nullified_cond:
672
+ paths = [Path(p) for p in x.path if p is not None]
673
+ chroma = self.cache.get_embed_from_cache(paths, x)
674
+ else:
675
+ assert all(sr == x.sample_rate[0] for sr in x.sample_rate), "All sample rates in batch should be equal."
676
+ chroma = self._compute_wav_embedding(x.wav, x.sample_rate[0])
677
 
678
  if self.match_len_on_eval:
679
+ B, T, C = chroma.shape
680
+ if T > self.chroma_len:
681
  chroma = chroma[:, :self.chroma_len]
682
+ logger.debug(f"Chroma was truncated to match length! ({T} -> {chroma.shape[1]})")
683
+ elif T < self.chroma_len:
684
+ n_repeat = int(math.ceil(self.chroma_len / T))
 
685
  chroma = chroma.repeat(1, n_repeat, 1)
686
  chroma = chroma[:, :self.chroma_len]
687
+ logger.debug(f"Chroma was repeated to match length! ({T} -> {chroma.shape[1]})")
688
+
689
  return chroma
690
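The `match_len_on_eval` branch above pads short chroma sequences by repetition. A small self-contained sketch of that pad-by-repetition step, with illustrative sizes:

```python
# Hedged sketch of padding chroma by repetition (sizes are illustrative).
import math
import torch

chroma_len = 235                          # training-time chroma length
chroma = torch.randn(1, 100, 12)          # too short: T=100 < chroma_len
n_repeat = int(math.ceil(chroma_len / chroma.shape[1]))
chroma = chroma.repeat(1, n_repeat, 1)[:, :chroma_len]
print(chroma.shape)                       # torch.Size([1, 235, 12])
```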
 
691
+ def tokenize(self, x: WavCondition) -> WavCondition:
692
+ """Apply WavConditioner tokenization and populate cache if needed."""
693
+ x = super().tokenize(x)
694
+ no_undefined_paths = all(p is not None for p in x.path)
695
+ if self.cache is not None and no_undefined_paths:
696
+ paths = [Path(p) for p in x.path if p is not None]
697
+ self.cache.populate_embed_cache(paths, x)
698
+ return x
699
 
700
+
701
+ class JointEmbeddingConditioner(BaseConditioner):
702
+ """Joint embedding conditioning supporting both audio or text conditioning.
703
 
704
  Args:
705
+ dim (int): Dimension.
706
+ output_dim (int): Output dimension.
707
+ device (str): Device.
708
+ attribute (str): Attribute used by the conditioner.
709
+ autocast_dtype (str): Autocast for the conditioner.
710
+ quantize (bool): Whether to quantize the CLAP embedding.
711
+ n_q (int): Number of residual quantizers (used if quantize is true).
712
+ bins (int): Quantizers' codebooks size (used if quantize is true).
713
+ kwargs: Additional parameters for residual vector quantizer.
714
  """
715
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
716
+ autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True,
717
+ n_q: int = 12, bins: int = 1024, **kwargs):
718
+ super().__init__(dim=dim, output_dim=output_dim)
 
719
  self.device = device
720
+ self.attribute = attribute
721
+ if autocast_dtype is None or device == 'cpu':
722
+ self.autocast = TorchAutocast(enabled=False)
723
+ logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.")
724
+ else:
725
+ dtype = getattr(torch, autocast_dtype)
726
+ assert isinstance(dtype, torch.dtype)
727
+ logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.")
728
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
729
+ # residual vector quantizer to discretize the conditioned embedding
730
+ self.quantizer: tp.Optional[ResidualVectorQuantizer] = None
731
+ if quantize:
732
+ self.quantizer = ResidualVectorQuantizer(dim, n_q=n_q, bins=bins, **kwargs)
733
+
734
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
735
+ """Get joint embedding in latent space from the inputs.
736
+
737
+ Returns:
738
+ tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding
739
+ and corresponding empty indexes.
740
+ """
741
+ raise NotImplementedError()
742
+
743
+ def forward(self, x: JointEmbedCondition) -> ConditionType:
744
  with self.autocast:
745
+ embed, empty_idx = self._get_embed(x)
746
+ if self.quantizer is not None:
747
+ embed = embed.view(-1, self.dim, 1)
748
+ q_res = self.quantizer(embed, frame_rate=1)
749
+ out_embed = q_res.x.view(-1, self.dim)
750
+ else:
751
+ out_embed = embed
752
+ out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim)
753
+ mask = torch.ones(*out_embed.shape[:2], device=out_embed.device)
754
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existent
755
+ out_embed = (out_embed * mask.unsqueeze(-1))
756
+ return out_embed, mask
757
+
758
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
759
+ return x
760
+
761
+
762
+ class CLAPEmbeddingConditioner(JointEmbeddingConditioner):
763
+ """Joint Embedding conditioner based on pre-trained CLAP model.
764
+
765
+ This CLAP-based conditioner supports a caching mechanism
766
+ over the computed embeddings for faster training.
767
+
768
+ Args:
769
+ dim (int): Dimension.
770
+ output_dim (int): Output dimension.
771
+ device (str): Device.
772
+ attribute (str): Attribute used by the conditioner.
773
+ quantize (bool): Whether to quantize the CLAP embedding.
774
+ n_q (int): Number of residual quantizers (used if quantize is true).
775
+ bins (int): Quantizers' codebooks size (used if quantize is true).
776
+ checkpoint (str): Path to CLAP checkpoint.
777
+ model_arch (str): CLAP model architecture.
778
+ enable_fusion (bool): Enable fusion for CLAP model.
779
+ sample_rate (int): Sample rate used by CLAP model.
780
+ max_audio_length (float): Maximum audio length for CLAP model.
781
+ audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence.
782
+ normalize (bool): Whether to normalize the CLAP embedding.
783
+ text_p (float): Probability of using text representation instead of audio at train time.
784
+ batch_size (Optional[int]): Batch size for CLAP embedding computation.
785
+ autocast_dtype (str): Autocast for the conditioner.
786
+ cache_path (Optional[str]): Path for pre-computed embeddings caching.
787
+ kwargs: Additional parameters for residual vector quantizer.
788
+ """
789
+ def __init__(self, dim: int, output_dim: int, device: str, attribute: str,
790
+ quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str,
791
+ enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int,
792
+ normalize: bool, text_p: bool, batch_size: tp.Optional[int] = None,
793
+ autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs):
794
+ try:
795
+ import laion_clap # type: ignore
796
+ except ImportError:
797
+ raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'")
798
+ warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). "
799
+ "Please retrain all models.")
800
+ checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint)
801
+ clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base')
802
+ clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch)
803
+ load_clap_state_dict(clap_model, checkpoint)
804
+ clap_model.eval()
805
+ clap_model.to(device)
806
+ super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute,
807
+ autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins,
808
+ **kwargs)
809
+ self.checkpoint = checkpoint
810
+ self.enable_fusion = enable_fusion
811
+ self.model_arch = model_arch
812
+ self.clap: laion_clap.CLAP_Module
813
+ self.clap_tokenize: RobertaTokenizer
814
+ self.clap_sample_rate = sample_rate
815
+ self.clap_max_frames = int(self.clap_sample_rate * max_audio_length)
816
+ self.clap_stride = int(self.clap_sample_rate * audio_stride)
817
+ self.batch_size = batch_size or 1
818
+ self.normalize = normalize
819
+ self.text_p = text_p
820
+ self.__dict__['clap_tokenize'] = clap_tokenize
821
+ self.__dict__['clap'] = clap_model
822
+ self.wav_cache, self.text_cache = None, None
823
+ if cache_path is not None:
824
+ self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device,
825
+ compute_embed_fn=self._get_wav_embedding_for_cache,
826
+ extract_embed_fn=self._extract_wav_embedding_chunk)
827
+ self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device,
828
+ compute_embed_fn=self._get_text_embedding_for_cache)
829
+
830
+ def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict:
831
+ # we use the default params from CLAP module here as well
832
+ return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt")
833
+
834
+ def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor:
835
+ """Compute text embedding from CLAP model on a given a batch of text.
836
+
837
+ Args:
838
+ text (list[str]): List of text for the batch, with B items.
839
+ Returns:
840
+ torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension.
841
+ """
842
+ with torch.no_grad():
843
+ embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True)
844
+ return embed.view(embed.size(0), 1, embed.size(-1))
845
+
846
+ def _get_text_embedding_for_cache(self, path: tp.Union[Path, str],
847
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
848
+ """Get text embedding function for the cache."""
849
+ text = x.text[idx]
850
+ text = text if text is not None else ""
851
+ return self._compute_text_embedding([text])[0]
852
+
853
+ def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor:
854
+ """Preprocess wav to expected format by CLAP model.
855
+
856
+ Args:
857
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
858
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
859
+ sample_rates (list[int]): Sample rates for each sample in the batch
860
+ Returns:
861
+ torch.Tensor: Audio wav of shape [B, T].
862
+ """
863
+ assert wav.dim() == 3, "Expecting wav to be [B, C, T]"
864
+ if sample_rates is not None:
865
+ _wav = []
866
+ for i, audio in enumerate(wav):
867
+ sr = sample_rates[i]
868
+ audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1)
869
+ _wav.append(audio)
870
+ wav = torch.stack(_wav, dim=0)
871
+ wav = wav.mean(dim=1)
872
+ return wav
873
+
874
+ def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor,
875
+ sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor:
876
+ """Compute audio wave embedding from CLAP model.
877
+
878
+ Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences,
879
+ we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and
880
+ average the resulting embeddings.
881
+
882
+ Args:
883
+ wav (torch.Tensor): Audio wav, of shape [B, C, T].
884
+ length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B].
885
+ sample_rates (list[int]): Sample rates for each sample in the batch.
886
+ reduce_mean (bool): Whether to get the average tensor.
887
+ Returns:
888
+ torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension.
889
+ """
890
+ with torch.no_grad():
891
+ wav = self._preprocess_wav(wav, length, sample_rates)
892
+ B, T = wav.shape
893
+ if T >= self.clap_max_frames:
894
+ wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) # [B, F, T]
895
+ else:
896
+ wav = wav.view(-1, 1, T) # [B, F, T] with F=1
897
+ wav = einops.rearrange(wav, 'b f t -> (b f) t')
898
+ embed_list = []
899
+ for i in range(0, wav.size(0), self.batch_size):
900
+ _wav = wav[i:i+self.batch_size, ...]
901
+ _embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True)
902
+ embed_list.append(_embed)
903
+ embed = torch.cat(embed_list, dim=0)
904
+ embed = einops.rearrange(embed, '(b f) d -> b f d', b=B)
905
+ if reduce_mean:
906
+ embed = embed.mean(dim=1, keepdim=True)
907
+ return embed # [B, F, D] with F=1 if reduce_mean is True
908
+
909
+ def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path],
910
+ x: JointEmbedCondition, idx: int) -> torch.Tensor:
911
+ """Compute audio wave embedding for the cache.
912
+ The embedding is computed on a given audio read from file.
913
+
914
+ Args:
915
+ path (str or Path): Path to the full audio file.
916
+ Returns:
917
+ torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension.
918
+ """
919
+ wav, sr = audio_read(path) # [C, T]
920
+ wav = wav.unsqueeze(0).to(self.device) # [1, C, T]
921
+ wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device)
922
+ embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) # [B, F, D]
923
+ return embed.squeeze(0) # [F, D]
924
+
925
+ def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor:
926
+ """Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding.
927
+
928
+ Args:
929
+ full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D].
930
+ x (JointEmbedCondition): Joint embedding condition for the full batch.
931
+ idx (int): Index considered for the given embedding to extract.
932
+ Returns:
933
+ torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D].
934
+ """
935
+ sample_rate = x.sample_rate[idx]
936
+ seek_time = x.seek_time[idx]
937
+ seek_time = 0. if seek_time is None else seek_time
938
+ clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate
939
+ end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate
940
+ start_offset = int(seek_time * sample_rate // clap_stride)
941
+ end_offset = int(end_seek_time * sample_rate // clap_stride)
942
+ wav_embed = full_embed[start_offset:end_offset, ...]
943
+ wav_embed = wav_embed.mean(dim=0, keepdim=True)
944
+ return wav_embed.to(self.device) # [F, D]
945
+
946
+ def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
947
+ """Get CLAP embedding from a batch of text descriptions."""
948
+ no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
949
+ if self.text_cache is not None and no_nullified_cond:
950
+ assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided"
951
+ paths = [Path(p) for p in x.path if p is not None]
952
+ embed = self.text_cache.get_embed_from_cache(paths, x)
953
+ else:
954
+ text = [xi if xi is not None else "" for xi in x.text]
955
+ embed = self._compute_text_embedding(text)
956
+ if self.normalize:
957
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
958
+ return embed
959
+
960
+ def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor:
961
+ """Get CLAP embedding from a batch of audio tensors (and corresponding sample rates)."""
962
+ no_undefined_paths = all(p is not None for p in x.path)
963
+ no_nullified_cond = x.wav.shape[-1] > 1 # we don't want to read from cache when condition dropout
964
+ if self.wav_cache is not None and no_undefined_paths and no_nullified_cond:
965
+ paths = [Path(p) for p in x.path if p is not None]
966
+ embed = self.wav_cache.get_embed_from_cache(paths, x)
967
+ else:
968
+ embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True)
969
+ if self.normalize:
970
+ embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1)
971
+ return embed
972
+
973
+ def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition:
974
+ # Try to limit sync points as much as possible when the cache is warm.
975
+ no_undefined_paths = all(p is not None for p in x.path)
976
+ if self.wav_cache is not None and no_undefined_paths:
977
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
978
+ paths = [Path(p) for p in x.path if p is not None]
979
+ self.wav_cache.populate_embed_cache(paths, x)
980
+ if self.text_cache is not None and no_undefined_paths:
981
+ assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided"
982
+ paths = [Path(p) for p in x.path if p is not None]
983
+ self.text_cache.populate_embed_cache(paths, x)
984
+ return x
985
+
986
+ def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]:
987
+ """Extract shared latent representation from either the wav or the text using CLAP."""
988
+ # decide whether to use text embedding at train time or not
989
+ use_text_embed = random.random() < self.text_p
990
+ if self.training and not use_text_embed:
991
+ embed = self._get_wav_embedding(x)
992
+ empty_idx = torch.LongTensor([]) # we assume we always have the audio wav
993
+ else:
994
+ embed = self._get_text_embedding(x)
995
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""])
996
+ return embed, empty_idx
997
+
998
+
999
+ def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes:
1000
  """Utility function for nullifying an attribute inside an ConditioningAttributes object.
1001
+ If the condition is of type "wav", then nullify it using `nullify_condition` function.
1002
+ If the condition is of any other type, set its value to None.
1003
  Works in-place.
1004
  """
1005
+ if condition_type not in ['text', 'wav', 'joint_embed']:
1006
  raise ValueError(
1007
  "dropout_condition got an unexpected condition type!"
1008
+ f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'"
1009
  )
1010
 
1011
  if condition not in getattr(sample, condition_type):
1012
  raise ValueError(
1013
  "dropout_condition received an unexpected condition!"
1014
  f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
1015
+ f" but got '{condition}' of type '{condition_type}'!"
1016
  )
1017
 
1018
+ if condition_type == 'wav':
1019
+ wav_cond = sample.wav[condition]
1020
+ sample.wav[condition] = nullify_wav(wav_cond)
1021
+ elif condition_type == 'joint_embed':
1022
+ embed = sample.joint_embed[condition]
1023
+ sample.joint_embed[condition] = nullify_joint_embed(embed)
1024
  else:
1025
  sample.text[condition] = None
1026
 
 
1028
 
1029
 
1030
  class DropoutModule(nn.Module):
1031
+ """Base module for all dropout modules."""
1032
  def __init__(self, seed: int = 1234):
1033
  super().__init__()
1034
  self.rng = torch.Generator()
 
1036
 
1037
 
1038
  class AttributeDropout(DropoutModule):
1039
+ """Dropout with a given probability per attribute.
1040
+ This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes
1041
+ to be dropped out separately. For example, "artist" can be dropped while "genre" remains.
1042
+ This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre"
1043
+ must also be dropped.
1044
 
1045
  Args:
1046
  p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
 
1063
  def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
1064
  """
1065
  Args:
1066
+ samples (list[ConditioningAttributes]): List of conditions.
1067
  Returns:
1068
+ list[ConditioningAttributes]: List of conditions after certain attributes were set to None.
1069
  """
1070
  if not self.training and not self.active_on_eval:
1071
  return samples
1072
 
1073
  samples = deepcopy(samples)
 
1074
  for condition_type, ps in self.p.items(): # for condition types [text, wav]
1075
  for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
1076
  if torch.rand(1, generator=self.rng).item() < p:
1077
  for sample in samples:
1078
  dropout_condition(sample, condition_type, condition)
 
1079
  return samples
1080
 
1081
  def __repr__(self):
 
1083
 
1084
 
1085
  class ClassifierFreeGuidanceDropout(DropoutModule):
1086
+ """Classifier Free Guidance dropout.
1087
+ All attributes are dropped with the same probability.
1088
 
1089
  Args:
1090
  p (float): Probability to apply condition dropout during training.
 
1097
  def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
1098
  """
1099
  Args:
1100
+ samples (list[ConditioningAttributes]): List of conditions.
1101
  Returns:
1102
+ list[ConditioningAttributes]: List of conditions after all attributes were set to None.
1103
  """
1104
  if not self.training:
1105
  return samples
 
1111
 
1112
  # nullify conditions of all attributes
1113
  samples = deepcopy(samples)
 
1114
  for condition_type in ["wav", "text"]:
1115
  for sample in samples:
1116
  for condition in sample.attributes[condition_type]:
1117
  dropout_condition(sample, condition_type, condition)
 
1118
  return samples
1119
 
1120
  def __repr__(self):
 
1122
 
1123
 
1124
  class ConditioningProvider(nn.Module):
1125
+ """Prepare and provide conditions given all the supported conditioners.
1126
 
1127
  Args:
1128
  conditioners (dict): Dictionary of conditioners.
1129
+ device (torch.device or str, optional): Device for conditioners and output condition types.
1130
  """
1131
+ def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
1132
  super().__init__()
1133
  self.device = device
 
 
1134
  self.conditioners = nn.ModuleDict(conditioners)
1135
 
1136
+ @property
1137
+ def joint_embed_conditions(self):
1138
+ return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
1139
+
1140
+ @property
1141
+ def has_joint_embed_conditions(self):
1142
+ return len(self.joint_embed_conditions) > 0
1143
+
1144
  @property
1145
  def text_conditions(self):
1146
  return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
 
1159
  This will return a dict matching conditioner names to their arbitrary tokenized representations.
1160
 
1161
  Args:
1162
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
1163
  text and wav conditions.
1164
  """
1165
+ assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
1166
+ "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
1167
  f" but types were {set([type(x) for x in inputs])}"
1168
+ )
1169
 
1170
  output = {}
1171
  text = self._collate_text(inputs)
1172
  wavs = self._collate_wavs(inputs)
1173
+ joint_embeds = self._collate_joint_embeds(inputs)
1174
 
1175
+ assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), (
1176
+ f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
1177
+ f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
1178
+ )
1179
 
1180
+ for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()):
1181
  output[attribute] = self.conditioners[attribute].tokenize(batch)
1182
  return output
1183
 
1184
  def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
1185
+ """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations.
1186
+ The output is for example:
1187
+ {
1188
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
1189
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
1190
+ ...
1191
+ }
 
1192
 
1193
  Args:
1194
  tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
 
1213
  "genre": ["Rock", "Hip-hop"],
1214
  "description": ["A rock song with a guitar solo", "A hip-hop verse"]
1215
  }
 
1216
 
1217
+ Args:
1218
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
1219
+ Returns:
1220
+ dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch.
1221
+ """
1222
+ out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
1223
  texts = [x.text for x in samples]
1224
  for text in texts:
1225
  for condition in self.text_conditions:
1226
+ out[condition].append(text[condition])
1227
+ return out
 
1228
 
1229
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]:
1230
  """Generate a dict where the keys are attributes by which we fetch similar wavs,
1231
+ and the values are Tensors of wavs according to said attributes.
1232
 
1233
  *Note*: by the time the samples reach this function, each sample should have some waveform
1234
  inside the "wav" attribute. It should be either:
 
1237
  3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
1238
 
1239
  Args:
1240
+ samples (list of ConditioningAttributes): List of ConditioningAttributes samples.
1241
  Returns:
1242
+ dict[str, WavCondition]: A dictionary mapping an attribute name to wavs.
1243
  """
1244
  wavs = defaultdict(list)
1245
+ lengths = defaultdict(list)
1246
+ sample_rates = defaultdict(list)
1247
  paths = defaultdict(list)
1248
+ seek_times = defaultdict(list)
1249
+ out: tp.Dict[str, WavCondition] = {}
1250
 
1251
  for sample in samples:
1252
  for attribute in self.wav_conditions:
1253
+ wav, length, sample_rate, path, seek_time = sample.wav[attribute]
1254
+ assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]"
1255
+ assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1"
1256
+ # mono-channel conditioning
1257
+ wav = wav.mean(1, keepdim=True) # [1, 1, T]
1258
+ wavs[attribute].append(wav.flatten()) # [T]
1259
+ lengths[attribute].append(length)
1260
+ sample_rates[attribute].extend(sample_rate)
1261
+ paths[attribute].extend(path)
1262
+ seek_times[attribute].extend(seek_time)
1263
 
1264
  # stack all wavs to a single tensor
1265
  for attribute in self.wav_conditions:
1266
  stacked_wav, _ = collate(wavs[attribute], dim=0)
1267
+ out[attribute] = WavCondition(
1268
+ stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute],
1269
+ paths[attribute], seek_times[attribute])
1270
+
1271
+ return out
1272
+
1273
+ def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]:
1274
+ """Generate a dict where the keys are attributes by which we compute joint embeddings,
1275
+ and the values are Tensors of pre-computed embeddings and the corresponding text attributes.
1276
+
1277
+ Args:
1278
+ samples (list[ConditioningAttributes]): List of ConditioningAttributes samples.
1279
+ Returns:
1280
+ A dictionary mapping an attribute name to joint embeddings.
1281
+ """
1282
+ texts = defaultdict(list)
1283
+ wavs = defaultdict(list)
1284
+ lengths = defaultdict(list)
1285
+ sample_rates = defaultdict(list)
1286
+ paths = defaultdict(list)
1287
+ seek_times = defaultdict(list)
1288
+ channels: int = 0
1289
+
1290
+ out = {}
1291
+ for sample in samples:
1292
+ for attribute in self.joint_embed_conditions:
1293
+ wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute]
1294
+ assert wav.dim() == 3
1295
+ if channels == 0:
1296
+ channels = wav.size(1)
1297
+ else:
1298
+ assert channels == wav.size(1), "not all audio has same number of channels in batch"
1299
+ assert wav.size(0) == 1, "Expecting single-wav batch in the collate method"
1300
+ wav = einops.rearrange(wav, "b c t -> (b c t)") # [1, C, T] => [C * T]
1301
+ wavs[attribute].append(wav)
1302
+ texts[attribute].extend(text)
1303
+ lengths[attribute].append(length)
1304
+ sample_rates[attribute].extend(sample_rate)
1305
+ paths[attribute].extend(path)
1306
+ seek_times[attribute].extend(seek_time)
1307
+
1308
+ for attribute in self.joint_embed_conditions:
1309
+ stacked_texts = texts[attribute]
1310
+ stacked_paths = paths[attribute]
1311
+ stacked_seek_times = seek_times[attribute]
1312
+ stacked_wavs = pad_sequence(wavs[attribute]).to(self.device)
1313
+ stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels)
1314
+ stacked_sample_rates = sample_rates[attribute]
1315
+ stacked_lengths = torch.cat(lengths[attribute]).to(self.device)
1316
+ assert stacked_lengths.size(0) == stacked_wavs.size(0)
1317
+ assert len(stacked_sample_rates) == stacked_wavs.size(0)
1318
+ assert len(stacked_texts) == stacked_wavs.size(0)
1319
+ out[attribute] = JointEmbedCondition(
1320
+ text=stacked_texts, wav=stacked_wavs,
1321
+ length=stacked_lengths, sample_rate=stacked_sample_rates,
1322
+ path=stacked_paths, seek_time=stacked_seek_times)
1323
 
1324
  return out
1325
 
 
1346
  super().__init__()
1347
  assert all(
1348
  [k in self.FUSING_METHODS for k in fuse2cond.keys()]
1349
+ ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
1350
  self.cross_attention_pos_emb = cross_attention_pos_emb
1351
  self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
1352
  self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
 
1357
 
1358
  def forward(
1359
  self,
1360
+ input: torch.Tensor,
1361
  conditions: tp.Dict[str, ConditionType]
1362
+ ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
1363
  """Fuse the conditions to the provided model input.
1364
 
1365
  Args:
1366
+ input (torch.Tensor): Transformer input.
1367
+ conditions (dict[str, ConditionType]): Dict of conditions.
1368
  Returns:
1369
+ tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input
1370
  after the conditions have been fused. The second output tensor is the tensor
1371
  used for cross-attention or None if no cross attention inputs exist.
1372
  """
 
1385
  cross_attention_output = None
1386
  for cond_type, (cond, cond_mask) in conditions.items():
1387
  op = self.cond2fuse[cond_type]
1388
+ if op == 'sum':
1389
  input += cond
1390
+ elif op == 'input_interpolate':
1391
+ cond = einops.rearrange(cond, "b t d -> b d t")
1392
  cond = F.interpolate(cond, size=input.shape[1])
1393
+ input += einops.rearrange(cond, "b d t -> b t d")
1394
+ elif op == 'prepend':
1395
  if first_step:
1396
  input = torch.cat([cond, input], dim=1)
1397
+ elif op == 'cross':
1398
  if cross_attention_output is not None:
1399
  cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
1400
  else:
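For reference, a hedged sketch of the kind of `fuse2cond` mapping `ConditionFuser` expects; the attribute names are illustrative (in practice the mapping comes from the model configuration):

```python
# Hedged sketch: a fuse2cond mapping and the cond2fuse inversion done by the fuser.
fuse2cond = {
    "cross": ["description"],    # text condition goes through cross-attention
    "prepend": ["self_wav"],     # chroma tokens are prepended to the transformer input
    "sum": [],
    "input_interpolate": [],
}
cond2fuse = {cond: method for method, conds in fuse2cond.items() for cond in conds}
print(cond2fuse)                 # {'description': 'cross', 'self_wav': 'prepend'}
```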
audiocraft/modules/conv.py CHANGED
@@ -11,7 +11,7 @@ import warnings
11
  import torch
12
  from torch import nn
13
  from torch.nn import functional as F
14
- from torch.nn.utils import spectral_norm, weight_norm
15
 
16
 
17
  CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
 
11
  import torch
12
  from torch import nn
13
  from torch.nn import functional as F
14
+ from torch.nn.utils.parametrizations import spectral_norm, weight_norm
15
 
16
 
17
  CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
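The new import above relies on `torch.nn.utils.parametrizations`, which only ships with newer PyTorch releases. A hedged compatibility sketch (the fallback branch is an assumption for older PyTorch versions, not part of this diff):

```python
# Hedged sketch: prefer the parametrizations module, fall back to the deprecated location.
try:
    from torch.nn.utils.parametrizations import spectral_norm, weight_norm
except ImportError:  # assumption: older PyTorch without parametrizations.weight_norm
    from torch.nn.utils import spectral_norm, weight_norm
```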
audiocraft/modules/diffusion_schedule.py ADDED
@@ -0,0 +1,272 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
9
+ """
10
+
11
+ from collections import namedtuple
12
+ import random
13
+ import typing as tp
14
+ import julius
15
+ import torch
16
+
17
+ TrainingItem = namedtuple("TrainingItem", "noisy noise step")
18
+
19
+
20
+ def betas_from_alpha_bar(alpha_bar):
21
+ alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
22
+ return 1 - alphas
23
+
24
+
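`betas_from_alpha_bar` simply inverts the cumulative product used elsewhere in the schedule. A quick sanity-check sketch, assuming the function above is in scope:

```python
# Hedged sketch: recovering per-step betas from alpha_bar (values are illustrative).
import torch

betas = torch.linspace(1e-4, 0.02, 10)
alpha_bar = (1 - betas).cumprod(dim=-1)
recovered = betas_from_alpha_bar(alpha_bar)
print(torch.allclose(recovered, betas, atol=1e-5))   # True
```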
25
+ class SampleProcessor(torch.nn.Module):
26
+ def project_sample(self, x: torch.Tensor):
27
+ """Project the original sample to the 'space' where the diffusion will happen."""
28
+ return x
29
+
30
+ def return_sample(self, z: torch.Tensor):
31
+ """Project back from diffusion space to the actual sample space."""
32
+ return z
33
+
34
+
35
+ class MultiBandProcessor(SampleProcessor):
36
+ """
37
+ MultiBand sample processor. The input audio is split across
38
+ frequency bands evenly distributed in mel-scale.
39
+
40
+ Each band will be rescaled to match the power distribution
41
+ of Gaussian noise in that band, using online metrics
42
+ computed on the first few samples.
43
+
44
+ Args:
45
+ n_bands (int): Number of mel-bands to split the signal over.
46
+ sample_rate (int): Sample rate of the audio.
47
+ num_samples (int): Number of samples to use to fit the rescaling
48
+ for each band. The processor won't be stable
49
+ until it has seen that many samples.
50
+ power_std (float or list/tensor): The rescaling factor computed to match the
51
+ power of Gaussian noise in each band is taken to
52
+ that power, i.e. `1.` means full correction of the energy
53
+ in each band, and values less than `1` means only partial
54
+ correction. Can be used to balance the relative importance
55
+ of low vs. high freq in typical audio signals.
56
+ """
57
+ def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
58
+ num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
59
+ super().__init__()
60
+ self.n_bands = n_bands
61
+ self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
62
+ self.num_samples = num_samples
63
+ self.power_std = power_std
64
+ if isinstance(power_std, list):
65
+ assert len(power_std) == n_bands
66
+ power_std = torch.tensor(power_std)
67
+ self.register_buffer('counts', torch.zeros(1))
68
+ self.register_buffer('sum_x', torch.zeros(n_bands))
69
+ self.register_buffer('sum_x2', torch.zeros(n_bands))
70
+ self.register_buffer('sum_target_x2', torch.zeros(n_bands))
71
+ self.counts: torch.Tensor
72
+ self.sum_x: torch.Tensor
73
+ self.sum_x2: torch.Tensor
74
+ self.sum_target_x2: torch.Tensor
75
+
76
+ @property
77
+ def mean(self):
78
+ mean = self.sum_x / self.counts
79
+ return mean
80
+
81
+ @property
82
+ def std(self):
83
+ std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
84
+ return std
85
+
86
+ @property
87
+ def target_std(self):
88
+ target_std = self.sum_target_x2 / self.counts
89
+ return target_std
90
+
91
+ def project_sample(self, x: torch.Tensor):
92
+ assert x.dim() == 3
93
+ bands = self.split_bands(x)
94
+ if self.counts.item() < self.num_samples:
95
+ ref_bands = self.split_bands(torch.randn_like(x))
96
+ self.counts += len(x)
97
+ self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
98
+ self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
99
+ self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
100
+ rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
101
+ bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
102
+ return bands.sum(dim=0)
103
+
104
+ def return_sample(self, x: torch.Tensor):
105
+ assert x.dim() == 3
106
+ bands = self.split_bands(x)
107
+ rescale = (self.std / self.target_std) ** self.power_std
108
+ bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
109
+ return bands.sum(dim=0)
110
+
111
+
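A small usage sketch for `MultiBandProcessor`, assuming the class above is in scope and `julius` is installed; the batch size, sample rate and `num_samples` are illustrative:

```python
# Hedged sketch: projecting audio into the rescaled multi-band "diffusion space".
import torch

proc = MultiBandProcessor(n_bands=8, sample_rate=24_000, num_samples=16)
x = torch.randn(16, 1, 24_000)      # [B, C, T]; this batch also fits the running stats
z = proc.project_sample(x)          # band-wise rescaled representation
x_hat = proc.return_sample(z)       # approximate inverse projection
print(z.shape, x_hat.shape)         # both keep the [B, C, T] shape
```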
112
+ class NoiseSchedule:
113
+ """Noise schedule for diffusion.
114
+
115
+ Args:
116
+ beta_t0 (float): Variance of the first diffusion step.
117
+ beta_t1 (float): Variance of the last diffusion step.
118
+ beta_exp (float): Power schedule exponent.
119
+ num_steps (int): Number of diffusion steps.
120
+ variance (str): Choice of the sigma value for the denoising equation. Choices: "beta" or "beta_tilde".
121
+ clip (float): Clipping value for the denoising steps.
122
+ rescale (float): Rescaling value to avoid vanishing signals; unused by default (i.e. 1).
123
+ repartition (str): Shape of the schedule; only the power schedule is supported.
124
+ sample_processor (SampleProcessor): Module that normalizes data to better match the Gaussian distribution.
125
+ noise_scale (float): Scaling factor for the noise.
126
+ """
127
+ def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
128
+ clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
129
+ repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
130
+ sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
131
+
132
+ self.beta_t0 = beta_t0
133
+ self.beta_t1 = beta_t1
134
+ self.variance = variance
135
+ self.num_steps = num_steps
136
+ self.clip = clip
137
+ self.sample_processor = sample_processor
138
+ self.rescale = rescale
139
+ self.n_bands = n_bands
140
+ self.noise_scale = noise_scale
141
+ assert n_bands is None
142
+ if repartition == "power":
143
+ self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
144
+ device=device, dtype=torch.float) ** beta_exp
145
+ else:
146
+ raise RuntimeError('Not implemented')
147
+ self.rng = random.Random(1234)
148
+
149
+ def get_beta(self, step: tp.Union[int, torch.Tensor]):
150
+ if self.n_bands is None:
151
+ return self.betas[step]
152
+ else:
153
+ return self.betas[:, step] # [n_bands, len(step)]
154
+
155
+ def get_initial_noise(self, x: torch.Tensor):
156
+ if self.n_bands is None:
157
+ return torch.randn_like(x)
158
+ return torch.randn((x.size(0), self.n_bands, x.size(2)))
159
+
160
+ def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
161
+ """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
162
+ if step is None:
163
+ return (1 - self.betas).cumprod(dim=-1) # works for single and multi bands
164
+ if type(step) is int:
165
+ return (1 - self.betas[:step + 1]).prod()
166
+ else:
167
+ return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
168
+
169
+ def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
170
+ """Create a noisy data item for diffusion model training:
171
+
172
+ Args:
173
+ x (torch.Tensor): clean audio data, of shape `[bs, 1, T]`.
174
+ tensor_step (bool): If False, a single step t is sampled and
175
+ the whole batch is diffused to that same step (t is an int).
176
+ If True, t is a tensor of size (x.size(0),) and
177
+ every element of the batch is diffused to an independently sampled step.
178
+ """
179
+ step: tp.Union[int, torch.Tensor]
180
+ if tensor_step:
181
+ bs = x.size(0)
182
+ step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
183
+ else:
184
+ step = self.rng.randrange(self.num_steps)
185
+ alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
186
+
187
+ x = self.sample_processor.project_sample(x)
188
+ noise = torch.randn_like(x)
189
+ noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
190
+ return TrainingItem(noisy, noise, step)
191
+
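Editor's note: for reference, a standalone sketch (not library code) of the closed-form forward noising that `get_training_item` applies, with illustrative beta values:

```python
import torch

num_steps = 1000
betas = torch.linspace(1e-4, 0.02, num_steps)   # same default range as the schedule above
alpha_bar = (1 - betas).cumprod(dim=-1)

x0 = torch.randn(2, 1, 24_000)                  # stand-in for clean, already-projected audio
t = 250
eps = torch.randn_like(x0)
x_t = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
# The diffusion model is then trained to predict `eps` from (x_t, t), matching TrainingItem above.
```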
192
+ def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
193
+ condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
194
+ """Full ddpm reverse process.
195
+
196
+ Args:
197
+ model (nn.Module): Diffusion model.
198
+ initial (tensor): Initial Noise.
199
+ condition (tensor): Input conditioning Tensor (e.g. encodec compressed representation).
200
+ return_list (bool): Whether to return the whole process or only the sampled point.
201
+ """
202
+ alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
203
+ current = initial
204
+ iterates = [initial]
205
+ for step in range(self.num_steps)[::-1]:
206
+ with torch.no_grad():
207
+ estimate = model(current, step, condition=condition).sample
208
+ alpha = 1 - self.betas[step]
209
+ previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
210
+ previous_alpha_bar = self.get_alpha_bar(step=step - 1)
211
+ if step == 0:
212
+ sigma2 = 0
213
+ elif self.variance == 'beta':
214
+ sigma2 = 1 - alpha
215
+ elif self.variance == 'beta_tilde':
216
+ sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
217
+ elif self.variance == 'none':
218
+ sigma2 = 0
219
+ else:
220
+ raise ValueError(f'Invalid variance type {self.variance}')
221
+
222
+ if sigma2 > 0:
223
+ previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
224
+ if self.clip:
225
+ previous = previous.clamp(-self.clip, self.clip)
226
+ current = previous
227
+ alpha_bar = previous_alpha_bar
228
+ if step == 0:
229
+ previous *= self.rescale
230
+ if return_list:
231
+ iterates.append(previous.cpu())
232
+
233
+ if return_list:
234
+ return iterates
235
+ else:
236
+ return self.sample_processor.return_sample(previous)
237
+
238
+ def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
239
+ condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
240
+ """Reverse process that only goes through Markov chain states in step_list."""
241
+ if step_list is None:
242
+ step_list = list(range(1000))[::-50] + [0]
243
+ alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
244
+ alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
245
+ betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
246
+ current = initial * self.noise_scale
247
+ iterates = [current]
248
+ for idx, step in enumerate(step_list[:-1]):
249
+ with torch.no_grad():
250
+ estimate = model(current, step, condition=condition).sample * self.noise_scale
251
+ alpha = 1 - betas_subsampled[-1 - idx]
252
+ previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
253
+ previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
254
+ if step == step_list[-2]:
255
+ sigma2 = 0
256
+ previous_alpha_bar = torch.tensor(1.0)
257
+ else:
258
+ sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
259
+ if sigma2 > 0:
260
+ previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
261
+ if self.clip:
262
+ previous = previous.clamp(-self.clip, self.clip)
263
+ current = previous
264
+ alpha_bar = previous_alpha_bar
265
+ if step == 0:
266
+ previous *= self.rescale
267
+ if return_list:
268
+ iterates.append(previous.cpu())
269
+ if return_list:
270
+ return iterates
271
+ else:
272
+ return self.sample_processor.return_sample(previous)
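Editor's note: a hedged end-to-end sketch of the reverse process above. The stub model only mimics the expected calling convention (an object exposing a `.sample` attribute); the real denoiser is the diffusion U-Net added elsewhere in this commit, and the module path is assumed.

```python
from dataclasses import dataclass
import torch
from audiocraft.modules.diffusion_schedule import NoiseSchedule  # assumed module path

@dataclass
class _Out:
    sample: torch.Tensor

class _DummyModel(torch.nn.Module):
    def forward(self, x, step, condition=None):
        return _Out(torch.zeros_like(x))  # pretend noise estimate

schedule = NoiseSchedule(num_steps=50, device='cpu')
noise = schedule.get_initial_noise(torch.zeros(1, 1, 8_000))
audio = schedule.generate(_DummyModel(), initial=noise)
# generate_subsampled() follows the same API but only visits the steps listed in `step_list`.
```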
audiocraft/modules/rope.py CHANGED
@@ -18,7 +18,7 @@ class XPos(nn.Module):
18
  dim (int): Embedding dimension.
19
  smoothing (float): Smoothing factor applied to the decay rates.
20
  base_scale (int): Base decay rate, given in terms of scaling time.
21
- device (torch.device or None): Device on which to initialize the module.
22
  dtype (torch.dtype): dtype to use to generate the embedding.
23
  """
24
  def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
@@ -36,8 +36,7 @@ class XPos(nn.Module):
36
  self.decay: tp.Optional[torch.Tensor] = None
37
 
38
  def get_decay(self, start: int, end: int):
39
- """Create complex decay tensor, cache values for fast computation.
40
- """
41
  if self.decay is None or end > self.decay.shape[0]:
42
  assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
43
  idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
@@ -55,7 +54,7 @@ class RotaryEmbedding(nn.Module):
55
  max_period (float): Maximum period of the rotation frequencies.
56
  xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
57
  scale (float): Scale of positional embedding, set to 0 to deactivate.
58
- device (torch.device or None): Device on which to initialize the module.
59
  dtype (torch.dtype): dtype to use to generate the embedding.
60
  """
61
  def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
@@ -74,8 +73,7 @@ class RotaryEmbedding(nn.Module):
74
  self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
75
 
76
  def get_rotation(self, start: int, end: int):
77
- """Create complex rotation tensor, cache values for fast computation.
78
- """
79
  if self.rotation is None or end > self.rotation.shape[0]:
80
  assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
81
  idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
@@ -83,14 +81,16 @@ class RotaryEmbedding(nn.Module):
83
  self.rotation = torch.polar(torch.ones_like(angles), angles)
84
  return self.rotation[start:end]
85
 
86
- def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
87
- """Apply rope rotation to query or key tensor.
88
- """
89
- T = x.shape[1]
90
- rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
 
 
91
 
92
  if self.xpos:
93
- decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
94
  else:
95
  decay = 1.0
96
 
@@ -99,26 +99,27 @@ class RotaryEmbedding(nn.Module):
99
 
100
  x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
101
  scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
102
- x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
103
 
104
  return x_out.type_as(x)
105
 
106
- def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
107
  """ Apply rope rotation to both query and key tensors.
108
  Supports streaming mode, in which query and key are not expected to have the same shape.
109
- In streaming mode, key will be of legnth [P + C] with P the cached past timesteps, but
110
  query will be [C] (typically C == 1).
111
 
112
  Args:
113
  query (torch.Tensor): Query to rotate.
114
  key (torch.Tensor): Key to rotate.
115
  start (int): Start index of the sequence for time offset.
 
116
  """
117
- query_timesteps = query.shape[1]
118
- key_timesteps = key.shape[1]
119
  streaming_offset = key_timesteps - query_timesteps
120
 
121
- query_out = self.rotate(query, start + streaming_offset)
122
- key_out = self.rotate(key, start, invert_decay=True)
123
 
124
  return query_out, key_out
 
18
  dim (int): Embedding dimension.
19
  smoothing (float): Smoothing factor applied to the decay rates.
20
  base_scale (int): Base decay rate, given in terms of scaling time.
21
+ device (torch.device, optional): Device on which to initialize the module.
22
  dtype (torch.dtype): dtype to use to generate the embedding.
23
  """
24
  def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
 
36
  self.decay: tp.Optional[torch.Tensor] = None
37
 
38
  def get_decay(self, start: int, end: int):
39
+ """Create complex decay tensor, cache values for fast computation."""
 
40
  if self.decay is None or end > self.decay.shape[0]:
41
  assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
42
  idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
 
54
  max_period (float): Maximum period of the rotation frequencies.
55
  xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
56
  scale (float): Scale of positional embedding, set to 0 to deactivate.
57
+ device (torch.device, optional): Device on which to initialize the module.
58
  dtype (torch.dtype): dtype to use to generate the embedding.
59
  """
60
  def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
 
73
  self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
74
 
75
  def get_rotation(self, start: int, end: int):
76
+ """Create complex rotation tensor, cache values for fast computation."""
 
77
  if self.rotation is None or end > self.rotation.shape[0]:
78
  assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
79
  idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
 
81
  self.rotation = torch.polar(torch.ones_like(angles), angles)
82
  return self.rotation[start:end]
83
 
84
+ def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
85
+ """Apply rope rotation to query or key tensor."""
86
+ T = x.shape[time_dim]
87
+ target_shape = [1] * x.dim()
88
+ target_shape[time_dim] = T
89
+ target_shape[-1] = -1
90
+ rotation = self.get_rotation(start, start + T).view(target_shape)
91
 
92
  if self.xpos:
93
+ decay = self.xpos.get_decay(start, start + T).view(target_shape)
94
  else:
95
  decay = 1.0
96
 
 
99
 
100
  x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
101
  scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
102
+ x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
103
 
104
  return x_out.type_as(x)
105
 
106
+ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
107
  """ Apply rope rotation to both query and key tensors.
108
  Supports streaming mode, in which query and key are not expected to have the same shape.
109
+ In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
110
  query will be [C] (typically C == 1).
111
 
112
  Args:
113
  query (torch.Tensor): Query to rotate.
114
  key (torch.Tensor): Key to rotate.
115
  start (int): Start index of the sequence for time offset.
116
+ time_dim (int): which dimension represents the time steps.
117
  """
118
+ query_timesteps = query.shape[time_dim]
119
+ key_timesteps = key.shape[time_dim]
120
  streaming_offset = key_timesteps - query_timesteps
121
 
122
+ query_out = self.rotate(query, start + streaming_offset, time_dim)
123
+ key_out = self.rotate(key, start, time_dim, invert_decay=True)
124
 
125
  return query_out, key_out
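Editor's note: a short sketch of the updated rotary API with the new `time_dim` argument, using the `[batch, time, heads, head_dim]` layout; shapes outside the diff are assumptions.

```python
import torch
from audiocraft.modules.rope import RotaryEmbedding

rope = RotaryEmbedding(dim=64)     # per-head dimension, must be even
q = torch.randn(2, 8, 4, 64)       # [B, T, H, D]
k = torch.randn(2, 8, 4, 64)
q_rot, k_rot = rope.rotate_qk(q, k, start=0, time_dim=1)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```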
audiocraft/modules/transformer.py CHANGED
@@ -35,8 +35,8 @@ def set_efficient_attention_backend(backend: str = 'torch'):
35
  _efficient_attention_backend = backend
36
 
37
 
38
- def _get_attention_time_dimension() -> int:
39
- if _efficient_attention_backend == 'torch':
40
  return 2
41
  else:
42
  return 1
@@ -89,11 +89,11 @@ def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float =
89
  return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
90
 
91
 
92
- def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
93
- """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
94
  if n_rep == 1:
95
  return x
96
- if _efficient_attention_backend == 'torch':
97
  bs, n_kv_heads, slen, head_dim = x.shape
98
  return (
99
  x[:, :, None, :, :]
@@ -111,14 +111,14 @@ def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
111
 
112
  class LayerScale(nn.Module):
113
  """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
114
- This rescales diagonaly the residual outputs close to 0, with a learnt scale.
115
 
116
  Args:
117
  channels (int): Number of channels.
118
  init (float): Initial scale.
119
  channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
120
- device (torch.device or None): Device on which to initialize the module.
121
- dtype (torch.dtype or None): dtype to use to initialize the module.
122
  """
123
  def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
124
  device=None, dtype=None):
@@ -144,22 +144,22 @@ class StreamingMultiheadAttention(StreamingModule):
144
  dropout (float): Dropout level.
145
  bias (bool): Use bias in projections.
146
  causal (bool): Causal mask applied automatically.
147
- past_context (int or None): Receptive field for the causal mask, infinite if None.
148
  custom (bool): Use custom MHA implementation, for testing / benchmarking.
149
  memory_efficient (bool): Use xformers based memory efficient attention.
150
  attention_as_float32 (bool): Perform the attention as float32
151
  (especially important with memory_efficient as autocast won't do this automatically).
152
- rope (`RotaryEmbedding` or None): Rope embedding to use.
153
  cross_attention: Should be true when used as a cross attention.
154
  All keys and values must be available at once, streaming is only for the queries.
155
  Cannot be used with `causal` or `rope` (as it wouldn't make sens to
156
- intepret the time steps in the keys relative to those in the queries).
157
  safe_streaming (bool): Bug fix, will go away with xformers update.
158
  qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
159
  kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
160
  This will lead to faster decoding time on A100 or other GPUs with tensorcore.
161
- device (torch.device or None): Sevice on which to initialize.
162
- dtype (torch.dtype or None): dtype to use.
163
  """
164
  def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
165
  causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
@@ -234,14 +234,14 @@ class StreamingMultiheadAttention(StreamingModule):
234
  # Return a causal mask, accounting for potentially stored past keys/values
235
  # We actually return a bias for the attention score, as this has the same
236
  # convention both in the builtin MHA in Pytorch, and Xformers functions.
237
- time_dim = _get_attention_time_dimension()
238
  if self.memory_efficient:
239
  from xformers.ops import LowerTriangularMask
240
  if current_steps == 1:
241
  # If we only have one step, then we do not need a mask.
242
  return None
243
  elif 'past_keys' in self._streaming_state:
244
- raise RuntimeError('Not supported at the moment')
245
  else:
246
  # Then we can safely use a lower triangular mask
247
  return LowerTriangularMask()
@@ -264,7 +264,7 @@ class StreamingMultiheadAttention(StreamingModule):
264
  torch.full([], float('-inf'), device=device, dtype=dtype))
265
 
266
  def _complete_kv(self, k, v):
267
- time_dim = _get_attention_time_dimension()
268
  if self.cross_attention:
269
  # With cross attention we assume all keys and values
270
  # are already available, and streaming is with respect
@@ -298,8 +298,7 @@ class StreamingMultiheadAttention(StreamingModule):
298
  return nk, nv
299
 
300
  def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
301
- # TODO: fix and verify layout.
302
- assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
303
  # Apply rope embeddings to query and key tensors.
304
  assert self.rope is not None
305
  if 'past_keys' in self._streaming_state:
@@ -311,16 +310,16 @@ class StreamingMultiheadAttention(StreamingModule):
311
  else:
312
  past_context_offset = 0
313
  streaming_offset = past_context_offset + past_keys_offset
314
- return self.rope.rotate_qk(query, key, start=streaming_offset)
315
 
316
  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
317
  key_padding_mask=None, need_weights=False, attn_mask=None,
318
  average_attn_weights=True, is_causal=False):
319
  assert attn_mask is None
320
- assert not is_causal, ("new param added in torch 2.0.1 not supported, "
321
  "use the causal args in the constructor.")
322
 
323
- time_dim = _get_attention_time_dimension()
324
  if time_dim == 2:
325
  layout = "b h t d"
326
  else:
@@ -394,8 +393,8 @@ class StreamingMultiheadAttention(StreamingModule):
394
  q, k = self._apply_rope(q, k)
395
  k, v = self._complete_kv(k, v)
396
  if self.kv_repeat > 1:
397
- k = expand_repeated_kv(k, self.kv_repeat)
398
- v = expand_repeated_kv(v, self.kv_repeat)
399
  if self.attention_as_float32:
400
  q, k, v = [x.float() for x in [q, k, v]]
401
  if self.memory_efficient:
@@ -455,7 +454,7 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
455
  bias_ff (bool): Use bias for FF.
456
  bias_attn (bool): Use bias for MHA.
457
  causal (bool): Causal mask applied automatically.
458
- past_context (int or None): Receptive field for the causal mask, infinite if None.
459
  custom (bool): Use custom MHA implementation, for testing / benchmarking.
460
  memory_efficient (bool): Use xformers based memory efficient attention.
461
  attention_as_float32 (bool): Perform the attention as float32
@@ -465,15 +464,15 @@ class StreamingTransformerLayer(nn.TransformerEncoderLayer):
465
  cross_attention (bool): If True, expect to get secondary input for cross-attention.
466
  Cross attention will use the default MHA, as it typically won't require
467
  special treatment.
468
- layer_scale (float or None): If not None, LayerScale will be used with
469
  the given value as initial scale.
470
- rope (`RotaryEmbedding` or None): Rope embedding to use.
471
- attention_dropout (float or None): If not None, separate the value of the dimension dropout
472
  in FFN and of the attention dropout.
473
  kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
474
  This will lead to faster decoding time on A100 or other GPUs with tensorcore.
475
- device (torch.device or None): Device on which to initialize.
476
- dtype (torch.dtype or None): dtype to use.
477
  **kwargs: See `nn.TransformerEncoderLayer`.
478
  """
479
  def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
@@ -576,30 +575,30 @@ class StreamingTransformer(StreamingModule):
576
  bias_ff (bool): Use bias for FF.
577
  bias_attn (bool): Use bias for MHA.
578
  causal (bool): Causal mask applied automatically.
579
- past_context (int or None): Receptive field for the causal mask, infinite if None.
580
  custom (bool): Use custom MHA implementation, for testing / benchmarking.
581
  memory_efficient (bool): Use xformers based memory efficient attention.
582
  attention_as_float32 (bool): Perform the attention as float32
583
  (especially important with memory_efficient as autocast won't do this automatically).
584
  cross_attention (bool): If True, expect to get secondary input for cross-attention.
585
- layer_scale (float or None): If not None, LayerScale will be used
586
  with the given value as initial scale.
587
  positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
588
  max_period (float): Maximum period of the time embedding.
589
  positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
590
  xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
591
- lr (float or None): learning rate override through the `make_optim_group` API.
592
- weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
593
  layer_class: (subclass of `StreamingTransformerLayer): class to use
594
- to initialize the layers, allowing further customization outside of Audiocraft.
595
  checkpointing (str): Checkpointing strategy to reduce memory usage.
596
  No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
597
  if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
598
  minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
599
  a policy for opting-out some operations of the checkpointing like
600
  linear layers and attention, providing a middle ground between speed and memory.
601
- device (torch.device or None): Device on which to initialize.
602
- dtype (torch.dtype or None): dtype to use.
603
  **kwargs: See `nn.TransformerEncoderLayer`.
604
  """
605
  def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
@@ -649,7 +648,6 @@ class StreamingTransformer(StreamingModule):
649
  # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
650
  # backward hook inside of FSDP...
651
  layer._magma_checkpointed = True # type: ignore
652
- assert layer.layer_drop == 0., "Need further checking" # type: ignore
653
 
654
  def _apply_layer(self, layer, *args, **kwargs):
655
  method = self.checkpointing
@@ -713,7 +711,7 @@ class StreamingTransformer(StreamingModule):
713
  return group
714
 
715
 
716
- # special attention attention related function
717
 
718
  def _verify_xformers_memory_efficient_compat():
719
  try:
 
35
  _efficient_attention_backend = backend
36
 
37
 
38
+ def _get_attention_time_dimension(memory_efficient: bool) -> int:
39
+ if _efficient_attention_backend == 'torch' and memory_efficient:
40
  return 2
41
  else:
42
  return 1
 
89
  return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
90
 
91
 
92
+ def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
93
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers."""
94
  if n_rep == 1:
95
  return x
96
+ if _efficient_attention_backend == 'torch' and memory_efficient:
97
  bs, n_kv_heads, slen, head_dim = x.shape
98
  return (
99
  x[:, :, None, :, :]
 
111
 
112
  class LayerScale(nn.Module):
113
  """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
114
+ This rescales diagonally the residual outputs close to 0, with a learnt scale.
115
 
116
  Args:
117
  channels (int): Number of channels.
118
  init (float): Initial scale.
119
  channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
120
+ device (torch.device or str, optional): Device on which to initialize the module.
121
+ dtype (torch.dtype, optional): dtype to use to initialize the module.
122
  """
123
  def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
124
  device=None, dtype=None):
 
144
  dropout (float): Dropout level.
145
  bias (bool): Use bias in projections.
146
  causal (bool): Causal mask applied automatically.
147
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
148
  custom (bool): Use custom MHA implementation, for testing / benchmarking.
149
  memory_efficient (bool): Use xformers based memory efficient attention.
150
  attention_as_float32 (bool): Perform the attention as float32
151
  (especially important with memory_efficient as autocast won't do this automatically).
152
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
153
  cross_attention: Should be true when used as a cross attention.
154
  All keys and values must be available at once, streaming is only for the queries.
155
  Cannot be used with `causal` or `rope` (as it wouldn't make sens to
156
+ interpret the time steps in the keys relative to those in the queries).
157
  safe_streaming (bool): Bug fix, will go away with xformers update.
158
  qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
159
  kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
160
  This will lead to faster decoding time on A100 or other GPUs with tensorcore.
161
+ device (torch.device, optional): Device on which to initialize.
162
+ dtype (torch.dtype, optional): dtype to use.
163
  """
164
  def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
165
  causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
 
234
  # Return a causal mask, accounting for potentially stored past keys/values
235
  # We actually return a bias for the attention score, as this has the same
236
  # convention both in the builtin MHA in Pytorch, and Xformers functions.
237
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
238
  if self.memory_efficient:
239
  from xformers.ops import LowerTriangularMask
240
  if current_steps == 1:
241
  # If we only have one step, then we do not need a mask.
242
  return None
243
  elif 'past_keys' in self._streaming_state:
244
+ raise RuntimeError("Not supported at the moment")
245
  else:
246
  # Then we can safely use a lower triangular mask
247
  return LowerTriangularMask()
 
264
  torch.full([], float('-inf'), device=device, dtype=dtype))
265
 
266
  def _complete_kv(self, k, v):
267
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
268
  if self.cross_attention:
269
  # With cross attention we assume all keys and values
270
  # are already available, and streaming is with respect
 
298
  return nk, nv
299
 
300
  def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
301
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
 
302
  # Apply rope embeddings to query and key tensors.
303
  assert self.rope is not None
304
  if 'past_keys' in self._streaming_state:
 
310
  else:
311
  past_context_offset = 0
312
  streaming_offset = past_context_offset + past_keys_offset
313
+ return self.rope.rotate_qk(query, key, start=streaming_offset, time_dim=time_dim)
314
 
315
  def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
316
  key_padding_mask=None, need_weights=False, attn_mask=None,
317
  average_attn_weights=True, is_causal=False):
318
  assert attn_mask is None
319
+ assert not is_causal, ("New param added in torch 2.0.1 not supported, "
320
  "use the causal args in the constructor.")
321
 
322
+ time_dim = _get_attention_time_dimension(self.memory_efficient)
323
  if time_dim == 2:
324
  layout = "b h t d"
325
  else:
 
393
  q, k = self._apply_rope(q, k)
394
  k, v = self._complete_kv(k, v)
395
  if self.kv_repeat > 1:
396
+ k = expand_repeated_kv(k, self.kv_repeat, self.memory_efficient)
397
+ v = expand_repeated_kv(v, self.kv_repeat, self.memory_efficient)
398
  if self.attention_as_float32:
399
  q, k, v = [x.float() for x in [q, k, v]]
400
  if self.memory_efficient:
 
454
  bias_ff (bool): Use bias for FF.
455
  bias_attn (bool): Use bias for MHA.
456
  causal (bool): Causal mask applied automatically.
457
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
458
  custom (bool): Use custom MHA implementation, for testing / benchmarking.
459
  memory_efficient (bool): Use xformers based memory efficient attention.
460
  attention_as_float32 (bool): Perform the attention as float32
 
464
  cross_attention (bool): If True, expect to get secondary input for cross-attention.
465
  Cross attention will use the default MHA, as it typically won't require
466
  special treatment.
467
+ layer_scale (float, optional): If not None, LayerScale will be used with
468
  the given value as initial scale.
469
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
470
+ attention_dropout (float, optional): If not None, separate the value of the dimension dropout
471
  in FFN and of the attention dropout.
472
  kv_repeat (int): If > 1, will repeat keys and queries multiple times (need to divide num_heads).
473
  This will lead to faster decoding time on A100 or other GPUs with tensorcore.
474
+ device (torch.device, optional): Device on which to initialize.
475
+ dtype (torch.dtype, optional): dtype to use.
476
  **kwargs: See `nn.TransformerEncoderLayer`.
477
  """
478
  def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
 
575
  bias_ff (bool): Use bias for FF.
576
  bias_attn (bool): Use bias for MHA.
577
  causal (bool): Causal mask applied automatically.
578
+ past_context (int, optional): Receptive field for the causal mask, infinite if None.
579
  custom (bool): Use custom MHA implementation, for testing / benchmarking.
580
  memory_efficient (bool): Use xformers based memory efficient attention.
581
  attention_as_float32 (bool): Perform the attention as float32
582
  (especially important with memory_efficient as autocast won't do this automatically).
583
  cross_attention (bool): If True, expect to get secondary input for cross-attention.
584
+ layer_scale (float, optional): If not None, LayerScale will be used
585
  with the given value as initial scale.
586
  positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
587
  max_period (float): Maximum period of the time embedding.
588
  positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
589
  xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
590
+ lr (float, optional): learning rate override through the `make_optim_group` API.
591
+ weight_decay (float, optional): Weight_decay override through the `make_optim_group` API.
592
  layer_class: (subclass of `StreamingTransformerLayer): class to use
593
+ to initialize the layers, allowing further customization outside of AudioCraft.
594
  checkpointing (str): Checkpointing strategy to reduce memory usage.
595
  No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
596
  if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
597
  minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
598
  a policy for opting-out some operations of the checkpointing like
599
  linear layers and attention, providing a middle ground between speed and memory.
600
+ device (torch.device, optional): Device on which to initialize.
601
+ dtype (torch.dtype, optional): dtype to use.
602
  **kwargs: See `nn.TransformerEncoderLayer`.
603
  """
604
  def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
 
648
  # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
649
  # backward hook inside of FSDP...
650
  layer._magma_checkpointed = True # type: ignore
 
651
 
652
  def _apply_layer(self, layer, *args, **kwargs):
653
  method = self.checkpointing
 
711
  return group
712
 
713
 
714
+ # special attention related function
715
 
716
  def _verify_xformers_memory_efficient_compat():
717
  try:
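Editor's note: as a standalone check (not library code), the key/value expansion that `expand_repeated_kv` performs in the `[batch, kv_heads, time, head_dim]` layout is equivalent to `torch.repeat_interleave` on the head axis:

```python
import torch

bs, n_kv_heads, slen, head_dim, n_rep = 2, 4, 16, 64, 3
x = torch.randn(bs, n_kv_heads, slen, head_dim)

expanded = (
    x[:, :, None, :, :]
    .expand(bs, n_kv_heads, n_rep, slen, head_dim)
    .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)
assert torch.equal(expanded, torch.repeat_interleave(x, n_rep, dim=1))
```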
audiocraft/quantization/core_vq.py CHANGED
@@ -75,7 +75,7 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
75
  return means, bins
76
 
77
 
78
- def orthgonal_loss_fn(t):
79
  # eq (2) from https://arxiv.org/abs/2112.00384
80
  n = t.shape[0]
81
  normed_codes = l2norm(t)
@@ -237,7 +237,7 @@ class VectorQuantization(nn.Module):
237
  orthogonal_reg_weight (float): Orthogonal regularization weights.
238
  orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
239
  orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
240
- for orthogonal regulariation.
241
  threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
242
  that have an exponential moving average cluster size less than the specified threshold with
243
  randomly selected vector from the current batch.
@@ -340,7 +340,7 @@ class VectorQuantization(nn.Module):
340
  rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
341
  codebook = codebook[rand_ids]
342
 
343
- orthogonal_reg_loss = orthgonal_loss_fn(codebook)
344
  loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
345
 
346
  quantize = self.project_out(quantize)
@@ -371,11 +371,16 @@ class ResidualVectorQuantization(nn.Module):
371
 
372
  for i, layer in enumerate(self.layers[:n_q]):
373
  quantized, indices, loss = layer(residual)
 
374
  residual = residual - quantized
375
  quantized_out = quantized_out + quantized
376
  all_indices.append(indices)
377
  all_losses.append(loss)
378
 
 
 
 
 
379
  out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
380
  return quantized_out, out_indices, out_losses
381
 
 
75
  return means, bins
76
 
77
 
78
+ def orthogonal_loss_fn(t):
79
  # eq (2) from https://arxiv.org/abs/2112.00384
80
  n = t.shape[0]
81
  normed_codes = l2norm(t)
 
237
  orthogonal_reg_weight (float): Orthogonal regularization weights.
238
  orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
239
  orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
240
+ for orthogonal regularization.
241
  threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
242
  that have an exponential moving average cluster size less than the specified threshold with
243
  randomly selected vector from the current batch.
 
340
  rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
341
  codebook = codebook[rand_ids]
342
 
343
+ orthogonal_reg_loss = orthogonal_loss_fn(codebook)
344
  loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
345
 
346
  quantize = self.project_out(quantize)
 
371
 
372
  for i, layer in enumerate(self.layers[:n_q]):
373
  quantized, indices, loss = layer(residual)
374
+ quantized = quantized.detach()
375
  residual = residual - quantized
376
  quantized_out = quantized_out + quantized
377
  all_indices.append(indices)
378
  all_losses.append(loss)
379
 
380
+ if self.training:
381
+ # Solving subtle bug with STE and RVQ: https://github.com/facebookresearch/encodec/issues/25
382
+ quantized_out = x + (quantized_out - x).detach()
383
+
384
  out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
385
  return quantized_out, out_indices, out_losses
386
 
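Editor's note: a standalone illustration (not library code) of the straight-through trick added above — the forward value equals the quantized signal while gradients flow back to the encoder output `x` as if quantization were the identity.

```python
import torch

x = torch.randn(4, 8, requires_grad=True)  # stand-in for the encoder output
quantized = torch.round(x * 4) / 4         # stand-in for the (non-differentiable) RVQ output

out = x + (quantized - x).detach()         # value == quantized, gradient == identity w.r.t. x
out.sum().backward()
assert torch.allclose(out, quantized)
assert torch.allclose(x.grad, torch.ones_like(x))
```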
audiocraft/utils/cache.py ADDED
@@ -0,0 +1,324 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from collections import deque
9
+ from functools import partial
10
+ from hashlib import sha1
11
+ import logging
12
+ from pathlib import Path
13
+ import sys
14
+ import typing as tp
15
+ import zipfile
16
+
17
+ import flashy
18
+ import torch
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
25
+ """Utility function for the EmbeddingCache, returning the full embedding without any chunking.
26
+ This method can be used in case there is no need to extract a chunk of the full embedding
27
+ read from the cache.
28
+
29
+ Args:
30
+ full_embed (torch.Tensor): The full embedding.
31
+ x (any): Batch object from which the full embedding is derived.
32
+ idx (int): Index of the object to consider in the batch object.
33
+ Returns:
34
+ full_embed (torch.Tensor): The full embedding.
35
+ """
36
+ return full_embed.to(device)
37
+
38
+
39
+ class EmbeddingCache:
40
+ """Cache around embeddings computation for faster execution.
41
+ The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
42
+ to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
43
+ using a user-provided function. When the cache is warm (all embeddings are pre-computed),
44
+ the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
45
+ Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
46
+ and synchronization points in the forward calls.
47
+
48
+ Args:
49
+ cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
50
+ device (str or torch.device): Device on which the embedding is returned.
51
+ compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
52
+ the embedding from a given object and path. This user-provided function can compute the
53
+ embedding from the provided object or using the provided path as entry point. The last parameter
54
+ specifies the index corresponding to the current embedding in the object that can represent batch metadata.
55
+ extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
56
+ the desired embedding chunk from the full embedding loaded from the cache. The last parameter
57
+ specifies the index corresponding to the current embedding in the object that can represent batch metadata.
58
+ If not specified, will return the full embedding unmodified.
59
+ """
60
+ def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
61
+ compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
62
+ extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
63
+ self.cache_path = Path(cache_path)
64
+ self.device = device
65
+ self._compute_embed_fn = compute_embed_fn
66
+ self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
67
+ if extract_embed_fn is not None:
68
+ self._extract_embed_fn = extract_embed_fn
69
+ else:
70
+ self._extract_embed_fn = partial(get_full_embed, device=device)
71
+ if self.cache_path is not None:
72
+ self.cache_path.mkdir(exist_ok=True, parents=True)
73
+ logger.info(f"Cache instantiated at: {self.cache_path}")
74
+ self.pool = ThreadPoolExecutor(8)
75
+ self.pool.__enter__()
76
+ self._current_batch_cache: dict = {}
77
+ self._memory_cache: dict = {}
78
+
79
+ def _get_cache_path(self, path: tp.Union[Path, str]):
80
+ """Get cache path for the given file path."""
81
+ sig = sha1(str(path).encode()).hexdigest()
82
+ return self.cache_path / sig
83
+
84
+ @staticmethod
85
+ def _get_full_embed_from_cache(cache: Path):
86
+ """Loads full pre-computed embedding from the cache."""
87
+ try:
88
+ embed = torch.load(cache, 'cpu')
89
+ except Exception as exc:
90
+ logger.error("Error loading %s: %r", cache, exc)
91
+ embed = None
92
+ return embed
93
+
94
+ def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
95
+ """Get embedding from cache, computing and storing it to cache if not already cached.
96
+ The EmbeddingCache first tries to load the embedding from the in-memory cache
97
+ containing the pre-computed chunks populated through `populate_embed_cache`.
98
+ If not found, the full embedding is computed and stored on disk to be later accessed
99
+ to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
100
+
101
+ Args:
102
+ paths (list[Path or str]): List of paths from where the embeddings can be loaded.
103
+ x (any): Object from which the embedding is extracted.
104
+ """
105
+ embeds = []
106
+ for idx, path in enumerate(paths):
107
+ cache = self._get_cache_path(path)
108
+ if cache in self._current_batch_cache:
109
+ embed = self._current_batch_cache[cache]
110
+ else:
111
+ full_embed = self._compute_embed_fn(path, x, idx)
112
+ try:
113
+ with flashy.utils.write_and_rename(cache, pid=True) as f:
114
+ torch.save(full_embed.cpu(), f)
115
+ except Exception as exc:
116
+ logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
117
+ else:
118
+ logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
119
+ embed = self._extract_embed_fn(full_embed, x, idx)
120
+ embeds.append(embed)
121
+ embed = torch.stack(embeds, dim=0)
122
+ return embed
123
+
124
+ def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
125
+ """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
126
+ The in-memory caches consist in a cache for the full embedding and another cache for the
127
+ final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
128
+ and reduce the IO footprint and synchronization points during forward passes.
129
+
130
+ Args:
131
+ paths (list[Path]): List of paths from where the embeddings can be loaded.
132
+ x (any): Object from which the embedding is extracted.
133
+ """
134
+ self._current_batch_cache.clear()
135
+ if self.cache_path is not None:
136
+ futures: list = []
137
+ for path in paths:
138
+ assert path is not None, "Path is required for computation from cache"
139
+ cache = self._get_cache_path(path)
140
+ if cache in self._memory_cache or not cache.exists():
141
+ futures.append(None)
142
+ else:
143
+ futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
144
+ for idx, (path, future) in enumerate(zip(paths, futures)):
145
+ assert path is not None
146
+ cache = self._get_cache_path(path)
147
+ full_embed = None
148
+ if future is None:
149
+ if cache in self._memory_cache:
150
+ full_embed = self._memory_cache[cache]
151
+ else:
152
+ full_embed = future.result()
153
+ if full_embed is not None:
154
+ self._memory_cache[cache] = full_embed
155
+ full_embed = full_embed.to(self.device)
156
+ if full_embed is not None:
157
+ embed = self._extract_embed_fn(full_embed, x, idx)
158
+ self._current_batch_cache[cache] = embed
159
+
160
+
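Editor's note: a hedged usage sketch of `EmbeddingCache`; the compute function and file names below are hypothetical stand-ins for a real embedder and dataset.

```python
from pathlib import Path
import typing as tp
import torch
from audiocraft.utils.cache import EmbeddingCache

def fake_compute_embed(path: Path, x: tp.Any, idx: int) -> torch.Tensor:
    return torch.randn(1, 512)  # pretend this runs an expensive model on `path`

cache = EmbeddingCache('./embed_cache', device='cpu', compute_embed_fn=fake_compute_embed)
paths = [Path('clip_a.wav'), Path('clip_b.wav')]
cache.populate_embed_cache(paths, x=None)            # warm the in-memory cache from disk (no-op when cold)
embeds = cache.get_embed_from_cache(paths, x=None)   # computes, persists, and stacks per-item embeddings
print(embeds.shape)                                  # torch.Size([2, 1, 512])
```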
161
+ class CachedBatchWriter:
162
+ """Write pre computed caches for mini batches. This can
163
+ make loading a lot more efficient depending on your filesystem.
164
+
165
+ Args:
166
+ cache_folder (Path): folder in which the cached minibatches
167
+ will be stored.
168
+
169
+ Inside cache folder, the structure is the following:
170
+ `epoch_number / update_number.zip`
171
+ And the zip file contains one entry per batch item.
172
+
173
+ It is possible to use the cache with a batch size smaller than
174
+ the one it was created with, but not larger. Make sure to call the
175
+ `start_epoch(epoch)` method to indicate changes of epochs.
176
+
177
+ See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
178
+ for an example of how to warmup the cache.
179
+ """
180
+ def __init__(self, cache_folder: Path):
181
+ self.cache_folder = cache_folder
182
+ self._current_epoch: tp.Optional[int] = None
183
+ self._current_index = 0
184
+
185
+ def start_epoch(self, epoch: int):
186
+ """Call at the beginning of each epoch.
187
+ """
188
+ self._current_epoch = epoch
189
+ self._current_index = 0
190
+ self._zip_path.parent.mkdir(exist_ok=True, parents=True)
191
+
192
+ @staticmethod
193
+ def _get_zip_path(cache_folder: Path, epoch: int, index: int):
194
+ return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
195
+
196
+ @property
197
+ def _zip_path(self):
198
+ assert self._current_epoch is not None
199
+ return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
200
+
201
+ def save(self, *content):
202
+ """Save one mini batch. This function is distributed-aware
203
+ and will automatically merge all the items from the different
204
+ workers.
205
+ """
206
+ all_contents = []
207
+ for rank in range(flashy.distrib.world_size()):
208
+ their_content = flashy.distrib.broadcast_object(content, src=rank)
209
+ all_contents.append(their_content)
210
+
211
+ if flashy.distrib.is_rank_zero():
212
+ idx = 0
213
+ with flashy.utils.write_and_rename(self._zip_path) as tmp:
214
+ with zipfile.ZipFile(tmp, 'w') as zf:
215
+ for content in all_contents:
216
+ for vals in zip(*content):
217
+ with zf.open(f'{idx}', 'w') as f: # type: ignore
218
+ torch.save(vals, f)
219
+ idx += 1
220
+ flashy.distrib.barrier()
221
+ self._current_index += 1
222
+
223
+
224
+ class CachedBatchLoader:
225
+ """Loader for cached mini-batches dumped with `CachedBatchWriter`.
226
+
227
+ Args:
228
+ cache_folder (Path): folder in which the cached minibatches are stored.
229
+ batch_size (int): batch size (per GPU) expected.
230
+ num_workers (int): number of workers to use for loading.
231
+ min_length (int): minimum expected length for each epoch. If some
232
+ mini-batches are missing, an error is raised.
233
+
234
+ This is iterable just like a regular DataLoader.
235
+ """
236
+
237
+ def __init__(self, cache_folder: Path, batch_size: int,
238
+ num_workers: int = 10, min_length: int = 1):
239
+ self.cache_folder = cache_folder
240
+ self.batch_size = batch_size
241
+ self.num_workers = num_workers
242
+ self.min_length = min_length
243
+ self._current_epoch: tp.Optional[int] = None
244
+ self.sampler = None # for compatibility with the regular DataLoader
245
+
246
+ def __len__(self):
247
+ path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
248
+ return len([p for p in path.iterdir() if p.suffix == ".zip"])
249
+
250
+ def start_epoch(self, epoch: int):
251
+ """Call at the beginning of each epoch.
252
+ """
253
+ self._current_epoch = epoch
254
+
255
+ def _zip_path(self, index: int):
256
+ assert self._current_epoch is not None
257
+ return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)
258
+
259
+ def _load_one(self, index: int):
260
+ zip_path = self._zip_path(index)
261
+ if not zip_path.exists():
262
+ if index < self.min_length:
263
+ raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
264
+
265
+ return None
266
+ mode = "rb" if sys.version_info >= (3, 9) else "r"
267
+ try:
268
+ with zipfile.ZipFile(zip_path, 'r') as zf:
269
+ rank = flashy.distrib.rank()
270
+ world_size = flashy.distrib.world_size()
271
+ root = zipfile.Path(zf)
272
+ items = list(root.iterdir())
273
+ total_batch_size = self.batch_size * world_size
274
+ if len(items) < total_batch_size:
275
+ raise RuntimeError(
276
+ f"The cache can handle a max batch size of {len(items)}, "
277
+ f"but {total_batch_size} is needed.")
278
+ start = rank * self.batch_size
279
+ items = items[start: start + self.batch_size]
280
+ assert len(items) == self.batch_size
281
+ entries = []
282
+ entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore
283
+ transposed = zip(*entries)
284
+ out = []
285
+ for part in transposed:
286
+ assert len(part) > 0
287
+ if isinstance(part[0], torch.Tensor):
288
+ out.append(torch.stack(part))
289
+ else:
290
+ assert isinstance(part, torch.Tensor)
291
+ out.append(part)
292
+ return out
293
+ except Exception:
294
+ logger.error("Error when reading zip path %s", zip_path)
295
+ raise
296
+
297
+ def __iter__(self):
298
+ """This will yields tuples, exactly as provided to the
299
+ `CachedBatchWriter.save` method.
300
+ """
301
+ pool = ThreadPoolExecutor(self.num_workers)
302
+ next_index = 0
303
+ queue = deque()
304
+
305
+ def _get_next():
306
+ nonlocal next_index
307
+ r = queue.popleft().result()
308
+ if r is None:
309
+ return None
310
+ else:
311
+ queue.append(pool.submit(self._load_one, next_index))
312
+ next_index += 1
313
+ return r
314
+
315
+ with pool:
316
+ # fill the buffer of fetching jobs.
317
+ for _ in range(2 * self.num_workers):
318
+ queue.append(pool.submit(self._load_one, next_index))
319
+ next_index += 1
320
+ while True:
321
+ batch = _get_next()
322
+ if batch is None:
323
+ return
324
+ yield batch
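Editor's note: a hedged sketch of the writer/loader pair above, assuming a single-process run where flashy's distributed helpers fall back to a world size of 1:

```python
from pathlib import Path
import torch
from audiocraft.utils.cache import CachedBatchWriter, CachedBatchLoader

folder = Path('./batch_cache')
writer = CachedBatchWriter(folder)
writer.start_epoch(0)
for _ in range(4):                           # pretend mini-batches coming from a data loader
    wav = torch.randn(8, 1, 16_000)
    tokens = torch.randint(0, 2048, (8, 4, 50))
    writer.save(wav, tokens)                 # one zip per update, one entry per batch item

loader = CachedBatchLoader(folder, batch_size=8)
loader.start_epoch(0)
for wav, tokens in loader:                   # yields tuples exactly as saved
    pass
```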
audiocraft/utils/cluster.py ADDED
@@ -0,0 +1,75 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility functions for SLURM configuration and cluster settings.
9
+ """
10
+
11
+ from enum import Enum
12
+ import os
13
+ import socket
14
+ import typing as tp
15
+
16
+ import omegaconf
17
+
18
+
19
+ class ClusterType(Enum):
20
+ AWS = "aws"
21
+ FAIR = "fair"
22
+ RSC = "rsc"
23
+ LOCAL_DARWIN = "darwin"
24
+ DEFAULT = "default" # used for any other cluster.
25
+
26
+
27
+ def _guess_cluster_type() -> ClusterType:
28
+ uname = os.uname()
29
+ fqdn = socket.getfqdn()
30
+ if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
31
+ return ClusterType.AWS
32
+
33
+ if fqdn.endswith(".fair"):
34
+ return ClusterType.FAIR
35
+
36
+ if fqdn.endswith(".facebook.com"):
37
+ return ClusterType.RSC
38
+
39
+ if uname.sysname == "Darwin":
40
+ return ClusterType.LOCAL_DARWIN
41
+
42
+ return ClusterType.DEFAULT
43
+
44
+
45
+ def get_cluster_type(
46
+ cluster_type: tp.Optional[ClusterType] = None,
47
+ ) -> tp.Optional[ClusterType]:
48
+ if cluster_type is None:
49
+ return _guess_cluster_type()
50
+
51
+ return cluster_type
52
+
53
+
54
+ def get_slurm_parameters(
55
+ cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
56
+ ) -> omegaconf.DictConfig:
57
+ """Update SLURM parameters in configuration based on cluster type.
58
+ If the cluster type is not specified, it is inferred automatically.
59
+ """
60
+ from ..environment import AudioCraftEnvironment
61
+ cluster_type = get_cluster_type(cluster_type)
62
+ # apply cluster-specific adjustments
63
+ if cluster_type == ClusterType.AWS:
64
+ cfg["mem_per_gpu"] = None
65
+ cfg["constraint"] = None
66
+ cfg["setup"] = []
67
+ elif cluster_type == ClusterType.RSC:
68
+ cfg["mem_per_gpu"] = None
69
+ cfg["setup"] = []
70
+ cfg["constraint"] = None
71
+ cfg["partition"] = "learn"
72
+ slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
73
+ if slurm_exclude is not None:
74
+ cfg["exclude"] = slurm_exclude
75
+ return cfg
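Editor's note: a small sketch of how these helpers are meant to be used; the config keys mirror the ones touched above and the values are placeholders.

```python
import omegaconf
from audiocraft.utils.cluster import get_cluster_type, get_slurm_parameters

cfg = omegaconf.OmegaConf.create(
    {"mem_per_gpu": 40, "constraint": "volta32gb", "setup": [], "partition": "gpu"}
)
print(get_cluster_type())        # inferred from hostname / OS when no type is passed
cfg = get_slurm_parameters(cfg)  # e.g. on AWS this clears mem_per_gpu and constraint
```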
audiocraft/utils/export.py CHANGED
@@ -11,46 +11,69 @@ Utility to export a training checkpoint to a lightweight release checkpoint.
11
  from pathlib import Path
12
  import typing as tp
13
 
14
- from omegaconf import OmegaConf, DictConfig
15
  import torch
16
 
 
17
 
18
- def _clean_lm_cfg(cfg: DictConfig):
19
- OmegaConf.set_struct(cfg, False)
20
- # This used to be set automatically in the LM solver, need a more robust solution
21
- # for the future.
22
- cfg['transformer_lm']['card'] = 2048
23
- cfg['transformer_lm']['n_q'] = 4
24
- # Experimental params no longer supported.
25
- bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
26
- 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
27
- for name in bad_params:
28
- del cfg['transformer_lm'][name]
29
- OmegaConf.set_struct(cfg, True)
30
- return cfg
31
-
32
-
33
- def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
34
- sig = Path(checkpoint_path).parent.name
35
- assert len(sig) == 8, "Not a valid Dora signature"
36
  pkg = torch.load(checkpoint_path, 'cpu')
37
  new_pkg = {
38
- 'best_state': pkg['ema']['state']['model'],
39
  'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
 
 
40
  }
41
- out_file = Path(out_folder) / f'{sig}.th'
42
  torch.save(new_pkg, out_file)
43
  return out_file
44
 
45
 
46
- def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
47
- sig = Path(checkpoint_path).parent.name
48
- assert len(sig) == 8, "Not a valid Dora signature"
 
 
 
 
 
49
  pkg = torch.load(checkpoint_path, 'cpu')
 
 
 
 
 
50
  new_pkg = {
51
- 'best_state': pkg['fsdp_best_state']['model'],
52
- 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
 
 
53
  }
54
- out_file = Path(out_folder) / f'{sig}.th'
 
55
  torch.save(new_pkg, out_file)
56
  return out_file
 
11
  from pathlib import Path
12
  import typing as tp
13
 
14
+ from omegaconf import OmegaConf
15
  import torch
16
 
17
+ from audiocraft import __version__
18
 
19
+
20
+ def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
21
+ """Export only the best state from the given EnCodec checkpoint. This
22
+ should be used if you trained your own EnCodec model.
23
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  pkg = torch.load(checkpoint_path, 'cpu')
25
  new_pkg = {
26
+ 'best_state': pkg['best_state']['model'],
27
  'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
28
+ 'version': __version__,
29
+ 'exported': True,
30
  }
31
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
32
  torch.save(new_pkg, out_file)
33
  return out_file
34
 
35
 
36
+ def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
37
+ """Export a compression model (potentially EnCodec) from a pretrained model.
38
+ This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
39
+ Do not include the //pretrained/ prefix. For instance if you trained a model
40
+ with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
41
+
42
+ In that case, this will not actually include a copy of the model, simply the reference
43
+ to the model used.
44
+ """
45
+ if Path(pretrained_encodec).exists():
46
+ pkg = torch.load(pretrained_encodec)
47
+ assert 'best_state' in pkg
48
+ assert 'xp.cfg' in pkg
49
+ assert 'version' in pkg
50
+ assert 'exported' in pkg
51
+ else:
52
+ pkg = {
53
+ 'pretrained': pretrained_encodec,
54
+ 'exported': True,
55
+ 'version': __version__,
56
+ }
57
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
58
+ torch.save(pkg, out_file)
59
+
60
+
61
+ def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
62
+ """Export only the best state from the given MusicGen or AudioGen checkpoint.
63
+ """
64
  pkg = torch.load(checkpoint_path, 'cpu')
65
+ if pkg['fsdp_best_state']:
66
+ best_state = pkg['fsdp_best_state']['model']
67
+ else:
68
+ assert pkg['best_state']
69
+ best_state = pkg['best_state']['model']
70
  new_pkg = {
71
+ 'best_state': best_state,
72
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
73
+ 'version': __version__,
74
+ 'exported': True,
75
  }
76
+
77
+ Path(out_file).parent.mkdir(exist_ok=True, parents=True)
78
  torch.save(new_pkg, out_file)
79
  return out_file
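
A minimal usage sketch for the exporters above; the checkpoint and output paths are placeholders, and the `state_dict.bin` / `compression_state_dict.bin` names are only an assumption about what the audiocraft loaders look for, not something these functions enforce:

```python
# Sketch only: paths are hypothetical, not part of this commit.
from pathlib import Path

from audiocraft.utils import export

ckpt = Path('/checkpoints/my_musicgen_run/checkpoint.th')   # hypothetical training checkpoint
out_dir = Path('/checkpoints/my_audio_lm')                  # hypothetical release folder

# Pack only the best LM state plus its config into a lightweight file.
export.export_lm(ckpt, out_dir / 'state_dict.bin')

# Reference the pretrained EnCodec tokenizer instead of copying its weights.
export.export_pretrained_compression_model('facebook/encodec_32khz',
                                            out_dir / 'compression_state_dict.bin')
```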
audiocraft/utils/export_legacy.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility to export a training checkpoint to a lightweight release checkpoint.
9
+ """
10
+
11
+ from pathlib import Path
12
+ import typing as tp
13
+
14
+ from omegaconf import OmegaConf, DictConfig
15
+ import torch
16
+
17
+
18
+ def _clean_lm_cfg(cfg: DictConfig):
19
+ OmegaConf.set_struct(cfg, False)
20
+ # This used to be set automatically in the LM solver, need a more robust solution
21
+ # for the future.
22
+ cfg['transformer_lm']['card'] = 2048
23
+ cfg['transformer_lm']['n_q'] = 4
24
+ # Experimental params no longer supported.
25
+ bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
26
+ 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
27
+ for name in bad_params:
28
+ del cfg['transformer_lm'][name]
29
+ OmegaConf.set_struct(cfg, True)
30
+ return cfg
31
+
32
+
33
+ def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
34
+ sig = Path(checkpoint_path).parent.name
35
+ assert len(sig) == 8, "Not a valid Dora signature"
36
+ pkg = torch.load(checkpoint_path, 'cpu')
37
+ new_pkg = {
38
+ 'best_state': pkg['ema']['state']['model'],
39
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
40
+ }
41
+ out_file = Path(out_folder) / f'{sig}.th'
42
+ torch.save(new_pkg, out_file)
43
+ return out_file
44
+
45
+
46
+ def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
47
+ sig = Path(checkpoint_path).parent.name
48
+ assert len(sig) == 8, "Not a valid Dora signature"
49
+ pkg = torch.load(checkpoint_path, 'cpu')
50
+ new_pkg = {
51
+ 'best_state': pkg['fsdp_best_state']['model'],
52
+ 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
53
+ }
54
+ out_file = Path(out_folder) / f'{sig}.th'
55
+ torch.save(new_pkg, out_file)
56
+ return out_file
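
For comparison, the legacy exporters keep the older calling convention: they take an output folder and derive the filename from the 8-character Dora signature of the checkpoint's parent directory. A sketch with hypothetical paths:

```python
# Sketch only: the Dora signature folders and output folder are placeholders.
from audiocraft.utils import export_legacy

# .../xps/1a2b3c4d/checkpoint.th is written out as /checkpoints/release/1a2b3c4d.th
export_legacy.export_lm('/experiments/xps/1a2b3c4d/checkpoint.th', '/checkpoints/release')
export_legacy.export_encodec('/experiments/xps/5e6f7a8b/checkpoint.th', '/checkpoints/release')
```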
audiocraft/utils/extend.py CHANGED
@@ -179,7 +179,7 @@ def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:
179
  descriptions=[text],
180
  melody_wavs=verse,
181
  sample_rate=sr,
182
- progress=False,
183
  prompt=prompt_segment,
184
  )
185
  # If user selects a prompt segment, use the prompt segment for all segments
@@ -280,9 +280,10 @@ def load_font(font_name, font_size=16):
280
  if font is None:
281
  try:
282
  req = requests.get(font_name)
283
- font = ImageFont.truetype(BytesIO(req.content), font_size)
284
- except (FileNotFoundError, OSError):
285
- print(f"Font not found: {font_name} Using default font\n")
 
286
  if font:
287
  print(f"Font loaded {font.getname()}")
288
  else:
 
179
  descriptions=[text],
180
  melody_wavs=verse,
181
  sample_rate=sr,
182
+ progress=True,
183
  prompt=prompt_segment,
184
  )
185
  # If user selects a prompt segment, use the prompt segment for all segments
 
280
  if font is None:
281
  try:
282
  req = requests.get(font_name)
283
+ font = ImageFont.truetype(BytesIO(req.content), font_size)
284
+ except (FileNotFoundError, OSError):
285
+ print(f"Font not found: {font_name} Using default font\n")
286
+
287
  if font:
288
  print(f"Font loaded {font.getname()}")
289
  else:
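
The hunk above restores the remote-font fallback in `load_font` (and re-enables per-segment progress reporting). A standalone sketch of the same fallback pattern, with a placeholder URL:

```python
# Sketch of the remote-font fallback used by load_font; the URL is a placeholder.
from io import BytesIO

import requests
from PIL import ImageFont

def load_font_from_url(url: str, font_size: int = 16):
    try:
        resp = requests.get(url, timeout=10)
        resp.raise_for_status()
        return ImageFont.truetype(BytesIO(resp.content), font_size)
    except (OSError, requests.RequestException):
        print(f"Font not found: {url} Using default font\n")
        return ImageFont.load_default()
```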
audiocraft/utils/utils.py CHANGED
@@ -5,9 +5,12 @@
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  from concurrent.futures import ProcessPoolExecutor
8
- from functools import wraps
 
9
  import hashlib
 
10
  import logging
 
11
  import typing as tp
12
 
13
  import flashy
@@ -20,6 +23,18 @@ from torch.nn.utils.rnn import pad_sequence
20
  logger = logging.getLogger(__name__)
21
 
22
23
  def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
24
  """Convenience function to map an omegaconf configuration to a dictionary.
25
 
@@ -172,7 +187,7 @@ def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> t
172
  assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
173
  final_length = lengths.max().item() if not max_len else max_len
174
  final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
175
- return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
176
 
177
 
178
  def hash_trick(word: str, vocab_size: int) -> int:
@@ -232,3 +247,54 @@ def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tens
232
  padded_tensors = padded_tensors.transpose(0, 1)
233
  padded_tensors = padded_tensors.transpose(1, dim + 1)
234
  return padded_tensors, lens
5
  # LICENSE file in the root directory of this source tree.
6
 
7
  from concurrent.futures import ProcessPoolExecutor
8
+ from contextlib import contextmanager
9
+ from functools import wraps, lru_cache
10
  import hashlib
11
+ import json
12
  import logging
13
+ from pathlib import Path
14
  import typing as tp
15
 
16
  import flashy
 
23
  logger = logging.getLogger(__name__)
24
 
25
 
26
+ def model_hash(model: torch.nn.Module) -> str:
27
+ """Return a model hash. This should allow us to track regressions in model init
28
+ from the logs of past experiments.
29
+ """
30
+ hasher = hashlib.sha1()
31
+ for p in model.parameters():
32
+ hasher.update(p.data.cpu().numpy().tobytes())
33
+ return hasher.hexdigest()
34
+
35
+
36
+
37
+
38
  def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
39
  """Convenience function to map an omegaconf configuration to a dictionary.
40
 
 
187
  assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
188
  final_length = lengths.max().item() if not max_len else max_len
189
  final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
190
+ return torch.arange(final_length, device=lengths.device)[None, :] < lengths[:, None]
191
 
192
 
193
  def hash_trick(word: str, vocab_size: int) -> int:
 
247
  padded_tensors = padded_tensors.transpose(0, 1)
248
  padded_tensors = padded_tensors.transpose(1, dim + 1)
249
  return padded_tensors, lens
250
+
251
+
252
+ # TODO: Move to flashy?
253
+ def copy_state(state: tp.Any, device: tp.Union[torch.device, str] = 'cpu',
254
+ dtype: tp.Optional[torch.dtype] = None) -> tp.Any:
255
+ if isinstance(state, torch.Tensor):
256
+ if dtype is None or not state.is_floating_point():
257
+ dtype = state.dtype
258
+ return state.detach().to(device=device, dtype=dtype, copy=True)
259
+ elif isinstance(state, dict):
260
+ return {k: copy_state(v, device, dtype) for k, v in state.items()}
261
+ elif isinstance(state, list):
262
+ return [copy_state(v, device, dtype) for v in state]
263
+
264
+
265
+ # TODO: Move to flashy?
266
+ @contextmanager
267
+ def swap_state(model, state, **kwargs):
268
+ old_state = copy_state(model.state_dict())
269
+ model.load_state_dict(state, **kwargs)
270
+ try:
271
+ yield
272
+ finally:
273
+ model.load_state_dict(old_state)
274
+
275
+
276
+ @lru_cache(None)
277
+ def warn_once(logger, msg):
278
+ """Warn about a given message only once."""
279
+ logger.warning(msg)
280
+
281
+
282
+ def is_jsonable(x: tp.Any):
283
+ """Check if an object can be serialized into JSON."""
284
+ try:
285
+ json.dumps(x)
286
+ return True
287
+ except (TypeError, OverflowError):
288
+ return False
289
+
290
+
291
+ def load_clap_state_dict(clap_model, path: tp.Union[str, Path]):
292
+ """Wrapper around state dict loading of CLAP model
293
+ addressing compatibility issues between CLAP and AudioCraft
294
+ HuggingFace transformer version.
295
+ See: https://github.com/LAION-AI/CLAP/issues/118
296
+ """
297
+ from clap_module.factory import load_state_dict # type: ignore
298
+ pkg = load_state_dict(path)
299
+ pkg.pop('text_branch.embeddings.position_ids', None)
300
+ clap_model.model.load_state_dict(pkg)
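
A short sketch of how the new helpers in `utils.py` fit together, with a throwaway `torch.nn.Linear` standing in for a real audiocraft module:

```python
# Sketch only: the tiny Linear model is a stand-in, not an audiocraft module.
import torch

from audiocraft.utils.utils import copy_state, model_hash, swap_state

model = torch.nn.Linear(4, 4)
print(model_hash(model))                     # stable fingerprint of the current weights

ema_state = copy_state(model.state_dict())   # detached CPU copy, safe to stash
with swap_state(model, ema_state):
    model(torch.randn(1, 4))                 # evaluate with the swapped-in weights
# original weights are restored on exit, even if an exception was raised
```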
modules/file_utils.py ADDED
@@ -0,0 +1,91 @@
1
+ # file_utils
2
+ import os
3
+ import shutil
4
+ from pathlib import Path
5
+
6
+ def get_file_parts(file_path: str):
7
+ # Split the path into directory and filename
8
+ directory, filename = os.path.split(file_path)
9
+
10
+ # Split the filename into name and extension
11
+ name, ext = os.path.splitext(filename)
12
+
13
+ # Convert the extension to lowercase
14
+ new_ext = ext.lower()
15
+ return directory, filename, name, ext, new_ext
16
+
17
+ def rename_file_to_lowercase_extension(file_path: str) -> str:
18
+ """
19
+ Renames a file's extension to lowercase in place.
20
+
21
+ Parameters:
22
+ file_path (str): The original file path.
23
+
24
+ Returns:
25
+ str: The new file path with the lowercase extension.
26
+
27
+ Raises:
28
+ OSError: If there is an error renaming the file (e.g., file not found, permissions issue).
29
+ """
30
+ directory, filename, name, ext, new_ext = get_file_parts(file_path)
31
+ # If the extension changes, rename the file
32
+ if ext != new_ext:
33
+ new_filename = name + new_ext
34
+ new_file_path = os.path.join(directory, new_filename)
35
+ try:
36
+ os.rename(file_path, new_file_path)
37
+ print(f"Renamed {file_path} to {new_file_path}\n")
38
+ except Exception as e:
39
+ print(f"os.rename failed: {e}. Falling back to binary copy operation.")
40
+ try:
41
+ # Read the file in binary mode and write it to new_file_path
42
+ with open(file_path, 'rb') as f:
43
+ data = f.read()
44
+ with open(new_file_path, 'wb') as f:
45
+ f.write(data)
46
+ print(f"Copied {file_path} to {new_file_path}\n")
47
+ # Optionally, remove the original file after copying
48
+ #os.remove(file_path)
49
+ except Exception as inner_e:
50
+ print(f"Failed to copy file from {file_path} to {new_file_path}: {inner_e}")
51
+ raise inner_e
52
+ return new_file_path
53
+ else:
54
+ return file_path
55
+
56
+ def get_filename(file):
57
+ # extract filename from file object
58
+ filename = None
59
+ if file is not None:
60
+ filename = file.name
61
+ return filename
62
+
63
+ def convert_title_to_filename(title):
64
+ # convert title to filename
65
+ filename = title.lower().replace(" ", "_").replace("/", "_")
66
+ return filename
67
+
68
+ def get_filename_from_filepath(filepath):
69
+ file_name = os.path.basename(filepath)
70
+ file_base, file_extension = os.path.splitext(file_name)
71
+ return file_base, file_extension
72
+
73
+ def delete_file(file_path: str) -> None:
74
+ """
75
+ Deletes the specified file.
76
+
77
+ Parameters:
78
+ file_path (str): The path to the file to delete.
79
+
80
+ Raises:
81
+ FileNotFoundError: If the file does not exist.
82
+ Exception: If there is an error deleting the file.
83
+ """
84
+ try:
85
+ path = Path(file_path)
86
+ path.unlink()
87
+ print(f"Deleted original file: {file_path}")
88
+ except FileNotFoundError:
89
+ print(f"File not found: {file_path}")
90
+ except Exception as e:
91
+ print(f"Error deleting file: {e}")
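
A quick sketch of how these helpers compose; the path and title are made up:

```python
# Sketch only: the path below is a placeholder and must exist on disk for the rename.
from modules.file_utils import (convert_title_to_filename, get_file_parts,
                                rename_file_to_lowercase_extension)

path = "./output/My Song.WAV"
directory, filename, name, ext, new_ext = get_file_parts(path)   # new_ext == ".wav"
normalized = rename_file_to_lowercase_extension(path)            # renames (or copies) in place
print(convert_title_to_filename("My Song / Take 2"))             # my_song___take_2
```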
modules/gradio.py ADDED
@@ -0,0 +1,272 @@
1
+ # modules.gradio
2
+ # holds updates and lost code from gradio changes
3
+ import os
4
+ import gradio as gr
5
+ import numpy as np
6
+ import PIL
7
+ import PIL.Image
8
+ import shutil
9
+ import subprocess
10
+ from tempfile import NamedTemporaryFile
11
+ from pathlib import Path
+ from typing import Any
12
+
13
+
14
+ class MatplotlibBackendMananger:
15
+ def __enter__(self):
16
+ try:
17
+ import matplotlib
18
+
19
+ self._original_backend = matplotlib.get_backend()
20
+ matplotlib.use("agg")
21
+ except ImportError:
22
+ pass
23
+
24
+ def __exit__(self, exc_type, exc_val, exc_tb):
25
+ try:
26
+ import matplotlib
27
+
28
+ matplotlib.use(self._original_backend)
29
+ except ImportError:
30
+ pass
31
+
32
+ gr.utils.MatplotlibBackendMananger = MatplotlibBackendMananger
33
+
34
+ def make_waveform(
35
+ audio: str | tuple[int, np.ndarray],
36
+ *,
37
+ bg_color: str = "#f3f4f6",
38
+ bg_image: str | None = None,
39
+ fg_alpha: float = 0.75,
40
+ bars_color: str | tuple[str, str] = ("#fbbf24", "#ea580c"),
41
+ bar_count: int = 50,
42
+ bar_width: float = 0.6,
43
+ animate: bool = False,
44
+ name: str = "",
45
+ ) -> str:
46
+ """
47
+ Generates a waveform video from an audio file. Useful for creating an easy to share audio visualization. The output should be passed into a `gr.Video` component.
48
+ Parameters:
49
+ audio: Audio file path or tuple of (sample_rate, audio_data)
50
+ bg_color: Background color of waveform (ignored if bg_image is provided)
51
+ bg_image: Background image of waveform
52
+ fg_alpha: Opacity of foreground waveform
53
+ bars_color: Color of waveform bars. Can be a single color or a tuple of (start_color, end_color) of gradient
54
+ bar_count: Number of bars in waveform
55
+ bar_width: Width of bars in waveform. 1 represents full width, 0.5 represents half width, etc.
56
+ animate: If true, the audio waveform overlay will be animated, if false, it will be static.
57
+ Returns:
58
+ A filepath to the output video in mp4 format.
59
+ """
60
+ import matplotlib.pyplot as plt
61
+ from matplotlib.animation import FuncAnimation
62
+
63
+ if isinstance(audio, str):
64
+ audio_file = audio
65
+ audio = gr.processing_utils.audio_from_file(audio)
66
+ else:
67
+ tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False, prefix = name)
68
+ gr.processing_utils.audio_to_file(audio[0], audio[1], tmp_wav.name, format="wav")
69
+ audio_file = tmp_wav.name
70
+
71
+ if not os.path.isfile(audio_file):
72
+ raise ValueError("Audio file not found.")
73
+
74
+ ffmpeg = shutil.which("ffmpeg")
75
+ if not ffmpeg:
76
+ raise RuntimeError("ffmpeg not found.")
77
+
78
+ duration = round(len(audio[1]) / audio[0], 4)
79
+
80
+ # Helper methods to create waveform
81
+ def hex_to_rgb(hex_str):
82
+ return [int(hex_str[i : i + 2], 16) for i in range(1, 6, 2)]
83
+
84
+ def get_color_gradient(c1, c2, n):
85
+ if n < 1:
86
+ raise ValueError("Must have at least one stop in gradient")
87
+ c1_rgb = np.array(hex_to_rgb(c1)) / 255
88
+ c2_rgb = np.array(hex_to_rgb(c2)) / 255
89
+ mix_pcts = [x / (n - 1) for x in range(n)]
90
+ rgb_colors = [((1 - mix) * c1_rgb + (mix * c2_rgb)) for mix in mix_pcts]
91
+ return [
92
+ "#" + "".join(f"{int(round(val * 255)):02x}" for val in item)
93
+ for item in rgb_colors
94
+ ]
95
+
96
+ # Reshape audio to have a fixed number of bars
97
+ samples = audio[1]
98
+ if len(samples.shape) > 1:
99
+ samples = np.mean(samples, 1)
100
+ bins_to_pad = bar_count - (len(samples) % bar_count)
101
+ samples = np.pad(samples, [(0, bins_to_pad)])
102
+ samples = np.reshape(samples, (bar_count, -1))
103
+ samples = np.abs(samples)
104
+ samples = np.max(samples, 1)
105
+
106
+ with MatplotlibBackendMananger():
107
+ plt.clf()
108
+ # Plot waveform
109
+ color = (
110
+ bars_color
111
+ if isinstance(bars_color, str)
112
+ else get_color_gradient(bars_color[0], bars_color[1], bar_count)
113
+ )
114
+
115
+ if animate:
116
+ fig = plt.figure(figsize=(5, 1), dpi=200, frameon=False)
117
+ fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
118
+ plt.axis("off")
119
+ plt.margins(x=0)
120
+
121
+ bar_alpha = fg_alpha if animate else 1.0
122
+ barcollection = plt.bar(
123
+ np.arange(0, bar_count),
124
+ samples * 2,
125
+ bottom=(-1 * samples),
126
+ width=bar_width,
127
+ color=color,
128
+ alpha=bar_alpha,
129
+ )
130
+
131
+ tmp_img = NamedTemporaryFile(suffix=".png", delete=False, prefix = name)
132
+
133
+ savefig_kwargs: dict[str, Any] = {"bbox_inches": "tight"}
134
+ if bg_image is not None:
135
+ savefig_kwargs["transparent"] = True
136
+ if animate:
137
+ savefig_kwargs["facecolor"] = "none"
138
+ else:
139
+ savefig_kwargs["facecolor"] = bg_color
140
+ plt.savefig(tmp_img.name, **savefig_kwargs)
141
+
142
+ if not animate:
143
+ waveform_img = PIL.Image.open(tmp_img.name)
144
+ waveform_img = waveform_img.resize((1000, 400))
145
+
146
+ # Composite waveform with background image
147
+ if bg_image is not None:
148
+ waveform_array = np.array(waveform_img)
149
+ waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha
150
+ waveform_img = PIL.Image.fromarray(waveform_array)
151
+
152
+ bg_img = PIL.Image.open(bg_image)
153
+ waveform_width, waveform_height = waveform_img.size
154
+ bg_width, bg_height = bg_img.size
155
+ if waveform_width != bg_width:
156
+ bg_img = bg_img.resize(
157
+ (
158
+ waveform_width,
159
+ 2 * int(bg_height * waveform_width / bg_width / 2),
160
+ )
161
+ )
162
+ bg_width, bg_height = bg_img.size
163
+ composite_height = max(bg_height, waveform_height)
164
+ composite = PIL.Image.new(
165
+ "RGBA", (waveform_width, composite_height), "#FFFFFF"
166
+ )
167
+ composite.paste(bg_img, (0, composite_height - bg_height))
168
+ composite.paste(
169
+ waveform_img, (0, composite_height - waveform_height), waveform_img
170
+ )
171
+ composite.save(tmp_img.name)
172
+ img_width, img_height = composite.size
173
+ else:
174
+ img_width, img_height = waveform_img.size
175
+ waveform_img.save(tmp_img.name)
176
+ else:
177
+
178
+ def _animate(_):
179
+ for idx, b in enumerate(barcollection):
180
+ rand_height = np.random.uniform(0.8, 1.2)
181
+ b.set_height(samples[idx] * rand_height * 2)
182
+ b.set_y((-rand_height * samples)[idx])
183
+
184
+ frames = int(duration * 10)
185
+ anim = FuncAnimation(
186
+ fig, # type: ignore
187
+ _animate, # type: ignore
188
+ repeat=False,
189
+ blit=False,
190
+ frames=frames,
191
+ interval=100,
192
+ )
193
+ anim.save(
194
+ tmp_img.name,
195
+ writer="pillow",
196
+ fps=10,
197
+ codec="png",
198
+ savefig_kwargs=savefig_kwargs,
199
+ )
200
+
201
+ # Convert waveform to video with ffmpeg
202
+ output_mp4 = NamedTemporaryFile(suffix=".mp4", delete=False, prefix = name)
203
+
204
+ if animate and bg_image is not None:
205
+ ffmpeg_cmd = [
206
+ ffmpeg,
207
+ "-loop",
208
+ "1",
209
+ "-i",
210
+ bg_image,
211
+ "-i",
212
+ tmp_img.name,
213
+ "-i",
214
+ audio_file,
215
+ "-filter_complex",
216
+ "[0:v]scale=w=trunc(iw/2)*2:h=trunc(ih/2)*2[bg];[1:v]format=rgba,colorchannelmixer=aa=1.0[ov];[bg][ov]overlay=(main_w-overlay_w*0.9)/2:main_h-overlay_h*0.9/2[output]",
217
+ "-t",
218
+ str(duration),
219
+ "-map",
220
+ "[output]",
221
+ "-map",
222
+ "2:a",
223
+ "-c:v",
224
+ "libx264",
225
+ "-c:a",
226
+ "aac",
227
+ "-shortest",
228
+ "-y",
229
+ output_mp4.name,
230
+ ]
231
+ elif animate and bg_image is None:
232
+ ffmpeg_cmd = [
233
+ ffmpeg,
234
+ "-i",
235
+ tmp_img.name,
236
+ "-i",
237
+ audio_file,
238
+ "-filter_complex",
239
+ "[0:v][1:a]concat=n=1:v=1:a=1[v];[v]scale=1000:400,format=yuv420p[v_scaled]",
240
+ "-map",
241
+ "[v_scaled]",
242
+ "-map",
243
+ "1:a",
244
+ "-c:v",
245
+ "libx264",
246
+ "-c:a",
247
+ "aac",
248
+ "-shortest",
249
+ "-y",
250
+ output_mp4.name,
251
+ ]
252
+ else:
253
+ ffmpeg_cmd = [
254
+ ffmpeg,
255
+ "-loop",
256
+ "1",
257
+ "-i",
258
+ tmp_img.name,
259
+ "-i",
260
+ audio_file,
261
+ "-vf",
262
+ f"color=c=#FFFFFF77:s={img_width}x{img_height}[bar];[0][bar]overlay=-w+(w/{duration})*t:H-h:shortest=1", # type: ignore
263
+ "-t",
264
+ str(duration),
265
+ "-y",
266
+ output_mp4.name,
267
+ ]
268
+
269
+ subprocess.check_call(ffmpeg_cmd)
270
+ return output_mp4.name
271
+
272
+ gr.make_waveform = make_waveform
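
Importing the module applies both patches as a side effect, after which the pre-Gradio-4 `make_waveform` call style is available again. A minimal sketch using assets that ship with this Space; the styling values are illustrative and ffmpeg must be on PATH:

```python
# Sketch only: the output video is written to a temporary .mp4 file.
import gradio as gr

import modules.gradio  # noqa: F401  (patches gr.make_waveform and gr.utils.MatplotlibBackendMananger)

video_path = gr.make_waveform(
    "./assets/a_duck_quacking_as_birds_chirp_and_a_pigeon_cooing.mp3",
    bg_image="./assets/screenshot.png",
    bars_color=("#fbbf24", "#ea580c"),
    animate=False,
)
print(video_path)  # pass this to a gr.Video component
```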
modules/user_history.py ADDED
@@ -0,0 +1,598 @@
1
+ """
2
+ User History is a plugin that you can add to your Spaces to cache generated images for your users.
3
+
4
+ Key features:
5
+ - 🤗 Sign in with Hugging Face
6
+ - Save generated image, video, audio and document files with their metadata: prompts, timestamp, hyper-parameters, etc.
7
+ - Export your history as zip.
8
+ - Delete your history to respect privacy.
9
+ - Compatible with Persistent Storage for long-term storage.
10
+ - Admin panel to check configuration and disk usage.
11
+
12
+ Useful links:
13
+ - Demo: https://huggingface.co/spaces/Wauplin/gradio-user-history
14
+ - README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md
15
+ - Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py
16
+ - Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions
17
+ """
18
+
19
+ __version__ = "0.2.0"
20
+
21
+ import json
22
+ import os
23
+ import shutil
24
+ import warnings
25
+ from datetime import datetime
26
+ from functools import cache
27
+ from pathlib import Path
28
+ from typing import Callable, Dict, List, Tuple, Any
29
+ from uuid import uuid4
30
+
31
+ import gradio as gr
32
+ import numpy as np
33
+ import requests
34
+ from filelock import FileLock
35
+ import PIL.Image
+ from PIL.Image import Image
36
+ import filetype
37
+ import wave
38
+ from mutagen.mp3 import MP3, EasyMP3
39
+ import torchaudio
40
+ import subprocess
41
+
42
+
43
+ def setup(folder_path: str | Path | None = None) -> None:
44
+ user_history = _UserHistory()
45
+ user_history.folder_path = _resolve_folder_path(folder_path)
46
+ user_history.initialized = True
47
+
48
+
49
+ def render() -> None:
50
+ user_history = _UserHistory()
51
+
52
+ # initialize with default config
53
+ if not user_history.initialized:
54
+ print("Initializing user history with default config. Use `user_history.setup(...)` to customize folder_path.")
55
+ setup()
56
+
57
+ # Render user history tab
58
+ gr.Markdown(
59
+ "## Your past generations\n\nLog in to keep a gallery of your previous generations. Your history will be saved"
60
+ " and available on your next visit. Make sure to export your images from time to time as this gallery may be"
61
+ " deleted in the future."
62
+ )
63
+
64
+ if os.getenv("SYSTEM") == "spaces" and not os.path.exists("/data"):
65
+ gr.Markdown(
66
+ "**⚠️ Persistent storage is disabled, meaning your history will be lost if the Space gets restarted."
67
+ " Only the Space owner can setup a Persistent Storage. If you are not the Space owner, consider"
68
+ " duplicating this Space to set your own storage.⚠️**"
69
+ )
70
+
71
+ with gr.Row():
72
+ gr.LoginButton(min_width=250)
73
+ #gr.LogoutButton(min_width=250)
74
+ refresh_button = gr.Button(
75
+ "Refresh",
76
+ icon="./assets/icon_refresh.png",
77
+ )
78
+ export_button = gr.Button(
79
+ "Export",
80
+ icon="./assets/icon_download.png",
81
+ )
82
+ delete_button = gr.Button(
83
+ "Delete history",
84
+ icon="./assets/icon_delete.png",
85
+ )
86
+
87
+ # "Export zip" row (hidden by default)
88
+ with gr.Row():
89
+ export_file = gr.File(file_count="single", file_types=[".zip"], label="Exported history", visible=False)
90
+
91
+ # "Config deletion" row (hidden by default)
92
+ with gr.Row():
93
+ confirm_button = gr.Button("Confirm delete all history", variant="stop", visible=False)
94
+ cancel_button = gr.Button("Cancel", visible=False)
95
+
96
+ # Gallery
97
+ gallery = gr.Gallery(
98
+ label="Past images",
99
+ show_label=True,
100
+ elem_id="gradio_user_history_gallery",
101
+ object_fit="cover",
102
+ columns=5,
103
+ height=600,
104
+ preview=False,
105
+ show_share_button=False,
106
+ show_download_button=True,
107
+ )
108
+ gr.Markdown(
109
+ "User history is powered by"
110
+ " [Wauplin/gradio-user-history](https://huggingface.co/spaces/Wauplin/gradio-user-history). Integrate it to"
111
+ " your own Space in just a few lines of code!"
112
+ )
113
+ gallery.attach_load_event(_fetch_user_history, every=None)
114
+
115
+ # Interactions
116
+ refresh_button.click(fn=_fetch_user_history, inputs=[], outputs=[gallery], queue=False)
117
+ export_button.click(fn=_export_user_history, inputs=[], outputs=[export_file], queue=False)
118
+
119
+ # Taken from https://github.com/gradio-app/gradio/issues/3324#issuecomment-1446382045
120
+ delete_button.click(
121
+ lambda: [gr.update(visible=True), gr.update(visible=True)],
122
+ outputs=[confirm_button, cancel_button],
123
+ queue=False,
124
+ )
125
+ cancel_button.click(
126
+ lambda: [gr.update(visible=False), gr.update(visible=False)],
127
+ outputs=[confirm_button, cancel_button],
128
+ queue=False,
129
+ )
130
+ confirm_button.click(_delete_user_history).then(
131
+ lambda: [gr.update(visible=False), gr.update(visible=False)],
132
+ outputs=[confirm_button, cancel_button],
133
+ queue=False,
134
+ )
135
+
136
+ # Admin section (only shown locally or when logged in as Space owner)
137
+ _admin_section()
138
+
139
+
140
+ def save_image(
141
+ profile: gr.OAuthProfile | None,
142
+ image: Image | np.ndarray | str | Path,
143
+ label: str | None = None,
144
+ metadata: Dict | None = None,
145
+ ):
146
+ # Ignore images from logged out users
147
+ if profile is None:
148
+ return
149
+ username = profile["preferred_username"]
150
+
151
+ # Ignore images if user history not used
152
+ user_history = _UserHistory()
153
+ if not user_history.initialized:
154
+ warnings.warn(
155
+ "User history is not set in Gradio demo. Saving image is ignored. You must use `user_history.render(...)`"
156
+ " first."
157
+ )
158
+ return
159
+
160
+ # Copy image to storage
161
+ image_path = _copy_image(image, dst_folder=user_history._user_images_path(username))
162
+
163
+ # Save new image + metadata
164
+ if metadata is None:
165
+ metadata = {}
166
+ if "datetime" not in metadata:
167
+ metadata["datetime"] = str(datetime.now())
168
+ data = {"path": str(image_path), "label": label, "metadata": metadata}
169
+ with user_history._user_lock(username):
170
+ with user_history._user_jsonl_path(username).open("a") as f:
171
+ f.write(json.dumps(data) + "\n")
172
+
173
+ def save_file(
174
+ profile: gr.OAuthProfile | None,
175
+ image: Image | np.ndarray | str | Path | None = None,
176
+ video: str | Path | None = None,
177
+ audio: str | Path | None = None,
178
+ document: str | Path | None = None,
179
+ label: str | None = None,
180
+ metadata: Dict | None = None,
181
+ ):
182
+ # Ignore files from logged out users
183
+ if profile is None:
184
+ return
185
+ username = profile["preferred_username"]
186
+
187
+ # Ignore files if user history not used
188
+ user_history = _UserHistory()
189
+ if not user_history.initialized:
190
+ warnings.warn(
191
+ "User history is not set in Gradio demo. Saving files is ignored. You must use `user_history.render(...)`"
192
+ " first."
193
+ )
194
+ return
195
+
196
+ # Save new files + metadata
197
+ if metadata is None:
198
+ metadata = {}
199
+ if "datetime" not in metadata:
200
+ metadata["datetime"] = str(datetime.now())
201
+
202
+ # Copy image to storage
203
+ image_path = video_path = audio_path = document_path = None
204
+ if image is not None:
205
+ image_path = _copy_image(image, dst_folder=user_history._user_images_path(username))
206
+ image_path = _add_metadata(image_path, metadata)
207
+
208
+ # Copy video to storage
209
+ if video is not None:
210
+ video_path = _copy_file(video, dst_folder=user_history._user_file_path(username, "videos"))
211
+ video_path = _add_metadata(video_path, metadata)
212
+
213
+ # Copy audio to storage
214
+ if audio is not None:
215
+ audio_path = _copy_file(audio, dst_folder=user_history._user_file_path(username, "audios"))
216
+ audio_path = _add_metadata(audio_path, metadata)
217
+
218
+ # Copy document to storage
219
+ if document is not None:
220
+ document_path = _copy_file(document, dst_folder=user_history._user_file_path(username, "documents"))
221
+ document_path = _add_metadata(document_path, metadata)
222
+
223
+ # Save Json file
224
+ data = {"image_path": str(image_path) if image_path else None, "video_path": str(video_path) if video_path else None, "audio_path": str(audio_path) if audio_path else None, "document_path": str(document_path) if document_path else None, "label": label, "metadata": metadata}
225
+ with user_history._user_lock(username):
226
+ with user_history._user_jsonl_path(username).open("a") as f:
227
+ f.write(json.dumps(data) + "\n")
228
+
229
+
230
+ #############
231
+ # Internals #
232
+ #############
233
+
234
+
235
+ class _UserHistory(object):
236
+ _instance = None
237
+ initialized: bool = False
238
+ folder_path: Path
239
+
240
+ def __new__(cls):
241
+ # Using singleton pattern => we don't want to expose an object (more complex to use) but still want to keep
242
+ # state between `render` and `save_image` calls.
243
+ if cls._instance is None:
244
+ cls._instance = super(_UserHistory, cls).__new__(cls)
245
+ return cls._instance
246
+
247
+ def _user_path(self, username: str) -> Path:
248
+ path = self.folder_path / username
249
+ path.mkdir(parents=True, exist_ok=True)
250
+ return path
251
+
252
+ def _user_lock(self, username: str) -> FileLock:
253
+ """Ensure history is not corrupted if concurrent calls."""
254
+ return FileLock(self.folder_path / f"{username}.lock") # lock outside of folder => better when exporting ZIP
255
+
256
+ def _user_jsonl_path(self, username: str) -> Path:
257
+ return self._user_path(username) / "history.jsonl"
258
+
259
+ def _user_images_path(self, username: str) -> Path:
260
+ path = self._user_path(username) / "images"
261
+ path.mkdir(parents=True, exist_ok=True)
262
+ return path
263
+
264
+ def _user_file_path(self, username: str, filetype: str = "images") -> Path:
265
+ path = self._user_path(username) / filetype
266
+ path.mkdir(parents=True, exist_ok=True)
267
+ return path
268
+
269
+
270
+
271
+ def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]]:
272
+ """Return saved history for that user, if it exists."""
273
+ # Cannot load history for logged out users
274
+ if profile is None:
275
+ return []
276
+ username = profile["preferred_username"]
277
+
278
+ user_history = _UserHistory()
279
+ if not user_history.initialized:
280
+ warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
281
+ return []
282
+
283
+ with user_history._user_lock(username):
284
+ # No file => no history saved yet
285
+ jsonl_path = user_history._user_jsonl_path(username)
286
+ if not jsonl_path.is_file():
287
+ return []
288
+
289
+ # Read history
290
+ images = []
291
+ for line in jsonl_path.read_text().splitlines():
292
+ data = json.loads(line)
293
+ images.append((data.get("path") or data.get("image_path"), data["label"] or ""))
294
+ return list(reversed(images))
295
+
296
+
297
+ def _export_user_history(profile: gr.OAuthProfile | None) -> Dict | None:
298
+ """Zip all history for that user, if it exists and return it as a downloadable file."""
299
+ # Cannot load history for logged out users
300
+ if profile is None:
301
+ return None
302
+ username = profile["preferred_username"]
303
+
304
+ user_history = _UserHistory()
305
+ if not user_history.initialized:
306
+ warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
307
+ return None
308
+
309
+ # Zip history
310
+ with user_history._user_lock(username):
311
+ path = shutil.make_archive(
312
+ str(_archives_path() / f"history_{username}"), "zip", user_history._user_path(username)
313
+ )
314
+
315
+ return gr.update(visible=True, value=path)
316
+
317
+
318
+ def _delete_user_history(profile: gr.OAuthProfile | None) -> None:
319
+ """Delete all history for that user."""
320
+ # Cannot load history for logged out users
321
+ if profile is None:
322
+ return
323
+ username = profile["preferred_username"]
324
+
325
+ user_history = _UserHistory()
326
+ if not user_history.initialized:
327
+ warnings.warn("User history is not set in Gradio demo. You must use `user_history.render(...)` first.")
328
+ return
329
+
330
+ with user_history._user_lock(username):
331
+ shutil.rmtree(user_history._user_path(username))
332
+
333
+
334
+ ####################
335
+ # Internal helpers #
336
+ ####################
337
+
338
+
339
+ def _copy_image(image: Image | np.ndarray | str | Path, dst_folder: Path) -> Path:
340
+ try:
341
+ """Copy image to the images folder."""
342
+ # Already a path => copy it
343
+ if isinstance(image, str):
344
+ image = Path(image)
345
+ if isinstance(image, Path):
346
+ dst = dst_folder / f"{uuid4().hex}_{Path(image).name}" # keep file ext
347
+ shutil.copyfile(image, dst)
348
+ return dst
349
+
350
+ # Still a Python object => serialize it
351
+ if isinstance(image, np.ndarray):
352
+ image = PIL.Image.fromarray(image)
353
+ if isinstance(image, Image):
354
+ dst = dst_folder / f"{uuid4().hex}.png"
355
+ image.save(dst)
356
+ return dst
357
+
358
+ raise ValueError(f"Unsupported image type: {type(image)}")
359
+
360
+ except Exception as e:
361
+ print(f"An error occurred: {e}")
362
+ if not isinstance(dst, Path):
363
+ dst = Path(image)
364
+ return dst # Return the original file_location if an error occurs
365
+
366
+ def _copy_file(file: Any | np.ndarray | str | Path, dst_folder: Path) -> Path:
367
+ try:
368
+ """Copy file to the appropriate folder."""
369
+ # Already a path => copy it
370
+ if isinstance(file, str):
371
+ file = Path(file)
372
+ if isinstance(file, Path):
373
+ dst = dst_folder / f"{file.stem}_{uuid4().hex}{file.suffix}" # keep file ext
374
+ shutil.copyfile(file, dst)
375
+ return dst
376
+
377
+ # Still a Python object => serialize it
378
+ if isinstance(file, np.ndarray):
379
+ file = PIL.Image.fromarray(file)
380
+ dst = dst_folder / f"{uuid4().hex}.png"
381
+ file.save(dst)
382
+ return dst
383
+
384
+ # try other file types
385
+ kind = filetype.guess(file)
386
+ if kind is not None:
387
+ dst = dst_folder / f"{Path(file).stem}_{uuid4().hex}.{kind.extension}"
388
+ shutil.copyfile(file, dst)
389
+ return dst
390
+ raise ValueError(f"Unsupported file type: {type(file)}")
391
+
392
+ except Exception as e:
393
+ print(f"An error occurred: {e}")
394
+ if not isinstance(dst, Path):
395
+ dst = Path(file)
396
+ return dst # Return the original file_location if an error occurs
397
+
398
+
399
+ def _add_metadata(file_location: Path, metadata: Dict[str, Any]) -> Path:
400
+ try:
401
+ file_type = file_location.suffix
402
+ valid_file_types = [".wav", ".mp3", ".mp4", ".png"]
403
+ if file_type not in valid_file_types:
404
+ raise ValueError("Invalid file type. Valid file types are .wav, .mp3, .mp4, .png")
405
+
406
+ if file_type == ".wav":
407
+ # Open and process .wav file. The wave module has no tag support, so the
+ # metadata stays in the JSON history record; rewrite the file with its
+ # original params and frames so the audio itself is preserved.
+ with wave.open(str(file_location), 'rb') as wav_file:
+ params = wav_file.getparams()
+ frames = wav_file.readframes(wav_file.getnframes())
+ with wave.open(str(file_location), 'wb') as wav_output_file:
+ wav_output_file.setparams(params)
+ wav_output_file.writeframes(frames)
422
+ elif file_type == ".mp3":
423
+ # Open and process .mp3 file
424
+ audio = EasyMP3(file_location)
425
+
426
+ # Add metadata to the file
427
+ for key, value in metadata.items():
428
+ audio[key] = value
429
+
430
+ # Save the MP3 file (overwriting the previous version)
431
+ audio.save()
432
+ elif file_type == ".mp4":
433
+ # Open and process .mp4 file
434
+ # Add metadata to the file
435
+ wav_file_location = file_location.with_suffix(".wav")
436
+ wave_exists = wav_file_location.exists()
437
+ if not wave_exists:
438
+ # Use torchaudio to create the WAV file if it doesn't exist
439
+ audio, sample_rate = torchaudio.load(file_location, normalize=True)
440
+ torchaudio.save(wav_file_location, audio, sample_rate, format='wav')
441
+
442
+ # Use ffmpeg to add metadata to the video file
443
+ metadata_args = [f"{key}={value}" for key, value in metadata.items()]
444
+ ffmpeg_metadata = ":".join(metadata_args)
445
+ ffmpeg_cmd = f'ffmpeg -i "{file_location}" -i "{wav_file_location}" -map 0:v:0 -map 1:a:0 -c:v copy -c:a aac -metadata "{ffmpeg_metadata}" "{file_location}"'
446
+ subprocess.run(ffmpeg_cmd, shell=True, check=True)
447
+
448
+ # Remove temporary WAV file
449
+ if not wave_exists:
450
+ wav_file_location.unlink()
451
+ elif file_type == ".png":
452
+ # Open and process .png file
453
+ image = PIL.Image.open(file_location)
454
+ exif_data = image.info.get("exif", {})
455
+ exif_data.update(metadata)
456
+ # Add metadata to the file
457
+ image.save(file_location, exif=exif_data)
458
+
459
+ return file_location # Return the path to the modified file
460
+
461
+ except Exception as e:
462
+ print(f"An error occurred: {e}")
463
+ return file_location # Return the original file_location if an error occurs
464
+
465
+ def _resolve_folder_path(folder_path: str | Path | None) -> Path:
466
+ if folder_path is not None:
467
+ return Path(folder_path).expanduser().resolve()
468
+
469
+ if os.getenv("SYSTEM") == "spaces" and os.path.exists("/data"): # Persistent storage is enabled!
470
+ return Path("/data") / "_user_history"
471
+
472
+ # Not in a Space or Persistent storage not enabled => local folder
473
+ return Path("_user_history").resolve()
474
+
475
+
476
+ def _archives_path() -> Path:
477
+ # Doesn't have to be on persistent storage as it's only used for download
478
+ path = Path(__file__).parent / "_user_history_exports"
479
+ path.mkdir(parents=True, exist_ok=True)
480
+ return path
481
+
482
+
483
+ #################
484
+ # Admin section #
485
+ #################
486
+
487
+
488
+ def _admin_section() -> None:
489
+ title = gr.Markdown()
490
+ title.attach_load_event(_display_if_admin(), every=None)
491
+
492
+
493
+ def _display_if_admin() -> Callable:
494
+ def _inner(profile: gr.OAuthProfile | None) -> str:
495
+ if profile is None:
496
+ return ""
497
+ if profile["preferred_username"] in _fetch_admins():
498
+ return _admin_content()
499
+ return ""
500
+
501
+ return _inner
502
+
503
+
504
+ def _admin_content() -> str:
505
+ return f"""
506
+ ## Admin section
507
+
508
+ Running on **{os.getenv("SYSTEM", "local")}** (id: {os.getenv("SPACE_ID")}). {_get_msg_is_persistent_storage_enabled()}
509
+
510
+ Admins: {', '.join(_fetch_admins())}
511
+
512
+ {_get_nb_users()} user(s), {_get_nb_images()} image(s)
513
+
514
+ ### Configuration
515
+
516
+ History folder: *{_UserHistory().folder_path}*
517
+
518
+ Exports folder: *{_archives_path()}*
519
+
520
+ ### Disk usage
521
+
522
+ {_disk_space_warning_message()}
523
+ """
524
+
525
+
526
+ def _get_nb_users() -> int:
527
+ user_history = _UserHistory()
528
+ if not user_history.initialized:
529
+ return 0
530
+ if user_history.folder_path is not None and user_history.folder_path.exists():
531
+ return len([path for path in user_history.folder_path.iterdir() if path.is_dir()])
532
+ return 0
533
+
534
+
535
+ def _get_nb_images() -> int:
536
+ user_history = _UserHistory()
537
+ if not user_history.initialized:
538
+ return 0
539
+ if user_history.folder_path is not None and user_history.folder_path.exists():
540
+ return len([path for path in user_history.folder_path.glob("*/images/*")])
541
+ return 0
542
+
543
+
544
+ def _get_msg_is_persistent_storage_enabled() -> str:
545
+ if os.getenv("SYSTEM") == "spaces":
546
+ if os.path.exists("/data"):
547
+ return "Persistent storage is enabled."
548
+ else:
549
+ return (
550
+ "Persistent storage is not enabled. This means that user histories will be deleted when the Space is"
551
+ " restarted. Consider adding a Persistent Storage in your Space settings."
552
+ )
553
+ return ""
554
+
555
+
556
+ def _disk_space_warning_message() -> str:
557
+ user_history = _UserHistory()
558
+ if not user_history.initialized:
559
+ return ""
560
+
561
+ message = ""
562
+ if user_history.folder_path is not None:
563
+ total, used, _ = _get_disk_usage(user_history.folder_path)
564
+ message += f"History folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)."
565
+
566
+ total, used, _ = _get_disk_usage(_archives_path())
567
+ message += f"\n\nExports folder: **{used / 1e9 :.0f}/{total / 1e9 :.0f}GB** used ({100*used/total :.0f}%)."
568
+
569
+ return f"{message.strip()}"
570
+
571
+
572
+ def _get_disk_usage(path: Path) -> Tuple[int, int, int]:
573
+ for path in [path] + list(path.parents): # first check target_dir, then each parents one by one
574
+ try:
575
+ return shutil.disk_usage(path)
576
+ except OSError: # if doesn't exist or can't read => fail silently and try parent one
577
+ pass
578
+ return 0, 0, 0
579
+
580
+
581
+ @cache
582
+ def _fetch_admins() -> List[str]:
583
+ # Running locally => fake user is admin
584
+ if os.getenv("SYSTEM") != "spaces":
585
+ return ["FakeGradioUser"]
586
+
587
+ # Running in Space but no space_id => ???
588
+ space_id = os.getenv("SPACE_ID")
589
+ if space_id is None:
590
+ return ["Unknown"]
591
+
592
+ # Running in Space => try to fetch organization members
593
+ # Otherwise, it's not an organization => namespace is the user
594
+ namespace = space_id.split("/")[0]
595
+ response = requests.get(f"https://huggingface.co/api/organizations/{namespace}/members")
596
+ if response.status_code == 200:
597
+ return sorted((member["user"] for member in response.json()), key=lambda x: x.lower())
598
+ return [namespace]
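
A sketch of how the history module is meant to be wired into a Blocks app; the `generate` callback and its outputs are placeholders, not the real app.py wiring:

```python
# Sketch only: `generate` is a stand-in for the real generation function.
import gradio as gr

import modules.user_history as user_history

def generate(prompt: str, profile: gr.OAuthProfile | None):
    audio_path, video_path = "out.wav", "out.mp4"   # placeholders for real outputs
    user_history.save_file(
        profile=profile,
        audio=audio_path,
        video=video_path,
        label=prompt,
        metadata={"prompt": prompt},
    )
    return video_path

with gr.Blocks() as demo:
    with gr.Tab("Generate"):
        prompt = gr.Textbox(label="Prompt")
        out = gr.Video()
        gr.Button("Generate").click(generate, inputs=[prompt], outputs=[out])
    with gr.Tab("Past generations"):
        user_history.render()

demo.launch()
```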
modules/version_info.py ADDED
@@ -0,0 +1,123 @@
1
+ # modules/version_info.py
2
+
3
+ from audiocraft import __version__ as audiocraft_version
4
+ import subprocess
5
+ import os
6
+ import sys
7
+ import gc
8
+ import gradio as gr
9
+
10
+ git = os.environ.get('GIT', "git")
11
+
12
+ def commit_hash():
13
+ try:
14
+ return subprocess.check_output([git, "rev-parse", "HEAD"], shell=False, encoding='utf8').strip()
15
+ except Exception:
16
+ return "<none>"
17
+
18
+ def get_xformers_version():
19
+ try:
20
+ import xformers
21
+ return xformers.__version__
22
+ except Exception:
23
+ return "<none>"
24
+ def get_transformers_version():
25
+ try:
26
+ import transformers
27
+ return transformers.__version__
28
+ except Exception:
29
+ return "<none>"
30
+
31
+ def get_accelerate_version():
32
+ try:
33
+ import accelerate
34
+ return accelerate.__version__
35
+ except Exception:
36
+ return "<none>"
37
+ def get_safetensors_version():
38
+ try:
39
+ import safetensors
40
+ return safetensors.__version__
41
+ except Exception:
42
+ return "<none>"
43
+ def get_diffusers_version():
44
+ try:
45
+ import diffusers
46
+ return diffusers.__version__
47
+ except Exception:
48
+ return "<none>"
49
+
50
+ def get_torch_info():
51
+ from torch import __version__ as torch_version_, version, cuda, backends
52
+ device_type = initialize_cuda()
53
+ if device_type == "cuda":
54
+ try:
55
+ info = [torch_version_, f"CUDA Version:{version.cuda}", f"Available:{cuda.is_available()}", f"flash attention enabled: {backends.cuda.flash_sdp_enabled()}", f"Capabilities: {cuda.get_device_capability(0)}", f"Device Name: {cuda.get_device_name(0)}", f"Device Count: {cuda.device_count()}"]
56
+ del torch_version_, version, cuda, backends
57
+ return info
58
+ except Exception:
59
+ del torch_version_, version, cuda, backends
60
+ return "<none>"
61
+ else:
62
+ return "Not Recognized"
63
+
64
+ def release_torch_resources():
65
+ from torch import cuda
66
+ # Clear the CUDA cache
67
+ cuda.empty_cache()
68
+ cuda.ipc_collect()
69
+ # Delete any objects that are using GPU memory
70
+ #for obj in gc.get_objects():
71
+ # if is_tensor(obj) or (hasattr(obj, 'data') and is_tensor(obj.data)):
72
+ # del obj
73
+ # Run garbage collection
74
+ del cuda
75
+ gc.collect()
76
+
77
+
78
+ def initialize_cuda():
79
+ from torch import cuda, version
80
+ if cuda.is_available():
81
+ device = cuda.device("cuda")
82
+ print(f"CUDA is available. Using device: {cuda.get_device_name(0)} with CUDA version: {version.cuda}")
83
+ result = "cuda"
84
+ else:
85
+ print("CUDA is not available. Using CPU.")
86
+ result = "cpu"
87
+ return result
88
+
89
+ def versions_html():
90
+ from torch import __version__ as torch_version_
91
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
92
+ commit = commit_hash()
93
+
94
+ # Define the Toggle Dark Mode link with JavaScript
95
+ toggle_dark_link = '''
96
+ <a href="#" onclick="document.body.classList.toggle('dark'); return false;" style="cursor: pointer; text-decoration: underline;">
97
+ Toggle Dark Mode
98
+ </a>
99
+ '''
100
+
101
+ v_html = f"""
102
+ version: <a href="https://huggingface.co/spaces/Surn/UnlimitedMusicGen/commit/{"huggingface" if commit == "<none>" else commit}" target="_blank">{"huggingface" if commit == "<none>" else commit}</a>
103
+ &#x2000;•&#x2000;
104
+ audiocraft: {audiocraft_version}
105
+ &#x2000;•&#x2000;
106
+ python: <span title="{sys.version}">{python_version}</span>
107
+ &#x2000;•&#x2000;
108
+ torch: {torch_version_}
109
+ &#x2000;•&#x2000;
110
+ xformers: {get_xformers_version()}
111
+ &#x2000;•&#x2000;
112
+ transformers: {get_transformers_version()}
113
+ &#x2000;•&#x2000;
114
+ safetensors: {get_safetensors_version()}
115
+ &#x2000;•&#x2000;
116
+ gradio: {gr.__version__}
117
+ &#x2000;•&#x2000;
118
+ {toggle_dark_link}
119
+ <br>
120
+ Full GPU Info:{get_torch_info()}
121
+ """
122
+ del torch_version_
123
+ return v_html
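
A sketch of surfacing the version banner in the UI; the `#versions` element id matches the rule in the new stylesheet, but the layout here is illustrative only:

```python
# Sketch only: shown outside of app.py's real layout.
import gradio as gr

from modules.version_info import versions_html

with gr.Blocks() as demo:
    gr.HTML(versions_html(), elem_id="versions")

demo.launch()
```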
pre-requirements.txt CHANGED
@@ -1 +1 @@
1
- pip>=23.3
 
1
+ pip>=24.0
requirements.txt CHANGED
@@ -1,22 +1,38 @@
1
  # please make sure you have already a pytorch install that is cuda enabled!
2
- av
3
  einops
4
  flashy>=0.0.1
5
  hydra-core>=1.1
6
  hydra_colorlog
7
- julius
8
- num2words
9
- numpy
10
- sentencepiece
11
- spacy==3.5.2
12
- torch==2.0.1
13
- torchaudio==2.0.2
14
  soundfile
15
  huggingface_hub
16
  tqdm
17
- transformers>=4.31.0
18
- xformers>=0.0.22
19
  demucs
20
  librosa
21
- gradio==3.38.00
22
- pillow
1
  # please make sure you have already a pytorch install that is cuda enabled!
2
+ av==11.0.0
3
  einops
4
  flashy>=0.0.1
5
  hydra-core>=1.1
6
  hydra_colorlog
7
+ torch==2.6.0 --extra-index-url https://download.pytorch.org/whl/cu124
8
+ torchaudio>=2.0.0,<2.6.2 --extra-index-url https://download.pytorch.org/whl/cu124
 
 
 
 
 
9
  soundfile
10
  huggingface_hub
11
  tqdm
12
+ transformers>=4.48.0 # need Encodec there.
13
+ xformers>=0.0.23 --index-url https://download.pytorch.org/whl/cu124
14
  demucs
15
  librosa
16
+ soundfile
17
+ gradio==5.23.3
18
+ gradio[oauth]
19
+ pillow
20
+ torchmetrics
21
+ encodec
22
+ protobuf>=3.20.1
23
+ filetype
24
+ wave
25
+ mutagen
26
+ fastapi>=0.88.0
27
+ pydantic
28
+ typer
29
+ torchvision>=0.21.0 --extra-index-url https://download.pytorch.org/whl/cu124
30
+ #torchtext
31
+ pesq
32
+ pystoi
33
+ julius
34
+ spacy==3.7.6
35
+ sentencepiece
36
+ num2words
37
+ numpy<1.26.4
38
+ matplotlib
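
Several of these packages are pinned against the cu124 wheel index, so a quick runtime check is useful to confirm which builds were actually resolved; a hedged sketch:

```python
# Sketch only: prints what pip actually installed for the pins above.
from importlib.metadata import PackageNotFoundError, version

import torch

for pkg in ("torch", "torchaudio", "torchvision", "xformers", "transformers", "gradio"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")

print("compiled for CUDA:", torch.version.cuda, "| available:", torch.cuda.is_available())
```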
style_20250331.css ADDED
@@ -0,0 +1,215 @@
1
+ .interface-wrapper {
2
+ max-width: 1024px;
3
+ margin: 0 auto;
4
+ }
5
+
6
+ .centered {
7
+ margin: 0 auto;
8
+ display: block;
9
+ text-align:center;
10
+ }
11
+
12
+ .solid {
13
+ opacity: 1.0 !important;
14
+ height: auto !important;
15
+ }
16
+
17
+ .intro {
18
+ font-size: 1.2em !important;
19
+ font-weight: bold;
20
+ text-align: center;
21
+ background-color: rgba(242, 218, 163, 0.62);
22
+ }
23
+
24
+ .dark .gradio-container.gradio-container-5-23-3 .contain .intro .prose {
25
+ background-color: rgba(41, 18, 5, 0.38) !important;
26
+ }
27
+ .toast-body.info {
28
+ background-color: rgba(242, 218, 163, 0.75);
29
+ }
30
+ .dark .toast-body.info {
31
+ background-color: rgba(128, 128, 128, 0.75);
32
+ }
33
+
34
+ .small {
35
+ font-size: smaller !important;
36
+ text-align: center;
37
+ }
38
+
39
+ .imgcontainer img {
40
+ object-fit: contain !important;
41
+ }
42
+
43
+ #examples {
44
+ font-weight: bolder;
45
+ }
46
+
47
+ /* --background-fill-primary: #FBCE50 !important; */
48
+ #col-container {
49
+ max-width: 1024px;
50
+ margin-left: auto;
51
+ margin-right: auto;
52
+ }
53
+
54
+ a {
55
+ text-decoration-line: underline;
56
+ font-weight: 600;
57
+ }
58
+
59
+ #btn-generate {
60
+ background-image: linear-gradient(to right bottom, rgb(157, 255, 157), rgb(229, 255, 235));
61
+ color: var(--primary-800);
62
+ }
63
+
64
+ #btn-generate:hover {
65
+ background-image: linear-gradient(to right bottom, rgb(229, 255, 229), rgb(255, 255, 255));
66
+ }
67
+
68
+ #btn-generate:active {
69
+ background-image: linear-gradient(to right bottom, rgb(229, 255, 235), rgb(157, 255, 157));
70
+ }
71
+
72
+ #versions {
73
+ margin-top: 1em;
74
+ width: 100%;
75
+ text-align: center;
76
+ }
77
+
78
+ .small-btn {
79
+ max-width: 75px;
80
+ }
81
+
82
+ #gallery .thumbnails, #lora_gallery .thumbnails {
83
+ flex-direction: column !important;
84
+ display: inline-flex !important;
85
+ flex-wrap: wrap !important;
86
+ position: relative !important;
87
+ }
88
+
89
+ #gallery caption.caption, #lora_gallery caption.caption {
90
+ flex-direction: row !important;
91
+ display: inline-flex !important;
92
+ flex-wrap: wrap;
93
+ white-space: unset !important;
94
+ }
95
+
96
+ #gallery .image-button img.with-caption, #lora_gallery .image-button img.with-caption {
97
+ object-fit: cover !important;
98
+ object-position: center !important;
99
+ }
100
+
101
+ #gallery button.preview, #lora_gallery button.preview {
102
+ position: relative !important;
103
+ }
104
+
105
+ .gradio-container::before {
106
+ content: ' ';
107
+ display: block;
108
+ position: absolute;
109
+ left: 0;
110
+ top: 0;
111
+ width: 100%;
112
+ height: 100%;
113
+ opacity: 0.5;
114
+ background-image: url('gradio_api/file=./assets/Vermilion-Musical-Notes-Typography-No-Background.svg');
115
+ background-repeat: no-repeat;
116
+ background-position: 50% 25%;
117
+ /*background-color: rgba(0,0,0,0.5);*/
118
+ background-size: 45vh;
119
+ overflow: hidden;
120
+ }
121
+
122
+ .gradio-container::after {
123
+ content: '';
124
+ position: absolute;
125
+ top: 0;
126
+ left: -60%; /* Start off-screen */
127
+ width: 30%;
128
+ height: 100%;
129
+ background: linear-gradient( 120deg, rgba(255, 255, 255, 0) 10%, rgba(255, 255, 255, 0.60) 50%, rgba(255, 255, 255, 0) 90% );
130
+ animation: shine 30s infinite;
131
+ }
132
+
133
+ #component-0, #component-1 {
134
+ opacity: 0.9;
135
+ }
136
+
137
+ #excluded_colors {
138
+ width: 95%;
139
+ margin: 0 auto;
140
+ font-size: smaller;
141
+ }
142
+
143
+ @media only screen and (min-width: 1920px) {
144
+ .gradio-container, .gradio-container::before {
145
+ max-width: 1920px !important;
146
+ }
147
+ }
148
+
149
+ .sidebar .toggle-button::before {
150
+ content: 'Sketch Pad';
151
+ font-weight: bold;
152
+ transform: rotate(180deg);
153
+ margin-right: -120px;
154
+ width: 120px;
155
+ background-color: rgba(242, 218, 163, 0.62);
156
+ }
157
+ .dark .sidebar .toggle-button::before {
158
+ background-color: rgba(41, 18, 5, 0.38) !important;
159
+ }
160
+ .sidebar.open .toggle-button::before {
161
+ content: '';
162
+ }
163
+
164
+ #sketchpd, #filters, #image_gen, #accordian_3d {
165
+ outline-color: #bbf7d0;
166
+ outline-style:solid;
167
+ outline-width: 1px;
168
+ outline-offset: 1px;
169
+ padding: 2px;
170
+ border-radius:6px;
171
+ }
172
+ .outline-important {
173
+ outline-color: var(--accordion-text-color);
174
+ outline-style: solid;
175
+ outline-width: 2px;
176
+ outline-offset: 2px;
177
+ padding: 2px;
178
+ border-radius: 6px;
179
+ }
180
+ .selected.svelte-1tcem6n.svelte-1tcem6n {
181
+ font-size: large;
182
+ font-weight: bold;
183
+ color: var(--body-text-color);
184
+ }
185
+ .tab-wrapper.svelte-1tcem6n.svelte-1tcem6n {
186
+ height: var(--size-12);
187
+ padding-bottom: var(--size-1);
188
+ text-align: center;
189
+ background-blend-mode: multiply;
190
+ border-radius: var(--block-radius);
191
+ background-color: var(--block-background-fill);
192
+
193
+ outline-color: var(--accordion-text-color);
194
+ outline-style: solid;
195
+ outline-width: 2px;
196
+ outline-offset: 2px;
197
+ padding: 2px;
198
+ border-radius: 6px;
199
+ }
200
+
201
+
202
+
203
+ @keyframes shine {
204
+ 0% {
205
+ left: -100%;
206
+ }
207
+
208
+ 20% {
209
+ left: 100%;
210
+ }
211
+
212
+ 100% {
213
+ left: 125%;
214
+ }
215
+ }