|
from audiocraft.models import MusicGen |
|
import streamlit as st |
|
import os |
|
import torch |
|
import torchaudio |
|
from io import BytesIO |
|
|
|
@st.cache_resource
def load_model():
    """Return the pretrained "facebook/musicgen-small" MusicGen model.

    The function is decorated with ``st.cache_resource``, so Streamlit is
    expected to reuse the loaded model rather than fetching it again on
    each script rerun.
    """
    return MusicGen.get_pretrained("facebook/musicgen-small")
|
|
|
def generate_music_tensors(description, duration: int):
    """Generate audio for a single text prompt with the shared MusicGen model.

    Args:
        description: Text prompt passed to the model as a one-element list.
        duration: Value forwarded as the ``duration`` generation parameter
            (the UI presents it as seconds).

    Returns:
        The first element of what ``model.generate`` returns when called
        with ``return_tokens=True`` (presumably the generated audio tensor;
        confirm against the audiocraft API).
    """
    # Echo the request to stdout for debugging.
    for label, value in (("Description:", description), ("Duration:", duration)):
        print(label, value)

    musicgen = load_model()

    sampling_options = {
        "use_sampling": True,
        "top_k": 250,
        "duration": duration,
    }
    musicgen.set_generation_params(**sampling_options)

    generated = musicgen.generate(
        descriptions=[description],
        progress=True,
        return_tokens=True,
    )
    # Only the first element of the result is handed back to the caller.
    return generated[0]
|
|
|
def save_audio_to_bytes(samples: torch.Tensor, sample_rate: int = 32000) -> BytesIO:
    """Encode an audio tensor as WAV and return it in an in-memory buffer.

    Args:
        samples: A 2-D tensor, or a 3-D tensor whose first entry along the
            leading dimension is the clip to encode. Presumably laid out as
            (channels, time) or (batch, channels, time); confirm against the
            caller.
        sample_rate: Sample rate passed to ``torchaudio.save``. Defaults to
            32000, the value that was previously hard-coded here.

    Returns:
        A ``BytesIO`` holding the WAV data, rewound to offset 0.

    Raises:
        ValueError: If ``samples`` is not 2-D or 3-D, or is 3-D with an
            empty leading dimension (there is no clip to encode).
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    rank = samples.dim()
    if rank not in (2, 3):
        raise ValueError(
            f"samples must be a 2-D or 3-D tensor, got {rank} dimension(s)"
        )
    if rank == 3 and samples.shape[0] == 0:
        raise ValueError("samples has an empty leading dimension; nothing to encode")

    samples = samples.detach().cpu()

    # A 2-D tensor is already a single clip; for 3-D take the first entry.
    clip = samples if rank == 2 else samples[0]

    audio_buffer = BytesIO()
    torchaudio.save(audio_buffer, clip, sample_rate=sample_rate, format="wav")
    # Rewind so callers can read the encoded data from the start.
    audio_buffer.seek(0)
    return audio_buffer
|
|
|
# Configure the page's title and icon.
st.set_page_config(page_title="Music Gen", page_icon=":musical_note:")
|
|
|
def main():
    """Render the page: collect a prompt and duration, then generate audio.

    Once a description has been entered, the request is echoed back, music
    is generated via ``generate_music_tensors``, encoded with
    ``save_audio_to_bytes``, and offered both as an inline player and as a
    downloadable WAV file.
    """
    st.title("Your Music")

    with st.expander("See Explanation"):
        st.write("This app uses Meta's Audiocraft Music Gen model to generate audio based on your description.")

    text_area = st.text_area("Enter description")
    time_slider = st.slider("Select time duration (seconds)", 2, 20, 5)

    # Nothing to generate until the user has typed a description.
    if not (text_area and time_slider):
        return

    # The st.json call must be closed before st.write; previously st.write
    # sat inside st.json's parentheses, which is a syntax error.
    st.json(
        {
            "Description": text_area,
            "Selected duration": time_slider,
        }
    )
    st.write("We will back with your music....please enjoy doing the rest of your tasks while we come back in some time :)")

    st.subheader("Generated Music")
    music_tensors = generate_music_tensors(text_area, time_slider)
    audio_buffer = save_audio_to_bytes(music_tensors)

    # Give both widgets the same bytes snapshot so neither depends on the
    # position of the shared stream; getvalue() returns the full contents.
    audio_bytes = audio_buffer.getvalue()

    st.audio(audio_bytes, format="audio/wav")

    st.download_button(
        label="Download Audio",
        data=audio_bytes,
        file_name="generated_music.wav",
        mime="audio/wav",
    )
|
|
|
# Build the UI only when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|