Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -508,9 +508,6 @@ class MusicgenStreamer(BaseStreamer):
|
|
508 |
self.to_yield = 0
|
509 |
|
510 |
self.is_longform = is_longform
|
511 |
-
if is_longform:
|
512 |
-
self.longform_stride = model.stride_longform
|
513 |
-
self.longform_stride_applied = True
|
514 |
|
515 |
# varibles used in the thread process
|
516 |
self.audio_queue = Queue()
|
@@ -564,15 +561,13 @@ class MusicgenStreamer(BaseStreamer):
|
|
564 |
if self.token_cache is None:
|
565 |
self.token_cache = value
|
566 |
else:
|
567 |
-
# if self.is_longform and not self.longform_stride_applied:
|
568 |
-
# value = value[self.longform_stride:]
|
569 |
-
# self.longform_stride_applied = True
|
570 |
self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
|
571 |
|
572 |
if self.token_cache.shape[-1] % self.play_steps == 0:
|
573 |
audio_values = self.apply_delay_pattern_mask(self.token_cache)
|
574 |
-
self.
|
575 |
-
|
|
|
576 |
|
577 |
def end(self, stream_end=False):
|
578 |
"""Flushes any remaining cache and appends the stop symbol."""
|
@@ -582,8 +577,6 @@ class MusicgenStreamer(BaseStreamer):
|
|
582 |
audio_values = np.zeros(self.to_yield)
|
583 |
|
584 |
stream_end = (not self.is_longform) or stream_end
|
585 |
-
if self.is_longform:
|
586 |
-
self.longform_stride_applied = False
|
587 |
self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)
|
588 |
|
589 |
def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
|
@@ -656,13 +649,13 @@ def generate_audio(text_prompt, audio, audio_length_in_s=10.0, play_steps_in_s=2
|
|
656 |
return_tensors="pt",
|
657 |
)
|
658 |
|
659 |
-
streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True, )
|
660 |
|
661 |
generation_kwargs = dict(
|
662 |
**inputs.to(device),
|
663 |
temperature=1.2,
|
664 |
streamer=streamer,
|
665 |
-
max_new_tokens=min(max_new_tokens,
|
666 |
max_longform_generation_length=max_new_tokens,
|
667 |
)
|
668 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
@@ -685,7 +678,7 @@ demo = gr.Interface(
|
|
685 |
inputs=[
|
686 |
gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
|
687 |
gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
|
688 |
-
gr.Slider(
|
689 |
gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds.", info="Lower = shorter chunks, lower latency, more codec steps."),
|
690 |
gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
|
691 |
],
|
|
|
508 |
self.to_yield = 0
|
509 |
|
510 |
self.is_longform = is_longform
|
|
|
|
|
|
|
511 |
|
512 |
# varibles used in the thread process
|
513 |
self.audio_queue = Queue()
|
|
|
561 |
if self.token_cache is None:
|
562 |
self.token_cache = value
|
563 |
else:
|
|
|
|
|
|
|
564 |
self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
|
565 |
|
566 |
if self.token_cache.shape[-1] % self.play_steps == 0:
|
567 |
audio_values = self.apply_delay_pattern_mask(self.token_cache)
|
568 |
+
if self.to_yield != len(audio_values) - self.stride:
|
569 |
+
self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
|
570 |
+
self.to_yield += len(audio_values) - self.to_yield - self.stride
|
571 |
|
572 |
def end(self, stream_end=False):
|
573 |
"""Flushes any remaining cache and appends the stop symbol."""
|
|
|
577 |
audio_values = np.zeros(self.to_yield)
|
578 |
|
579 |
stream_end = (not self.is_longform) or stream_end
|
|
|
|
|
580 |
self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)
|
581 |
|
582 |
def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
|
|
|
649 |
return_tensors="pt",
|
650 |
)
|
651 |
|
652 |
+
streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True, stride=1)
|
653 |
|
654 |
generation_kwargs = dict(
|
655 |
**inputs.to(device),
|
656 |
temperature=1.2,
|
657 |
streamer=streamer,
|
658 |
+
max_new_tokens=min(max_new_tokens, 1503),
|
659 |
max_longform_generation_length=max_new_tokens,
|
660 |
)
|
661 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
|
|
678 |
inputs=[
|
679 |
gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
|
680 |
gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
|
681 |
+
gr.Slider(30, 60, value=45, step=5, label="(Approximate) Audio length in seconds."),
|
682 |
gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds.", info="Lower = shorter chunks, lower latency, more codec steps."),
|
683 |
gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
|
684 |
],
|