kind of working
Browse files- app.py +12 -6
- audiocraft/models/musicgen.py +19 -6
app.py
CHANGED
@@ -59,6 +59,9 @@ def load_model(version='melody'):
|
|
59 |
|
60 |
|
61 |
def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
|
|
|
|
|
|
62 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
63 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
64 |
be = time.time()
|
@@ -76,7 +79,7 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
|
76 |
melody = convert_audio(melody, sr, target_sr, target_ac)
|
77 |
processed_melodies.append(melody)
|
78 |
|
79 |
-
if
|
80 |
outputs = MODEL.generate_with_chroma(
|
81 |
descriptions=texts,
|
82 |
melody_wavs=processed_melodies,
|
@@ -110,12 +113,10 @@ def predict_batched(texts, melodies):
|
|
110 |
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
111 |
topk = int(topk)
|
112 |
load_model(model)
|
113 |
-
if duration > MODEL.lm.cfg.dataset.segment_duration:
|
114 |
-
raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
|
115 |
|
116 |
outs = _do_predictions(
|
117 |
[text], [melody], duration,
|
118 |
-
|
119 |
return outs[0]
|
120 |
|
121 |
|
@@ -138,7 +139,7 @@ def ui_full(launch_kwargs):
|
|
138 |
with gr.Row():
|
139 |
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
140 |
with gr.Row():
|
141 |
-
duration = gr.Slider(minimum=1, maximum=
|
142 |
with gr.Row():
|
143 |
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
144 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
@@ -184,7 +185,12 @@ def ui_full(launch_kwargs):
|
|
184 |
### More details
|
185 |
|
186 |
The model will generate a short music extract based on the description you provided.
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
188 |
|
189 |
We present 4 model variations:
|
190 |
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
|
|
59 |
|
60 |
|
61 |
def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
62 |
+
if duration > MODEL.lm.cfg.dataset.segment_duration:
|
63 |
+
raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
|
64 |
+
|
65 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
66 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
67 |
be = time.time()
|
|
|
79 |
melody = convert_audio(melody, sr, target_sr, target_ac)
|
80 |
processed_melodies.append(melody)
|
81 |
|
82 |
+
if any(m is not None for m in processed_melodies):
|
83 |
outputs = MODEL.generate_with_chroma(
|
84 |
descriptions=texts,
|
85 |
melody_wavs=processed_melodies,
|
|
|
113 |
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
114 |
topk = int(topk)
|
115 |
load_model(model)
|
|
|
|
|
116 |
|
117 |
outs = _do_predictions(
|
118 |
[text], [melody], duration,
|
119 |
+
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
120 |
return outs[0]
|
121 |
|
122 |
|
|
|
139 |
with gr.Row():
|
140 |
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
141 |
with gr.Row():
|
142 |
+
duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
|
143 |
with gr.Row():
|
144 |
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
145 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
|
|
185 |
### More details
|
186 |
|
187 |
The model will generate a short music extract based on the description you provided.
|
188 |
+
The model can generate up to 30 seconds of audio in one pass. It is now possible
|
189 |
+
to extend the generation by feeding back the end of the previous chunk of audio.
|
190 |
+
This can take a long time, and the model might lose consistency. The model might also
|
191 |
+
decide at arbitrary positions that the song ends.
|
192 |
+
|
193 |
+
**WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
|
194 |
|
195 |
We present 4 model variations:
|
196 |
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
audiocraft/models/musicgen.py
CHANGED
@@ -45,6 +45,7 @@ class MusicGen:
|
|
45 |
self.device = next(iter(lm.parameters())).device
|
46 |
self.generation_params: dict = {}
|
47 |
self.set_generation_params(duration=15) # 15 seconds by default
|
|
|
48 |
if self.device.type == 'cpu':
|
49 |
self.autocast = TorchAutocast(enabled=False)
|
50 |
else:
|
@@ -127,6 +128,9 @@ class MusicGen:
|
|
127 |
'two_step_cfg': two_step_cfg,
|
128 |
}
|
129 |
|
|
|
|
|
|
|
130 |
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
131 |
"""Generate samples in an unconditional manner.
|
132 |
|
@@ -274,6 +278,10 @@ class MusicGen:
|
|
274 |
current_gen_offset: int = 0
|
275 |
|
276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
|
|
|
|
|
|
|
|
277 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
278 |
|
279 |
if prompt_tokens is not None:
|
@@ -296,11 +304,17 @@ class MusicGen:
|
|
296 |
# melody conditioning etc.
|
297 |
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
298 |
all_tokens = []
|
299 |
-
if prompt_tokens is
|
|
|
|
|
300 |
all_tokens.append(prompt_tokens)
|
|
|
|
|
|
|
|
|
301 |
|
302 |
-
|
303 |
-
|
304 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
305 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
306 |
for attr, ref_wav in zip(attributes, ref_wavs):
|
@@ -321,14 +335,13 @@ class MusicGen:
|
|
321 |
gen_tokens = self.lm.generate(
|
322 |
prompt_tokens, attributes,
|
323 |
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
324 |
-
stride_tokens = int(self.frame_rate * self.extend_stride)
|
325 |
if prompt_tokens is None:
|
326 |
all_tokens.append(gen_tokens)
|
327 |
else:
|
328 |
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
329 |
-
prompt_tokens = gen_tokens[:, :, stride_tokens]
|
|
|
330 |
current_gen_offset += stride_tokens
|
331 |
-
time_offset += self.extend_stride
|
332 |
|
333 |
gen_tokens = torch.cat(all_tokens, dim=-1)
|
334 |
|
|
|
45 |
self.device = next(iter(lm.parameters())).device
|
46 |
self.generation_params: dict = {}
|
47 |
self.set_generation_params(duration=15) # 15 seconds by default
|
48 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
49 |
if self.device.type == 'cpu':
|
50 |
self.autocast = TorchAutocast(enabled=False)
|
51 |
else:
|
|
|
128 |
'two_step_cfg': two_step_cfg,
|
129 |
}
|
130 |
|
131 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
132 |
+
self._progress_callback = progress_callback
|
133 |
+
|
134 |
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
135 |
"""Generate samples in an unconditional manner.
|
136 |
|
|
|
278 |
current_gen_offset: int = 0
|
279 |
|
280 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
281 |
+
generated_tokens += current_gen_offset
|
282 |
+
if self._progress_callback is not None:
|
283 |
+
self._progress_callback(generated_tokens, total_gen_len)
|
284 |
+
else:
|
285 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
286 |
|
287 |
if prompt_tokens is not None:
|
|
|
304 |
# melody conditioning etc.
|
305 |
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
306 |
all_tokens = []
|
307 |
+
if prompt_tokens is None:
|
308 |
+
prompt_length = 0
|
309 |
+
else:
|
310 |
all_tokens.append(prompt_tokens)
|
311 |
+
prompt_length = prompt_tokens.shape[-1]
|
312 |
+
|
313 |
+
|
314 |
+
stride_tokens = int(self.frame_rate * self.extend_stride)
|
315 |
|
316 |
+
while current_gen_offset + prompt_length < total_gen_len:
|
317 |
+
time_offset = current_gen_offset / self.frame_rate
|
318 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
319 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
320 |
for attr, ref_wav in zip(attributes, ref_wavs):
|
|
|
335 |
gen_tokens = self.lm.generate(
|
336 |
prompt_tokens, attributes,
|
337 |
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
|
|
338 |
if prompt_tokens is None:
|
339 |
all_tokens.append(gen_tokens)
|
340 |
else:
|
341 |
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
342 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
343 |
+
prompt_length = prompt_tokens.shape[-1]
|
344 |
current_gen_offset += stride_tokens
|
|
|
345 |
|
346 |
gen_tokens = torch.cat(all_tokens, dim=-1)
|
347 |
|