# midi-composer / app.py
# awacke1's picture
# Update app.py
# 784b974 verified
# raw
# history blame
# 10.6 kB
import base64
import html
import io
import os
import random
import tempfile
import threading

import gradio as gr
import numpy as np
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
# Largest signed 32-bit integer (2**31 - 1). NOTE(review): appears unused in this file.
MAX_SEED = int(np.iinfo(np.int32).max)
class MIDIManager:
    """Hold loaded MIDI files, generate variations, and capture live MIDI input.

    Generated MIDI is exchanged as base64-encoded strings. Every string this
    class produces is also appended to ``modified_files`` so the UI can offer
    it for download.

    NOTE(review): ``MIDI`` and ``MidiSynthesizer`` come from modules outside
    this file. The calls below assume ``MIDI.load``, ``MIDI.MIDIFile`` (with
    ``addTrack``/``addNote``/``writeFile``), track ``.events`` and
    ``MidiSynthesizer.play_midi`` exist with the signatures used here.
    Confirm against those modules.
    """

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}  # Key: midi_id, Value: (file_path, midi_data)
        self.modified_files = []  # base64-encoded MIDI strings, in creation order
        self.is_playing = False
        # Bumped by every play_with_loop call so a superseded loop thread exits.
        self._play_generation = 0
        # Must exist before the callback is registered: the callback appends to it.
        self.live_notes = []
        self.midi_in = rtmidi.MidiIn()
        if self.midi_in.get_ports():
            self.midi_in.open_port(0)
        self.midi_in.set_callback(self.midi_callback)
        self.example_files = self.load_example_midis()

    def load_example_midis(self):
        """Load example MIDI files from the local ``examples`` directory.

        Returns:
            dict mapping ``example_<n>`` ids to ``(file_path, midi)`` tuples.
            When the directory is missing or contains no ``.mid``/``.midi``
            files, a one-note placeholder is returned under ``example_0``.
        """
        example_dir = "examples"  # Adjust this path as needed
        examples = {}
        # isdir (not exists): listdir on a regular file would raise.
        if os.path.isdir(example_dir):
            # sorted() keeps the example_<n> ids stable; listdir order is arbitrary.
            for file_name in sorted(os.listdir(example_dir)):
                if file_name.lower().endswith((".mid", ".midi")):
                    file_path = os.path.join(example_dir, file_name)
                    examples[f"example_{len(examples)}"] = (file_path, MIDI.load(file_path))
        # Add a default example if none found
        if not examples:
            midi = MIDI.MIDIFile(1)
            midi.addTrack()
            midi.addNote(0, 0, 60, 0, 100, 100)  # C4 note
            examples["example_0"] = ("Simple C4.mid", midi)
        return examples

    def midi_callback(self, event, data=None):
        """Record incoming Note On messages with non-zero velocity.

        Registered with ``rtmidi.MidiIn.set_callback``.

        Args:
            event: ``(message, delta_time)`` pair; ``message`` is a sequence of
                MIDI bytes.
            data: Unused user data passed through by rtmidi.
        """
        message, _delta = event
        # Status 0x9n is Note On; a Note On with velocity 0 acts as Note Off.
        if len(message) >= 3 and (message[0] & 0xF0) == 0x90:
            note, velocity = message[1], message[2]
            if velocity > 0:
                self.live_notes.append((note, velocity, 0))

    def _next_midi_id(self):
        """Return the id for the next user-loaded file, ``midi_<n>``.

        ``example_files`` are kept separately and never enter ``loaded_midi``.
        Subtracting their count (as before) produced ids such as ``midi_-1``.
        """
        return f"midi_{len(self.loaded_midi)}"

    def load_midi(self, file_path):
        """Load ``file_path`` with ``MIDI.load``, register it, and return its new id."""
        midi = MIDI.load(file_path)
        midi_id = self._next_midi_id()
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes(self, midi):
        """Return ``(note, velocity, time)`` for every ``note_on`` event with velocity > 0."""
        notes = []
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
        return notes

    def _store_encoded(self, midi):
        """Serialize ``midi``, base64-encode it, append it to ``modified_files`` and return it."""
        buffer = io.BytesIO()
        midi.writeFile(buffer)
        midi_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data

    def generate_variation(self, midi_id, length_factor=2, variation=0.3):
        """Build a randomized variation of a loaded MIDI file.

        The source notes are repeated ``int(length_factor)`` times. Each note
        is, with probability ``variation``, shifted by -2..+2 semitones and
        -10..+10 velocity, both clamped to 0..127.

        Returns:
            The base64-encoded result, or ``None`` if ``midi_id`` is not loaded.
        """
        if midi_id not in self.loaded_midi:
            return None
        _, midi = self.loaded_midi[midi_id]
        notes = self.extract_notes(midi)
        new_notes = []
        # NOTE(review): every repetition reuses the source note times, so the
        # copies overlap instead of extending the piece. Confirm whether a
        # per-repetition time offset was intended.
        for _ in range(int(length_factor)):
            for note, vel, time in notes:
                if random.random() < variation:
                    note = min(127, max(0, note + random.randint(-2, 2)))
                    vel = min(127, max(0, vel + random.randint(-10, 10)))
                new_notes.append((note, vel, time))
        new_midi = MIDI.MIDIFile(1)
        new_midi.addTrack()
        for note, vel, time in new_notes:
            # assumes addNote(track, channel, pitch, time, duration, volume) — confirm
            new_midi.addNote(0, 0, note, time, 100, vel)
        return self._store_encoded(new_midi)

    def apply_synth_effect(self, midi_data, effect, intensity):
        """Apply ``effect`` to base64 MIDI and return the new base64 string.

        Only ``"tempo"`` is handled. It scales every event time by
        ``1 + (intensity - 0.5) * 0.4``, i.e. 0.8x..1.2x for intensity 0..1.
        Any other effect re-encodes the MIDI unchanged.
        """
        midi = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        if effect == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in midi.tracks:
                for event in track.events:
                    event.time = int(event.time * factor)
        return self._store_encoded(midi)

    def play_with_loop(self, midi_data):
        """Start looping playback of base64 ``midi_data`` on a daemon thread.

        Previously this spun in the caller's thread until ``stop_playback``,
        so UI event handlers never returned. The loop repeats
        ``synthesizer.play_midi`` until ``stop_playback`` is called or a newer
        ``play_with_loop`` call supersedes it.

        Returns:
            ``"Playing"`` once the loop thread has been started.
        """
        # Decode before flipping is_playing so a bad payload leaves state untouched.
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        self._play_generation += 1
        generation = self._play_generation
        self.is_playing = True

        def _loop():
            while self.is_playing and self._play_generation == generation:
                self.synthesizer.play_midi(midi_file)

        threading.Thread(target=_loop, daemon=True).start()
        return "Playing"

    def stop_playback(self):
        """Signal any running playback loop to stop after its current pass."""
        self.is_playing = False
        return "Stopping..."

    def save_live_midi(self):
        """Write the recorded live notes to MIDI and return it base64-encoded.

        Notes are placed 100 time units apart with duration 100. The recording
        buffer is swapped out *before* writing: the rtmidi callback may append
        concurrently, and notes arriving during the write land in the fresh
        buffer instead of being discarded.

        Returns:
            The base64 string, or ``None`` if nothing was recorded.
        """
        notes, self.live_notes = self.live_notes, []
        if not notes:
            return None
        midi = MIDI.MIDIFile(1)
        midi.addTrack()
        for index, (note, vel, _) in enumerate(notes):
            midi.addNote(0, 0, note, index * 100, 100, vel)
        return self._store_encoded(midi)
# Single shared manager used by the helpers and UI handlers below. Constructing
# it downloads the soundfont via hf_hub_download and opens MIDI input port 0
# when one is available.
midi_manager = MIDIManager()
def create_download_list(files=None):
    """Build the HTML download list for generated MIDI files.

    Args:
        files: Base64-encoded MIDI strings to list. Defaults to
            ``midi_manager.modified_files``.

    Returns:
        An HTML fragment with one ``data:audio/midi`` download link per entry,
        named ``midi_<i>.mid``.
    """
    if files is None:
        files = midi_manager.modified_files
    # join instead of += in a loop: linear rather than quadratic string building.
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{data}" download="midi_{i}.mid">MIDI {i}</a></li>'
        for i, data in enumerate(files)
    )
    return f"<h3>Downloads</h3><ul>{items}</ul>"
def get_midi_choices():
    """Return ``(label, midi_id)`` dropdown choices for every loaded MIDI file.

    The label is the base name of the stored file path. The value is the key
    in ``midi_manager.loaded_midi``.
    """
    choices = []
    for midi_id, entry in midi_manager.loaded_midi.items():
        file_path, _midi = entry
        choices.append((os.path.basename(file_path), midi_id))
    return choices
def _midi_to_file(midi_data):
    """Write base64-encoded MIDI to a temporary ``.mid`` file for download.

    ``gr.Audio`` accepts no ``type="bytes"`` and cannot play MIDI, so MIDI
    results are exposed through ``gr.File`` as a file path.

    Returns:
        The temporary file path, or ``None`` when ``midi_data`` is falsy.
    """
    if not midi_data:
        return None
    with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as tmp:
        tmp.write(base64.b64decode(midi_data))
    return tmp.name


def _start_playback(midi_data):
    """Run ``midi_manager.play_with_loop`` on a daemon thread.

    The event handler returns its outputs immediately instead of blocking
    until playback is stopped.
    """
    threading.Thread(
        target=midi_manager.play_with_loop, args=(midi_data,), daemon=True
    ).start()


with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
    with gr.Tabs():
        # Tab 1: Load MIDI Files
        with gr.Tab("Load MIDI"):
            midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
            midi_list = gr.State({})
            midi_choices = gr.State(get_midi_choices())
            file_display = gr.HTML(value="No files loaded")
            load_output = gr.File(label="Generated Preview")

            def load_and_generate(files):
                """Register uploaded files and return a variation of the last one."""
                midi_list_val = midi_manager.loaded_midi.copy()
                listing = "<h3>Loaded Files</h3>"
                midi_data = None
                for file in files or []:
                    # Newer Gradio passes str paths; older versions pass tempfile
                    # wrappers exposing .name.
                    path = file if isinstance(file, str) else file.name
                    midi_id = midi_manager.load_midi(path)
                    midi_list_val[midi_id] = midi_manager.loaded_midi[midi_id]
                    # Filenames are user-controlled: escape before embedding in HTML.
                    listing += f"<div>{html.escape(os.path.basename(path))}</div>"
                    midi_data = midi_manager.generate_variation(midi_id)
                return midi_list_val, listing, _midi_to_file(midi_data), get_midi_choices()

            midi_files.change(load_and_generate, inputs=[midi_files],
                              outputs=[midi_list, file_display, load_output, midi_choices])

        # Tab 2: Generate & Perform
        with gr.Tab("Generate & Perform"):
            midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value=None)
            length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
            variation = gr.Slider(0, 1, value=0.3, label="Variation")
            generate_btn = gr.Button("Generate")
            effect = gr.Radio(["tempo"], label="Effect", value="tempo")
            intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
            apply_btn = gr.Button("Apply Effect")
            stop_btn = gr.Button("Stop Playback")
            output = gr.File(label="Preview")
            # Base64 MIDI of the latest result. Effects are applied to this
            # instead of decoding whatever the output component hands back.
            current_midi = gr.State(None)
            status = gr.Textbox(label="Status", value="Ready")

            def update_dropdown(midi_list):
                return gr.update(choices=get_midi_choices())

            midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])

            def generate(midi_id, length, var):
                if not midi_id:
                    return None, "Select a MIDI file", None
                midi_data = midi_manager.generate_variation(midi_id, length, var)
                if midi_data is None:
                    return None, "Unknown MIDI selection", None
                _start_playback(midi_data)
                return _midi_to_file(midi_data), "Playing", midi_data

            def apply_effect(midi_data, fx, inten):
                if not midi_data:
                    return None, "Generate a MIDI first", None
                new_data = midi_manager.apply_synth_effect(midi_data, fx, inten)
                _start_playback(new_data)
                return _midi_to_file(new_data), "Playing", new_data

            generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
                               outputs=[output, status, current_midi])
            apply_btn.click(apply_effect, inputs=[current_midi, effect, intensity],
                            outputs=[output, status, current_midi])
            stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])

        # Tab 3: MIDI Input
        with gr.Tab("MIDI Input"):
            gr.Markdown("Play your MIDI keyboard to record notes")
            save_btn = gr.Button("Save Live MIDI")
            live_output = gr.File(label="Live MIDI")

            def save_live():
                midi_data = midi_manager.save_live_midi()
                if midi_data:
                    _start_playback(midi_data)
                return _midi_to_file(midi_data)

            save_btn.click(save_live, inputs=None, outputs=[live_output])

        # Tab 4: Downloads
        with gr.Tab("Downloads"):
            downloads = gr.HTML(value="No files yet")

            def update_downloads(*args):
                return create_download_list()

            gr.on(triggers=[midi_files.change, generate_btn.click, apply_btn.click, save_btn.click],
                  fn=update_downloads, inputs=None, outputs=[downloads])

    gr.Markdown("""
<div style='text-align: center; margin-top: 20px;'>
<img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
<strong>Hugging Face</strong><br>
<a href='https://huggingface.co/models'>Models</a> |
<a href='https://huggingface.co/datasets'>Datasets</a> |
<a href='https://huggingface.co/spaces'>Spaces</a> |
<a href='https://huggingface.co/posts'>Posts</a> |
<a href='https://huggingface.co/docs'>Docs</a> |
<a href='https://huggingface.co/enterprise'>Enterprise</a> |
<a href='https://huggingface.co/pricing'>Pricing</a>
</div>
""")

if __name__ == "__main__":
    app.queue().launch(inbrowser=True)