awacke1 commited on
Commit
784b974
·
verified ·
1 Parent(s): dbc6fc5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -24
app.py CHANGED
@@ -5,6 +5,7 @@ import rtmidi
5
  import MIDI
6
  import base64
7
  import io
 
8
  from huggingface_hub import hf_hub_download
9
  from midi_synthesizer import MidiSynthesizer
10
 
@@ -14,25 +15,44 @@ class MIDIManager:
14
  def __init__(self):
15
  self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
16
  self.synthesizer = MidiSynthesizer(self.soundfont_path)
17
- self.loaded_midi = {} # Store uploaded MIDI files
18
- self.modified_files = [] # Track generated files
19
  self.is_playing = False
20
  self.midi_in = rtmidi.MidiIn()
21
  self.midi_in.open_port(0) if self.midi_in.get_ports() else None
22
  self.midi_in.set_callback(self.midi_callback)
23
  self.live_notes = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def midi_callback(self, event, data=None):
26
  message, _ = event
27
  if len(message) >= 3 and message[0] & 0xF0 == 0x90: # Note On
28
  note, velocity = message[1], message[2]
29
  if velocity > 0:
30
- self.live_notes.append((note, velocity, 0)) # Time placeholder
31
 
32
  def load_midi(self, file_path):
33
  midi = MIDI.load(file_path)
34
- midi_id = f"midi_{len(self.loaded_midi)}"
35
- self.loaded_midi[midi_id] = midi
36
  return midi_id
37
 
38
  def extract_notes(self, midi):
@@ -46,7 +66,8 @@ class MIDIManager:
46
  def generate_variation(self, midi_id, length_factor=2, variation=0.3):
47
  if midi_id not in self.loaded_midi:
48
  return None
49
- notes = self.extract_notes(self.loaded_midi[midi_id])
 
50
  new_notes = []
51
  for _ in range(int(length_factor)):
52
  for note, vel, time in notes:
@@ -100,12 +121,12 @@ class MIDIManager:
100
  time_cum = 0
101
  for note, vel, _ in self.live_notes:
102
  midi.addNote(0, 0, note, time_cum, 100, vel)
103
- time_cum += 100 # Simple timing
104
  output = io.BytesIO()
105
  midi.writeFile(output)
106
  midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
107
  self.modified_files.append(midi_data)
108
- self.live_notes = [] # Reset after saving
109
  return midi_data
110
 
111
  midi_manager = MIDIManager()
@@ -117,6 +138,9 @@ def create_download_list():
117
  html += "</ul>"
118
  return html
119
 
 
 
 
120
  with gr.Blocks(theme=gr.themes.Soft()) as app:
121
  gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
122
 
@@ -126,34 +150,39 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
126
  midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
127
  midi_list = gr.State({})
128
  file_display = gr.HTML(value="No files loaded")
 
129
 
130
- def load_files(files):
131
- midi_list_val = {}
132
  html = "<h3>Loaded Files</h3>"
 
133
  for file in files or []:
134
  midi_id = midi_manager.load_midi(file.name)
135
- midi_list_val[midi_id] = file.name
136
  html += f"<div>{file.name}</div>"
137
- return midi_list_val, html
 
 
 
138
 
139
- midi_files.change(load_files, inputs=[midi_files], outputs=[midi_list, file_display])
 
140
 
141
  # Tab 2: Generate & Perform
142
  with gr.Tab("Generate & Perform"):
143
- midi_select = gr.Dropdown(label="Select MIDI", choices=[])
144
  length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
145
  variation = gr.Slider(0, 1, value=0.3, label="Variation")
146
  generate_btn = gr.Button("Generate")
147
  effect = gr.Radio(["tempo"], label="Effect", value="tempo")
148
  intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
149
  apply_btn = gr.Button("Apply Effect")
150
- play_btn = gr.Button("Play Loop")
151
- stop_btn = gr.Button("Stop")
152
- output = gr.Audio(label="Preview", type="bytes")
153
  status = gr.Textbox(label="Status", value="Ready")
154
 
155
  def update_dropdown(midi_list):
156
- return gr.update(choices=list(midi_list.keys()))
157
 
158
  midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])
159
 
@@ -161,30 +190,34 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
161
  if not midi_id:
162
  return None, "Select a MIDI file"
163
  midi_data = midi_manager.generate_variation(midi_id, length, var)
164
- return io.BytesIO(base64.b64decode(midi_data)), "Generated"
 
165
 
166
  def apply_effect(midi_data, fx, inten):
167
  if not midi_data:
168
  return None, "Generate a MIDI first"
169
  new_data = midi_manager.apply_synth_effect(midi_data.decode('utf-8'), fx, inten)
170
- return io.BytesIO(base64.b64decode(new_data)), "Effect Applied"
 
171
 
172
  generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
173
  outputs=[output, status])
174
  apply_btn.click(apply_effect, inputs=[output, effect, intensity],
175
  outputs=[output, status])
176
- play_btn.click(midi_manager.play_with_loop, inputs=[output], outputs=[status])
177
  stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])
178
 
179
  # Tab 3: MIDI Input
180
  with gr.Tab("MIDI Input"):
181
  gr.Markdown("Play your MIDI keyboard to record notes")
182
  save_btn = gr.Button("Save Live MIDI")
183
- live_output = gr.Audio(label="Live MIDI", type="bytes")
184
 
185
  def save_live():
186
  midi_data = midi_manager.save_live_midi()
187
- return io.BytesIO(base64.b64decode(midi_data)) if midi_data else None
 
 
 
188
 
189
  save_btn.click(save_live, inputs=None, outputs=[live_output])
190
 
@@ -193,7 +226,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
193
  downloads = gr.HTML(value="No files yet")
194
  def update_downloads(*args):
195
  return create_download_list()
196
- gr.on(triggers=[generate_btn.click, apply_btn.click, save_btn.click],
197
  fn=update_downloads, inputs=None, outputs=[downloads])
198
 
199
  gr.Markdown("""
 
5
  import MIDI
6
  import base64
7
  import io
8
+ import os
9
  from huggingface_hub import hf_hub_download
10
  from midi_synthesizer import MidiSynthesizer
11
 
 
15
  def __init__(self):
16
  self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
17
  self.synthesizer = MidiSynthesizer(self.soundfont_path)
18
+ self.loaded_midi = {} # Key: midi_id, Value: (file_path, midi_data)
19
+ self.modified_files = []
20
  self.is_playing = False
21
  self.midi_in = rtmidi.MidiIn()
22
  self.midi_in.open_port(0) if self.midi_in.get_ports() else None
23
  self.midi_in.set_callback(self.midi_callback)
24
  self.live_notes = []
25
+ self.example_files = self.load_example_midis()
26
+
27
+ def load_example_midis(self):
28
+ # Check for MIDI files in a local 'examples' directory or predefined paths
29
+ example_dir = "examples" # Adjust this path as needed
30
+ examples = {}
31
+ if os.path.exists(example_dir):
32
+ for file in os.listdir(example_dir):
33
+ if file.endswith(".mid") or file.endswith(".midi"):
34
+ midi_id = f"example_{len(examples)}"
35
+ file_path = os.path.join(example_dir, file)
36
+ examples[midi_id] = (file_path, MIDI.load(file_path))
37
+ # Add a default example if none found
38
+ if not examples:
39
+ midi = MIDI.MIDIFile(1)
40
+ midi.addTrack()
41
+ midi.addNote(0, 0, 60, 0, 100, 100) # C4 note
42
+ examples["example_0"] = ("Simple C4.mid", midi)
43
+ return examples
44
 
45
  def midi_callback(self, event, data=None):
46
  message, _ = event
47
  if len(message) >= 3 and message[0] & 0xF0 == 0x90: # Note On
48
  note, velocity = message[1], message[2]
49
  if velocity > 0:
50
+ self.live_notes.append((note, velocity, 0))
51
 
52
  def load_midi(self, file_path):
53
  midi = MIDI.load(file_path)
54
+ midi_id = f"midi_{len(self.loaded_midi) - len(self.example_files)}"
55
+ self.loaded_midi[midi_id] = (file_path, midi)
56
  return midi_id
57
 
58
  def extract_notes(self, midi):
 
66
  def generate_variation(self, midi_id, length_factor=2, variation=0.3):
67
  if midi_id not in self.loaded_midi:
68
  return None
69
+ _, midi = self.loaded_midi[midi_id]
70
+ notes = self.extract_notes(midi)
71
  new_notes = []
72
  for _ in range(int(length_factor)):
73
  for note, vel, time in notes:
 
121
  time_cum = 0
122
  for note, vel, _ in self.live_notes:
123
  midi.addNote(0, 0, note, time_cum, 100, vel)
124
+ time_cum += 100
125
  output = io.BytesIO()
126
  midi.writeFile(output)
127
  midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
128
  self.modified_files.append(midi_data)
129
+ self.live_notes = []
130
  return midi_data
131
 
132
  midi_manager = MIDIManager()
 
138
  html += "</ul>"
139
  return html
140
 
141
+ def get_midi_choices():
142
+ return [(os.path.basename(path), midi_id) for midi_id, (path, _) in midi_manager.loaded_midi.items()]
143
+
144
  with gr.Blocks(theme=gr.themes.Soft()) as app:
145
  gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
146
 
 
150
  midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
151
  midi_list = gr.State({})
152
  file_display = gr.HTML(value="No files loaded")
153
+ output = gr.Audio(label="Generated Preview", type="bytes", autoplay=True)
154
 
155
+ def load_and_generate(files):
156
+ midi_list_val = midi_manager.loaded_midi.copy()
157
  html = "<h3>Loaded Files</h3>"
158
+ midi_data = None
159
  for file in files or []:
160
  midi_id = midi_manager.load_midi(file.name)
161
+ midi_list_val[midi_id] = (file.name, midi_manager.loaded_midi[midi_id][1])
162
  html += f"<div>{file.name}</div>"
163
+ midi_data = midi_manager.generate_variation(midi_id)
164
+ return (midi_list_val, html,
165
+ io.BytesIO(base64.b64decode(midi_data)) if midi_data else None,
166
+ get_midi_choices())
167
 
168
+ midi_files.change(load_and_generate, inputs=[midi_files],
169
+ outputs=[midi_list, file_display, output, gr.State(get_midi_choices())])
170
 
171
  # Tab 2: Generate & Perform
172
  with gr.Tab("Generate & Perform"):
173
+ midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value=None)
174
  length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
175
  variation = gr.Slider(0, 1, value=0.3, label="Variation")
176
  generate_btn = gr.Button("Generate")
177
  effect = gr.Radio(["tempo"], label="Effect", value="tempo")
178
  intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
179
  apply_btn = gr.Button("Apply Effect")
180
+ stop_btn = gr.Button("Stop Playback")
181
+ output = gr.Audio(label="Preview", type="bytes", autoplay=True)
 
182
  status = gr.Textbox(label="Status", value="Ready")
183
 
184
  def update_dropdown(midi_list):
185
+ return gr.update(choices=get_midi_choices())
186
 
187
  midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])
188
 
 
190
  if not midi_id:
191
  return None, "Select a MIDI file"
192
  midi_data = midi_manager.generate_variation(midi_id, length, var)
193
+ midi_manager.play_with_loop(midi_data)
194
+ return io.BytesIO(base64.b64decode(midi_data)), "Playing"
195
 
196
  def apply_effect(midi_data, fx, inten):
197
  if not midi_data:
198
  return None, "Generate a MIDI first"
199
  new_data = midi_manager.apply_synth_effect(midi_data.decode('utf-8'), fx, inten)
200
+ midi_manager.play_with_loop(new_data)
201
+ return io.BytesIO(base64.b64decode(new_data)), "Playing"
202
 
203
  generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
204
  outputs=[output, status])
205
  apply_btn.click(apply_effect, inputs=[output, effect, intensity],
206
  outputs=[output, status])
 
207
  stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])
208
 
209
  # Tab 3: MIDI Input
210
  with gr.Tab("MIDI Input"):
211
  gr.Markdown("Play your MIDI keyboard to record notes")
212
  save_btn = gr.Button("Save Live MIDI")
213
+ live_output = gr.Audio(label="Live MIDI", type="bytes", autoplay=True)
214
 
215
  def save_live():
216
  midi_data = midi_manager.save_live_midi()
217
+ if midi_data:
218
+ midi_manager.play_with_loop(midi_data)
219
+ return io.BytesIO(base64.b64decode(midi_data))
220
+ return None
221
 
222
  save_btn.click(save_live, inputs=None, outputs=[live_output])
223
 
 
226
  downloads = gr.HTML(value="No files yet")
227
  def update_downloads(*args):
228
  return create_download_list()
229
+ gr.on(triggers=[midi_files.change, generate_btn.click, apply_btn.click, save_btn.click],
230
  fn=update_downloads, inputs=None, outputs=[downloads])
231
 
232
  gr.Markdown("""