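# Streamlit app that generates music from a text prompt with Meta's MusicGen
# ('facebook/musicgen-small') and offers the result for playback and download.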
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
import os
import numpy as np
import base64
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical",
"Lofi", "Chillpop","Country","R&G", "Folk","Heavy Metal",
"EDM", "Soil", "Funk","Reggae", "Disco", "Punk Rock", "House",
"Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock" ]
@st.cache_resource()
def load_model(device="cpu"):
    # Cache the pretrained model so it is downloaded and loaded only once per session
    model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
    return model
def generate_music_tensors(descriptions, duration: int, device):
    # Load the cached model on the requested device
    model = load_model(device)
    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration  # Generation length in seconds
    )
    with st.spinner("Generating Music..."):
        # Generate music from the text descriptions; output has shape [batch, channels, samples]
        output = model.generate(
            descriptions=descriptions,
            progress=True
        )
        # Save the generated music audio at the model's native sample rate
        save_audio(output, model.sample_rate)
    st.success("Music Generation Complete!")
    return output
def save_audio(samples: torch.Tensor, sample_rate: int = 32000):
    # MusicGen generates audio at 32 kHz
    save_path = "audio_output"
    os.makedirs(save_path, exist_ok=True)
    assert samples.dim() == 2 or samples.dim() == 3
    samples = samples.detach().cpu()  # Move the samples to the CPU before saving
    if samples.dim() == 2:
        # Add a batch dimension so a single clip and a batch are handled the same way
        samples = samples[None, ...]
    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)
def get_binary_file_downloader_html(bin_file, file_label='File'):
    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()
    href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
    return href
st.set_page_config(
    page_icon="🎵",
    page_title="Music Gen"
)
def main():
    with st.sidebar:
        st.header("""⚙️ Generate Music ⚙️""", divider="rainbow")
        st.text("")
        st.subheader("1. Enter your music description")
        bpm = st.number_input("Enter Speed in BPM", min_value=60)
        text_area = st.text_area('Ex : 80s rock song with guitar and drums')
        st.text('')
        # Dropdown for genres
        selected_genre = st.selectbox("Select Genre", genres)
        st.subheader("2. Select time duration (In Seconds)")
        time_slider = st.slider("Select time duration (In Seconds)", 1, 60, 10, step=1)
    st.title("""🎵 Song Lab AI 🎵""")
    st.text('')
    left_co, right_co = st.columns(2)
    left_co.write("""Music Generation through a prompt""")
    left_co.write("""PS : First generation may take some time ...""")
    if st.sidebar.button('Generate !'):
        with left_co:
            st.text('')
            st.text('')
            st.text('')
            st.subheader("Generated Music")
        # Build a single text description from the prompt, genre, and BPM
        descriptions = [f"{text_area} {selected_genre} {bpm} BPM"]
        # Generate on the CPU; generate_music_tensors also writes the audio to disk
        device = "cpu"
        generate_music_tensors(descriptions, time_slider, device)
        idx = 0
        audio_filepath = f'audio_output/audio_{idx}.wav'
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()
        # Play the full audio and offer a download link
        st.audio(audio_bytes, format='audio/wav')
        st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'),
                    unsafe_allow_html=True)
if __name__ == "__main__":
    main()