Spaces:
Running
on
Zero
Running
on
Zero
import argparse
import base64
import io
import json
import os
import random

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# --- Module-wide constants -------------------------------------------------
# Upper bound for the random seed passed to np.random.RandomState.
MAX_SEED = np.iinfo(np.int32).max
# True when the SYSTEM environment variable is "spaces"
# (presumably set by Hugging Face Spaces — confirm).
IN_SPACE = os.environ.get("SYSTEM") == "spaces"
# Generation stops once the token sequence reaches this length.
MAX_LENGTH = 1024
# MIDI Device Manager
class MIDIDeviceManager:
    """Hold one rtmidi output handle and one rtmidi input handle."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a text listing of the available MIDI output and input ports.

        When a direction reports no ports, a placeholder line is listed instead.
        """
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string replacement
        # field is a SyntaxError before Python 3.12 (PEP 701).
        out_text = "\n".join(out_ports)
        in_text = "\n".join(in_ports)
        return f"Output Devices:\n{out_text}\n\nInput Devices:\n{in_text}"

    def close(self):
        """Close any open ports and drop the references to both rtmidi handles.

        Safe to call more than once: after the first call both handles are None
        and later calls do nothing.
        """
        for port in (self.midiout, self.midiin):
            if port is not None and port.is_port_open():
                port.close_port()
        self.midiout = None
        self.midiin = None
# MIDI Processor with ONNX Generation
class MIDIManager:
    """Download MIDI-model assets from the Hugging Face Hub and generate/render MIDI.

    At construction this downloads ``soundfont.sf2``, ``config.json``,
    ``onnx/model_base.onnx`` and ``onnx/model_token.onnx`` from ``repo_id``.
    It builds a MidiSynthesizer, a MIDITokenizer and two ONNX Runtime sessions.
    """

    # ONNX Runtime execution providers, in order of preference.
    _PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self, repo_id="skytnt/midi-model"):
        """Download the model assets from ``repo_id`` and build the inference sessions."""
        self.repo_id = repo_id
        self.soundfont = hf_hub_download(repo_id=repo_id, filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont)
        self.tokenizer = self._load_tokenizer(repo_id)
        self.model_base = self._load_session("onnx/model_base.onnx")
        self.model_token = self._load_session("onnx/model_token.onnx")
        self.generated_files = []  # base64-encoded MIDI data, one entry per generation
        self.is_playing = False

    def _load_session(self, filename):
        """Download ``filename`` from ``self.repo_id`` and open an ONNX inference session."""
        model_path = hf_hub_download(repo_id=self.repo_id, filename=filename)
        return rt.InferenceSession(model_path, providers=list(self._PROVIDERS))

    def _load_tokenizer(self, repo_id):
        """Build a MIDITokenizer from the ``tokenizer`` section of ``repo_id``'s config.json."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer_cfg = config["tokenizer"]
        tokenizer = MIDITokenizer(tokenizer_cfg["version"])
        tokenizer.set_optimise_midi(tokenizer_cfg["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load ``file_path`` with ``MIDI.load``.

        Raises:
            ValueError: if loading fails for any reason. The original exception
                is chained as ``__cause__``.
        """
        try:
            return MIDI.load(file_path)
        except Exception as e:
            raise ValueError(f"Failed to load MIDI file: {e}") from e

    def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20, max_length=None):
        """Extend the tokenized ``midi_data`` by sampling and return the result as base64 MIDI.

        Args:
            midi_data: value returned by ``load_midi``; passed to ``MIDI.midi2score``.
            temp: softmax temperature that divides the logits; must be > 0.
            top_p: nucleus-sampling threshold forwarded to ``sample_top_p_k``.
            top_k: number of most-probable tokens kept; must be >= 1.
            max_length: stop once the sequence has this many positions;
                defaults to the module-level ``MAX_LENGTH``.

        Returns:
            The generated MIDI bytes, base64-encoded as a UTF-8 ``str``. The same
            string is appended to ``self.generated_files``.

        Raises:
            ValueError: if ``temp <= 0`` or ``top_k < 1``. Without this check such
                values only fail later as NaN probabilities.
        """
        if temp <= 0:
            raise ValueError(f"temp must be > 0, got {temp}")
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        if max_length is None:
            max_length = MAX_LENGTH

        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        generator = np.random.RandomState(random.randint(0, MAX_SEED))

        while cur_len < max_length:
            # NOTE(review): only the most recent token is fed to model_base, and no
            # state is carried between steps, so the model sees no earlier context.
            # Confirm this matches the ONNX models' expected inputs.
            hidden = self.model_base.run(None, {"x": input_tensor[:, -1:]})[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = softmax(logits / temp, axis=-1)
            next_token = sample_top_p_k(probs, top_p, top_k, generator)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1

        new_score = self.tokenizer.detokenize(input_tensor[0].tolist())
        midi_buffer = io.BytesIO()
        # NOTE(review): assumes MIDI.score2midi writes into the given buffer. If it
        # instead returns the bytes, this buffer stays empty — confirm the MIDI API.
        MIDI.score2midi(new_score, midi_buffer)
        encoded = base64.b64encode(midi_buffer.getvalue()).decode("utf-8")
        self.generated_files.append(encoded)
        return encoded

    def play_midi(self, midi_data):
        """Decode base64 ``midi_data`` and pass it to ``synthesizer.render_midi``.

        Sets ``is_playing`` to True. Returns the BytesIO that ``render_midi``
        wrote into, rewound to position 0.
        """
        self.is_playing = True
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        audio = io.BytesIO()
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    def stop(self):
        """Clear the ``is_playing`` flag (nothing in this class reads it)."""
        self.is_playing = False
# Helper Functions
def softmax(x, axis):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    values do not overflow ``np.exp``.
    """
    peak = np.amax(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total
def sample_top_p_k(probs, p, k, generator):
    """Sample one token id per row of ``probs`` using top-k plus nucleus (top-p) filtering.

    Args:
        probs: 2-D array of probabilities, shape ``(batch, vocab)``.
        p: nucleus threshold. In descending-probability order, a token is dropped
            once the probability mass strictly before it exceeds ``p``. For
            ``p >= 0`` the most probable token always survives.
        k: keep at most the ``k`` most probable tokens; must be >= 1.
        generator: ``np.random.RandomState`` used for sampling.

    Returns:
        int64 array of shape ``(batch, 1)`` holding token ids, i.e. indices into
        the last axis of ``probs``.

    Raises:
        ValueError: if ``k < 1``. That would zero every probability.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)  # new array; probs untouched
    cumulative = np.cumsum(sorted_probs, axis=-1)
    sorted_probs[cumulative - sorted_probs > p] = 0.0  # nucleus (top-p) filter
    sorted_probs[:, k:] = 0.0  # top-k filter
    sorted_probs /= sorted_probs.sum(axis=-1, keepdims=True)
    vocab = probs.shape[-1]
    # generator.choice picks a position in *sorted* order; map it back through
    # `order` so the caller receives the actual token id, not its rank.
    tokens = [order[row, generator.choice(vocab, p=row_probs)] for row, row_probs in enumerate(sorted_probs)]
    return np.array(tokens, dtype=np.int64).reshape(-1, 1)
# UI Functions
def process_midi_upload(files):
    """Generate a variation of an uploaded MIDI file and render it for playback.

    Args:
        files: value of the upload component. It may be a single file or a
            list/tuple of files, in which case only the first is used. Each file
            may be a path (``str`` / ``os.PathLike``) or an object exposing its
            path as ``.name``.

    Returns:
        ``(audio, status_message, downloads_html)``. On failure ``audio`` is None,
        the status carries the error, and ``downloads_html`` is empty.
    """
    if not files:
        return None, "No file uploaded", ""
    # gr.File defaults to single-file mode, so accept a lone file as well as a list.
    file = files[0] if isinstance(files, (list, tuple)) else files
    try:
        # Plain paths (incl. Gradio's str-based file values) are used as-is;
        # tempfile-style wrappers carry the path in `.name`.
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name
        midi_data = midi_processor.load_midi(file_path)
        generated_midi = midi_processor.generate_variation(midi_data)
        audio = midi_processor.play_midi(generated_midi)
        download_html = create_download_list()
        return audio, "Generated and playing", download_html
    except Exception as e:
        # UI boundary: surface any failure in the status box instead of crashing the app.
        return None, f"Error: {e}", ""
def create_download_list():
    """Return HTML with one data-URI download link per generated MIDI file.

    Reads ``midi_processor.generated_files``. When that list is empty, returns a
    short placeholder paragraph instead.
    """
    files = midi_processor.generated_files
    if not files:
        return "<p>No generated files yet.</p>"
    # Build once with str.join rather than repeated += concatenation.
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{midi_data}" download="generated_{i}.mid">Download MIDI {i}</a></li>'
        for i, midi_data in enumerate(files)
    )
    return f"<h3>Generated MIDI Files</h3><ul>{items}</ul>"
def refresh_devices():
    """Re-query ``device_manager`` and return its current MIDI port listing text."""
    listing = device_manager.get_device_info()
    return listing
def stop_playback():
    """Call ``midi_processor.stop()`` and return the status text for the UI."""
    midi_processor.stop()
    message = "Playback stopped"
    return message
# Main Application
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer with ONNX Generation")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    # These stay module-level names: the UI callbacks above look them up as globals.
    device_manager = MIDIDeviceManager()
    try:
        midi_processor = MIDIManager()

        with gr.Blocks(title="MIDI Composer", theme=gr.themes.Soft()) as app:
            gr.Markdown("# 🎵 MIDI Composer 🎵")
            with gr.Tabs():
                # MIDI Prompt Tab
                with gr.Tab("MIDI Prompt"):
                    midi_upload = gr.File(label="Upload MIDI File", file_types=[".mid", ".midi"])
                    # gr.Audio accepts type "numpy" or "filepath"; "bytes" is rejected.
                    # NOTE(review): play_midi returns an io.BytesIO — confirm the installed
                    # Gradio version can render that as an Audio output value.
                    audio_output = gr.Audio(label="Generated MIDI", type="filepath", autoplay=True)
                    status = gr.Textbox(label="Status", value="Ready", interactive=False)
                # Downloads Tab
                with gr.Tab("Downloads"):
                    downloads_html = gr.HTML(value=create_download_list(), elem_id="downloads")
                # Devices Tab
                with gr.Tab("Devices"):
                    device_info = gr.Textbox(
                        label="MIDI Devices",
                        value=device_manager.get_device_info(),
                        interactive=False,
                    )
                    refresh_btn = gr.Button("Refresh Devices")
                    stop_btn = gr.Button("Stop Playback")

            # Wire events after every component exists, so the upload handler
            # refreshes the Downloads tab's HTML rather than a stray duplicate.
            midi_upload.change(
                process_midi_upload,
                inputs=[midi_upload],
                outputs=[audio_output, status, downloads_html],
            )
            refresh_btn.click(refresh_devices, outputs=[device_info])
            stop_btn.click(stop_playback, outputs=[status])

        # Don't try to open a local browser when running inside a Space.
        app.launch(server_port=args.port, share=args.share, inbrowser=not IN_SPACE)
    finally:
        # Runs on normal exit and also when model setup or launch raises.
        device_manager.close()