Gpagejr12 commited on
Commit
c9d1263
·
verified ·
1 Parent(s): 6f1d0eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -12,43 +12,52 @@ genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
12
  "Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock" ]
13
 
14
  @st.cache_resource()
 
15
  def load_model():
16
  model = MusicGen.get_pretrained('facebook/musicgen-small')
17
  return model
18
 
19
  def generate_music_tensors(descriptions, duration: int, device):
 
20
  model = load_model()
 
21
 
22
  model.set_generation_params(
23
  use_sampling=True,
24
  top_k=250,
25
- duration=duration
26
  )
27
 
28
  with st.spinner("Generating Music..."):
 
29
  output = model.generate(
30
  descriptions=descriptions,
31
  progress=True,
32
  return_tokens=True,
33
- device=device # Pass the device to the generate method
34
  )
35
 
 
 
 
36
  st.success("Music Generation Complete!")
37
  return output
38
 
39
 
40
- def save_audio(samples: torch.Tensor):
 
41
  sample_rate = 30000
42
  save_path = "audio_output"
43
  assert samples.dim() == 2 or samples.dim() == 3
44
 
45
- samples = samples.detach().cpu()
46
  if samples.dim() == 2:
47
  samples = samples[None, ...]
48
 
49
  for idx, audio in enumerate(samples):
50
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
51
- torchaudio.save(audio_path, audio, sample_rate)
 
52
 
53
  def get_binary_file_downloader_html(bin_file, file_label='File'):
54
  with open(bin_file, 'rb') as f:
 
12
  "Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock" ]
13
 
14
  @st.cache_resource()
15
+
16
  def load_model():
17
  model = MusicGen.get_pretrained('facebook/musicgen-small')
18
  return model
19
 
20
  def generate_music_tensors(descriptions, duration: int, device):
21
+ # Load the model and move it to the specified device
22
  model = load_model()
23
+ model = model.to(device)
24
 
25
  model.set_generation_params(
26
  use_sampling=True,
27
  top_k=250,
28
+ duration=duration * 60 # Multiply by 60 to convert minutes to seconds
29
  )
30
 
31
  with st.spinner("Generating Music..."):
32
+ # Generate music using the model
33
  output = model.generate(
34
  descriptions=descriptions,
35
  progress=True,
36
  return_tokens=True,
37
+ device=device
38
  )
39
 
40
+ # Save the generated music audio
41
+ save_audio(output, device)
42
+
43
  st.success("Music Generation Complete!")
44
  return output
45
 
46
 
47
+
48
+ def save_audio(samples: torch.Tensor, device):
49
  sample_rate = 30000
50
  save_path = "audio_output"
51
  assert samples.dim() == 2 or samples.dim() == 3
52
 
53
+ samples = samples.to(device) # Move the samples to the device
54
  if samples.dim() == 2:
55
  samples = samples[None, ...]
56
 
57
  for idx, audio in enumerate(samples):
58
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
59
+ torchaudio.save(audio_path, audio.cpu(), sample_rate) # Move the audio to the CPU before saving
60
+
61
 
62
  def get_binary_file_downloader_html(bin_file, file_label='File'):
63
  with open(bin_file, 'rb') as f: