ylacombe committed on
Commit
b4777d5
·
verified ·
1 Parent(s): 2ab4d19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -13
app.py CHANGED
@@ -508,9 +508,6 @@ class MusicgenStreamer(BaseStreamer):
508
  self.to_yield = 0
509
 
510
  self.is_longform = is_longform
511
- if is_longform:
512
- self.longform_stride = model.stride_longform
513
- self.longform_stride_applied = True
514
 
515
  # varibles used in the thread process
516
  self.audio_queue = Queue()
@@ -564,15 +561,13 @@ class MusicgenStreamer(BaseStreamer):
564
  if self.token_cache is None:
565
  self.token_cache = value
566
  else:
567
- # if self.is_longform and not self.longform_stride_applied:
568
- # value = value[self.longform_stride:]
569
- # self.longform_stride_applied = True
570
  self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
571
 
572
  if self.token_cache.shape[-1] % self.play_steps == 0:
573
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
574
- self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
575
- self.to_yield += len(audio_values) - self.to_yield - self.stride
 
576
 
577
  def end(self, stream_end=False):
578
  """Flushes any remaining cache and appends the stop symbol."""
@@ -582,8 +577,6 @@ class MusicgenStreamer(BaseStreamer):
582
  audio_values = np.zeros(self.to_yield)
583
 
584
  stream_end = (not self.is_longform) or stream_end
585
- if self.is_longform:
586
- self.longform_stride_applied = False
587
  self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)
588
 
589
  def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
@@ -656,13 +649,13 @@ def generate_audio(text_prompt, audio, audio_length_in_s=10.0, play_steps_in_s=2
656
  return_tensors="pt",
657
  )
658
 
659
- streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True, )
660
 
661
  generation_kwargs = dict(
662
  **inputs.to(device),
663
  temperature=1.2,
664
  streamer=streamer,
665
- max_new_tokens=min(max_new_tokens, 1500),
666
  max_longform_generation_length=max_new_tokens,
667
  )
668
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
@@ -685,7 +678,7 @@ demo = gr.Interface(
685
  inputs=[
686
  gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
687
  gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
688
- gr.Slider(35, 60, value=45, step=5, label="(Approximate) Audio length in seconds."),
689
  gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds.", info="Lower = shorter chunks, lower latency, more codec steps."),
690
  gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
691
  ],
 
508
  self.to_yield = 0
509
 
510
  self.is_longform = is_longform
 
 
 
511
 
512
  # varibles used in the thread process
513
  self.audio_queue = Queue()
 
561
  if self.token_cache is None:
562
  self.token_cache = value
563
  else:
 
 
 
564
  self.token_cache = torch.concatenate([self.token_cache, value[:, None]], dim=-1)
565
 
566
  if self.token_cache.shape[-1] % self.play_steps == 0:
567
  audio_values = self.apply_delay_pattern_mask(self.token_cache)
568
+ if self.to_yield != len(audio_values) - self.stride:
569
+ self.on_finalized_audio(audio_values[self.to_yield : -self.stride])
570
+ self.to_yield += len(audio_values) - self.to_yield - self.stride
571
 
572
  def end(self, stream_end=False):
573
  """Flushes any remaining cache and appends the stop symbol."""
 
577
  audio_values = np.zeros(self.to_yield)
578
 
579
  stream_end = (not self.is_longform) or stream_end
 
 
580
  self.on_finalized_audio(audio_values[self.to_yield :], stream_end=stream_end)
581
 
582
  def on_finalized_audio(self, audio: np.ndarray, stream_end: bool = False):
 
649
  return_tensors="pt",
650
  )
651
 
652
+ streamer = MusicgenStreamer(model, device=device, play_steps=play_steps, is_longform=True, stride=1)
653
 
654
  generation_kwargs = dict(
655
  **inputs.to(device),
656
  temperature=1.2,
657
  streamer=streamer,
658
+ max_new_tokens=min(max_new_tokens, 1503),
659
  max_longform_generation_length=max_new_tokens,
660
  )
661
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
 
678
  inputs=[
679
  gr.Text(label="Prompt", value="80s pop track with synth and instrumentals"),
680
  gr.Audio(type="filepath", label="Conditioning audio. Use this for melody-guided generation."),
681
+ gr.Slider(30, 60, value=45, step=5, label="(Approximate) Audio length in seconds."),
682
  gr.Slider(0.5, 2.5, value=1.5, step=0.5, label="Streaming interval in seconds.", info="Lower = shorter chunks, lower latency, more codec steps."),
683
  gr.Number(value=5, precision=0, step=1, minimum=0, label="Seed for random generations."),
684
  ],