# Hugging Face Spaces page banner (scrape residue): "Running on Zero"
import argparse
import base64
import glob
import io
import json
import os
import random

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Upper bound for the random seed drawn in MIDIManager.generate_onnx
# (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# True when the SYSTEM environment variable is exactly "spaces".
in_space = os.getenv("SYSTEM") == "spaces"
class MIDIDeviceManager:
    """Enumerates MIDI ports through one ``rtmidi.MidiOut`` and one ``rtmidi.MidiIn``."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_output_devices(self):
        """Return the output port names, or a one-item placeholder list if there are none."""
        return self.midiout.get_ports() or ["No MIDI output devices"]

    def get_input_devices(self):
        """Return the input port names, or a one-item placeholder list if there are none."""
        return self.midiin.get_ports() or ["No MIDI input devices"]

    def get_device_info(self):
        """Return a multi-line, human-readable summary of output and input ports.

        The raw port lists are queried directly rather than through
        ``get_output_devices`` / ``get_input_devices``: those never return an
        empty list, which would make the "none detected" text unreachable and
        print the placeholder as if it were port 0.
        """
        out_ports = self.midiout.get_ports()
        in_ports = self.midiin.get_ports()
        if out_ports:
            out_info = "\n".join(f"Out Port {i}: {name}" for i, name in enumerate(out_ports))
        else:
            out_info = "No MIDI output devices detected"
        if in_ports:
            in_info = "\n".join(f"In Port {i}: {name}" for i, name in enumerate(in_ports))
        else:
            in_info = "No MIDI input devices detected"
        return f"Output Devices:\n{out_info}\n\nInput Devices:\n{in_info}"

    def close(self):
        """Close any open port and drop both rtmidi objects.

        Safe to call more than once: an attribute that has already been
        removed is skipped instead of raising ``AttributeError``.
        """
        for attr in ("midiout", "midiin"):
            port = getattr(self, attr, None)
            if port is None:
                continue
            if port.is_port_open():
                port.close_port()
            # Drop our reference, as the original did with ``del``.
            delattr(self, attr)
class MIDIManager:
    """Loads MIDI files, derives new MIDI data from them, and plays it back.

    State kept on the instance:

    * ``loaded_midi`` maps a generated id (``"midi_0"``, ``"midi_1"``, ...) to
      a ``(file_path, midi)`` tuple.
    * ``modified_files`` lists every base64-encoded MIDI payload produced so
      far, in generation order.
    * ``is_playing`` is the flag polled by ``play_with_loop`` and cleared by
      ``stop_playback``.
    """

    def __init__(self):
        repo_id = "skytnt/midi-model"
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.soundfont_path = hf_hub_download(repo_id=repo_id, filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}
        self.modified_files = []
        self.is_playing = False
        self.tokenizer = self.load_tokenizer(repo_id)
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id=repo_id, filename="onnx/model_base.onnx"),
            providers=providers,
        )
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id=repo_id, filename="onnx/model_token.onnx"),
            providers=providers,
        )

    def load_tokenizer(self, repo_id):
        """Download ``config.json`` from *repo_id* and build the tokenizer it describes."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load *file_path*, register it under a new id, and return that id."""
        midi = MIDI.load(file_path)
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes_and_instruments(self, midi):
        """Collect sounding notes and program numbers from every track of *midi*.

        Returns:
            ``(notes, instruments)`` where ``notes`` is a list of
            ``(note, velocity, time)`` tuples for ``note_on`` events with a
            velocity above zero, and ``instruments`` is the sorted list of
            distinct ``program`` values seen on any event.
        """
        notes = []
        instruments = set()
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
                if hasattr(event, 'program'):
                    instruments.add(event.program)
        # Sorted so the track order of generated files is reproducible.
        return notes, sorted(instruments)

    def _store_result(self, raw_bytes):
        """Base64-encode *raw_bytes*, record it in ``modified_files``, and return it."""
        encoded = base64.b64encode(raw_bytes).decode('utf-8')
        self.modified_files.append(encoded)
        return encoded

    def generate_variation(self, midi_id, length_factor=10, variation=0.3):
        """Build a randomised, repeated variation of a loaded MIDI file.

        Args:
            midi_id: id previously returned by ``load_midi``.
            length_factor: how many times the extracted note list is repeated
                (truncated with ``int``).
            variation: per-note probability of shifting the pitch by up to
                2 either way and the velocity by up to 10 either way; both
                stay clamped to 0..127.

        Returns:
            The new file as a base64 string, or ``None`` if *midi_id* is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None
        _, midi = self.loaded_midi[midi_id]
        notes, instruments = self.extract_notes_and_instruments(midi)
        new_notes = []
        for _ in range(int(length_factor)):
            for note, vel, time in notes:
                if random.random() < variation:
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    new_vel = min(127, max(0, vel + random.randint(-10, 10)))
                    new_notes.append((new_note, new_vel, time))
                else:
                    new_notes.append((note, vel, time))
        # NOTE(review): MIDIFile / addTrack / addProgramChange / addNote come
        # from the project's MIDI module; argument meaning is not visible here.
        new_midi = MIDI.MIDIFile(len(instruments) or 1)
        for i, inst in enumerate(instruments or [0]):
            new_midi.addTrack()
            new_midi.addProgramChange(i, 0, 0, inst)
            for note, vel, time in new_notes:
                new_midi.addNote(i, 0, note, time, 100, vel)
        midi_output = io.BytesIO()
        new_midi.writeFile(midi_output)
        return self._store_result(midi_output.getvalue())

    def generate_onnx(self, midi_id, max_len=1024, temp=1.0, top_p=0.98, top_k=20):
        """Extend a loaded MIDI file token by token with the two ONNX sessions.

        Args:
            midi_id: id previously returned by ``load_midi``.
            max_len: generation stops once the sequence reaches this length
                along axis 1.
            temp: divisor applied to the logits before the softmax.
            top_p: cumulative-probability cut-off passed to ``sample_top_p_k``.
            top_k: rank cut-off passed to ``sample_top_p_k``.

        Returns:
            The generated file as a base64 string, or ``None`` if *midi_id*
            is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None
        _, mid = self.loaded_midi[midi_id]
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(mid))
        input_tensor = np.asarray([mid_seq], dtype=np.int64)
        generator = np.random.RandomState(random.randint(0, MAX_SEED))
        # The sessions are created with a CPU fallback, so only bind tensors on
        # the GPU when the CUDA provider is actually active for this session.
        if "CUDAExecutionProvider" in self.model_base.get_providers():
            device = "cuda"
        else:
            device = "cpu"
        cur_len = input_tensor.shape[1]
        while cur_len < max_len:
            last_step = np.ascontiguousarray(input_tensor[:, -1:])
            x_value = rt.OrtValue.ortvalue_from_numpy(last_step, device_type=device)
            hidden_value = rt.OrtValue.ortvalue_from_shape_and_type(
                (1, 1, 1024), np.float32, device_type=device
            )
            io_binding = self.model_base.io_binding()
            io_binding.bind_ortvalue_input("x", x_value)
            io_binding.bind_ortvalue_output("hidden", hidden_value)
            self.model_base.run_with_iobinding(io_binding)
            hidden = hidden_value.numpy()[:, -1:]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            scores = softmax(logits / temp, -1)
            next_token = sample_top_p_k(scores, top_p, top_k, generator)
            # NOTE(review): assumes next_token has the same rank as
            # input_tensor so it can be appended along axis 1 -- confirm
            # against the tokenizer/model output shapes.
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        mid_seq = input_tensor.tolist()[0]
        new_midi = self.tokenizer.detokenize(mid_seq)
        midi_output = io.BytesIO()
        MIDI.score2midi(new_midi, midi_output)
        return self._store_result(midi_output.getvalue())

    def play_with_loop(self, midi_data):
        """Decode base64 *midi_data* and replay it until ``stop_playback`` is called.

        This call blocks the calling thread for as long as playback is active.
        ``is_playing`` is always reset on exit, including when decoding or
        playback raises, so the flag cannot get stuck at ``True``.
        """
        self.is_playing = True
        try:
            midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
            while self.is_playing:
                self.synthesizer.play_midi(midi_file)
        finally:
            self.is_playing = False

    def stop_playback(self):
        """Clear the playback flag and return a status message."""
        self.is_playing = False
        return "Playback stopped"
def softmax(x, axis):
    """Return the softmax of *x* along *axis*.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow; the result is unchanged mathematically.
    """
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)
def sample_top_p_k(probs, p, k, generator=None):
    """Draw one index per distribution using combined top-p and top-k filtering.

    Args:
        probs: array whose last axis holds a probability distribution.
        p: entries are dropped once the cumulative probability of the
            higher-ranked entries exceeds this value.
        k: only the *k* highest-ranked entries stay eligible.
        generator: object providing ``choice(a, p=...)``; defaults to the
            global ``np.random`` module.

    Returns:
        Array of sampled indices with the shape of *probs* minus its last axis.
    """
    if generator is None:
        generator = np.random
    # Rank every distribution from most to least likely.
    probs_idx = np.argsort(-probs, axis=-1)
    probs_sort = np.take_along_axis(probs, probs_idx, -1)
    probs_sum = np.cumsum(probs_sort, axis=-1)
    # Top-p: drop an entry when the mass ranked strictly above it already exceeds p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Top-k: zero everything beyond the first k ranks.
    mask = np.zeros(probs_sort.shape[-1])
    mask[:k] = 1
    probs_sort *= mask
    # Renormalise the survivors so each row sums to one again.
    probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
    shape = probs_sort.shape
    probs_sort_flat = probs_sort.reshape(-1, shape[-1])
    probs_idx_flat = probs_idx.reshape(-1, shape[-1])
    next_token = np.stack([
        generator.choice(idxs, p=pvals)
        for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)
    ])
    return next_token.reshape(*shape[:-1])
def create_download_list():
    """Render the generated files as an HTML list of ``data:`` download links.

    Reads ``modified_files`` from the module-level ``midi_processor``; entry
    *i* becomes a link that downloads as ``midi_<i>.mid``.
    """
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{midi_data}" download="midi_{i}.mid">MIDI {i}</a></li>'
        for i, midi_data in enumerate(midi_processor.modified_files)
    )
    return f"<h3>Downloads</h3><ul>{items}</ul>"
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the UI, and serve it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    opt = parser.parse_args()

    midi_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
        with gr.Tabs():
            # Tab 1: MIDI Prompt (Main Tab)
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File", file_count="multiple")
                # "bytes" is not an accepted ``type`` for gr.Audio; the
                # component is output-only here, so the default is used.
                output = gr.Audio(label="Generated MIDI", autoplay=True)
                status = gr.Textbox(label="Status", value="Ready", interactive=False)

            # Tab 2: Downloads
            with gr.Tab("Downloads", elem_id="downloads"):
                downloads = gr.HTML(value="No generated files yet")

            # Tab 3: Devices
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="Connected MIDI Devices", value=midi_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                stop_btn = gr.Button("Stop Playback")

                def refresh_devices():
                    """Re-query the MIDI ports for the Devices tab."""
                    return midi_manager.get_device_info()

                refresh_btn.click(refresh_devices, inputs=None, outputs=[device_info])
                stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])

            def process_midi(files):
                """Generate from each uploaded file and return (audio, status, downloads HTML).

                Every return path yields exactly three values, matching the
                three output components wired below.
                """
                if not files:
                    return None, "No file uploaded", create_download_list()
                midi_data = None
                for file in files:
                    # Accept both file objects exposing ``.name`` and plain paths.
                    midi_id = midi_processor.load_midi(getattr(file, "name", file))
                    # Use ONNX generation for advanced synthesis
                    midi_data = midi_processor.generate_onnx(midi_id, max_len=1024)
                    # NOTE(review): play_with_loop blocks this handler until
                    # "Stop Playback" is pressed; the outputs below are only
                    # delivered after that.
                    midi_processor.play_with_loop(midi_data)
                # NOTE(review): these are raw MIDI bytes; confirm gr.Audio can
                # play them, or render audio with the synthesizer first.
                return base64.b64decode(midi_data), "Playing", create_download_list()

            # Registered after every tab is built so the ``downloads``
            # component exists and can be passed as a real output.
            midi_upload.change(process_midi, inputs=[midi_upload],
                               outputs=[output, status, downloads])

        gr.Markdown("""
<div style='text-align: center; margin-top: 20px;'>
<img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
<strong>Hugging Face</strong><br>
<a href='https://huggingface.co/models'>Models</a> |
<a href='https://huggingface.co/datasets'>Datasets</a> |
<a href='https://huggingface.co/spaces'>Spaces</a> |
<a href='https://huggingface.co/posts'>Posts</a> |
<a href='https://huggingface.co/docs'>Docs</a> |
<a href='https://huggingface.co/enterprise'>Enterprise</a> |
<a href='https://huggingface.co/pricing'>Pricing</a>
</div>
""")

    try:
        app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
    finally:
        # Release the MIDI ports even if launch() raises.
        midi_manager.close()