Surn committed
Commit 9634e77
1 Parent(s): af50291

Updates Round3

app.py CHANGED
@@ -17,6 +17,7 @@ from pathlib import Path
 import time
 import typing as tp
 import warnings
+from tqdm import tqdm
 from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
 from audiocraft.data.audio_utils import apply_fade, apply_tafade, apply_splice_effect
@@ -48,6 +49,7 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segmen
 # os.environ['USE_FLASH_ATTENTION'] = '1'
 # os.environ['XFORMERS_FORCE_DISABLE_TRITON']= '1'
 
+
 def interrupt_callback():
     return INTERRUPTED
 
@@ -162,7 +164,7 @@ def load_melody_filepath(melody_filepath, title, assigned_model):
 
     return gr.update(value=melody_name), gr.update(maximum=MAX_PROMPT_INDEX, value=0), gr.update(value=assigned_model, interactive=True)
 
-def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False):
+def predict(model, text, melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap=1, prompt_index = 0, include_title = True, include_settings = True, harmony_only = False, profile = gr.OAuthProfile, progress=gr.Progress(track_tqdm=True)):
     global MODEL, INTERRUPTED, INTERRUPTING, MOVE_TO_CPU
     output_segments = None
     melody_name = "Not Used"
@@ -228,14 +230,16 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
         cfg_coef=cfg_coef,
         duration=segment_duration,
         two_step_cfg=False,
+        extend_stride=10,
         rep_penalty=0.5
     )
+    MODEL.set_custom_progress_callback(gr.Progress(track_tqdm=True))
 
     try:
         if melody:
             # return excess duration, load next model and continue in loop structure building up output_segments
             if duration > MODEL.lm.cfg.dataset.segment_duration:
-                output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False)
+                output_segments, duration = generate_music_segments(text, melody, seed, MODEL, duration, overlap, MODEL.lm.cfg.dataset.segment_duration, prompt_index, harmony_only=False, progress=gr.Progress(track_tqdm=True))
             else:
                 # pure original code
                 sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t().unsqueeze(0)
@@ -247,20 +251,20 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
                     descriptions=[text],
                     melody_wavs=melody,
                     melody_sample_rate=sr,
-                    progress=False
+                    progress=True, progress_callback=gr.Progress(track_tqdm=True)
                 )
                 # All output_segments are populated, so we can break the loop or set duration to 0
                 break
         else:
             #output = MODEL.generate(descriptions=[text], progress=False)
             if not output_segments:
-                next_segment = MODEL.generate(descriptions=[text], progress=True)
+                next_segment = MODEL.generate(descriptions=[text], progress=True, progress_callback=gr.Progress(track_tqdm=True))
                 duration -= segment_duration
             else:
                 last_chunk = output_segments[-1][:, :, -overlap*MODEL.sample_rate:]
-                next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=True)
+                next_segment = MODEL.generate_continuation(last_chunk, MODEL.sample_rate, descriptions=[text], progress=True, progress_callback=gr.Progress(track_tqdm=True))
                 duration -= segment_duration - overlap
-            if next_segment != None:
+            if next_segment != None:
                 output_segments.append(next_segment)
     except Exception as e:
         print(f"Error generating audio: {e}")
@@ -312,7 +316,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
         return None, None, seed
     else:
         output = output.detach().cpu().float()[0]
-        profile: gr.OAuthProfile | None = None
+
        title_file_name = convert_title_to_filename(title)
        with NamedTemporaryFile("wb", suffix=".wav", delete=False, prefix = title_file_name) as file:
            video_description = f"{text}\n Duration: {str(initial_duration)} Dimension: {dimension}\n Top-k:{topk} Top-p:{topp}\n Randomness:{temperature}\n cfg:{cfg_coef} overlap: {overlap}\n Seed: {seed}\n Model: {model}\n Melody Condition:{melody_name}\n Sample Segment: {prompt_index}"
@@ -357,7 +361,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
             "background": background,
             "include_title": include_title,
             "include_settings": include_settings,
-            "profile": profile,
+            "profile": "Satoshi Nakamoto" if profile.value is None else profile.value.username,
             "commit": commit_hash(),
             "tag": git_tag(),
             "version": gr.__version__,
@@ -396,11 +400,11 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
 
         if waveform_video_path:
             modules.user_history.save_file(
-                profile=profile,
+                profile=profile.value,
                 image=background,
-                audio=file,
+                audio=file.name,
                 video=waveform_video_path,
-                label=text,
+                label=title,
                 metadata=metadata,
             )
 
@@ -423,7 +427,7 @@ def ui(**kwargs):
         This is your private demo for [UnlimitedMusicGen](https://github.com/Oncorporation/audiocraft), a simple and controllable model for music generation
         presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284)
 
-        Disclaimer: This won't run on CPU only. Clone this App and run on GPU instance!
+        Disclaimer: This won't run on CPU only. Clone this App and run on GPU instance!
 
         Todo: Working on improved Interrupt.
         Theme Available at ["Surn/Beeuty"](https://huggingface.co/spaces/Surn/Beeuty)
@@ -482,12 +486,12 @@ def ui(**kwargs):
             with gr.Column() as c:
                 output = gr.Video(label="Generated Music")
                 wave_file = gr.File(label=".wav file", elem_id="output_wavefile", interactive=True)
-                seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
+                seed_used = gr.Number(label='Seed used', value=-1, interactive=False)
 
         radio.change(toggle_audio_src, radio, [melody_filepath], queue=False, show_progress=False)
         melody_filepath.change(load_melody_filepath, inputs=[melody_filepath, title, model], outputs=[title, prompt_index , model], api_name="melody_filepath_change", queue=False)
         reuse_seed.click(fn=lambda x: x, inputs=[seed_used], outputs=[seed], queue=False, api_name="reuse_seed")
-        submit.click(predict, inputs=[model, text,melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only], outputs=[output, wave_file, seed_used], api_name="submit")
+
         gr.Examples(
             examples=[
                 [
@@ -524,10 +528,25 @@ def ui(**kwargs):
             inputs=[text, melody_filepath, model, title],
             outputs=[output]
         )
-        gr.HTML(value=versions_html(), visible=True, elem_id="versions")
+
        with gr.Tab("User History") as history_tab:
            modules.user_history.render()
+        user_profile = gr.State(None)
 
+        with gr.Row("Versions") as versions_row:
+            gr.HTML(value=versions_html(), visible=True, elem_id="versions")
+
+        submit.click(
+            modules.user_history.get_profile,
+            inputs=[],
+            outputs=[user_profile],
+            queue=True,
+            api_name="submit"
+        ).then(
+            predict,
+            inputs=[model, text,melody_filepath, duration, dimension, topk, topp, temperature, cfg_coef, background, title, settings_font, settings_font_color, seed, overlap, prompt_index, include_title, include_settings, harmony_only, user_profile],
+            outputs=[output, wave_file, seed_used])
+
    # Show the interface
    launch_kwargs = {}
    share = kwargs.get('share', False)
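Note on the submit wiring above: instead of calling `predict` directly, the button now runs a two-step Gradio event chain. The first step resolves the logged-in user into a `gr.State`, and `.then()` passes that state into the main handler. A minimal, self-contained sketch of the same pattern (component names and the `greet` handler are illustrative, not part of this app):

```python
import gradio as gr

def get_user(profile: gr.OAuthProfile | None = None):
    # Gradio injects the OAuth profile automatically when the parameter is
    # annotated with gr.OAuthProfile; it is None when nobody is logged in.
    return profile

def greet(prompt: str, profile) -> str:
    user = profile.username if profile is not None else "anonymous"
    return f"{user}: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    user_state = gr.State(None)  # holds the profile between the two steps
    submit = gr.Button("Submit")
    # Step 1: stash the profile in state. Step 2: run the main handler with it.
    submit.click(get_user, inputs=[], outputs=[user_state]).then(
        greet, inputs=[prompt, user_state], outputs=[result]
    )

if __name__ == "__main__":
    demo.launch()
```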
audiocraft/__init__.py CHANGED
@@ -7,4 +7,4 @@
 # flake8: noqa
 from . import data, modules, models
 
-__version__ = '1.4.Surn'
+__version__ = '1.2.Surn'
audiocraft/models/musicgen.py CHANGED
@@ -15,6 +15,7 @@ import warnings
 
 import omegaconf
 import torch
+import gradio as gr
 
 from .encodec import CompressionModel
 from .lm import LMModel
@@ -67,7 +68,7 @@ class MusicGen:
         self.device = next(iter(lm.parameters())).device
         self.generation_params: dict = {}
         self.set_generation_params(duration=self.duration)  # 15 seconds by default
-        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        self._progress_callback: tp.Union[tp.Callable[[int, int], None], gr.Progress] = None
         if self.device.type == 'cpu':
             self.autocast = TorchAutocast(enabled=False)
         else:
@@ -142,7 +143,7 @@ class MusicGen:
     def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                               top_p: float = 0.0, temperature: float = 1.0,
                               duration: float = 30.0, cfg_coef: float = 3.0,
-                              two_step_cfg: bool = False, extend_stride: float = 18, rep_penalty: float = None):
+                              two_step_cfg: bool = False, extend_stride: float = 10, rep_penalty: float = None):
         """Set the generation parameters for MusicGen.
 
         Args:
@@ -173,12 +174,12 @@ class MusicGen:
             'two_step_cfg': two_step_cfg,
         }
 
-    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+    def set_custom_progress_callback(self, progress_callback: tp.Union[tp.Callable[[int, int], None],gr.Progress] = None):
         """Override the default progress callback."""
         self._progress_callback = progress_callback
 
     def generate_unconditional(self, num_samples: int, progress: bool = False,
-                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                               return_tokens: bool = False, progress_callback: gr.Progress = None) -> tp.Union[torch.Tensor,
                                                                         tp.Tuple[torch.Tensor, torch.Tensor]]:
         """Generate samples in an unconditional manner.
 
@@ -194,7 +195,7 @@ class MusicGen:
             return self.generate_audio(tokens), tokens
         return self.generate_audio(tokens)
 
-    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
+    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False, progress_callback: gr.Progress = None) \
             -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
         """Generate samples conditioned on text.
 
@@ -212,7 +213,7 @@ class MusicGen:
 
     def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
                              melody_sample_rate: int, progress: bool = False,
-                             return_tokens: bool = False) -> tp.Union[torch.Tensor,
+                             return_tokens: bool = False, progress_callback=gr.Progress(track_tqdm=True)) -> tp.Union[torch.Tensor,
                                                                       tp.Tuple[torch.Tensor, torch.Tensor]]:
         """Generate samples conditioned on text and melody.
 
@@ -250,7 +251,7 @@ class MusicGen:
         return self.generate_audio(tokens)
 
     def generate_with_all(self, descriptions: tp.List[str], melody_wavs: MelodyType,
-                          sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None, return_tokens: bool = False) \
+                          sample_rate: int, progress: bool = False, prompt: tp.Optional[torch.Tensor] = None, return_tokens: bool = False, progress_callback: gr.Progress = None) \
             -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
         """Generate samples conditioned on text and melody and audio prompts.
         Args:
@@ -307,7 +308,7 @@ class MusicGen:
 
     def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                               descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
-                              progress: bool = False, return_tokens: bool = False) \
+                              progress: bool = False, return_tokens: bool = False, progress_callback: gr.Progress = None) \
             -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
         """Generate samples conditioned on audio prompts.
 
@@ -317,7 +318,8 @@ class MusicGen:
             prompt_sample_rate (int): Sampling rate of the given audio waveforms.
             descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
             progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
-            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
+            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.\
+                This is truly a hack and does not follow the progression of conditioning melody or previously generated audio.
         """
         if prompt.dim() == 2:
             prompt = prompt[None]
@@ -338,7 +340,8 @@ class MusicGen:
             self,
             descriptions: tp.Sequence[tp.Optional[str]],
             prompt: tp.Optional[torch.Tensor],
-            melody_wavs: tp.Optional[MelodyList] = None,
+            melody_wavs: tp.Optional[MelodyList] = None,
+            progress_callback: tp.Optional[gr.Progress] = None
     ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
         """Prepare model inputs.
 
@@ -392,7 +395,7 @@ class MusicGen:
         return attributes, prompt_tokens
 
     def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
-                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, progress_callback: gr.Progress = None) -> torch.Tensor:
         """Generate discrete audio tokens given audio prompt and/or conditions.
 
         Args:
@@ -411,17 +414,19 @@ class MusicGen:
             if self._progress_callback is not None:
                 # Note that total_gen_len might be quite wrong depending on the
                 # codebook pattern used, but with delay it is almost accurate.
-                self._progress_callback(generated_tokens, total_gen_len)
-            else:
+                self._progress_callback((generated_tokens / total_gen_len), f"Generated {generated_tokens}/{total_gen_len} tokens")
+            if progress_callback is not None:
+                # Update Gradio progress bar
+                progress_callback((generated_tokens / total_gen_len), f"Generated {generated_tokens}/{total_gen_len} tokens")
+            if progress:
                 print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
 
         if prompt_tokens is not None:
             assert max_prompt_len >= prompt_tokens.shape[-1], \
                 "Prompt is longer than audio to generate"
 
-        callback = None
-        if progress:
-            callback = _progress_callback
+        # callback = None
+        callback = _progress_callback
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
@@ -481,7 +486,7 @@ class MusicGen:
 
     # generate audio
 
-    def generate_audio(self, gen_tokens: torch.Tensor):
+    def generate_audio(self, gen_tokens: torch.Tensor):
         try:
             """Generate Audio from tokens"""
             assert gen_tokens.dim() == 3
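With this change the model-side callback is no longer invoked with the raw `(generated_tokens, total_gen_len)` pair: both `self._progress_callback` and the new `progress_callback` argument receive a 0-1 fraction plus a description string, which matches how a `gr.Progress` object is called. A rough sketch of driving it from plain Python outside Gradio, via `set_custom_progress_callback` (the checkpoint name and prompt are illustrative):

```python
from audiocraft.models import MusicGen

def console_progress(fraction: float, desc: str = "") -> None:
    # Same call shape the patched _generate_tokens uses: a fraction in [0, 1] plus a label.
    print(f"\r{desc} ({fraction:.0%})", end="", flush=True)

model = MusicGen.get_pretrained("melody")  # illustrative checkpoint name
model.set_generation_params(duration=30)
model.set_custom_progress_callback(console_progress)

# progress=True keeps the original console counter; the custom callback fires as well.
wav = model.generate(descriptions=["upbeat synthwave"], progress=True)
```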
audiocraft/utils/extend.py CHANGED
@@ -12,6 +12,7 @@ import requests
 from io import BytesIO
 from huggingface_hub import hf_hub_download
 import librosa
+import gradio as gr
 
 
 INTERRUPTING = False
@@ -48,7 +49,7 @@ def separate_audio_segments(audio, segment_duration=30, overlap=1):
     print(f"separate_audio_segments: {len(segments)} segments of length {segment_samples // sr} seconds")
     return segments
 
-def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False):
+def generate_music_segments(text, melody, seed, MODEL, duration:int=10, overlap:int=1, segment_duration:int=30, prompt_index:int=0, harmony_only:bool= False, progress= gr.Progress(track_tqdm=True)):
     # generate audio segments
     melody_segments = separate_audio_segments(melody, segment_duration, 0)
 
modules/user_history.py CHANGED
@@ -14,9 +14,11 @@ Useful links:
 - README: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/README.md
 - Source file: https://huggingface.co/spaces/Wauplin/gradio-user-history/blob/main/user_history.py
 - Discussions: https://huggingface.co/spaces/Wauplin/gradio-user-history/discussions
+
+Update by Surn (Charles Fettinger)
 """
 
-__version__ = "0.2.0"
+__version__ = "0.2.1"
 
 import json
 import os
@@ -39,6 +41,13 @@ from mutagen.mp3 import MP3, EasyMP3
 import torchaudio
 import subprocess
 
+user_profile = gr.State(None)
+
+def get_profile() -> gr.OAuthProfile | None:
+    global user_profile
+    """Return the user profile if logged in, None otherwise."""
+
+    return user_profile
 
 def setup(folder_path: str | Path | None = None) -> None:
     user_history = _UserHistory()
@@ -205,23 +214,30 @@ def save_file(
         image_path = _copy_image(image, dst_folder=user_history._user_images_path(username))
         image_path = _add_metadata(image_path, metadata)
 
+    video_path = None
     # Copy video to storage
     if video is not None:
         video_path = _copy_file(video, dst_folder=user_history._user_file_path(username, "videos"))
         video_path = _add_metadata(video_path, metadata)
 
+    audio_path = None
     # Copy audio to storage
     if audio is not None:
         audio_path = _copy_file(audio, dst_folder=user_history._user_file_path(username, "audios"))
         audio_path = _add_metadata(audio_path, metadata)
 
+    document_path = None
     # Copy document to storage
     if document is not None:
         document_path = _copy_file(document, dst_folder=user_history._user_file_path(username, "documents"))
         document_path = _add_metadata(document_path, metadata)
 
+
+    # If no image, video, audio or document => nothing to save
+    if image_path is None and video_path is None and audio_path is None and document_path is None:
+        return
     # Save Json file
-    data = {"image_path": str(image_path), "video_path": str(video_path), "audio_path": str(audio_path), "document_path": str(document_path), "label": label, "metadata": metadata}
+    data = {"image_path": str(image_path), "video_path": str(video_path), "audio_path": str(audio_path), "document_path": str(document_path), "label": _UserHistory._sanitize_for_json(label), "metadata": _UserHistory._sanitize_for_json(metadata)}
     with user_history._user_lock(username):
         with user_history._user_jsonl_path(username).open("a") as f:
             f.write(json.dumps(data) + "\n")
@@ -266,14 +282,34 @@ class _UserHistory(object):
         path.mkdir(parents=True, exist_ok=True)
         return path
 
+    @staticmethod
+    def _sanitize_for_json(obj: Any) -> Any:
+        """
+        Recursively convert non-serializable objects into their string representation.
+        """
+        if isinstance(obj, dict):
+            return {str(key): _UserHistory._sanitize_for_json(value) for key, value in obj.items()}
+        elif isinstance(obj, list):
+            return [_UserHistory._sanitize_for_json(item) for item in obj]
+        elif isinstance(obj, (str, int, float, bool)) or obj is None:
+            return obj
+        elif hasattr(obj, "isoformat"):
+            # For datetime objects and similar.
+            return obj.isoformat()
+        else:
+            return str(obj)
 
 
 def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]]:
     """Return saved history for that user, if it exists."""
     # Cannot load history for logged out users
+    global user_profile
     if profile is None:
+        user_profile = gr.State(None)
         return []
-    username = profile["preferred_username"]
+    username = str(profile["preferred_username"])
+
+    user_profile = gr.State(profile)
 
     user_history = _UserHistory()
     if not user_history.initialized:
@@ -290,7 +326,7 @@ def _fetch_user_history(profile: gr.OAuthProfile | None) -> List[Tuple[str, str]
     images = []
     for line in jsonl_path.read_text().splitlines():
         data = json.loads(line)
-        images.append((data["path"], data["label"] or ""))
+        images.append((data["image_path"], data["label"] or ""))
     return list(reversed(images))
 
 
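The new `_UserHistory._sanitize_for_json` helper exists so the metadata written to the history `.jsonl` can contain values that `json.dumps` would otherwise reject (paths, datetimes, Gradio objects). A quick sketch of its behavior, assuming it is run from the Space's repo root; the sample metadata is made up:

```python
import json
from datetime import datetime
from pathlib import Path

from modules.user_history import _UserHistory

# Made-up metadata of the kind save_file() receives; not all of it is JSON-serializable.
metadata = {
    "prompt": "upbeat synthwave",
    "created": datetime(2025, 3, 31, 12, 0),
    "background": Path("/tmp/background.png"),
    "settings": {"topk": 250, "cfg": 3.0},
}

print(json.dumps(_UserHistory._sanitize_for_json(metadata)))
# datetime -> ISO string, Path -> str(), plain scalars pass through unchanged.
```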
style_20250331.css CHANGED
@@ -102,6 +102,10 @@ a {
     position: relative !important;
 }
 
+.gradio-container {
+    position: relative;
+}
+
 .gradio-container::before {
     content: ' ';
     display: block;
@@ -122,13 +126,15 @@ a {
 .gradio-container::after {
     content: '';
     position: absolute;
-    top: 0;
+    top: -5%;
     left: -60%; /* Start off-screen */
-    width: 30%;
-    height: 100%;
-    background: linear-gradient( 92deg, rgba(255, 255, 255, 0) 25%, rgba(255, 255, 255, 0.60) 50%, rgba(255, 255, 255, 0) 75% );
-    animation: shine 30s infinite;
-    opacity:0.35;
+    width: 100%;
+    height: calc(100% + 150px);
+    background: -webkit-linear-gradient(to top right, rgba(255, 255, 255, 0) 0%, rgba(255, 255, 255, 0) 45%, rgba(255, 255, 255, 0.5) 48%, rgba(255, 255, 255, 0.8) 50%, rgba(255, 255, 255, 0.5) 52%, rgba(255, 255, 255, 0) 57%, rgba(255, 255, 255, 0) 100%);
+    animation: 15s infinite shine;
+    animation: shine 20s infinite;
+    opacity: 0.35;
+    z-index:2;
 }
 
 #component-0, #component-1 {
@@ -213,4 +219,18 @@ a {
     100% {
         left: 125%;
     }
+}
+
+@keyframes shinebg {
+    0% {
+        background-position: center, -100% 0;
+    }
+
+    20% {
+        background-position: center, 100% 0;
+    }
+
+    100% {
+        background-position: center, 125% 0;
+    }
 }