"""Gradio "MIDI Composer" app.

Load MIDI files, generate randomized variations, apply a tempo effect, record
notes from a MIDI keyboard and loop playback through a soundfont synthesizer.
"""

import base64
import io
import os
import random
import tempfile
import threading
from html import escape as html_escape

import gradio as gr
import MIDI
import numpy as np
import rtmidi
from huggingface_hub import hf_hub_download

from midi_synthesizer import MidiSynthesizer

MAX_SEED = np.iinfo(np.int32).max
# Duration given to every generated note, and the spacing between recorded live
# notes, in whatever time unit MIDI.MIDIFile.addNote() expects.
DEFAULT_NOTE_LENGTH = 100

# NOTE(review): the `MIDI` calls below (MIDI.load, MIDI.MIDIFile, addTrack,
# addNote, writeFile, .tracks[...].events) and MidiSynthesizer.play_midi are
# kept exactly as the original code used them — confirm they match the
# installed `MIDI` module and the local midi_synthesizer implementation.


def _clamp_midi(value):
    """Clamp *value* to the 7-bit MIDI data range 0..127."""
    return min(127, max(0, value))


class MIDIManager:
    """Own the app's MIDI state: loaded files, generated data, live input, playback.

    Generated and recorded MIDI is passed around as base64-encoded file bytes
    (``str``); every such result is also appended to ``modified_files``.
    """

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.example_files = self.load_example_midis()
        # Key: midi_id, Value: (file_path, midi_data). Seeded with the examples so
        # they are selectable and load_midi() numbers user files midi_0, midi_1, ...
        self.loaded_midi = dict(self.example_files)
        self.modified_files = []  # base64 MIDI strings, oldest first
        self.is_playing = False
        self._stop_event = None  # threading.Event of the currently running loop
        # live_notes is appended to from python-rtmidi's callback thread.
        self.live_notes = []
        self._live_lock = threading.Lock()
        self.midi_in = rtmidi.MidiIn()
        if self.midi_in.get_ports():
            self.midi_in.open_port(0)
        # Registered last so the callback never sees a half-initialised object.
        self.midi_in.set_callback(self.midi_callback)

    def load_example_midis(self):
        """Load ``*.mid``/``*.midi`` files from ./examples, or build a fallback.

        Returns a dict mapping ``example_<n>`` ids to ``(file_path, midi)``. When
        no example file is found, a single-note (C4) file named "Simple C4.mid"
        is created in memory.
        """
        example_dir = "examples"  # relative to the working directory; adjust as needed
        examples = {}
        if os.path.isdir(example_dir):
            for file_name in sorted(os.listdir(example_dir)):
                if file_name.endswith((".mid", ".midi")):
                    midi_id = f"example_{len(examples)}"
                    file_path = os.path.join(example_dir, file_name)
                    examples[midi_id] = (file_path, MIDI.load(file_path))
        if not examples:
            midi = MIDI.MIDIFile(1)
            midi.addTrack()
            midi.addNote(0, 0, 60, 0, 100, 100)  # C4 note
            examples["example_0"] = ("Simple C4.mid", midi)
        return examples

    def midi_callback(self, event, data=None):
        """python-rtmidi input callback: record Note On messages with velocity > 0.

        *event* is ``(message_bytes, delta_time)``; each recorded note is stored
        as ``(note, velocity, 0)``.
        """
        message, _delta_time = event
        if len(message) >= 3 and message[0] & 0xF0 == 0x90:  # Note On, any channel
            note, velocity = message[1], message[2]
            if velocity > 0:  # Note On with velocity 0 is a Note Off
                with self._live_lock:
                    self.live_notes.append((note, velocity, 0))

    def load_midi(self, file_path):
        """Load *file_path*, register it in ``loaded_midi`` and return its new id."""
        midi = MIDI.load(file_path)
        # loaded_midi also holds the examples; subtracting them numbers uploads
        # from midi_0 and keeps ids unique because the dict only ever grows.
        midi_id = f"midi_{len(self.loaded_midi) - len(self.example_files)}"
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes(self, midi):
        """Return ``(note, velocity, time)`` for every note_on event with velocity > 0.

        NOTE(review): ``event.time`` is used as a position; in some MIDI libraries it
        is a delta from the previous event — confirm for the installed module.
        """
        return [
            (event.note, event.velocity, event.time)
            for track in midi.tracks
            for event in track.events
            if event.type == 'note_on' and event.velocity > 0
        ]

    def _encode_midi(self, midi):
        """Serialize *midi*, record it in ``modified_files`` and return it as base64."""
        buffer = io.BytesIO()
        midi.writeFile(buffer)
        midi_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data

    @staticmethod
    def _decode_midi(midi_data):
        """Parse base64 MIDI file bytes back into a MIDI object."""
        return MIDI.load(io.BytesIO(base64.b64decode(midi_data)))

    def generate_variation(self, midi_id, length_factor=2, variation=0.3):
        """Return base64 MIDI built from ``int(length_factor)`` copies of *midi_id*'s notes.

        Copies are laid out back to back. With probability *variation* each note
        is shifted by -2..+2 semitones and its velocity by -10..+10 (both clamped
        to 0..127). Returns None if *midi_id* is not loaded.
        """
        if midi_id not in self.loaded_midi:
            return None
        _, midi = self.loaded_midi[midi_id]
        notes = self.extract_notes(midi)
        # Offset each copy by the length of the material; without it every copy
        # would be stacked on top of the first at identical times.
        span = max((time for _, _, time in notes), default=0) + DEFAULT_NOTE_LENGTH
        new_notes = []
        for repeat in range(int(length_factor)):
            offset = repeat * span
            for note, vel, time in notes:
                if random.random() < variation:
                    note = _clamp_midi(note + random.randint(-2, 2))
                    vel = _clamp_midi(vel + random.randint(-10, 10))
                new_notes.append((note, vel, time + offset))
        new_midi = MIDI.MIDIFile(1)
        new_midi.addTrack()
        for note, vel, time in new_notes:
            new_midi.addNote(0, 0, note, time, DEFAULT_NOTE_LENGTH, vel)
        return self._encode_midi(new_midi)

    def apply_synth_effect(self, midi_data, effect, intensity):
        """Apply *effect* to base64 MIDI and return the result as new base64 MIDI.

        Only "tempo" is implemented: every event time is multiplied by a factor
        from 0.8 (intensity 0) through 1.0 (0.5) to 1.2 (1). Any other effect
        returns an unmodified re-encoded copy.
        """
        midi = self._decode_midi(midi_data)
        if effect == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in midi.tracks:
                for event in track.events:
                    event.time = int(event.time * factor)
        return self._encode_midi(midi)

    def play_with_loop(self, midi_data):
        """Loop playback of base64 *midi_data* on a background thread.

        Any previously started loop is asked to stop first. Returns immediately;
        call stop_playback() to end the loop, which takes effect once the
        current play_midi() call returns.
        """
        midi_file = self._decode_midi(midi_data)
        self.stop_playback()
        stop_event = threading.Event()
        self._stop_event = stop_event
        self.is_playing = True

        def _loop():
            # NOTE(review): assumes play_midi() blocks until the file has played;
            # if it returns immediately this loop spins.
            while not stop_event.is_set():
                self.synthesizer.play_midi(midi_file)

        threading.Thread(target=_loop, name="midi-playback", daemon=True).start()
        return "Playing"

    def stop_playback(self):
        """Ask the running playback loop (if any) to stop; returns a status string."""
        self.is_playing = False
        if self._stop_event is not None:
            self._stop_event.set()
        return "Stopping..."

    def save_live_midi(self):
        """Turn the recorded live notes into base64 MIDI and clear the recording.

        Notes are placed DEFAULT_NOTE_LENGTH apart in arrival order. Returns None
        when nothing has been recorded.
        """
        # Swap under the lock so notes arriving meanwhile go to the new list.
        with self._live_lock:
            notes, self.live_notes = self.live_notes, []
        if not notes:
            return None
        midi = MIDI.MIDIFile(1)
        midi.addTrack()
        for index, (note, vel, _) in enumerate(notes):
            midi.addNote(0, 0, note, index * DEFAULT_NOTE_LENGTH, DEFAULT_NOTE_LENGTH, vel)
        return self._encode_midi(midi)


midi_manager = MIDIManager()


def create_download_list():
    """Return HTML listing every generated MIDI as a data-URI download link."""
    links = "".join(
        f'<li><a href="data:audio/midi;base64,{data}" download="midi_{index}.mid">'
        f"midi_{index}.mid</a></li>"
        for index, data in enumerate(midi_manager.modified_files)
    )
    body = f"<ul>{links}</ul>" if links else "<p>No files yet</p>"
    return f"<h3>Downloads</h3>{body}"


def get_midi_choices():
    """Return ``(display_name, midi_id)`` pairs for the MIDI selection dropdown."""
    return [(os.path.basename(path), midi_id) for midi_id, (path, _) in midi_manager.loaded_midi.items()]


def _midi_data_to_file(midi_data):
    """Write base64 MIDI to a new temporary ``.mid`` file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as handle:
        handle.write(base64.b64decode(midi_data))
        return handle.name


with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1 style='text-align: center'>🎵 MIDI Composer 🎵</h1>")
    # Base64 MIDI of the most recent generation; effects are applied to it.
    current_midi = gr.State(None)

    with gr.Tabs():
        # Tab 1: Load MIDI Files
        with gr.Tab("Load MIDI"):
            midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
            file_display = gr.HTML(value="No files loaded")
            # gr.Audio cannot play MIDI bytes (and has no "bytes" type), so the
            # generated MIDI is offered as a file instead.
            load_preview = gr.File(label="Generated Preview")

        # Tab 2: Generate & Perform
        with gr.Tab("Generate & Perform"):
            midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value=None)
            length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
            variation = gr.Slider(0, 1, value=0.3, label="Variation")
            generate_btn = gr.Button("Generate")
            effect = gr.Radio(["tempo"], label="Effect", value="tempo")
            intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
            apply_btn = gr.Button("Apply Effect")
            stop_btn = gr.Button("Stop Playback")
            perform_preview = gr.File(label="Preview")
            status = gr.Textbox(label="Status", value="Ready")

        # Tab 3: MIDI Input
        with gr.Tab("MIDI Input"):
            gr.Markdown("Play your MIDI keyboard to record notes")
            save_btn = gr.Button("Save Live MIDI")
            live_output = gr.File(label="Live MIDI")

        # Tab 4: Downloads
        with gr.Tab("Downloads"):
            downloads = gr.HTML(value="No files yet")

    def load_and_generate(files, current):
        """Register uploads, generate a variation of each and preview the last one."""
        html = "<h3>Loaded Files</h3>"
        midi_data = None
        for file in files or []:
            # Newer Gradio passes path strings; older versions pass tempfile
            # wrappers whose .name is the full path.
            path = str(file) if isinstance(file, (str, os.PathLike)) else file.name
            midi_id = midi_manager.load_midi(path)
            html += f"<p>{html_escape(os.path.basename(path))}</p>"
            midi_data = midi_manager.generate_variation(midi_id)
        preview = _midi_data_to_file(midi_data) if midi_data else None
        return html, preview, midi_data or current, gr.update(choices=get_midi_choices())

    def generate(midi_id, length, var, current):
        """Generate a variation of the selected file and start looping it."""
        if not midi_id:
            return None, "Select a MIDI file", current
        midi_data = midi_manager.generate_variation(midi_id, length, var)
        if midi_data is None:
            return None, "Selected MIDI file not found", current
        midi_manager.play_with_loop(midi_data)
        return _midi_data_to_file(midi_data), "Playing", midi_data

    def apply_effect(midi_data, fx, inten):
        """Apply the chosen effect to the current MIDI and start looping the result."""
        if not midi_data:
            return None, "Generate a MIDI first", midi_data
        if isinstance(midi_data, bytes):
            midi_data = midi_data.decode('utf-8')
        new_data = midi_manager.apply_synth_effect(midi_data, fx, inten)
        midi_manager.play_with_loop(new_data)
        return _midi_data_to_file(new_data), "Playing", new_data

    def save_live():
        """Save the recorded keyboard notes, start looping them and offer the file."""
        midi_data = midi_manager.save_live_midi()
        if not midi_data:
            return None
        midi_manager.play_with_loop(midi_data)
        return _midi_data_to_file(midi_data)

    def update_downloads(*args):
        """Refresh the Downloads tab."""
        return create_download_list()

    # Downloads are refreshed with .then() so they run after the handler has
    # appended its result, not concurrently with it.
    midi_files.change(
        load_and_generate,
        inputs=[midi_files, current_midi],
        outputs=[file_display, load_preview, current_midi, midi_select],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    generate_btn.click(
        generate,
        inputs=[midi_select, length_factor, variation, current_midi],
        outputs=[perform_preview, status, current_midi],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    apply_btn.click(
        apply_effect,
        inputs=[current_midi, effect, intensity],
        outputs=[perform_preview, status, current_midi],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])
    save_btn.click(save_live, inputs=None, outputs=[live_output]).then(
        update_downloads, inputs=None, outputs=[downloads]
    )

    # NOTE(review): this footer text looks like scraped Hugging Face site
    # navigation; kept verbatim.
    gr.Markdown("""
Hugging Face Logo
Hugging Face
Models | Datasets | Spaces | Posts | Docs | Enterprise | Pricing
""")


if __name__ == "__main__":
    app.queue().launch(inbrowser=True)