Gpagejr12 commited on
Commit
6f1d0eb
·
verified ·
1 Parent(s): f71571a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -6
app.py CHANGED
@@ -25,13 +25,12 @@ def generate_music_tensors(descriptions, duration: int, device):
25
  duration=duration
26
  )
27
 
28
- # Set the device during the forward pass
29
- with torch.no_grad():
30
  output = model.generate(
31
  descriptions=descriptions,
32
  progress=True,
33
  return_tokens=True,
34
- device=device
35
  )
36
 
37
  st.success("Music Generation Complete!")
@@ -98,15 +97,12 @@ def main():
98
  st.subheader("Generated Music")
99
 
100
  # Generate audio
101
- # descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(5)]
102
  descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(1)] # Change the batch size to 1
103
- music_tensors = generate_music_tensors(descriptions, time_slider)
104
 
105
  # Pass the device parameter when calling generate_music_tensors
106
  device = torch.device('cpu')
107
  music_tensors = generate_music_tensors(descriptions, time_slider, device)
108
 
109
- # Only play the full audio for index 0
110
  idx = 0
111
  music_tensor = music_tensors[idx]
112
  save_music_file = save_audio(music_tensor)
 
25
  duration=duration
26
  )
27
 
28
+ with st.spinner("Generating Music..."):
 
29
  output = model.generate(
30
  descriptions=descriptions,
31
  progress=True,
32
  return_tokens=True,
33
+ device=device # Pass the device to the generate method
34
  )
35
 
36
  st.success("Music Generation Complete!")
 
97
  st.subheader("Generated Music")
98
 
99
  # Generate audio
 
100
  descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(1)] # Change the batch size to 1
 
101
 
102
  # Pass the device parameter when calling generate_music_tensors
103
  device = torch.device('cpu')
104
  music_tensors = generate_music_tensors(descriptions, time_slider, device)
105
 
 
106
  idx = 0
107
  music_tensor = music_tensors[idx]
108
  save_music_file = save_audio(music_tensor)