awacke1 committed
Commit 18caa10 · 1 Parent(s): dc3a2de

Create app.py

Files changed (1)
  1. app.py +233 -0
app.py ADDED
@@ -0,0 +1,233 @@
from queue import Queue
from threading import Thread
from typing import Optional

import numpy as np
import torch

from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed
from transformers.generation.streamers import BaseStreamer

import gradio as gr
import spaces


model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
processor = MusicgenProcessor.from_pretrained("facebook/musicgen-small")

title = "MusicGen Streaming"

description = """
Stream the outputs of the MusicGen text-to-music model by playing the generated audio as soon as the first chunk is ready.
The demo uses [MusicGen Small](https://huggingface.co/facebook/musicgen-small) in the 🤗 Transformers library. Note that the
demo works best in the Chrome browser; if there is no audio output, try switching to Chrome.
"""

article = """
## How Does It Work?

MusicGen is an auto-regressive transformer-based model, meaning it generates audio codes (tokens) in a causal fashion.
At each decoding step, the model generates a new set of audio codes, conditional on the text input and all previous audio codes. Given the
frame rate of the [EnCodec model](https://huggingface.co/facebook/encodec_32khz) used to decode the generated codes to an audio waveform,
each set of generated audio codes corresponds to 0.02 seconds of audio. This means we require a total of 1000 decoding steps to generate
20 seconds of audio.

Rather than waiting for the entire audio sequence to be generated, which would require the full 1000 decoding steps, we can start
playing the audio after a specified number of decoding steps have been reached, a technique known as [*streaming*](https://huggingface.co/docs/transformers/main/en/generation_strategies#streaming).
For example, after 250 steps we have the first 5 seconds of audio ready, and so can play this without waiting for the remaining
750 decoding steps to complete. As we continue to generate with the MusicGen model, we append new chunks of generated audio
to our output waveform on-the-fly. After the full 1000 decoding steps, the generated audio is complete, and is composed of four
chunks of audio, each corresponding to 250 tokens.

This method of playing incremental generations reduces the latency of the MusicGen model from the total time to generate 1000 tokens
to the time taken to play the first chunk of audio (250 tokens). This can result in significant improvements to perceived latency,
particularly when the chunk size is chosen to be small. In practice, the chunk size should be tuned to your device: using a
smaller chunk size means the first chunk is ready faster, but the chunk size should not be so small that the model generates audio
more slowly than it can be played back.

For details on how the streaming class works, check out the source code for the [MusicgenStreamer](https://huggingface.co/spaces/sanchit-gandhi/musicgen-streaming/blob/main/app.py#L52).
"""


class MusicgenStreamer(BaseStreamer):
    def __init__(
        self,
        model: MusicgenForConditionalGeneration,
        device: Optional[str] = None,
        play_steps: Optional[int] = 10,
        stride: Optional[int] = None,
        timeout: Optional[float] = None,
    ):
        """
        Streamer that stores playback-ready audio in a queue, to be used by a downstream application as an iterator. This is
        useful for applications that benefit from accessing the generated audio in a non-blocking way (e.g. in an interactive
        Gradio demo).

        Parameters:
            model (`MusicgenForConditionalGeneration`):
                The MusicGen model used to generate the audio waveform.
            device (`str`, *optional*):
                The torch device on which to run the computation. If `None`, will default to the device of the model.
            play_steps (`int`, *optional*, defaults to 10):
                The number of generation steps with which to return the generated audio array. Using fewer steps will
                mean the first chunk is ready faster, but will require more codec decoding steps overall. This value
                should be tuned to your device and latency requirements.
            stride (`int`, *optional*):
                The window (stride) between adjacent audio samples. Using a stride between adjacent audio samples reduces
                the hard boundary between them, giving smoother playback. If `None`, will default to a value equivalent to
                `play_steps // 6` in the audio space.
            timeout (`float`, *optional*):
                The timeout for the audio queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
                in `.generate()`, when it is called in a separate thread.
        """
        self.decoder = model.decoder
        self.audio_encoder = model.audio_encoder
        self.generation_config = model.generation_config
        self.device = device if device is not None else model.device

        # variables used in the streaming process
        self.play_steps = play_steps
        if stride is not None:
            self.stride = stride
        else:
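            # hop_length is the number of waveform samples produced per audio token frame
            # (the product of the EnCodec upsampling ratios); because of the delay pattern,
            # roughly `num_codebooks` frames per chunk are not yet decodable, so the default
            # stride holds back about one sixth of each chunk to overlap with the next one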
            hop_length = np.prod(self.audio_encoder.config.upsampling_ratios)
            self.stride = hop_length * (play_steps - self.decoder.num_codebooks) // 6
        self.token_cache = None
        self.to_yield = 0

        # variables used in the thread process
        self.audio_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def apply_delay_pattern_mask(self, input_ids):
        # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to MusicGen)
        _, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
            input_ids[:, :1],
            pad_token_id=self.generation_config.decoder_start_token_id,
            max_length=input_ids.shape[-1],
        )
        # apply the pattern mask to the input ids
        input_ids = self.decoder.apply_delay_pattern_mask(input_ids, decoder_delay_pattern_mask)

        # revert the pattern delay mask by filtering the pad token id
        input_ids = input_ids[input_ids != self.generation_config.pad_token_id].reshape(
            1, self.decoder.num_codebooks, -1
        )

        # append the frame dimension back to the audio codes
        input_ids = input_ids[None, ...]

        # send the input_ids to the correct device
        input_ids = input_ids.to(self.audio_encoder.device)

        output_values = self.audio_encoder.decode(
            input_ids,
            audio_scales=[None],
        )
        audio_values = output_values.audio_values[0, 0]
        return audio_values.cpu().float().numpy()

    def put(self, value):
        batch_size = value.shape[0] // self.decoder.num_codebooks
        if batch_size > 1:
            raise ValueError("MusicgenStreamer only supports batch size 1")

        if self.token_cache is None:
            self.token_cache = value
        else:
            self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)

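        # every `play_steps` tokens, decode the full token cache and emit only the samples
        # that have not been yielded yet, holding back `stride` samples; the held-back tail
        # is re-decoded with more context on the next pass, smoothing the chunk boundary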
        if self.token_cache.shape[-1] % self.play_steps == 0:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
            self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
            self.to_yield += len(audio_values) - self.to_yield - self.stride

    def end(self):
        """Flushes any remaining cache and appends the stop symbol."""
        if self.token_cache is not None:
            audio_values = self.apply_delay_pattern_mask(self.token_cache)
        else:
            audio_values = np.zeros(self.to_yield)

        self.on_finalized_audio(audio_values[self.to_yield :], stream_end=True)

    def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
        """Put the new audio in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.audio_queue.put(audio, timeout=self.timeout)
        if stream_end:
            self.audio_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.audio_queue.get(timeout=self.timeout)
        if not isinstance(value, np.ndarray) and value == self.stop_signal:
            raise StopIteration()
        else:
            return value


sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate

target_dtype = np.int16
max_range = np.iinfo(target_dtype).max


@spaces.GPU
def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0, seed=0):
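    # convert the requested durations (in seconds) to decoder steps via the codec frame rate;
    # as described in the article above, each decoding step corresponds to roughly 0.02 s of
    # audio, so e.g. 10 s of audio needs about 500 decoding steps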
    max_new_tokens = int(frame_rate * audio_length_in_s)
    play_steps = int(frame_rate * play_steps_in_s)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != str(model.device):
        model.to(device)
        if device == "cuda:0":
            model.half()

    inputs = processor(
        text=text_prompt,
        padding=True,
        return_tensors="pt",
    )

    streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)

    generation_kwargs = dict(
        **inputs.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )
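    # run generation in a background thread so the main thread can drain the streamer queue
    # as chunks become ready; yielding from this generator lets Gradio start playback before
    # generation has finished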
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    set_seed(seed)
    for new_audio in streamer:
        print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 2)} seconds")
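        # scale the float waveform (roughly in [-1, 1]) to 16-bit integer PCM for the streaming audio output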
        new_audio = (new_audio * max_range).astype(np.int16)
        yield (sampling_rate, new_audio)


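# a generator function paired with a streaming `gr.Audio` output lets the demo play each
# audio chunk as soon as it is yielded, rather than waiting for the full generation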
demo = gr.Interface(
    fn=generate_audio,
    inputs=[
        gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
        gr.Slider(10, 30, value=15, step=5, label="Audio length in seconds"),
        gr.Slider(0.5, 2.5, value=0.5, step=0.5, label="Streaming interval in seconds", info="Lower = shorter chunks, lower latency, more codec steps"),
        gr.Slider(0, 10, value=5, step=1, label="Seed for random generations"),
    ],
    outputs=[
        gr.Audio(label="Generated Music", streaming=True, autoplay=True)
    ],
    examples=[
        ["An 80s driving pop song with heavy drums and synth pads in the background", 20, 0.5, 5],
        ["A cheerful country song with acoustic guitars", 15, 0.5, 5],
        ["90s rock song with electric guitar and heavy drums", 15, 0.5, 5],
        ["a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions bpm: 130", 30, 0.5, 5],
        ["lofi slow bpm electro chill with organic samples", 30, 0.5, 5],
    ],
    title=title,
    description=description,
    article=article,
    cache_examples=False,
)


demo.queue().launch()