SethyYann98 committed on
Commit
b97d304
·
verified ·
1 Parent(s): ba5666f
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Jukebox lyric-conditioned sampling script (adapted from a notebook workflow).
#
# Flow:
#   1. Build the VQ-VAE and the top-level prior for the chosen model.
#   2. Describe the desired sample (artist, genre, lyrics) and build labels.
#   3. Sample top-level codes with the top-level prior.
#   4. Free the top-level prior and upsample through the two lower priors.
import shutil
import subprocess

import jukebox
import torch as t
import librosa
import os
from IPython.display import Audio
from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
from jukebox.hparams import Hyperparams, setup_hparams
from jukebox.sample import sample_single_window, _sample, \
    sample_partial_window, upsample
from jukebox.utils.dist_utils import setup_dist_from_mpi
from jukebox.utils.torch_utils import empty_cache

# The notebook used the IPython shell escape `!nvidia-smi`, which is a syntax
# error in a plain .py file. Run the tool as a best-effort diagnostic instead
# and skip it silently when it is not installed.
if shutil.which("nvidia-smi"):
    subprocess.run(["nvidia-smi"], check=False)

rank, local_rank, device = setup_dist_from_mpi()

# ---------------------------------------------------------------------------
# Model selection and hyperparameters
# ---------------------------------------------------------------------------
model = "5b_lyrics"  # or "1b_lyrics"
hps = Hyperparams()
hps.sr = 44100
hps.n_samples = 3 if model == "5b_lyrics" else 8
hps.name = 'samples'
chunk_size = 16 if model == "5b_lyrics" else 32
max_batch_size = 3 if model == "5b_lyrics" else 16
hps.levels = 3
hps.hop_fraction = [.5, .5, .125]

# MODELS[model] lists the VQ-VAE name first, followed by the prior names; the
# last prior is used as the top-level prior and the others as upsamplers.
vqvae, *priors = MODELS[model]
vqvae = make_vqvae(setup_hparams(vqvae, dict(sample_length=1048576)), device)
top_prior = make_prior(setup_hparams(priors[-1], dict()), vqvae, device)

# ---------------------------------------------------------------------------
# Sample length and conditioning metadata
# ---------------------------------------------------------------------------
# Full length of musical sample to generate - we find songs in the 1 to 4
# minute range work well, with generation time proportional to sample length.
# This total length affects how quickly the model progresses through lyrics
# (model also generates differently depending on if it thinks it's in the
# beginning, middle, or end of sample).
sample_length_in_seconds = 60

# Round the requested length down to a whole multiple of raw_to_tokens.
hps.sample_length = (int(sample_length_in_seconds * hps.sr) // top_prior.raw_to_tokens) * top_prior.raw_to_tokens
assert hps.sample_length >= top_prior.n_ctx * top_prior.raw_to_tokens, 'Please choose a larger sampling rate'

# One identical metadata dict per requested sample.
metas = [dict(artist="Zac Brown Band",
              genre="Country",
              total_length=hps.sample_length,
              offset=0,
              lyrics="""I met a traveller from an antique land,
Who said—“Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
And wrinkled lip, and sneer of cold command,
Tell that its sculptor well those passions read
Which yet survive, stamped on these lifeless things,
The hand that mocked them, and the heart that fed;
And on the pedestal, these words appear:
My name is Ozymandias, King of Kings;
Look on my Works, ye Mighty, and despair!
Nothing beside remains. Round the decay
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
              ),
         ] * hps.n_samples

# Only the top level has labels for now; the two lower-level entries are
# filled in once the upsamplers have been built.
labels = [None, None, top_prior.labeller.get_batch_labels(metas, 'cuda')]

# ---------------------------------------------------------------------------
# Sampling settings (must be defined before any sampling/upsampling call)
# ---------------------------------------------------------------------------
sampling_temperature = .98

lower_batch_size = 16
lower_level_chunk_size = 32
# One kwargs dict per level: the first two for the upsamplers, the last for
# the top-level prior. max_batch_size / chunk_size for the top level were
# already chosen above based on the model.
sampling_kwargs = [dict(temp=.99, fp16=True, max_batch_size=lower_batch_size,
                        chunk_size=lower_level_chunk_size),
                   dict(temp=0.99, fp16=True, max_batch_size=lower_batch_size,
                        chunk_size=lower_level_chunk_size),
                   dict(temp=sampling_temperature, fp16=True,
                        max_batch_size=max_batch_size, chunk_size=chunk_size)]

# ---------------------------------------------------------------------------
# Top-level sampling
# ---------------------------------------------------------------------------
# Start from empty code sequences (one per level) and sample level 2 only.
# This step follows the upstream Jukebox notebook; without it `zs` is
# undefined when `upsample` is called below.
zs = [t.zeros(hps.n_samples, 0, dtype=t.long, device='cuda') for _ in range(len(priors))]
zs = _sample(zs, labels, sampling_kwargs, [None, None, top_prior], [2], hps)

# NOTE: a bare Audio(...) expression only renders as a player when it is the
# last expression of a notebook cell; in a script it just builds the object.
Audio(f'{hps.name}/level_2/item_0.wav')

# ---------------------------------------------------------------------------
# Upsampling
# ---------------------------------------------------------------------------
# Set this False if you are on a local machine that has enough memory (this
# allows you to do the lyrics alignment visualization during the upsampling
# stage). For a hosted runtime, we'll need to go ahead and delete the
# top_prior if you are using the 5b_lyrics model.
free_top_prior = True
if free_top_prior:
    del top_prior
    empty_cache()
    top_prior = None

upsamplers = [make_prior(setup_hparams(prior, dict()), vqvae, 'cpu') for prior in priors[:-1]]
labels[:2] = [prior.labeller.get_batch_labels(metas, 'cuda') for prior in upsamplers]

zs = upsample(zs, labels, sampling_kwargs, [*upsamplers, top_prior], hps)

Audio(f'{hps.name}/level_0/item_0.wav')

# Release the upsamplers once the final level has been produced.
del upsamplers
empty_cache()