File size: 2,730 Bytes
ac4cbcf
 
 
 
 
 
4819bc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from musicautobot.numpy_encode import *
from musicautobot.utils.file_processing import process_all, process_file
from musicautobot.config import *
from musicautobot.music_transformer import *
from musicautobot.utils.setup_musescore import setup_musescore
# Called at import time, before the model below is loaded.
# NOTE(review): presumably configures MuseScore for music21 score rendering —
# confirm in musicautobot.utils.setup_musescore.
setup_musescore()

import gradio as gr
from midi2audio import FluidSynth
import tempfile
import os

# Bootloading model.
# Both locations below are relative to the current working directory.
pretrained_path = './music_transformer.pth'
data_path = Path('./')

# NOTE(review): MusicDataBunch.empty presumably builds a data bunch without
# training items, used here for its vocabulary and as the learner's data — confirm.
data = MusicDataBunch.empty(data_path)
vocab = data.vocab

# Learner with the pretrained weights and the default model configuration;
# `predict` below generates with it.
learn = music_model_learner(
    data,
    pretrained_path=pretrained_path,
    config=default_config(),
)


def _reserve_temp_path(suffix):
    """Create an empty named temporary file and return its path.

    The handle is closed before returning, so the path can safely be reopened
    by another writer (music21, the ``fluidsynth`` subprocess). Reopening a
    still-open NamedTemporaryFile by name fails on Windows. The file is NOT
    deleted automatically; the caller owns it.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        return handle.name


def predict(seed_midi, n_words=400, temperature1=1.1, temperature2=0.4, min_bars=12, top_k=24, top_p=0.7,
            *, cutoff_beat=10):
    """Continue a seed MIDI file with the music transformer and render both to audio.

    Args:
        seed_midi: The uploaded seed MIDI. Either an object exposing ``.name``
            (the tempfile wrapper older Gradio ``File`` inputs pass) or a plain
            path string (what newer Gradio versions pass).
        n_words: Passed to ``learn.predict`` as ``n_words``.
        temperature1, temperature2: Passed to ``learn.predict`` together as
            ``temperatures=(temperature1, temperature2)``.
        min_bars: Passed to ``learn.predict`` as ``min_bars``.
        top_k: Passed to ``learn.predict`` as ``top_k``.
        top_p: Passed to ``learn.predict`` as ``top_p``.
        cutoff_beat: Beat passed to ``trim_to_beat`` to cut the seed down to a
            prompt. Keyword-only, so Gradio's positional inputs are unaffected.

    Returns:
        ``(seed_wav_path, generated_wav_path)``: paths to WAV files rendered by
        FluidSynth with ``sound_font.sf2``. Both files are left on disk for the
        caller (Gradio) to serve. Only the intermediate generated MIDI file is
        removed.
    """
    # Older Gradio passes a tempfile wrapper; newer Gradio passes a path string.
    seed_path = getattr(seed_midi, 'name', seed_midi)

    # Load input MIDI file as MusicItem and trim it to the prompt length.
    item = MusicItem.from_file(seed_path, data.vocab)
    seed_item = item.trim_to_beat(cutoff_beat)

    # Slider values may arrive as floats; these arguments are counts.
    # NOTE(review): assumes learn.predict expects ints for these — confirm.
    pred, _full = learn.predict(
        seed_item,
        n_words=int(n_words),
        temperatures=(temperature1, temperature2),
        min_bars=int(min_bars),
        top_k=int(top_k),
        top_p=top_p,
    )

    # One synthesizer instance serves both renders.
    synth = FluidSynth("sound_font.sf2")

    # Convert input MIDI to audio. The temp file is already closed here.
    seed_audio_path = _reserve_temp_path('.wav')
    synth.midi_to_audio(seed_path, seed_audio_path)

    # Save the generated MIDI, render it, and always remove the intermediate
    # MIDI file, even if writing or rendering raises.
    pred_midi_path = _reserve_temp_path('.midi')
    try:
        pred.stream.write('midi', fp=pred_midi_path)
        pred_audio_path = _reserve_temp_path('.wav')
        synth.midi_to_audio(pred_midi_path, pred_audio_path)
    finally:
        os.remove(pred_midi_path)

    return seed_audio_path, pred_audio_path



# Gradio UI. Inputs are passed to `predict` positionally, in the order listed.
# This uses the top-level component classes: the `gr.inputs` / `gr.outputs`
# namespaces are deprecated in Gradio 3 and removed in Gradio 4.
# NOTE(review): the temperature sliders allow 0.0; confirm the sampler
# tolerates a zero temperature.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Seed MIDI"),
        gr.Slider(50, 1000, step=10, value=400, label="Number of Words"),
        gr.Slider(0.0, 2.0, step=0.1, value=1.1, label="Temperature 1"),
        gr.Slider(0.0, 2.0, step=0.1, value=0.4, label="Temperature 2"),
        gr.Slider(1, 32, step=1, value=12, label="Min Bars"),
        gr.Slider(1, 50, step=1, value=24, label="Top K"),
        gr.Slider(0.0, 1.0, step=0.1, value=0.7, label="Top P"),
    ],
    outputs=[
        # `predict` returns file paths, so the outputs take filepaths.
        gr.Audio(type='filepath', label="Seed Audio"),
        gr.Audio(type='filepath', label="Generated Audio"),
    ],
)

iface.launch()