Spaces:
Runtime error
Runtime error
File size: 4,047 Bytes
e0c1514 b4432a0 e0c1514 bce68e6 0f49a1d 57cdeb5 0f49a1d e0c1514 b52703a bce68e6 dfe4c30 e0c1514 3b9debd e0c1514 3b9debd e0c1514 8f641db f0c1424 e0c1514 7a5d65f e0c1514 3b9debd 2cae516 7159bbe 3b9debd 7159bbe e0c1514 7c39bf5 e0c1514 3b9debd e0c1514 0f49a1d e0c1514 7c39bf5 bce68e6 0c3e3fb 0f49a1d 0c3e3fb ecc5750 e0c1514 0f49a1d e0c1514 7c39bf5 a888933 a22f91a a888933 57cdeb5 bce68e6 a888933 3b9debd a888933 5662bb5 a888933 5f5c021 4363c0a 3b9debd e0c1514 4a79b1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
import os
import numpy as np
import base64
# Genre choices offered in the sidebar dropdown; the selected entry is
# appended verbatim to the text prompt sent to MusicGen, so spelling matters
# ("R&B" and "Soul" were previously misspelled as "R&G" and "Soil").
genres = [
    "Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
    "Lofi", "Chillpop", "Country", "R&B", "Folk", "Heavy Metal",
    "EDM", "Soul", "Funk", "Reggae", "Disco", "Punk Rock", "House",
    "Techno", "Indie Rock", "Grunge", "Ambient", "Gospel", "Latin Music",
    "Grime", "Trap", "Psychedelic Rock",
]
@st.cache_resource()
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    Wrapped in ``st.cache_resource`` so Streamlit keeps the loaded model and
    hands back the same object on later calls instead of reloading it.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music_tensors(descriptions, duration: int):
    """Generate music for the given text prompts with the cached MusicGen model.

    Args:
        descriptions: Sequence of text prompts, forwarded unchanged to
            ``model.generate(descriptions=...)``.
        duration: Clip length forwarded to
            ``set_generation_params(duration=...)`` (MusicGen interprets it
            in seconds).

    Returns:
        Whatever ``model.generate(..., return_tokens=True)`` returns. In
        audiocraft that is an ``(audio, tokens)`` pair; callers in this file
        take element 0 as the generated audio batch.
    """
    model = load_model()
    # NOTE: audiocraft's MusicGen is a plain wrapper object, not a
    # torch.nn.Module, and does not define ``.to()``. The former
    # ``model.to(torch.device('cpu'))`` therefore raised instead of moving the
    # model. Device placement is decided when the model is loaded
    # (``MusicGen.get_pretrained`` accepts a ``device=`` argument).
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration,
    )
    with st.spinner("Generating Music..."):
        output = model.generate(
            descriptions=descriptions,
            progress=True,
            return_tokens=True,
        )
    st.success("Music Generation Complete!")
    return output
def save_audio(samples: torch.Tensor, sample_rate: int = 32000,
               save_path: str = "audio_output"):
    """Write each clip in *samples* to ``<save_path>/audio_<idx>.wav``.

    Args:
        samples: Audio tensor of rank 2 (treated as a single clip) or rank 3
            (a batch of clips, iterated along the first axis). Each clip is
            passed unchanged to ``torchaudio.save``, which by default expects
            ``(channels, frames)``.
        sample_rate: Rate written into the WAV header. Defaults to 32000 Hz,
            the output rate of the MusicGen checkpoints. The previous
            hard-coded 30000 made playback slower and lower-pitched.
        save_path: Output directory; created if it does not already exist.

    Returns:
        List of the written file paths, in batch order.

    Raises:
        ValueError: If *samples* is not rank 2 or rank 3.
    """
    # Validate explicitly rather than with ``assert`` (stripped under -O).
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"expected a 2-D or 3-D audio tensor, got {samples.dim()}-D"
        )
    samples = samples.detach().cpu()
    if samples.dim() == 2:
        # Promote a single clip to a batch of one so the loop below is uniform.
        samples = samples[None, ...]

    # torchaudio.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    written = []
    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
        written.append(audio_path)
    return written
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Return an HTML ``<a>`` tag that downloads *bin_file* from a data URI.

    The file's bytes are read in full, base64-encoded, and embedded in the
    link's ``href``; the ``download`` attribute carries the file's basename
    and the link text reads ``Download <file_label>``.
    """
    with open(bin_file, 'rb') as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload).decode()
    file_name = os.path.basename(bin_file)
    return (
        f'<a href="data:application/octet-stream;base64,{encoded}" '
        f'download="{file_name}">Download {file_label}</a>'
    )
# Browser-tab configuration for the app. ``page_icon`` must be an emoji or a
# ``:shortcode:``; the bare word "musical_note" is not recognized as either,
# so Streamlit would treat it as an image path/URL.
st.set_page_config(
    page_icon="🎵",
    page_title="Music Gen",
)
def main():
    """Render the Song Lab AI page and generate a clip when requested.

    The sidebar collects a BPM value, a free-text description, a genre and a
    duration in seconds. When the sidebar "Generate !" button is pressed, those
    inputs are joined into one text prompt and passed to
    ``generate_music_tensors``. The result is written to
    ``audio_output/audio_0.wav`` via ``save_audio`` and then shown in the left
    column as an inline player plus a download link.
    """
    with st.sidebar:
        st.header("""⚙️Generate Music ⚙️""", divider="rainbow")
        st.text("")
        st.subheader("1. Enter your music description.......")
        bpm = st.number_input("Enter Speed in BPM", min_value=60)
        text_area = st.text_area('Ex : 80s rock song with guitar and drums')
        st.text('')
        # Dropdown for genres
        selected_genre = st.selectbox("Select Genre", genres)
        st.subheader("2. Select time duration (In Seconds)")
        # The slider value is passed straight through as MusicGen's duration
        # (seconds). Label it accordingly and disallow a zero-length clip.
        time_slider = st.slider(
            "Select time duration (In Seconds)", 1, 300, 10, step=1
        )

    st.title("""🎵 Song Lab AI 🎵""")
    st.text('')
    left_co, _right_co = st.columns(2)
    left_co.write("""Music Generation through a prompt""")
    left_co.write("""PS : First generation may take some time .......""")

    if st.sidebar.button('Generate !'):
        with left_co:
            # Vertical padding so the result lines up below the intro text.
            for _ in range(6):
                st.text('')
            st.text('\n\n')
            st.subheader("Generated Music")

            # Skip empty parts so a blank description does not leave a
            # leading space in the prompt.
            prompt = " ".join(
                part
                for part in (text_area.strip(), selected_genre, f"{bpm} BPM")
                if part
            )
            descriptions = [prompt]  # batch size of 1
            music_tensors = generate_music_tensors(descriptions, time_slider)

            # With return_tokens=True the result is an (audio, tokens) pair;
            # element 0 is the audio batch, whose first clip save_audio writes
            # to audio_0.wav.
            idx = 0
            music_tensor = music_tensors[idx]
            os.makedirs("audio_output", exist_ok=True)
            save_audio(music_tensor)

            audio_filepath = os.path.join("audio_output", f"audio_{idx}.wav")
            # Close the file handle instead of leaking it on every rerun.
            with open(audio_filepath, 'rb') as audio_file:
                audio_bytes = audio_file.read()
            st.audio(audio_bytes, format='audio/wav')
            st.markdown(
                get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'),
                unsafe_allow_html=True,
            )
# Entry point: render the app when this file is executed as the main script.
if __name__ == "__main__":
    main()
|