Spaces:
Sleeping
Sleeping
"""preprocess_egmd.py""" | |
import os | |
import csv | |
import glob | |
import re | |
import json | |
from typing import Dict, List, Tuple | |
import numpy as np | |
from utils.audio import get_audio_file_info | |
from utils.midi import midi2note, note_event2midi | |
from utils.note2event import note2note_event, note_event2event | |
from utils.event2note import event2note_event | |
from utils.note_event_dataclasses import Note, NoteEvent | |
from utils.utils import note_event2token2note_event_sanity_check | |
# from utils.utils import assert_note_events_almost_equal | |
def create_note_event_and_note_from_midi(mid_file: str, id: str) -> Tuple[Dict, Dict]:
    """Extract notes and note events, plus metadata, from an E-GMD MIDI file.

    The MIDI file is parsed with ``midi2note`` using ``force_all_drum=True``,
    and both returned dicts are tagged with ``program=[128]`` and
    ``is_drum=[1]``.

    Args:
        mid_file: Path to the MIDI file.
        id: E-GMD id stored in the returned metadata. The name shadows the
            ``id`` builtin but is kept for keyword-argument compatibility.

    Returns:
        A ``(notes, note_events)`` tuple:
            notes (dict): the notes returned by ``midi2note`` plus metadata
                (``egmd_id``, ``program``, ``is_drum``, ``duration_sec``).
            note_events (dict): ``note2note_event(notes)`` plus the same
                metadata.
    """
    notes, dur_sec = midi2note(
        mid_file,
        binary_velocity=True,
        ch_9_as_drum=True,
        force_all_drum=True,
        trim_overlap=True,
        fix_offset=True,
        quantize=True,
        verbose=0,
        minimum_offset_sec=0.01,
        drum_offset_sec=0.01,
        ignore_pedal=True)
    return {  # notes
        'egmd_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'notes': notes,
    }, {  # note_events
        # Same id key as the `notes` dict so both files are self-consistent.
        'egmd_id': id,
        # 'maps_id' looks like a leftover from another dataset's preprocessor;
        # it is kept so consumers that already read this key keep working.
        'maps_id': id,
        'program': [128],
        'is_drum': [1],
        'duration_sec': dur_sec,
        'note_events': note2note_event(notes),
    }
def _egmd_id_from_filename(midi_filename: str) -> str:
    """Return the E-GMD id for a CSV ``midi_filename``: the relative path without its extension."""
    # splitext only strips the final extension; split('.')[0] would also cut
    # at any '.' earlier in the relative path.
    return os.path.splitext(midi_filename)[0]


def _egmd_derived_paths(midi_file: str) -> Tuple[str, str, str]:
    """Return ``(notes_file, note_events_file, quantized_midi_file)`` derived from a MIDI path.

    Only the trailing extension is replaced. ``str.replace('.midi', ...)`` would
    also rewrite any '.midi' occurring earlier in the path (e.g. in data_home).
    """
    stem = os.path.splitext(midi_file)[0]
    return (stem + '_notes.npy',
            stem + '_note_events.npy',
            stem + '_quantized_120bpm.mid')


def _egmd_index_entry(base_dir: str, row: Dict[str, str]) -> Dict:
    """Build one file-list entry for an E-GMD CSV row after checking its files exist.

    Raises:
        FileNotFoundError: if the audio, MIDI, notes or note-events file is missing.
    """
    midi_file = os.path.join(base_dir, row['midi_filename'])
    mix_audio_file = os.path.join(base_dir, row['audio_filename'])
    notes_file, note_events_file, _ = _egmd_derived_paths(midi_file)
    # Check existence before reading the audio header, so a missing file is
    # reported clearly. Explicit raises also survive `python -O`, unlike assert.
    for path in (mix_audio_file, midi_file, notes_file, note_events_file):
        if not os.path.exists(path):
            raise FileNotFoundError(f'Missing E-GMD file: {path}')
    return {
        'egmd_id': _egmd_id_from_filename(row['midi_filename']),
        'n_frames': get_audio_file_info(mix_audio_file)[1],
        'mix_audio_file': mix_audio_file,
        'notes_file': notes_file,
        'note_events_file': note_events_file,
        'midi_file': midi_file,
        'program': [128],
        'is_drum': [1],
    }


def _write_egmd_file_list(file_list: Dict[int, Dict], output_index_dir: str,
                          dataset_name: str, split: str) -> None:
    """Write ``file_list`` to ``{output_index_dir}/{dataset_name}_{split}_file_list.json``."""
    output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(file_list, f, indent=4)
    print(f'Wrote {output_file}')


def _check_egmd_count(what: str, actual: int, expected: int) -> None:
    """Raise ValueError when ``actual`` differs from the expected hard-coded count."""
    if actual != expected:
        raise ValueError(f'Expected {expected} {what}, found {actual}')


def preprocess_egmd16k(data_home: os.PathLike, dataset_name='egmd') -> None:
    """Create note / note-event files and YourMT3 index files for E-GMD.

    Reads ``{data_home}/{dataset_name}_yourmt3_16k/e-gmd-v1.0.0.csv``.

    Splits:
        - train: 35217 files
        - validation: 5031 files
        - test: 5289 files
        - test_reduced: 246 test files whose MIDI filename ends with
          '_5.midi' or '_10.midi'

    Writes, next to each MIDI file:
        - ``*_notes.npy`` / ``*_note_events.npy``: pickled dicts from
          ``create_note_event_and_note_from_midi``.
        - ``*_quantized_120bpm.mid``: a 120 bpm quantized MIDI rewritten from
          the note events.
    and, in ``{data_home}/yourmt3_indexes``:
        - {dataset_name}_{split}_file_list.json: a dictionary of the form
            {
                index:
                {
                    'egmd_id': egmd_id,  # CSV midi_filename without extension
                    'n_frames': (int),
                    'mix_audio_file': 'path/to/mix.wav',
                    'notes_file': 'path/to/notes.npy',
                    'note_events_file': 'path/to/note_events.npy',
                    'midi_file': 'path/to/midi.mid',
                    'program': List[int],
                    'is_drum': List[int],  # 0 or 1
                }
            }

    Raises:
        FileNotFoundError: if a file referenced by an index entry is missing.
        ValueError: if the CSV row count or a split size differs from the
            expected value above.
    """
    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Load the metadata CSV (newline='' is what the csv module expects).
    csv_file = os.path.join(base_dir, 'e-gmd-v1.0.0.csv')
    with open(csv_file, 'r', newline='', encoding='utf-8') as f:
        egmd_rows = list(csv.DictReader(f))
    _check_egmd_count('CSV rows', len(egmd_rows), 45537)

    # Process MIDI files: save notes / note events and a quantized MIDI copy.
    for row in egmd_rows:
        egmd_id = _egmd_id_from_filename(row['midi_filename'])
        midi_file = os.path.join(base_dir, row['midi_filename'])
        notes, note_events = create_note_event_and_note_from_midi(midi_file, egmd_id)
        notes_file, note_events_file, quantized_midi_file = _egmd_derived_paths(midi_file)
        np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
        print(f"Created {notes_file}")
        np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
        print(f"Created {note_events_file}")
        note_event2midi(note_events['note_events'], quantized_midi_file)
        print(f'Wrote {quantized_midi_file}')

    # No audio processing happens here: audio files are expected to already be
    # present in base_dir (presumably 16 kHz given the '_16k' suffix — confirm).

    # Create one index file per official split.
    expected_split_sizes = {'train': 35217, 'validation': 5031, 'test': 5289}
    for split, expected_size in expected_split_sizes.items():
        split_rows = [row for row in egmd_rows if row['split'] == split]
        file_list = {i: _egmd_index_entry(base_dir, row) for i, row in enumerate(split_rows)}
        _write_egmd_file_list(file_list, output_index_dir, dataset_name, split)
        _check_egmd_count(f'{split} files', len(file_list), expected_size)

    # Create the reduced test index from test rows whose MIDI filename ends
    # with '_5.midi' or '_10.midi'.
    reduced_rows = [
        row for row in egmd_rows
        if row['split'] == 'test' and row['midi_filename'].endswith(('_5.midi', '_10.midi'))
    ]
    file_list = {i: _egmd_index_entry(base_dir, row) for i, row in enumerate(reduced_rows)}
    _write_egmd_file_list(file_list, output_index_dir, dataset_name, 'test_reduced')
    _check_egmd_count('test_reduced files', len(file_list), 246)