Spaces:
Build error
Build error
import base64
import os
import tempfile

import streamlit as st
import torch
import torchaudio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
def generate_music(model, description, duration):
    """Generate a single audio clip for a text prompt.

    Args:
        model: Object exposing ``set_generation_params(duration=...)`` and
            ``generate(list_of_prompts)``; the app passes a ``MusicGen``
            model here.
        description: Text prompt describing the desired music.
        duration: Clip length forwarded unchanged to
            ``set_generation_params`` (the UI labels it as seconds).

    Returns:
        The first item of the batch returned by ``model.generate``.
    """
    model.set_generation_params(duration=duration)
    # The model works on batches of prompts; submit a batch of one.
    prompts = [description]
    batch = model.generate(prompts)
    return batch[0]
def get_audio_player(audio_data, sample_rate=32000):
    """Return an HTML ``<audio>`` element that plays ``audio_data`` inline.

    The tensor is encoded as a WAV file and embedded in the returned markup
    as a base64 ``data:`` URI, so nothing has to be served separately.

    Args:
        audio_data: Audio tensor accepted by ``torchaudio.save``; it is moved
            to the CPU before encoding. Presumably shaped
            ``(channels, samples)`` -- TODO confirm against the caller.
        sample_rate: Sample rate in Hz passed to ``torchaudio.save``.
            Defaults to 32000, the value that used to be hard-coded here.

    Returns:
        str: ``<audio controls>`` markup with the WAV data embedded.
    """
    # A private temporary directory gives each call its own file, so
    # concurrent sessions cannot overwrite one another's audio, and the file
    # is removed even when encoding or reading raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "temp_audio.wav")
        torchaudio.save(wav_path, audio_data.cpu(), sample_rate=sample_rate)
        with open(wav_path, "rb") as wav_file:
            audio_bytes = wav_file.read()

    b64 = base64.b64encode(audio_bytes).decode("ascii")
    return f'<audio controls><source src="data:audio/wav;base64,{b64}" type="audio/wav"></audio>'
@st.cache_resource
def load_model(model_size):
    """Load the pretrained MusicGen checkpoint for ``model_size``.

    Wrapped in ``st.cache_resource`` so the model is loaded once per size
    and reused, instead of being reloaded on every Streamlit rerun (each
    widget interaction re-executes the script).

    Args:
        model_size: Checkpoint suffix, e.g. ``"small"``, ``"medium"`` or
            ``"large"``.

    Returns:
        The object returned by ``MusicGen.get_pretrained``.
    """
    return MusicGen.get_pretrained(f'facebook/musicgen-{model_size}')


def main():
    """Render the SunnAI page: collect a prompt, generate music, offer it.

    The page lets the user pick a model size, enter a text description and a
    duration, and on "Generate Music" shows an inline audio player plus a
    download button for the generated WAV file.
    """
    # The icon literal had been corrupted to "π΅" (UTF-8 bytes of the
    # musical-note emoji read in a Greek code page); restore the emoji.
    st.set_page_config(page_title="SunnAI - Music Generation", page_icon="🎵")
    st.title("SunnAI - AI Music Generation")

    st.sidebar.header("Model Settings")
    model_size = st.sidebar.selectbox("Select model size", ["small", "medium", "large"])
    model = load_model(model_size)

    st.write("Welcome to SunnAI! Generate music using AI with just a text description.")
    description = st.text_area(
        "Enter a description for your music:",
        "A happy rock song with electric guitar and drums",
    )
    duration = st.slider(
        "Select music duration (in seconds):", min_value=1, max_value=30, value=10
    )

    if not st.button("Generate Music"):
        return

    # Do not send a blank prompt to the model.
    if not description.strip():
        st.warning("Please enter a description before generating music.")
        return

    with st.spinner("Generating music... This may take a moment."):
        generated_audio = generate_music(model, description, duration)
    st.success("Music generated successfully!")
    st.markdown(get_audio_player(generated_audio), unsafe_allow_html=True)

    # Option to download the generated audio. The WAV is written inside a
    # private temporary directory so concurrent sessions cannot collide on a
    # shared file name, and it is cleaned up even if saving fails. The button
    # receives the bytes, so no file handle has to stay open.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "generated_music.wav")
        torchaudio.save(wav_path, generated_audio.cpu(), sample_rate=32000)
        with open(wav_path, "rb") as wav_file:
            wav_bytes = wav_file.read()

    st.download_button(
        label="Download generated music",
        data=wav_bytes,
        file_name="sunnai_generated_music.wav",
        mime="audio/wav",
    )


if __name__ == "__main__":
    main()