File size: 7,049 Bytes
246d69a
0b75620
246d69a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d878645
 
 
 
 
246d69a
 
 
 
d878645
246d69a
 
 
 
 
 
 
 
 
 
 
 
d878645
 
 
 
 
 
 
 
 
 
246d69a
d878645
246d69a
d878645
 
246d69a
 
 
 
d878645
 
1e5f245
 
 
 
 
 
1e517e1
d878645
 
 
246d69a
d878645
 
246d69a
0b75620
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
from transformers import pipeline
import tempfile, os
from midi2audio import FluidSynth

# --- Music Generation Logic (API-like Function) ---

def generate_music_api(midi_data=None, chord_progression=None, tempo=120, temperature=0.95, nb_tokens=512,
                       bar_range="0-4", time_signature=(4, 4), model=None):
    """Generate music with MusicLang Predict and return the rendered files as bytes.

    The generation mode is chosen from the inputs:

    * prompt MIDI + chord progression -> ``model.continue_sequence``
    * prompt MIDI only                -> ``model.predict``
    * no prompt MIDI                  -> ``model.predict_chords``

    Args:
        midi_data: Raw bytes of an optional prompt MIDI file, or ``None``.
        chord_progression: Optional chord string (e.g. ``"Am CM Dm7/F E7"``).
            ``None`` and whitespace-only strings mean "no chord progression".
        tempo: Tempo passed to ``to_midi`` when writing the MIDI file.
        temperature: Sampling temperature, forwarded as ``float``.
        nb_tokens: Number of tokens to generate, forwarded as ``int``.
        bar_range: Bar range of the prompt file (e.g. ``"0-4"``).
            NOTE(review): accepted for API compatibility but not yet passed to
            any generation call.
        time_signature: Time signature passed to ``to_midi``.
        model: A loaded MusicLang Predict model exposing ``continue_sequence``,
            ``predict`` and ``predict_chords``. Passing it in lets the caller
            load the model once instead of on every request.

    Returns:
        On success, a dict with ``"mp3"`` and ``"midi"`` (bytes), plus
        ``"chord_repr"`` (the chord progression used) and ``"tempo_message"``.
        On any failure, ``{"error": <message>}``. No exception escapes.
    """
    midi_path = None
    mp3_path = None
    try:
        ml = model
        if ml is None:
            # TODO: load the model once at startup, e.g.
            # ml = pipeline("music-generation", model="your-musiclang-predict-model"),
            # and pass it in via ``model=``.
            raise RuntimeError("MusicLang Predict model is not loaded.")

        # None and blank strings both mean "no chord progression".
        chords = (chord_progression or "").strip()

        if midi_data is not None and chords:
            # Continue the uploaded sequence following the chord progression.
            generated_score = ml.continue_sequence(
                midi_data,
                chord_progression=chords,
                nb_tokens=int(nb_tokens),
                temperature=float(temperature),
            )
        elif midi_data is not None:
            # Use the uploaded MIDI as the prompt.
            generated_score = ml.predict(
                midi_data,
                nb_tokens=int(nb_tokens),
                temperature=float(temperature),
            )
        else:
            # No prompt file: generate from the chord progression.
            # NOTE(review): this branch is also taken when ``chords`` is empty.
            generated_score = ml.predict_chords(chords)

        # Reserve temp paths and close the handles immediately; the files are
        # reopened by name below, which an open handle would block on Windows.
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
            midi_path = tmp.name
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            mp3_path = tmp.name

        generated_score.to_midi(midi_path, tempo=tempo, time_signature=time_signature)

        # TODO: render midi_path into mp3_path (FluidSynth + FFmpeg).
        # Until that is implemented the MP3 file stays empty.

        with open(mp3_path, "rb") as f_mp3:
            mp3_binary = f_mp3.read()
        with open(midi_path, "rb") as f_midi:
            midi_binary = f_midi.read()

        return {
            "mp3": mp3_binary,
            "midi": midi_binary,
            "chord_repr": chords,
            "tempo_message": f"Tempo: {tempo} BPM",
        }

    except Exception as e:
        # Boundary of the API: report every failure to the caller as data.
        return {"error": str(e)}

    finally:
        # Remove temp files on success AND failure; the bytes were already read.
        for path in (midi_path, mp3_path):
            if path is not None:
                try:
                    os.remove(path)
                except OSError:
                    pass  # best-effort cleanup; never mask the real result

# --- Gradio Interface ---

def musiclang_gradio(midi_file, chord_progression, tempo, temperature, nb_tokens, bar_range):
    """Gradio callback: run generation and expose the results as temporary files.

    Args:
        midi_file: The uploaded prompt MIDI. With ``gr.File(type="filepath")``
            this is a path string; objects exposing ``.name`` (file wrappers)
            are accepted too. ``None``/empty means "no prompt".
        chord_progression, tempo, temperature, nb_tokens, bar_range:
            Forwarded unchanged to ``generate_music_api``.

    Returns:
        ``(mp3_path, midi_path, info_message)``, matching the outputs
        ``[generated_music, generated_midi, info_message]``. On failure both
        paths are ``None`` and the message is the error text.
    """
    midi_data = None
    if midi_file:
        # type="filepath" yields a plain path; fall back to ``.name`` for file objects.
        prompt_path = midi_file if isinstance(midi_file, (str, os.PathLike)) else midi_file.name
        with open(prompt_path, "rb") as f:
            midi_data = f.read()

    api_response = generate_music_api(
        midi_data=midi_data,
        chord_progression=chord_progression,
        tempo=tempo,
        temperature=temperature,
        nb_tokens=nb_tokens,
        bar_range=bar_range,
    )

    if "error" in api_response:
        return None, None, api_response["error"]

    # delete=False: Gradio serves these files by path after this function
    # returns, so they must outlive the handle (closed by the ``with``).
    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
        f.write(api_response["mp3"])
        mp3_path = f.name

    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as f:
        f.write(api_response["midi"])
        midi_path = f.name

    # Surface the API's informational message (if any) in the info box.
    return mp3_path, midi_path, api_response.get("tempo_message")

# Example prompt MIDI files are looked up next to this script. The examples
# used to hard-code /home/user/app, presumably the app directory on
# Hugging Face Spaces. NOTE(review): assumes the .mid files ship alongside this file.
_EXAMPLES_DIR = os.path.dirname(os.path.abspath(__file__))

# (prompt MIDI filename, chord progression) pairs offered as clickable examples.
_EXAMPLE_PROMPTS = [
    ("bach_847.mid", ""),
    ("bach_847.mid", "Cm C7/E Fm F#dim G7"),
    ("boney_m_ma_baker.mid", ""),
    ("eminem_slim_shady.mid", "Cm AbM BbM G7 Cm"),
    ("mozart_alla_turca.mid", ""),
    ("mozart_alla_turca.mid", "Am Em CM G7 E7 Am Am E7 Am"),
    ("daft_punk_around_the_world.mid", ""),
]

with gr.Blocks() as demo:
    # Introductory text
    gr.Markdown("""
    # Controllable Symbolic Music Generation with MusicLang Predict
    [MusicLang Predict](https://github.com/musiclang/musiclang_predict) offers advanced controllability features and high-quality music generation by manipulating symbolic music.
    You can for example use it to continue your composition with a specific chord progression.
    """)

    with gr.Row():
        # Left column: generation inputs.
        with gr.Column():
            with gr.Row():
                midi_file = gr.File(label="Prompt MIDI File (Optional)", type="filepath", file_types=[".mid", ".midi"],
                                    elem_id='midi_file_input')
            with gr.Column():
                bar_range = gr.Textbox(label="Bar Range of input file (eg: 0-4 for first four bars)", placeholder="0-4",
                                       value="0-4", elem_id='bar_range_input')
                nb_tokens = gr.Number(label="Nb Tokens",
                                      value=512, minimum=256, maximum=2048, step=256, elem_id='nb_tokens_input')
                # Hidden: temperature stays fixed at its default value.
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.95,
                    visible=False,
                    minimum=0.1, maximum=1.0, step=0.1, elem_id='temperature_input')
                tempo = gr.Slider(label="Tempo", value=120, minimum=60, maximum=240, step=1, elem_id='tempo_input')
            with gr.Row():
                chord_progression = gr.Textbox(
                    label="Chord Progression (Optional)",
                    placeholder="Am CM Dm7/F E7 Asus4", lines=2, value="", elem_id='chord_progression_input')
            with gr.Row():
                generate_btn = gr.Button("Generate", elem_id='generate_button')
        # Right column: generation outputs.
        with gr.Column():
            info_message = gr.Textbox(label="Info Message", elem_id='info_message_output')
            generated_music = gr.Audio(label="Preview generated Music", elem_id='generated_music_output')
            generated_midi = gr.File(label="Download MIDI", elem_id='generated_midi_output')

    # Order must match musiclang_gradio's parameters and return tuple.
    generation_inputs = [midi_file, chord_progression, tempo, temperature, nb_tokens, bar_range]
    generation_outputs = [generated_music, generated_midi, info_message]

    generate_btn.click(
        fn=musiclang_gradio,
        inputs=generation_inputs,
        outputs=generation_outputs,
    )

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Examples")
            gr.Examples(
                examples=[
                    [os.path.join(_EXAMPLES_DIR, filename), chords, 120, 0.95, 512, "0-4"]
                    for filename, chords in _EXAMPLE_PROMPTS
                ],
                inputs=generation_inputs,
                outputs=generation_outputs,
                fn=musiclang_gradio,
                cache_examples=True,
            )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()