# Hugging Face Spaces status at time of capture: "Build error".
# --- Dependencies -----------------------------------------------------------
import os
import tempfile
from pathlib import Path  # explicit; previously only available via star-import leakage

import gradio as gr
from midi2audio import FluidSynth

from musicautobot.numpy_encode import *
from musicautobot.utils.file_processing import process_all, process_file
from musicautobot.config import *
from musicautobot.music_transformer import *
from musicautobot.utils.setup_musescore import setup_musescore

# Run musicautobot's MuseScore setup once at import time, before the model is
# built and before any request is served.
setup_musescore()

# --- Bootloading model ------------------------------------------------------
# Loaded once at import time so every Gradio request reuses the same learner.
data_path = Path('./')
# An empty data bunch is built only to obtain the vocabulary the learner and
# MusicItem.from_file need (presumably no training data is loaded — confirm).
data = MusicDataBunch.empty(data_path)
vocab = data.vocab
pretrained_path = './music_transformer.pth'
learn = music_model_learner(data, pretrained_path=pretrained_path, config=default_config())
def _reserve_temp_path(suffix):
    """Create an empty, closed temporary file ending in *suffix*; return its path.

    The handle is closed immediately so external tools (FluidSynth, music21)
    can open the path themselves on any platform. The file is NOT deleted
    automatically; the caller owns it.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _midi_to_wav(synth, midi_path):
    """Render *midi_path* to a new temporary WAV file using *synth*.

    Returns the WAV path, which is left on disk for the caller to serve.
    If rendering raises, the half-written WAV file is removed before re-raising.
    """
    wav_path = _reserve_temp_path('.wav')
    try:
        synth.midi_to_audio(midi_path, wav_path)
    except Exception:
        os.remove(wav_path)
        raise
    return wav_path


def predict(seed_midi, n_words=400, temperature1=1.1, temperature2=0.4,
            min_bars=12, top_k=24, top_p=0.7, cutoff_beat=10):
    """Continue a seed MIDI with the music transformer and render both to audio.

    Args:
        seed_midi: The uploaded seed MIDI. Either an object exposing ``.name``
            (a path on disk) or a plain path string is accepted.
        n_words: Number of tokens to generate (coerced to ``int``).
        temperature1: First element of the ``temperatures`` tuple passed to
            ``learn.predict``.
        temperature2: Second element of that tuple.
        min_bars: Forwarded to ``learn.predict`` (coerced to ``int``).
        top_k: Forwarded to ``learn.predict`` (coerced to ``int``).
        top_p: Forwarded to ``learn.predict`` unchanged.
        cutoff_beat: Beat at which the seed is trimmed before generation.

    Returns:
        ``(seed_wav_path, generated_wav_path)``. Both WAV files are left on
        disk so the UI can serve them.

    Raises:
        ValueError: If no seed file was provided.
    """
    if seed_midi is None:
        raise ValueError("Please upload a seed MIDI file.")
    # Older Gradio passes a tempfile-like object, newer Gradio a path string.
    seed_path = getattr(seed_midi, 'name', seed_midi)

    # Load input MIDI file as MusicItem and keep only the opening beats as seed.
    item = MusicItem.from_file(seed_path, data.vocab)
    seed_item = item.trim_to_beat(cutoff_beat)

    # Sliders may deliver floats; token/bar/top-k counts must be integers.
    pred, _full = learn.predict(
        seed_item,
        n_words=int(n_words),
        temperatures=(temperature1, temperature2),
        min_bars=int(min_bars),
        top_k=int(top_k),
        top_p=top_p,
    )

    synth = FluidSynth("sound_font.sf2")

    # Convert the (untrimmed) input MIDI to audio.
    seed_wav = _midi_to_wav(synth, seed_path)

    # Write the generated MIDI to a scratch file, render it, then always
    # remove the scratch MIDI even if rendering fails.
    pred_midi_path = _reserve_temp_path('.midi')
    try:
        pred.stream.write('midi', fp=pred_midi_path)
        pred_wav = _midi_to_wav(synth, pred_midi_path)
    finally:
        os.remove(pred_midi_path)

    return seed_wav, pred_wav
# Build and launch the web UI. The order of `inputs` must match predict's
# positional parameters:
#   (seed_midi, n_words, temperature1, temperature2, min_bars, top_k, top_p).
# Uses the top-level gr.File / gr.Slider / gr.Audio components: the legacy
# gr.inputs / gr.outputs namespaces (and Slider's `default=`) are deprecated
# in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Seed MIDI"),
        gr.Slider(50, 1000, value=400, step=10, label="Number of Words"),
        # NOTE(review): a 0.0 minimum lets users choose temperature 0; confirm
        # learn.predict tolerates that (it may divide logits by temperature).
        gr.Slider(0.0, 2.0, value=1.1, step=0.1, label="Temperature 1"),
        gr.Slider(0.0, 2.0, value=0.4, step=0.1, label="Temperature 2"),
        gr.Slider(1, 32, value=12, step=1, label="Min Bars"),
        gr.Slider(1, 50, value=24, step=1, label="Top K"),
        gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Top P"),
    ],
    outputs=[
        # 'filepath' makes Gradio serve the WAV paths returned by predict.
        gr.Audio(type='filepath', label="Seed Audio"),
        gr.Audio(type='filepath', label="Generated Audio"),
    ],
)
iface.launch()