Jukebox
app.py
ADDED
@@ -0,0 +1,91 @@
import os

# "!nvidia-smi" is notebook shell syntax and is a syntax error in a plain .py file;
# call it through the OS instead to print GPU information.
os.system("nvidia-smi")

import jukebox
import torch as t
import librosa
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
    sample_partial_window, upsample
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

# Initialize the (single-process) distributed state and pick the CUDA device.
rank, local_rank, device = setup_dist_from_mpi()
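# --- Illustrative aside (not part of the original app.py) -------------------------------
# Sampling with the "5b_lyrics" prior needs a CUDA GPU with a lot of memory. A minimal
# sanity check using only standard torch calls, assuming `t` is the torch import above:
if not t.cuda.is_available():
    raise RuntimeError("Jukebox sampling requires a CUDA-capable GPU")
print("Using GPU:", t.cuda.get_device_name(0))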

model = "5b_lyrics"  # or "1b_lyrics"
hps = Hyperparams()
hps.sr = 44100                                      # output sample rate (Hz)
hps.n_samples = 3 if model == "5b_lyrics" else 8    # fewer samples for the larger model
hps.name = "samples"                                # output directory
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16
hps.levels = 3                                      # number of VQ-VAE levels
hps.hop_fraction = [.5, .5, .125]

# Load the VQ-VAE and the top-level (lyrics-conditioned) prior.
vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

# Full length of the musical sample to generate. Songs in the 1-4 minute range work well,
# and generation time is proportional to sample length. The total length also affects how
# quickly the model progresses through the lyrics (it generates differently depending on
# whether it thinks it is at the beginning, middle, or end of the sample).
sample_length_in_seconds = 60

# Round down to a whole number of top-level tokens.
hps.sample_length = (int(sample_length_in_seconds * hps.sr) // top_prior.raw_to_tokens) * top_prior.raw_to_tokens
assert hps.sample_length >= top_prior.n_ctx * top_prior.raw_to_tokens, 'Please choose a longer sample length'
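# --- Illustrative aside (not part of the original app.py) -------------------------------
# Why the rounding and assertion above work out: assuming the released top-level Jukebox
# prior, which compresses raw audio by raw_to_tokens = 128 and has a context of
# n_ctx = 8192 tokens (an assumption, not read from this file), 60 s at 44.1 kHz rounds
# down to 2,645,888 samples and the minimum allowed length is 1,048,576 samples (~23.8 s).
def _illustrate_sample_length(seconds=60, sr=44100, raw_to_tokens=128, n_ctx=8192):
    rounded = (int(seconds * sr) // raw_to_tokens) * raw_to_tokens  # 2,645,888
    minimum = n_ctx * raw_to_tokens                                 # 1,048,576
    return rounded, minimum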
metas = [dict(artist="Zac Brown Band",
              genre="Country",
              total_length=hps.sample_length,
              offset=0,
              lyrics="""I met a traveller from an antique land,
Who said—“Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
And wrinkled lip, and sneer of cold command,
Tell that its sculptor well those passions read
Which yet survive, stamped on these lifeless things,
The hand that mocked them, and the heart that fed;
And on the pedestal, these words appear:
My name is Ozymandias, King of Kings;
Look on my Works, ye Mighty, and despair!
Nothing beside remains. Round the decay
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
              ),
         ] * hps.n_samples

# Conditioning labels, one entry per level. Only the top level is filled here; the two
# upsampler entries are filled in later, once the upsamplers have been built.
labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]
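# --- Illustrative aside (not part of the original app.py) -------------------------------
# `metas` above repeats a single conditioning dict hps.n_samples times, so every sample
# shares the same artist, genre, and lyrics. A sketch of per-sample conditioning would use
# one dict per sample with the same fields (artist and genre names must exist in the
# model's vocabulary; the values below are hypothetical):
# metas = [
#     dict(artist="Zac Brown Band", genre="Country", total_length=hps.sample_length, offset=0, lyrics="..."),
#     dict(artist="Frank Sinatra",  genre="Jazz",    total_length=hps.sample_length, offset=0, lyrics="..."),
#     dict(artist="Queen",          genre="Rock",    total_length=hps.sample_length, offset=0, lyrics="..."),
# ]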
# Sampling hyperparameters for the three levels (two upsamplers, then the top level).
sampling_temperature = .98

lower_batch_size = 16
max_batch_size = 3 if model == "5b_lyrics" else 16
lower_level_chunk_size = 32
chunk_size = 16 if model == "5b_lyrics" else 32
sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
                        chunk_size=lower_level_chunk_size),
                   dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
                        chunk_size=lower_level_chunk_size),
                   dict(temp=sampling_temperature, fp16=True,
                        max_batch_size=max_batch_size, chunk_size=chunk_size)]

# Top-level (ancestral) sampling: generate the level-2 codes with the top prior. This
# defines `zs`, which the upsampling step below consumes, following the standard Jukebox
# sampling flow (sample only level 2 here).
zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

Audio(f'{hps.name}/level_2/item_0.wav')

# Set this False if you are on a local machine that has enough memory (this allows you to
# do the lyrics alignment visualization during the upsampling stage). For a hosted runtime,
# the top_prior must be deleted to free GPU memory when using the 5b_lyrics model.
if True:
    del top_prior
    empty_cache()
    top_prior = None

upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

# Upsample the top-level codes through the two lower levels to produce the final audio.
zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

del upsamplers
empty_cache()

Audio(f'{hps.name}/level_0/item_0.wav')
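# --- Illustrative aside (not part of the original app.py) -------------------------------
# IPython.display.Audio only renders inside a notebook, so in a plain script it produces no
# visible output. One way to verify the final result is to load the generated file with
# librosa (imported above); the path assumes the default output layout used by this script.
final_wav = f"{hps.name}/level_0/item_0.wav"
if os.path.exists(final_wav):
    wav, wav_sr = librosa.load(final_wav, sr=None)
    print(final_wav, wav.shape, "sample rate:", wav_sr)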