Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,493 Bytes
e8f8f3c 1151735 e8f8f3c 1151735 92acab8 1151735 75808a5 92acab8 75808a5 dbc6fc5 75808a5 92acab8 1151735 92acab8 1151735 92acab8 1151735 b50d4ec e8f8f3c b50d4ec 1151735 e8f8f3c 1151735 b50d4ec 1151735 92acab8 1151735 92acab8 75808a5 1151735 75808a5 dbc6fc5 92acab8 1151735 75808a5 1151735 75808a5 1151735 e8f8f3c 75808a5 e8f8f3c 1151735 75808a5 92acab8 1151735 92acab8 1151735 92acab8 75808a5 1151735 75808a5 1151735 75808a5 1151735 75808a5 1151735 75808a5 1151735 dbc6fc5 1151735 92acab8 1151735 e8f8f3c 1151735 e8f8f3c 1151735 e8f8f3c 1151735 e8f8f3c 1151735 e8f8f3c 1151735 e8f8f3c 1151735 75808a5 1151735 75808a5 1151735 e8f8f3c 1151735 75808a5 1151735 75808a5 e8f8f3c 1151735 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import argparse
import base64
import io
import json
import os
import random

import gradio as gr
import numpy as np
import onnxruntime as rt
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
from midi_tokenizer import MIDITokenizer
# Constants
# Upper bound on the total token-sequence length produced by generate_variation.
MAX_LENGTH = 1024
# Largest value drawn when seeding the per-call NumPy RNG (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# True when the SYSTEM env var equals "spaces" (presumably set by the Hugging Face
# Spaces runtime — confirm).
IN_SPACE = "spaces" == os.getenv("SYSTEM")
# MIDI Device Manager
class MIDIDeviceManager:
    """Hold one rtmidi output handle and one rtmidi input handle for listing devices."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_device_info(self):
        """Return a human-readable listing of the available MIDI output and input ports.

        A placeholder line is shown for a direction that has no ports.
        """
        out_ports = self.midiout.get_ports() or ["No MIDI output devices"]
        in_ports = self.midiin.get_ports() or ["No MIDI input devices"]
        # Join outside the f-string: a backslash inside an f-string expression is a
        # SyntaxError before Python 3.12 (PEP 701).
        out_listing = "\n".join(out_ports)
        in_listing = "\n".join(in_ports)
        return f"Output Devices:\n{out_listing}\n\nInput Devices:\n{in_listing}"

    def close(self):
        """Close any open ports and drop the rtmidi handles.

        Safe to call more than once: handles that were already released are skipped.
        """
        for attr in ("midiout", "midiin"):
            port = getattr(self, attr, None)
            if port is None:
                continue
            if port.is_port_open():
                port.close_port()
            delattr(self, attr)
# MIDI Processor with ONNX Generation
class MIDIManager:
    """Load MIDI files, extend them with the ONNX models and render the result.

    Construction downloads the soundfont, the tokenizer config and both ONNX models
    from the Hugging Face repo ``REPO_ID``. Every generated file is kept as a base64
    string in ``generated_files`` so the UI can list it for download.
    """

    # Hugging Face repo hosting the soundfont, config.json and the ONNX models.
    REPO_ID = "skytnt/midi-model"
    # onnxruntime tries execution providers in order: CUDA first, then CPU.
    PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self):
        self.soundfont = hf_hub_download(repo_id=self.REPO_ID, filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont)
        self.tokenizer = self._load_tokenizer(self.REPO_ID)
        self.model_base = self._load_session("onnx/model_base.onnx")
        self.model_token = self._load_session("onnx/model_token.onnx")
        self.generated_files = []  # base64-encoded MIDI data, one entry per generation
        self.is_playing = False  # UI flag: set by play_midi, cleared by stop

    def _load_session(self, filename):
        """Download ``filename`` from REPO_ID and open it as an ONNX inference session."""
        model_path = hf_hub_download(repo_id=self.REPO_ID, filename=filename)
        return rt.InferenceSession(model_path, providers=list(self.PROVIDERS))

    def _load_tokenizer(self, repo_id):
        """Build a MIDITokenizer from the ``tokenizer`` section of the repo's config.json."""
        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        tokenizer_config = config["tokenizer"]
        tokenizer = MIDITokenizer(tokenizer_config["version"])
        tokenizer.set_optimise_midi(tokenizer_config["optimise_midi"])
        return tokenizer

    def load_midi(self, file_path):
        """Load the MIDI file at ``file_path`` via ``MIDI.load``.

        Raises:
            ValueError: if loading fails; the underlying error is chained as ``__cause__``.
        """
        try:
            return MIDI.load(file_path)
        except Exception as e:
            raise ValueError(f"Failed to load MIDI file: {e}") from e

    def generate_variation(self, midi_data, temp=1.0, top_p=0.98, top_k=20):
        """Extend ``midi_data`` with sampled tokens until MAX_LENGTH and store the result.

        Args:
            midi_data: MIDI data as returned by ``load_midi``; converted with
                ``MIDI.midi2score`` and then tokenized.
            temp: softmax temperature applied to the logits; must be > 0.
            top_p: nucleus-sampling threshold passed to ``sample_top_p_k``.
            top_k: number of most likely tokens kept by ``sample_top_p_k``.

        Returns:
            The generated MIDI as a base64 string, which is also appended to
            ``generated_files``.

        Raises:
            ValueError: if ``temp`` is not positive (the logits would be divided by zero
                or have their sign flipped).
        """
        if temp <= 0:
            raise ValueError(f"temp must be > 0, got {temp}")

        mid_seq = self.tokenizer.tokenize(MIDI.midi2score(midi_data))
        input_tensor = np.array([mid_seq], dtype=np.int64)
        cur_len = input_tensor.shape[1]
        # Fresh RNG per call, seeded from Python's global `random` state.
        generator = np.random.RandomState(random.randint(0, MAX_SEED))

        # NOTE(review): only the last token is fed to model_base and no past state is
        # passed, so each step sees a single token of context. Confirm this matches the
        # exported ONNX graph's expected inputs.
        while cur_len < MAX_LENGTH:
            hidden = self.model_base.run(None, {"x": input_tensor[:, -1:]})[0]
            logits = self.model_token.run(None, {"hidden": hidden})[0]
            probs = softmax(logits / temp, axis=-1)
            next_token = sample_top_p_k(probs, top_p, top_k, generator)
            input_tensor = np.concatenate([input_tensor, next_token], axis=1)
            cur_len += 1

        new_midi = self.tokenizer.detokenize(input_tensor[0].tolist())
        midi_output = io.BytesIO()
        # NOTE(review): this assumes MIDI.score2midi writes into the stream passed as its
        # second argument. If it instead returns the MIDI bytes, midi_output stays empty.
        # Verify against the local MIDI module.
        MIDI.score2midi(new_midi, midi_output)
        encoded = base64.b64encode(midi_output.getvalue()).decode('utf-8')
        self.generated_files.append(encoded)
        return encoded

    def play_midi(self, midi_data):
        """Decode base64 ``midi_data`` and render it with the soundfont synthesizer.

        Sets ``is_playing`` and returns an ``io.BytesIO`` rewound to position 0.
        """
        self.is_playing = True
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        audio = io.BytesIO()
        # NOTE(review): assumes render_midi writes encoded audio into `audio`. Confirm the
        # format is what the gr.Audio output expects.
        self.synthesizer.render_midi(midi_file, audio)
        audio.seek(0)
        return audio

    def stop(self):
        """Clear the ``is_playing`` flag."""
        self.is_playing = False
# Helper Functions
def softmax(x, axis):
    """Return the numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so large inputs do
    not overflow. The result sums to 1 along ``axis``.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total
def sample_top_p_k(probs, p, k, generator):
    """Sample one vocabulary index per row of ``probs`` after top-p and top-k filtering.

    Args:
        probs: probabilities over the vocabulary on the last axis. Leading axes are
            flattened and each resulting row is sampled independently.
        p: nucleus threshold. Tokens after the cumulative probability of the more
            likely ones exceeds ``p`` are dropped (the most likely token is always kept).
        k: keep at most the ``k`` most likely tokens; values below 1 are treated as 1
            so the distribution never becomes all-zero (NaN after renormalising).
        generator: ``np.random.RandomState`` (anything with a compatible ``choice``).

    Returns:
        int64 array of shape ``(rows, 1)`` with vocabulary indices, i.e. ``(1, 1)`` for
        a single distribution.
    """
    probs = np.asarray(probs, dtype=np.float64)
    vocab_size = probs.shape[-1]
    rows = probs.reshape(-1, vocab_size)

    order = np.argsort(-rows, axis=-1)  # vocabulary ids, most likely first
    ranked = np.take_along_axis(rows, order, axis=-1)  # new array; safe to mutate
    cumulative = np.cumsum(ranked, axis=-1)
    ranked[cumulative - ranked > p] = 0.0  # top-p
    ranked[:, max(int(k), 1):] = 0.0  # top-k
    ranked /= ranked.sum(axis=-1, keepdims=True)

    # `choice` returns a position in the *sorted* order; map it back to the token id.
    picks = [
        row_order[generator.choice(vocab_size, p=row_probs)]
        for row_order, row_probs in zip(order, ranked)
    ]
    return np.array(picks, dtype=np.int64).reshape(-1, 1)
# UI Functions
def process_midi_upload(files):
    """Generate a variation of an uploaded MIDI file and render it for playback.

    Args:
        files: value of the upload component. Either a single file or a list of files,
            of which only the first is used. Each file may be a path string or an object
            exposing its path as ``.name`` (e.g. a tempfile wrapper).

    Returns:
        ``(audio, status_message, downloads_html)``. On failure ``audio`` is None,
        the status carries the error and the HTML is empty.
    """
    if not files:
        return None, "No file uploaded", ""
    # With the default file_count="single", gr.File passes one value, not a list.
    file = files[0] if isinstance(files, (list, tuple)) else files
    file_path = getattr(file, "name", file)
    try:
        midi_data = midi_processor.load_midi(file_path)
        generated_midi = midi_processor.generate_variation(midi_data)
        # NOTE(review): play_midi returns an io.BytesIO. Confirm gr.Audio accepts it
        # (it may require a file path or raw bytes).
        audio = midi_processor.play_midi(generated_midi)
        download_html = create_download_list()
        return audio, "Generated and playing", download_html
    except Exception as e:
        # UI boundary: surface any failure in the status box instead of crashing the app.
        return None, f"Error: {e}", ""
def create_download_list(files=None):
    """Render generated MIDI files as an HTML list of data-URI download links.

    Args:
        files: base64-encoded MIDI strings to list. Defaults to
            ``midi_processor.generated_files``.

    Returns:
        An HTML fragment, or a placeholder paragraph when there is nothing to list.
    """
    if files is None:
        files = midi_processor.generated_files
    if not files:
        return "<p>No generated files yet.</p>"
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{midi_data}" download="generated_{i}.mid">Download MIDI {i}</a></li>'
        for i, midi_data in enumerate(files)
    )
    return f"<h3>Generated MIDI Files</h3><ul>{items}</ul>"
def refresh_devices():
    """Re-query the MIDI ports and return the listing shown in the Devices tab."""
    listing = device_manager.get_device_info()
    return listing
def stop_playback():
    """Clear the processor's playing flag and return the status-box message."""
    message = "Playback stopped"
    midi_processor.stop()
    return message
# Main Application
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="MIDI Composer with ONNX Generation")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    # Module-level globals read by the UI callbacks defined above.
    device_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()

    with gr.Blocks(title="MIDI Composer", theme=gr.themes.Soft()) as app:
        gr.Markdown("# 🎵 MIDI Composer 🎵")
        with gr.Tabs():
            # MIDI Prompt Tab
            with gr.Tab("MIDI Prompt"):
                midi_upload = gr.File(label="Upload MIDI File", file_types=[".mid", ".midi"])
                # gr.Audio supports type="numpy" or "filepath"; "bytes" is not a valid type.
                audio_output = gr.Audio(label="Generated MIDI", type="filepath", autoplay=True)
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
            # Downloads Tab: one HTML component, used for the initial render and updated
            # by uploads, so the tab does not show a stale build-time snapshot.
            with gr.Tab("Downloads", elem_id="downloads"):
                downloads_html = gr.HTML(value=create_download_list())
            # Devices Tab
            with gr.Tab("Devices"):
                device_info = gr.Textbox(label="MIDI Devices", value=device_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("Refresh Devices")
                stop_btn = gr.Button("Stop Playback")

        midi_upload.change(
            process_midi_upload,
            inputs=[midi_upload],
            outputs=[audio_output, status, downloads_html],
        )
        refresh_btn.click(refresh_devices, outputs=[device_info])
        stop_btn.click(stop_playback, outputs=[status])

    try:
        # Opening a local browser is pointless when hosted on a Space.
        app.launch(server_port=args.port, share=args.share, inbrowser=not IN_SPACE)
    finally:
        # launch() blocks when run as a script; release MIDI ports even if interrupted.
        device_manager.close()