import io
import base64
import random
import argparse
import os
import glob
import json

import rtmidi
import gradio as gr
import numpy as np
import onnxruntime as rt
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer

MAX_SEED = np.iinfo(np.int32).max
in_space = os.getenv("SYSTEM") == "spaces"


class MIDIDeviceManager:
    """Enumerates and manages rtmidi input/output ports."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_output_devices(self):
        return self.midiout.get_ports() or ["No MIDI output devices"]

    def get_input_devices(self):
        return self.midiin.get_ports() or ["No MIDI input devices"]

    def get_device_info(self):
        out_devices = self.get_output_devices()
        in_devices = self.get_input_devices()
        out_info = "\n".join([f"Out Port {i}: {name}" for i, name in enumerate(out_devices)]) if out_devices else "No MIDI output devices detected"
        in_info = "\n".join([f"In Port {i}: {name}" for i, name in enumerate(in_devices)]) if in_devices else "No MIDI input devices detected"
        return f"Output Devices:\n{out_info}\n\nInput Devices:\n{in_info}"

    def close(self):
        if self.midiout.is_port_open():
            self.midiout.close_port()
        if self.midiin.is_port_open():
            self.midiin.close_port()
        del self.midiout
        del self.midiin


class MIDIManager:
    """Loads MIDI files, generates variations, and runs ONNX-based generation."""

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}
        self.modified_files = []
        self.is_playing = False
        self.tokenizer = self.load_tokenizer("skytnt/midi-model")
        self.model_base = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_base.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.model_token = rt.InferenceSession(
            hf_hub_download(repo_id="skytnt/midi-model", filename="onnx/model_token.onnx"),
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    def load_tokenizer(self, repo_id):
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r") as f:
            config = json.load(f)
        tokenizer = MIDITokenizer(config["tokenizer"]["version"])
        tokenizer.set_optimise_midi(config["tokenizer"]["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        midi = MIDI.load(file_path)
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes_and_instruments(self, midi):
        notes = []
        instruments = set()
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
                if hasattr(event, 'program'):
                    instruments.add(event.program)
        return notes, list(instruments)

    def generate_variation(self, midi_id, length_factor=10, variation=0.3):
        if midi_id not in self.loaded_midi:
            return None
        _, midi = self.loaded_midi[midi_id]
        notes, instruments = self.extract_notes_and_instruments(midi)
        new_notes = []
        for _ in range(int(length_factor)):  # Max length: 10x repetition
            for note, vel, time in notes:
                if random.random() < variation:
                    # Randomly perturb pitch and velocity, clamped to the valid MIDI range
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    new_vel = min(127, max(0, vel + random.randint(-10, 10)))
                    new_notes.append((new_note, new_vel, time))
                else:
                    new_notes.append((note, vel, time))
        new_midi = MIDI.MIDIFile(len(instruments) or 1)
        for i, inst in enumerate(instruments or [0]):
            new_midi.addTrack()
            new_midi.addProgramChange(i, 0, 0, inst)
            for note, vel, time in new_notes:
                new_midi.addNote(i, 0, note, time, 100, vel)
        midi_output = io.BytesIO()
        new_midi.writeFile(midi_output)
        midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data

    def generate_onnx(self, midi_id, max_len=1024, temp=1.0, top_p=0.98, top_k=20):
        if midi_id not in self.loaded_midi:
            return None
        _, mid = self.loaded_midi[midi_id]
        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(mid))
        mid = np.asarray([mid_seq], dtype=np.int64)
        generator = np.random.RandomState(random.randint(0, MAX_SEED))
        # Simplified ONNX generation from app_onnx.py
        input_tensor = mid
        cur_len = input_tensor.shape[1]
        model = [self.model_base, self.model_token, self.tokenizer]
        while cur_len < max_len:
            # Bind the latest token as input and a GPU buffer for the hidden state
            inputs = {"x": rt.OrtValue.ortvalue_from_numpy(input_tensor[:, -1:], device_type="cuda")}
            outputs = {"hidden": rt.OrtValue.ortvalue_from_shape_and_type((1, 1, 1024), np.float32, device_type="cuda")}
            io_binding = model[0].io_binding()
            for name, val in inputs.items():
                io_binding.bind_ortvalue_input(name, val)
            for name in outputs:
                io_binding.bind_ortvalue_output(name, outputs[name])
            model[0].run_with_iobinding(io_binding)
            hidden = outputs["hidden"].numpy()[:, -1:]
            logits = model[1].run(None, {"hidden": hidden})[0]
            scores = softmax(logits / temp, -1)
            next_token = sample_top_p_k(scores, top_p, top_k, generator)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1
        mid_seq = input_tensor.tolist()[0]
        new_midi = self.tokenizer.detokenize(mid_seq)
        midi_output = io.BytesIO()
        MIDI.score2midi(new_midi, midi_output)
        midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data

    def play_with_loop(self, midi_data):
        self.is_playing = True
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        while self.is_playing:
            self.synthesizer.play_midi(midi_file)

    def stop_playback(self):
        self.is_playing = False
        return "Playback stopped"


def softmax(x, axis):
    # Numerically stable softmax along the given axis
    x_max = np.max(x, axis=axis, keepdims=True)
    exp_x_shifted = np.exp(x - x_max)
    return exp_x_shifted / np.sum(exp_x_shifted, axis=axis, keepdims=True)


def sample_top_p_k(probs, p, k, generator=None):
    # Combined nucleus (top-p) and top-k sampling over the last axis of `probs`
    if generator is None:
        generator = np.random
    probs_idx = np.argsort(-probs, axis=-1)
    probs_sort = np.take_along_axis(probs, probs_idx, -1)
    probs_sum = np.cumsum(probs_sort, axis=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    mask = np.zeros(probs_sort.shape[-1])
    mask[:k] = 1
    probs_sort *= mask
    probs_sort /= np.sum(probs_sort, axis=-1, keepdims=True)
    shape = probs_sort.shape
    probs_sort_flat = probs_sort.reshape(-1, shape[-1])
    probs_idx_flat = probs_idx.reshape(-1, shape[-1])
    next_token = np.stack([generator.choice(idxs, p=pvals)
                           for pvals, idxs in zip(probs_sort_flat, probs_idx_flat)])
    return next_token.reshape(*shape[:-1])


def create_download_list():
    html = "