# midi-composer / app.py
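# Gradio app for generating MIDI variations: an uploaded MIDI file is
# tokenized, extended with the skytnt/midi-model ONNX checkpoints, rendered to
# audio through a SoundFont synthesizer, and offered for download as
# base64-encoded MIDI data.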
import argparse
import base64
import io
import json
import os
import random
import tempfile
import numpy as np
import gradio as gr
import rtmidi
import onnxruntime as rt
from huggingface_hub import hf_hub_download
import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Constants
MAX_SEED = np.iinfo(np.int32).max
IN_SPACE = os.getenv("SYSTEM") == "spaces"
MAX_LENGTH = 1024 # Maximum tokens for generation
# MIDI Device Manager
class MIDIDeviceManager:
def __init__(self):
self.midiout = rtmidi.MidiOut()
self.midiin = rtmidi.MidiIn()
    def get_device_info(self):
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: backslashes inside f-string expressions are
        # a syntax error on Python < 3.12
        out_list = "\n".join(out_ports)
        in_list = "\n".join(in_ports)
        return f"Output Devices:\n{out_list}\n\nInput Devices:\n{in_list}"
def close(self):
if self.midiout.is_port_open():
self.midiout.close_port()
if self.midiin.is_port_open():
self.midiin.close_port()
del self.midiout, self.midiin
# MIDI Processor with ONNX Generation
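# The skytnt/midi-model ONNX export is split into two graphs: model_base maps
# the token sequence to hidden states, and model_token turns those hidden
# states into next-token logits. generate_variation() below alternates the two
# graphs and samples from the logits with temperature, top-p, and top-k
# filtering until MAX_LENGTH tokens have been produced.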
class MIDIManager:
def __init__(self):
self.soundfont = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
self.synthesizer = MidiSynthesizer(self.soundfont)
self.tokenizer = self._load_tokenizer("skytnt/midi-model")
self.model_base = rt.InferenceSession(
hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
self.model_token = rt.InferenceSession(
hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
)
self.generated_files = [] # Store base64-encoded MIDI data
self.is_playing = False
def _load_tokenizer(self, repo_id):
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
with open(config_path, "r") as f:
config = json.load(f)
tokenizer = MIDITokenizer(config["tokenizer"]["version"])
tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
return tokenizer
    def load_midi(self, file_path):
        # Read the raw MIDI bytes; MIDI.midi2score() (used during tokenization)
        # expects a byte string rather than a parsed object.
        try:
            with open(file_path, "rb") as f:
                return f.read()
        except Exception as e:
            raise ValueError(f"Failed to load MIDI file: {e}")
def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20):
# Tokenize input MIDI
mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
input_tensor = np.array([mid_seq], dtype=np.int64)
cur_len = input_tensor.shape[1]
generator = np.random.RandomState(random.randint(0, MAX_SEED))
# Generate up to MAX_LENGTH
while cur_len < MAX_LENGTH:
inputs = {"x": input_tensor[:, -1:]} # Last token
hidden = self.model_base.run(None, inputs)[0] # Base model output
logits = self.model_token.run(None, {"hidden": hidden})[0] # Token model output
probs = softmax(logits / temp, axis=-1)
next_token = sample_top_p_k(probs, top_p, top_k, generator)
input_tensor = np.concatenate([input_tensor, next_token], axis=1)
cur_len += 1
        # Detokenize and encode the result as a base64 MIDI byte string
        new_seq = input_tensor[0].tolist()
        new_midi = self.tokenizer.detokenize(new_seq)
        midi_bytes = MIDI.score2midi(new_midi)  # score2midi returns the raw MIDI bytes
        midi_data = base64.b64encode(midi_bytes).decode('utf-8')
        self.generated_files.append(midi_data)
        return midi_data
    def play_midi(self, midi_data):
        self.is_playing = True
        midi_bytes = base64.b64decode(midi_data)
        # Parse the raw bytes into a score structure before handing it to the
        # synthesizer (MIDI.py has no load() helper)
        score = MIDI.midi2score(midi_bytes)
        audio = io.BytesIO()
        self.synthesizer.render_midi(score, audio)
        audio.seek(0)
        return audio
def stop(self):
self.is_playing = False
# Helper Functions
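# softmax converts raw logits into a probability distribution (the row maximum
# is subtracted first for numerical stability); sample_top_p_k draws the next
# token after nucleus (top-p) and top-k filtering of that distribution.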
def softmax(x, axis):
exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def sample_top_p_k(probs, p, k, generator):
    # Sort probabilities in descending order, remembering the original token ids
    probs_idx = np.argsort(-probs, axis=-1)
    probs_sort = np.take_along_axis(probs, probs_idx, axis=-1)
    # Nucleus (top-p) filtering: zero out tokens outside the smallest set whose
    # cumulative probability exceeds p
    probs_sum = np.cumsum(probs_sort, axis=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort[:, k:] = 0.0  # Top-k filtering
    probs_sort /= probs_sort.sum(axis=-1, keepdims=True)
    # Sample a position in the sorted list, then map it back to the original
    # token id (sampling the sorted index directly would return the wrong token)
    sampled_pos = generator.choice(probs.shape[-1], p=probs_sort[0])
    next_token = probs_idx[0, sampled_pos]
    return np.array([[next_token]])
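# Illustrative example: for probs = [[0.5, 0.3, 0.15, 0.05]] with p=0.9 and
# k=2, top-p filtering zeroes only the 0.05 tail token, top-k then zeroes
# everything beyond the two most likely tokens, and after renormalization the
# sampler picks token 0 with probability 0.625 or token 1 with 0.375.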
# UI Functions
def process_midi_upload(file):
    if file is None:
        return None, "No file uploaded", ""
    try:
        # gr.File with the default file_count passes a single file; depending on
        # the Gradio version this is either a path string or an object with .name
        file_path = file.name if hasattr(file, "name") else file
        midi_data = midi_processor.load_midi(file_path)
        generated_midi = midi_processor.generate_variation(midi_data)
        audio = midi_processor.play_midi(generated_midi)
        # gr.Audio(type="filepath") expects a path, so persist the rendered
        # stream to a temporary file (assumes the synthesizer writes WAV data)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio.read())
            audio_path = tmp.name
        download_html = create_download_list()
        return audio_path, "Generated and playing", download_html
    except Exception as e:
        return None, f"Error: {e}", ""
def create_download_list():
if not midi_processor.generated_files:
return "<p>No generated files yet.</p>"
html = "<h3>Generated MIDI Files</h3><ul>"
for i, midi_data in enumerate(midi_processor.generated_files):
html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="generated_{i}.mid">Download MIDI {i}</a></li>'
html += "</ul>"
return html
def refresh_devices():
return device_manager.get_device_info()
def stop_playback():
midi_processor.stop()
return "Playback stopped"
# Main Application
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MIDI Composer with ONNX Generation")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
device_manager = MIDIDeviceManager()
midi_processor = MIDIManager()
with gr.Blocks(title="MIDI Composer", theme=gr.themes.Soft()) as app:
gr.Markdown("# 🎵 MIDI Composer 🎵")
with gr.Tabs():
# MIDI Prompt Tab
with gr.Tab("MIDI Prompt"):
midi_upload = gr.File(label="Upload MIDI File", file_types=[".mid", ".midi"])
                audio_output = gr.Audio(label="Generated MIDI", type="filepath", autoplay=True)
status = gr.Textbox(label="Status", value="Ready", interactive=False)
                # Create the downloads HTML component once and reference it as an
                # output, rather than instantiating a new component inline
                downloads_html = gr.HTML()
                midi_upload.change(
                    process_midi_upload,
                    inputs=[midi_upload],
                    outputs=[audio_output, status, downloads_html]
                )
# Downloads Tab
with gr.Tab("Downloads", elem_id="downloads"):
gr.HTML(value=create_download_list())
# Devices Tab
with gr.Tab("Devices"):
device_info = gr.Textbox(label="MIDI Devices", value=device_manager.get_device_info(), interactive=False)
refresh_btn = gr.Button("Refresh Devices")
stop_btn = gr.Button("Stop Playback")
refresh_btn.click(refresh_devices, outputs=[device_info])
stop_btn.click(stop_playback, outputs=[status])
app.launch(server_port=args.port, share=args.share, inbrowser=True)
device_manager.close()