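"""Streamlit front end for Meta's MusicGen (audiocraft) text-to-music model.

The app collects a text description, genre and tempo in the sidebar, generates
a short clip with the 'facebook/musicgen-small' checkpoint, then plays it back
and offers a download link. Launch it with `streamlit run <script>.py`
(replace <script> with whatever this file is saved as).
"""
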
import streamlit as st
import torch
import torchaudio
from audiocraft.models import MusicGen
import os
import numpy as np
import base64

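# Genre options for the sidebar dropdown; the selected genre is appended to the
# user's text prompt before generation.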
genres = ["Pop", "Rock", "Jazz", "Electronic", "Hip-Hop", "Classical", 
          "Lofi", "Chillpop","Country","R&G", "Folk","Heavy Metal", 
          "EDM", "Soil", "Funk","Reggae", "Disco", "Punk Rock", "House",
          "Techno","Indie Rock", "Grunge", "Ambient","Gospel", "Latin Music","Grime" ,"Trap", "Psychedelic Rock"  ]

@st.cache_resource()
def load_model(device: str = "cpu"):
    # Load the pretrained checkpoint once per device and cache it; MusicGen
    # takes its device at load time rather than being moved afterwards
    model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
    return model

def generate_music_tensors(descriptions, duration: int, device):
    # Load the cached model on the requested device
    model = load_model(str(device))

    model.set_generation_params(
        use_sampling=True,
        top_k=250,
        duration=duration  # duration is given in seconds
    )

    with st.spinner("Generating Music..."):
        # Generate one waveform per description; the output shape is
        # [batch, channels, samples]
        output = model.generate(
            descriptions=descriptions,
            progress=True
        )

    # Save the generated music audio to disk
    save_audio(output, model.sample_rate)

    st.success("Music Generation Complete!")
    return output



def save_audio(samples: torch.Tensor, sample_rate: int = 32000):
    # MusicGen small produces audio at 32 kHz
    save_path = "audio_output"
    os.makedirs(save_path, exist_ok=True)
    assert samples.dim() == 2 or samples.dim() == 3

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]

    for idx, audio in enumerate(samples):
        audio_path = os.path.join(save_path, f"audio_{idx}.wav")
        torchaudio.save(audio_path, audio, sample_rate)


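# Embed the generated WAV file as a base64 data URI so it can be served as an
# HTML download link from Streamlit markdown.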
def get_binary_file_downloader_html(bin_file, file_label='File'):
    with open(bin_file, 'rb') as f:
        data = f.read()
    bin_str = base64.b64encode(data).decode()
    href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
    return href

st.set_page_config(
    page_icon="🎵",
    page_title="Music Gen"
)

def main():
    with st.sidebar:
        st.header("⚙️ Generate Music ⚙️", divider="rainbow")
        st.text("")
        st.subheader("1. Enter your music description")
        bpm = st.number_input("Enter Speed in BPM", min_value=60)

        text_area = st.text_area('Ex: 80s rock song with guitar and drums')
        st.text('')
        # Dropdown for genres
        selected_genre = st.selectbox("Select Genre", genres)

        st.subheader("2. Select time duration (In Seconds)")
        time_slider = st.slider("Select time duration (In Seconds)", 1, 60, 10, step=1)

    st.title("🎵 Song Lab AI 🎵")
    st.text('')
    left_co, right_co = st.columns(2)
    left_co.write("Music generation from a text prompt.")
    left_co.write("PS: The first generation may take a while.")
    
    if st.sidebar.button('Generate !'):
        with left_co:
            st.text('')
            st.subheader("Generated Music")

            # Build the prompt from the description, genre and tempo
            descriptions = [f"{text_area} {selected_genre} {bpm} BPM"]

            # Generation runs on CPU here; switch to 'cuda' if a GPU is available
            device = torch.device('cpu')
            generate_music_tensors(descriptions, time_slider, device)

            # generate_music_tensors already wrote the clip to audio_output/
            idx = 0
            audio_filepath = f'audio_output/audio_{idx}.wav'
            with open(audio_filepath, 'rb') as audio_file:
                audio_bytes = audio_file.read()

            # Play the generated audio and offer it as a download
            st.audio(audio_bytes, format='audio/wav')
            st.markdown(get_binary_file_downloader_html(audio_filepath, f'Audio_{idx}'), unsafe_allow_html=True)


if __name__ == "__main__":
    main()