Gpagejr12 commited on
Commit
766f359
·
verified ·
1 Parent(s): 5015127

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -27
app.py CHANGED
@@ -12,53 +12,44 @@ genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
12
  "Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock" ]
13
 
14
  @st.cache_resource()
15
-
16
  def load_model():
17
  model = MusicGen.get_pretrained('facebook/musicgen-small')
18
  return model
19
 
20
- def generate_music_tensors(descriptions, duration: int, device):
21
- # Load the model and move it to the specified device
22
  model = load_model()
23
- model = model
 
24
 
25
-
26
  model.set_generation_params(
27
  use_sampling=True,
28
  top_k=250,
29
- duration=duration * 60 # Multiply by 60 to convert minutes to seconds
30
  )
31
 
32
  with st.spinner("Generating Music..."):
33
- # Generate music using the model
34
  output = model.generate(
35
- descriptions=descriptions,
36
- progress=True,
37
- return_tokens=True,
38
  )
39
 
40
- # Save the generated music audio
41
- # Remove the device argument
42
- save_audio(output)
43
-
44
  st.success("Music Generation Complete!")
45
  return output
46
 
47
 
48
-
49
- def save_audio(samples: torch.Tensor, device):
50
  sample_rate = 30000
51
  save_path = "audio_output"
52
  assert samples.dim() == 2 or samples.dim() == 3
53
 
54
- samples = samples.to(device) # Move the samples to the device
55
  if samples.dim() == 2:
56
  samples = samples[None, ...]
57
 
58
  for idx, audio in enumerate(samples):
59
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
60
- torchaudio.save(audio_path, audio.cpu(), sample_rate) # Move the audio to the CPU before saving
61
-
62
 
63
  def get_binary_file_downloader_html(bin_file, file_label='File'):
64
  with open(bin_file, 'rb') as f:
@@ -74,7 +65,7 @@ st.set_page_config(
74
 
75
  def main():
76
  with st.sidebar:
77
- st.header("""⚙️Generate Music ⚙️""", divider="rainbow")
78
  st.text("")
79
  st.subheader("1. Enter your music description.......")
80
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
@@ -85,13 +76,13 @@ def main():
85
  selected_genre = st.selectbox("Select Genre", genres)
86
 
87
  st.subheader("2. Select time duration (In Seconds)")
88
- # time_slider = st.slider("Select time duration (In Seconds)", 0, 60, 10)
89
- time_slider = st.slider("Select time duration (In Minutes)", 0, 300, 10, step=1)
90
 
91
 
92
  st.title("""🎵 Song Lab AI 🎵""")
93
  st.text('')
94
- left_co, right_co = st.columns(2)
95
  left_co.write("""Music Generation through a prompt""")
96
  left_co.write(("""PS : First generation may take some time ......."""))
97
 
@@ -107,12 +98,11 @@ def main():
107
  st.subheader("Generated Music")
108
 
109
  # Generate audio
 
110
  descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(1)] # Change the batch size to 1
 
111
 
112
- # Pass the device parameter when calling generate_music_tensors
113
- device = torch.device('cpu')
114
- music_tensors = generate_music_tensors(descriptions, time_slider, device)
115
-
116
  idx = 0
117
  music_tensor = music_tensors[idx]
118
  save_music_file = save_audio(music_tensor)
@@ -127,3 +117,4 @@ def main():
127
 
128
  if __name__ == "__main__":
129
  main()
 
 
12
  "Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock" ]
13
 
14
  @st.cache_resource()
 
15
  def load_model():
16
  model = MusicGen.get_pretrained('facebook/musicgen-small')
17
  return model
18
 
19
+ def generate_music_tensors(descriptions, duration: int):
 
20
  model = load_model()
21
+ # model = load_model().to('cpu')
22
+
23
 
 
24
  model.set_generation_params(
25
  use_sampling=True,
26
  top_k=250,
27
+ duration=duration
28
  )
29
 
30
  with st.spinner("Generating Music..."):
 
31
  output = model.generate(
32
+ descriptions=descriptions,
33
+ progress=True,
34
+ return_tokens=True
35
  )
36
 
 
 
 
 
37
  st.success("Music Generation Complete!")
38
  return output
39
 
40
 
41
+ def save_audio(samples: torch.Tensor):
 
42
  sample_rate = 30000
43
  save_path = "audio_output"
44
  assert samples.dim() == 2 or samples.dim() == 3
45
 
46
+ samples = samples.detach().cpu()
47
  if samples.dim() == 2:
48
  samples = samples[None, ...]
49
 
50
  for idx, audio in enumerate(samples):
51
  audio_path = os.path.join(save_path, f"audio_{idx}.wav")
52
+ torchaudio.save(audio_path, audio, sample_rate)
 
53
 
54
  def get_binary_file_downloader_html(bin_file, file_label='File'):
55
  with open(bin_file, 'rb') as f:
 
65
 
66
  def main():
67
  with st.sidebar:
68
+ st.header("""⚙️Generate Music ⚙️""",divider="rainbow")
69
  st.text("")
70
  st.subheader("1. Enter your music description.......")
71
  bpm = st.number_input("Enter Speed in BPM", min_value=60)
 
76
  selected_genre = st.selectbox("Select Genre", genres)
77
 
78
  st.subheader("2. Select time duration (In Seconds)")
79
+ time_slider = st.slider("Select time duration (In Seconds)", 0, 60, 10)
80
+ # time_slider = st.slider("Select time duration (In Minutes)", 0,300,10, step=1)
81
 
82
 
83
  st.title("""🎵 Song Lab AI 🎵""")
84
  st.text('')
85
+ left_co,right_co = st.columns(2)
86
  left_co.write("""Music Generation through a prompt""")
87
  left_co.write(("""PS : First generation may take some time ......."""))
88
 
 
98
  st.subheader("Generated Music")
99
 
100
  # Generate audio
101
+ # descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(5)]
102
  descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(1)] # Change the batch size to 1
103
+ music_tensors = generate_music_tensors(descriptions, time_slider)
104
 
105
+ # Only play the full audio for index 0
 
 
 
106
  idx = 0
107
  music_tensor = music_tensors[idx]
108
  save_music_file = save_audio(music_tensor)
 
117
 
118
  if __name__ == "__main__":
119
  main()
120
+