File size: 3,925 Bytes
e0c1514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4a79b1a
 
e0c1514
 
 
 
 
 
4a79b1a
e0c1514
 
 
 
2cae516
 
f0c1424
 
2cae516
f0c1424
e0c1514
7a5d65f
e0c1514
 
 
 
 
2cae516
 
 
 
 
 
 
 
7159bbe
 
 
5d07865
 
e0c1514
7159bbe
 
 
 
 
 
 
 
 
e0c1514
7c39bf5
 
e0c1514
 
 
5d07865
e0c1514
 
 
 
 
 
 
 
7c39bf5
e0c1514
 
 
 
 
a888933
e0c1514
 
7c39bf5
 
a888933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6e817b5
a888933
 
 
 
 
6f62f17
 
a888933
e0c1514
 
4a79b1a
503ccd3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
import os
import numpy as np
import base64


@st.cache_resource()
def load_model():
    """Return the pretrained ``facebook/musicgen-small`` MusicGen model.

    The function is decorated with ``st.cache_resource``, so Streamlit is
    expected to hand back the same model object on later calls instead of
    loading the checkpoint again.
    """
    return MusicGen.get_pretrained('facebook/musicgen-small')

def generate_music_tensors(descriptions, duration: int):
    """Generate audio for each text prompt in *descriptions*.

    Args:
        descriptions: Text prompts forwarded unchanged to ``model.generate``.
        duration: Value passed as the ``duration`` generation parameter.

    Returns:
        Whatever ``model.generate`` returns when called with
        ``return_tokens=True``.
        NOTE(review): presumably an ``(audio, tokens)`` pair - confirm against
        the installed audiocraft version.
    """
    musicgen = load_model()

    # Sampling configuration applied to the (cached) model before generating.
    generation_settings = {
        "use_sampling": True,
        "top_k": 250,
        "duration": duration,
    }
    musicgen.set_generation_params(**generation_settings)

    # Keep the spinner and the status line on screen while the model runs.
    with st.spinner("Generating Music..."):
        st.markdown("### Generating Music... 🎵🎶🎹")
        generated = musicgen.generate(
            descriptions=descriptions, progress=True, return_tokens=True
        )

    st.success("Music Generation Complete! 🎉")
    return generated

def save_audio(samples: torch.Tensor, sample_rate: int = 32000):
    """Write each clip in *samples* to ``/tmp/audio_output/audio_<idx>.wav``.

    Args:
        samples: Audio tensor with 2 or 3 dimensions. A 2-D tensor is treated
            as a single clip and is given a leading axis of size 1; a 3-D
            tensor is iterated along its first axis, one file per entry.
        sample_rate: Sample rate in Hz handed to ``torchaudio.save``. Defaults
            to 32000, the output rate of MusicGen checkpoints; the previously
            hard-coded 30000 made saved clips play back slow and flat.

    Returns:
        The output directory path on success, or ``None`` if saving any clip
        failed (the error is shown in the UI via ``st.error``).

    Raises:
        ValueError: If *samples* is neither 2-D nor 3-D.
    """
    # Validate before touching the filesystem. An explicit raise is used
    # instead of ``assert`` so the check survives ``python -O``.
    if samples.dim() not in (2, 3):
        raise ValueError(
            f"samples must be a 2-D or 3-D tensor, got {samples.dim()} dimension(s)"
        )

    save_path = "/tmp/audio_output"  # Use /tmp directory
    os.makedirs(save_path, exist_ok=True)

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        # Promote a single clip to a batch of one so the loop below is uniform.
        samples = samples[None, ...]

    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        try:
            torchaudio.save(audio_path, audio, sample_rate)
        except Exception as e:
            # UI boundary: report the failure to the user instead of crashing
            # the app; callers detect it through the None return value.
            st.error(f"Error saving audio file: {e}")
            return None

    return save_path


    

# Genre choices offered by the "Select Genre" dropdown in the sidebar.
genres = [
    "Pop",
    "Rock",
    "Jazz",
    "Electronic",
    "Hip-Hop",
    "Classical",
    "Lofi",
    "Chillpop",
]


# Add this function for downloading binary files
def get_binary_file_downloader_html(bin_file, file_label='File'):
    """Build an HTML download link that embeds a file as a base64 data URI.

    Args:
        bin_file: Path of the file to embed; it is read in binary mode.
        file_label: Text used both as the ``download`` attribute (the
            suggested filename) and in the visible "Download ..." link text.

    Returns:
        An ``<a>`` element string whose ``href`` is a
        ``data:application/octet-stream;base64,...`` URI of the file contents.

    Raises:
        OSError: If *bin_file* cannot be opened or read.
    """
    # Local import: ``html`` is not among this file's top-level imports.
    import html

    with open(bin_file, 'rb') as f:
        bin_str = base64.b64encode(f.read()).decode()

    # The label lands inside an attribute and in element text; escape it so a
    # quote or angle bracket cannot break out of the markup. Labels without
    # special characters produce exactly the same output as before.
    safe_label = html.escape(str(file_label), quote=True)
    return (
        f'<a href="data:application/octet-stream;base64,{bin_str}" '
        f'download="{safe_label}">Download {safe_label}</a>'
    )

# Browser-tab metadata for the app. The icon is given as an emoji shortcode,
# which Streamlit only recognises when it is wrapped in colons; the bare word
# "musical_note" is not a valid shortcode.
st.set_page_config(
    page_icon=":musical_note:",
    page_title="Music Gen",
)
def main():
    """Render the Streamlit page and, on request, generate and present a clip.

    The sidebar collects a BPM value, a free-text description, a genre and a
    duration. Pressing "Generate !" builds the prompt, calls
    ``generate_music_tensors``, saves the result with ``save_audio`` and shows
    an audio player plus a download link for clip 0 in the left column.
    """
    with st.sidebar:
        st.header("""⚙️Generate Music ⚙️""", divider="rainbow")
        st.text("")
        st.subheader("1. Enter your music description.......")
        bpm = st.number_input("Enter Speed in BPM", min_value=60)

        text_area = st.text_area('Ex : 80s rock song with guitar and drums')
        st.text('')
        # Dropdown for genres
        selected_genre = st.selectbox("Select Genre", genres)

        st.subheader("2. Select time duration (In Seconds)")
        # Minimum is 1 rather than 0: a zero-second request gives the model
        # nothing to generate.
        time_slider = st.slider("Select time duration (In Seconds)", 1, 60, 10)

    st.title("""🎵 Song Lab AI 🎵""")
    st.text('')
    left_co, right_co = st.columns(2)
    left_co.write("""Music Generation through a prompt""")
    left_co.write(("""PS : First generation may take some time ......."""))

    # Nothing more to render until the user asks for a generation.
    if not st.sidebar.button('Generate !'):
        return

    with left_co:
        # Vertical spacing above the results heading.
        for _ in range(6):
            st.text('')
        st.subheader("Generated Music")

        # Generate audio. The same prompt is repeated to form a batch of 5.
        # NOTE(review): only clip 0 is played below, so the other four look
        # like wasted generation time - confirm whether the batch is needed.
        descriptions = [f"{text_area} {selected_genre} {bpm} BPM" for _ in range(5)]
        music_tensors = generate_music_tensors(descriptions, time_slider)

        # NOTE(review): generate_music_tensors requests return_tokens=True, so
        # index 0 here presumably selects the audio tensor for the whole batch
        # (not the first clip) - confirm against the audiocraft version in use.
        idx = 0
        music_tensor = music_tensors[idx]

        save_dir = save_audio(music_tensor)
        if save_dir is None:
            # save_audio has already reported the failure via st.error.
            return

        # Read the clip back from the directory save_audio actually wrote to,
        # not from a path relative to the current working directory.
        audio_filepath = os.path.join(save_dir, f'audio_{idx}.wav')
        with open(audio_filepath, 'rb') as audio_file:
            audio_bytes = audio_file.read()

        # Play the full audio
        st.audio(audio_bytes, format='audio/wav')

        # Add download link; the label doubles as the suggested filename, so
        # it carries the .wav extension.
        st.markdown(
            get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}.wav'),
            unsafe_allow_html=True,
        )

# Run the app only when this file is executed as the entry script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()