Spaces:
Runtime error
Runtime error
Merge branch 'main' into our_hf
Browse files- README.md +11 -5
- app.py +32 -23
- app_batched.py +24 -25
- audiocraft/models/loaders.py +37 -10
- audiocraft/models/musicgen.py +15 -20
- audiocraft/utils/utils.py +1 -1
- hf_loading.py +0 -61
- mypy.ini +1 -1
- requirements.txt +1 -0
README.md
CHANGED
@@ -56,15 +56,21 @@ You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./d
|
|
56 |
## API
|
57 |
|
58 |
We provide a simple API and 4 pre-trained models. The pre trained models are:
|
59 |
-
- `small`: 300M model, text to music only
|
60 |
-
- `medium`: 1.5B model, text to music only
|
61 |
-
- `melody`: 1.5B model, text to music and text+melody to music
|
62 |
-
- `large`: 3.3B model, text to music only.
|
63 |
|
64 |
We observe the best trade-off between quality and compute with the `medium` or `melody` model.
|
65 |
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
|
66 |
GPUs will be able to generate short sequences, or longer sequences with the `small` model.
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
See after a quick example for using the API.
|
69 |
|
70 |
```python
|
@@ -84,7 +90,7 @@ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), s
|
|
84 |
|
85 |
for idx, one_wav in enumerate(wav):
|
86 |
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
87 |
-
audio_write(f'{idx}', one_wav, model.sample_rate, strategy="loudness")
|
88 |
```
|
89 |
|
90 |
|
|
|
56 |
## API
|
57 |
|
58 |
We provide a simple API and 4 pre-trained models. The pre trained models are:
|
59 |
+
- `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
|
60 |
+
- `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
|
61 |
+
- `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
|
62 |
+
- `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
|
63 |
|
64 |
We observe the best trade-off between quality and compute with the `medium` or `melody` model.
|
65 |
In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
|
66 |
GPUs will be able to generate short sequences, or longer sequences with the `small` model.
|
67 |
|
68 |
+
**Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
|
69 |
+
You can install it with:
|
70 |
+
```
|
71 |
+
apt get install ffmpeg
|
72 |
+
```
|
73 |
+
|
74 |
See after a quick example for using the API.
|
75 |
|
76 |
```python
|
|
|
90 |
|
91 |
for idx, one_wav in enumerate(wav):
|
92 |
# Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
|
93 |
+
audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
|
94 |
```
|
95 |
|
96 |
|
app.py
CHANGED
@@ -6,9 +6,12 @@ This source code is licensed under the license found in the
|
|
6 |
LICENSE file in the root directory of this source tree.
|
7 |
"""
|
8 |
|
|
|
9 |
import torch
|
10 |
import gradio as gr
|
11 |
-
from
|
|
|
|
|
12 |
|
13 |
|
14 |
MODEL = None
|
@@ -16,7 +19,7 @@ MODEL = None
|
|
16 |
|
17 |
def load_model(version):
|
18 |
print("Loading model", version)
|
19 |
-
return get_pretrained(version)
|
20 |
|
21 |
|
22 |
def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
@@ -51,8 +54,11 @@ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
|
51 |
else:
|
52 |
output = MODEL.generate(descriptions=[text], progress=False)
|
53 |
|
54 |
-
output = output.detach().cpu().
|
55 |
-
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
with gr.Blocks() as demo:
|
@@ -60,25 +66,12 @@ with gr.Blocks() as demo:
|
|
60 |
"""
|
61 |
# MusicGen
|
62 |
|
63 |
-
This is the demo for MusicGen, a simple and controllable model for music generation
|
64 |
-
|
65 |
-
Below we present 3 model variations:
|
66 |
-
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
67 |
-
2. Small -- a 300M transformer decoder conditioned on text only.
|
68 |
-
3. Medium -- a 1.5B transformer decoder conditioned on text only.
|
69 |
-
4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
|
70 |
-
|
71 |
-
When the optional melody conditioning wav is provided, the model will extract
|
72 |
-
a broad melody and try to follow it in the generated samples.
|
73 |
-
|
74 |
-
For skipping queue, you can duplicate this space, and upgrade to GPU in the settings.
|
75 |
<br/>
|
76 |
-
<a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true">
|
77 |
-
<img style="margin-
|
78 |
-
|
79 |
-
|
80 |
-
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
81 |
-
for more details.
|
82 |
"""
|
83 |
)
|
84 |
with gr.Row():
|
@@ -98,7 +91,7 @@ with gr.Blocks() as demo:
|
|
98 |
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
99 |
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
100 |
with gr.Column():
|
101 |
-
output = gr.
|
102 |
submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
|
103 |
gr.Examples(
|
104 |
fn=predict,
|
@@ -132,5 +125,21 @@ with gr.Blocks() as demo:
|
|
132 |
inputs=[text, melody, model],
|
133 |
outputs=[output]
|
134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
demo.launch()
|
|
|
6 |
LICENSE file in the root directory of this source tree.
|
7 |
"""
|
8 |
|
9 |
+
from tempfile import NamedTemporaryFile
|
10 |
import torch
|
11 |
import gradio as gr
|
12 |
+
from audiocraft.models import MusicGen
|
13 |
+
|
14 |
+
from audiocraft.data.audio import audio_write
|
15 |
|
16 |
|
17 |
MODEL = None
|
|
|
19 |
|
20 |
def load_model(version):
|
21 |
print("Loading model", version)
|
22 |
+
return MusicGen.get_pretrained(version)
|
23 |
|
24 |
|
25 |
def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
|
|
54 |
else:
|
55 |
output = MODEL.generate(descriptions=[text], progress=False)
|
56 |
|
57 |
+
output = output.detach().cpu().float()[0]
|
58 |
+
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
59 |
+
audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
|
60 |
+
waveform_video = gr.make_waveform(file.name)
|
61 |
+
return waveform_video
|
62 |
|
63 |
|
64 |
with gr.Blocks() as demo:
|
|
|
66 |
"""
|
67 |
# MusicGen
|
68 |
|
69 |
+
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
70 |
+
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
<br/>
|
72 |
+
<a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
|
73 |
+
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
74 |
+
for longer sequences, more control and no queue.</p>
|
|
|
|
|
|
|
75 |
"""
|
76 |
)
|
77 |
with gr.Row():
|
|
|
91 |
temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
|
92 |
cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
|
93 |
with gr.Column():
|
94 |
+
output = gr.Video(label="Generated Music")
|
95 |
submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
|
96 |
gr.Examples(
|
97 |
fn=predict,
|
|
|
125 |
inputs=[text, melody, model],
|
126 |
outputs=[output]
|
127 |
)
|
128 |
+
gr.Markdown(
|
129 |
+
"""
|
130 |
+
### More details
|
131 |
+
|
132 |
+
By typing a description of the music you want and an optional audio used for melody conditioning,
|
133 |
+
|
134 |
+
We present 4 model variations:
|
135 |
+
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
136 |
+
2. Small -- a 300M transformer decoder conditioned on text only.
|
137 |
+
3. Medium -- a 1.5B transformer decoder conditioned on text only.
|
138 |
+
4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
|
139 |
+
|
140 |
+
When the optional melody conditioning wav is provided, the model will extract
|
141 |
+
a broad melody and try to follow it in the generated samples.
|
142 |
+
"""
|
143 |
+
)
|
144 |
|
145 |
demo.launch()
|
app_batched.py
CHANGED
@@ -11,7 +11,7 @@ import torch
|
|
11 |
import gradio as gr
|
12 |
from audiocraft.data.audio_utils import convert_audio
|
13 |
from audiocraft.data.audio import audio_write
|
14 |
-
from
|
15 |
|
16 |
|
17 |
MODEL = None
|
@@ -19,7 +19,7 @@ MODEL = None
|
|
19 |
|
20 |
def load_model():
|
21 |
print("Loading model")
|
22 |
-
return get_pretrained("melody")
|
23 |
|
24 |
|
25 |
def predict(texts, melodies):
|
@@ -58,8 +58,9 @@ def predict(texts, melodies):
|
|
58 |
for output in outputs:
|
59 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
60 |
audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
|
61 |
-
|
62 |
-
|
|
|
63 |
|
64 |
|
65 |
with gr.Blocks() as demo:
|
@@ -67,35 +68,23 @@ with gr.Blocks() as demo:
|
|
67 |
"""
|
68 |
# MusicGen
|
69 |
|
70 |
-
This is the demo for MusicGen, a simple and controllable model for music generation
|
71 |
-
presented at: "Simple and Controllable Music Generation".
|
72 |
-
|
73 |
-
Enter the description of the music you want and an optional audio used for melody conditioning.
|
74 |
-
The model will extract the broad melody from the uploaded wav if provided.
|
75 |
-
This will generate a 12s extract with the `melody` model.
|
76 |
-
|
77 |
-
For generating longer sequences (up to 30 seconds) and skipping queue, you can duplicate
|
78 |
-
to full demo space, which contains more control and upgrade to GPU in the settings.
|
79 |
<br/>
|
80 |
-
<a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true">
|
81 |
-
<img style="margin-
|
82 |
-
</p>
|
83 |
-
|
84 |
-
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
85 |
-
|
86 |
-
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
87 |
-
for more details.
|
88 |
"""
|
89 |
)
|
90 |
with gr.Row():
|
91 |
with gr.Column():
|
92 |
with gr.Row():
|
93 |
-
text = gr.Text(label="
|
94 |
-
melody = gr.Audio(source="upload", type="numpy", label="
|
95 |
with gr.Row():
|
96 |
-
submit = gr.Button("
|
97 |
with gr.Column():
|
98 |
-
output = gr.
|
99 |
submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
|
100 |
gr.Examples(
|
101 |
fn=predict,
|
@@ -124,5 +113,15 @@ with gr.Blocks() as demo:
|
|
124 |
inputs=[text, melody],
|
125 |
outputs=[output]
|
126 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
demo.queue(max_size=15).launch()
|
|
|
11 |
import gradio as gr
|
12 |
from audiocraft.data.audio_utils import convert_audio
|
13 |
from audiocraft.data.audio import audio_write
|
14 |
+
from audiocraft.models import MusicGen
|
15 |
|
16 |
|
17 |
MODEL = None
|
|
|
19 |
|
20 |
def load_model():
|
21 |
print("Loading model")
|
22 |
+
return MusicGen.get_pretrained("melody")
|
23 |
|
24 |
|
25 |
def predict(texts, melodies):
|
|
|
58 |
for output in outputs:
|
59 |
with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
|
60 |
audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
|
61 |
+
waveform_video = gr.make_waveform(file.name)
|
62 |
+
out_files.append(waveform_video)
|
63 |
+
return [out_files]
|
64 |
|
65 |
|
66 |
with gr.Blocks() as demo:
|
|
|
68 |
"""
|
69 |
# MusicGen
|
70 |
|
71 |
+
This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
|
72 |
+
presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
<br/>
|
74 |
+
<a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
|
75 |
+
<img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
|
76 |
+
for longer sequences, more control and no queue</p>
|
|
|
|
|
|
|
|
|
|
|
77 |
"""
|
78 |
)
|
79 |
with gr.Row():
|
80 |
with gr.Column():
|
81 |
with gr.Row():
|
82 |
+
text = gr.Text(label="Describe your music", lines=2, interactive=True)
|
83 |
+
melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
|
84 |
with gr.Row():
|
85 |
+
submit = gr.Button("Generate")
|
86 |
with gr.Column():
|
87 |
+
output = gr.Video(label="Generated Music")
|
88 |
submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
|
89 |
gr.Examples(
|
90 |
fn=predict,
|
|
|
113 |
inputs=[text, melody],
|
114 |
outputs=[output]
|
115 |
)
|
116 |
+
gr.Markdown("""
|
117 |
+
### More details
|
118 |
+
By typing a description of the music you want and an optional audio used for melody conditioning,
|
119 |
+
the model will extract the broad melody from the uploaded wav if provided and generate a 12s extract with the `melody` model.
|
120 |
+
|
121 |
+
You can also use your own GPU or a Google Colab by following the instructions on our repo.
|
122 |
+
|
123 |
+
See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
|
124 |
+
for more details.
|
125 |
+
""")
|
126 |
|
127 |
demo.queue(max_size=15).launch()
|
audiocraft/models/loaders.py
CHANGED
@@ -20,7 +20,9 @@ of the returned model.
|
|
20 |
"""
|
21 |
|
22 |
from pathlib import Path
|
|
|
23 |
import typing as tp
|
|
|
24 |
|
25 |
from omegaconf import OmegaConf
|
26 |
import torch
|
@@ -28,18 +30,43 @@ import torch
|
|
28 |
from . import builders
|
29 |
|
30 |
|
31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# Return the state dict either from a file or url
|
33 |
-
|
34 |
-
assert isinstance(
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
else:
|
38 |
-
|
39 |
|
40 |
|
41 |
-
def load_compression_model(
|
42 |
-
pkg = _get_state_dict(
|
43 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
44 |
cfg.device = str(device)
|
45 |
model = builders.get_compression_model(cfg)
|
@@ -48,8 +75,8 @@ def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
|
|
48 |
return model
|
49 |
|
50 |
|
51 |
-
def load_lm_model(
|
52 |
-
pkg = _get_state_dict(
|
53 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
54 |
cfg.device = str(device)
|
55 |
if cfg.device == 'cpu':
|
|
|
20 |
"""
|
21 |
|
22 |
from pathlib import Path
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
import typing as tp
|
25 |
+
import os
|
26 |
|
27 |
from omegaconf import OmegaConf
|
28 |
import torch
|
|
|
30 |
from . import builders
|
31 |
|
32 |
|
33 |
+
HF_MODEL_CHECKPOINTS_MAP = {
|
34 |
+
"small": "facebook/musicgen-small",
|
35 |
+
"medium": "facebook/musicgen-medium",
|
36 |
+
"large": "facebook/musicgen-large",
|
37 |
+
"melody": "facebook/musicgen-melody",
|
38 |
+
}
|
39 |
+
|
40 |
+
|
41 |
+
def _get_state_dict(
|
42 |
+
file_or_url_or_id: tp.Union[Path, str],
|
43 |
+
filename: tp.Optional[str] = None,
|
44 |
+
device='cpu',
|
45 |
+
cache_dir: tp.Optional[str] = None,
|
46 |
+
):
|
47 |
# Return the state dict either from a file or url
|
48 |
+
file_or_url_or_id = str(file_or_url_or_id)
|
49 |
+
assert isinstance(file_or_url_or_id, str)
|
50 |
+
|
51 |
+
if os.path.isfile(file_or_url_or_id):
|
52 |
+
return torch.load(file_or_url_or_id, map_location=device)
|
53 |
+
|
54 |
+
elif file_or_url_or_id.startswith('https://'):
|
55 |
+
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
56 |
+
|
57 |
+
elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
|
58 |
+
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
59 |
+
|
60 |
+
repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
|
61 |
+
file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
|
62 |
+
return torch.load(file, map_location=device)
|
63 |
+
|
64 |
else:
|
65 |
+
raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
|
66 |
|
67 |
|
68 |
+
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
69 |
+
pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
|
70 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
71 |
cfg.device = str(device)
|
72 |
model = builders.get_compression_model(cfg)
|
|
|
75 |
return model
|
76 |
|
77 |
|
78 |
+
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
79 |
+
pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
|
80 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
81 |
cfg.device = str(device)
|
82 |
if cfg.device == 'cpu':
|
audiocraft/models/musicgen.py
CHANGED
@@ -17,7 +17,7 @@ import torch
|
|
17 |
from .encodec import CompressionModel
|
18 |
from .lm import LMModel
|
19 |
from .builders import get_debug_compression_model, get_debug_lm_model
|
20 |
-
from .loaders import load_compression_model, load_lm_model
|
21 |
from ..data.audio_utils import convert_audio
|
22 |
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
23 |
from ..utils.autocast import TorchAutocast
|
@@ -67,10 +67,10 @@ class MusicGen:
|
|
67 |
@staticmethod
|
68 |
def get_pretrained(name: str = 'melody', device='cuda'):
|
69 |
"""Return pretrained model, we provide four models:
|
70 |
-
- small (300M), text to music,
|
71 |
-
- medium (1.5B), text to music,
|
72 |
-
- melody (1.5B) text to music and text+melody to music,
|
73 |
-
- large (3.3B), text to music.
|
74 |
"""
|
75 |
|
76 |
if name == 'debug':
|
@@ -79,21 +79,16 @@ class MusicGen:
|
|
79 |
lm = get_debug_lm_model(device)
|
80 |
return MusicGen(name, compression_model, lm)
|
81 |
|
82 |
-
if
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
'large': '9b6e835c-1f0cf17b5e',
|
93 |
-
'melody': 'f79af192-61305ffc49',
|
94 |
-
}
|
95 |
-
sig = names[name]
|
96 |
-
lm = load_lm_model(ROOT + f'{sig}.th', device=device)
|
97 |
return MusicGen(name, compression_model, lm)
|
98 |
|
99 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
|
|
17 |
from .encodec import CompressionModel
|
18 |
from .lm import LMModel
|
19 |
from .builders import get_debug_compression_model, get_debug_lm_model
|
20 |
+
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
21 |
from ..data.audio_utils import convert_audio
|
22 |
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
23 |
from ..utils.autocast import TorchAutocast
|
|
|
67 |
@staticmethod
|
68 |
def get_pretrained(name: str = 'melody', device='cuda'):
|
69 |
"""Return pretrained model, we provide four models:
|
70 |
+
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
71 |
+
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
72 |
+
- melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
|
73 |
+
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
74 |
"""
|
75 |
|
76 |
if name == 'debug':
|
|
|
79 |
lm = get_debug_lm_model(device)
|
80 |
return MusicGen(name, compression_model, lm)
|
81 |
|
82 |
+
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
83 |
+
raise ValueError(
|
84 |
+
f"{name} is not a valid checkpoint name. "
|
85 |
+
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
86 |
+
)
|
87 |
+
|
88 |
+
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
89 |
+
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
90 |
+
lm = load_lm_model(name, device=device, cache_dir=cache_dir)
|
91 |
+
|
|
|
|
|
|
|
|
|
|
|
92 |
return MusicGen(name, compression_model, lm)
|
93 |
|
94 |
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
audiocraft/utils/utils.py
CHANGED
@@ -122,7 +122,7 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
|
|
122 |
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
123 |
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
124 |
mask = probs_sum - probs_sort > p
|
125 |
-
probs_sort *= (~mask).float(
|
126 |
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
127 |
next_token = multinomial(probs_sort, num_samples=1)
|
128 |
next_token = torch.gather(probs_idx, -1, next_token)
|
|
|
122 |
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
123 |
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
124 |
mask = probs_sum - probs_sort > p
|
125 |
+
probs_sort *= (~mask).float()
|
126 |
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
127 |
next_token = multinomial(probs_sort, num_samples=1)
|
128 |
next_token = torch.gather(probs_idx, -1, next_token)
|
hf_loading.py
DELETED
@@ -1,61 +0,0 @@
|
|
1 |
-
"""Utility for loading the models from HF."""
|
2 |
-
from pathlib import Path
|
3 |
-
import typing as tp
|
4 |
-
|
5 |
-
from omegaconf import OmegaConf
|
6 |
-
from huggingface_hub import hf_hub_download
|
7 |
-
import torch
|
8 |
-
|
9 |
-
from audiocraft.models import builders, MusicGen
|
10 |
-
|
11 |
-
MODEL_CHECKPOINTS_MAP = {
|
12 |
-
"small": "facebook/musicgen-small",
|
13 |
-
"medium": "facebook/musicgen-medium",
|
14 |
-
"large": "facebook/musicgen-large",
|
15 |
-
"melody": "facebook/musicgen-melody",
|
16 |
-
}
|
17 |
-
|
18 |
-
|
19 |
-
def _get_state_dict(file_or_url: tp.Union[Path, str],
|
20 |
-
filename="state_dict.bin", device='cpu'):
|
21 |
-
# Return the state dict either from a file or url
|
22 |
-
print("loading", file_or_url, filename)
|
23 |
-
file_or_url = str(file_or_url)
|
24 |
-
assert isinstance(file_or_url, str)
|
25 |
-
return torch.load(
|
26 |
-
hf_hub_download(repo_id=file_or_url, filename=filename), map_location=device)
|
27 |
-
|
28 |
-
|
29 |
-
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
|
30 |
-
pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
|
31 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
32 |
-
cfg.device = str(device)
|
33 |
-
model = builders.get_compression_model(cfg)
|
34 |
-
model.load_state_dict(pkg['best_state'])
|
35 |
-
model.eval()
|
36 |
-
model.cfg = cfg
|
37 |
-
return model
|
38 |
-
|
39 |
-
|
40 |
-
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
|
41 |
-
pkg = _get_state_dict(file_or_url)
|
42 |
-
cfg = OmegaConf.create(pkg['xp.cfg'])
|
43 |
-
cfg.device = str(device)
|
44 |
-
if cfg.device == 'cpu':
|
45 |
-
cfg.transformer_lm.memory_efficient = False
|
46 |
-
cfg.transformer_lm.custom = True
|
47 |
-
cfg.dtype = 'float32'
|
48 |
-
else:
|
49 |
-
cfg.dtype = 'float16'
|
50 |
-
model = builders.get_lm_model(cfg)
|
51 |
-
model.load_state_dict(pkg['best_state'])
|
52 |
-
model.eval()
|
53 |
-
model.cfg = cfg
|
54 |
-
return model
|
55 |
-
|
56 |
-
|
57 |
-
def get_pretrained(name: str = 'small', device='cuda'):
|
58 |
-
model_id = MODEL_CHECKPOINTS_MAP[name]
|
59 |
-
compression_model = load_compression_model(model_id, device=device)
|
60 |
-
lm = load_lm_model(model_id, device=device)
|
61 |
-
return MusicGen(name, compression_model, lm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mypy.ini
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
[mypy]
|
2 |
|
3 |
-
[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy]
|
4 |
ignore_missing_imports = True
|
|
|
1 |
[mypy]
|
2 |
|
3 |
+
[mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub]
|
4 |
ignore_missing_imports = True
|
requirements.txt
CHANGED
@@ -11,6 +11,7 @@ sentencepiece
|
|
11 |
spacy==3.5.2
|
12 |
torch>=2.0.0
|
13 |
torchaudio>=2.0.0
|
|
|
14 |
tqdm
|
15 |
transformers
|
16 |
xformers
|
|
|
11 |
spacy==3.5.2
|
12 |
torch>=2.0.0
|
13 |
torchaudio>=2.0.0
|
14 |
+
huggingface_hub
|
15 |
tqdm
|
16 |
transformers
|
17 |
xformers
|