# Music_Generator / app.py
# Author: annapurnapadmaprema-ji — last commit 0b3a9f2 ("Update app.py"), 2.41 kB.
from audiocraft.models import MusicGen
import streamlit as st
import os
import torch
import torchaudio
from io import BytesIO
@st.cache_resource
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    Wrapped in ``st.cache_resource`` so Streamlit builds the model once and
    hands back the same object on later calls instead of reloading weights
    on every script rerun.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
def generate_music_tensors(description, duration: int):
    """Generate audio for a single text prompt with the cached MusicGen model.

    Args:
        description: Text prompt describing the desired music.
        duration: Target clip length in seconds, passed to
            ``set_generation_params``.

    Returns:
        Element 0 of ``model.generate(..., return_tokens=True)``. Per the
        audiocraft docs that call returns ``(audio, tokens)``, so this is
        presumably the audio batch tensor — verify for your audiocraft version.
    """
    # Debug echo of the request to stdout.
    for label, value in (("Description:", description), ("Duration:", duration)):
        print(label, value)

    musicgen = load_model()
    # NOTE: this mutates the shared, cached model's generation settings.
    musicgen.set_generation_params(
        duration=duration,
        top_k=250,
        use_sampling=True,
    )
    generated = musicgen.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    return generated[0]
def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode a waveform tensor as an in-memory WAV file.

    Args:
        samples: 2-D or 3-D audio tensor. A 3-D tensor is treated as a batch
            and only its first item is encoded. The remaining 2-D tensor is
            handed to ``torchaudio.save`` with its default channels-first
            layout, i.e. presumably ``(channels, time)`` — confirm against the
            generator's output.
        sample_rate: Rate written into the WAV header. Defaults to 32000, the
            value this app previously hard-coded (presumably MusicGen's output
            rate; ``model.sample_rate`` is the authoritative source).

    Returns:
        A ``BytesIO`` rewound to offset 0 that holds the WAV data.

    Raises:
        ValueError: If ``samples`` is neither 2-D nor 3-D.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            "expected a 2-D or 3-D audio tensor, "
            f"got a {samples.dim()}-D tensor"
        )
    samples = samples.detach().cpu()
    # Drop the batch axis (keep the first item) so torchaudio gets a 2-D tensor.
    waveform = samples[0] if samples.dim() == 3 else samples

    audio_buffer = BytesIO()
    torchaudio.save(audio_buffer, waveform, sample_rate=sample_rate, format="wav")
    audio_buffer.seek(0)  # rewind so readers start at the WAV header
    return audio_buffer
# Browser-tab title and icon for the app (must run before other st.* output).
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
def main():
    """Render the music-generation page.

    Shows a description box and a duration slider. Once a non-blank
    description is entered, the request is echoed back, MusicGen is run, and
    the resulting WAV is offered for playback and download.

    Streamlit re-executes the whole script on every widget interaction,
    including a click on the download button. The generated WAV bytes are
    therefore kept in ``st.session_state`` and regenerated only when the
    description or duration changes. Otherwise every rerun would synthesize
    a new (different) clip, and the page would stop matching what the user
    just heard or downloaded.
    """
    st.title("Your Music")
    with st.expander("See Explanation"):
        st.write("This app uses Meta's Audiocraft Music Gen model to generate audio based on your description.")
    description = st.text_area("Enter description")
    duration = st.slider("Select time duration (seconds)", 2, 20, 5)

    # Nothing to do until there is a real prompt; the slider minimum is 2,
    # so duration is never falsy.
    if not description.strip():
        return

    st.json(
        {
            "Description": description,
            "Selected duration": duration,
        }
    )
    st.write("We will be back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")
    st.subheader("Generated Music")

    request = (description, duration)
    cached = st.session_state.get("generated_music")
    if cached is not None and cached[0] == request:
        audio_bytes = cached[1]
    else:
        with st.spinner("Generating music..."):
            music_tensors = generate_music_tensors(description, duration)
            # Encode once; the same bytes feed both playback and download.
            audio_bytes = save_audio_to_bytes(music_tensors).getvalue()
        st.session_state["generated_music"] = (request, audio_bytes)

    # Play audio
    st.audio(audio_bytes, format="audio/wav")
    # Download button for audio
    st.download_button(
        label="Download Audio",
        data=audio_bytes,
        file_name="generated_music.wav",
        mime="audio/wav",
    )


if __name__ == "__main__":
    main()