Spaces:
Sleeping
Sleeping
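""" preprocess_mtrack_slakh.py

Preprocess Slakh multitracks into stem audio arrays, note / note-event files
and per-split JSON file lists.
"""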
import os | |
import time | |
import json | |
from typing import Dict, List, Tuple | |
import numpy as np | |
from utils.audio import get_audio_file_info, load_audio_file | |
from utils.midi import midi2note | |
from utils.note2event import note2note_event, mix_notes | |
import mirdata | |
from utils.mirdata_dev.datasets import slakh16k | |
def create_audio_stem_from_mtrack(ds: mirdata.core.Dataset,
                                  mtrack_id: str,
                                  delete_source_files: bool = False) -> Dict:
    """Extract the stem audio of a multitrack into one padded array plus metadata.

    Every stem listed in ``mtrack.track_ids`` is loaded with ``dtype=np.int16``,
    divided by 2**15 and stored as float16. Shorter stems are zero-padded at the
    end to the length of the longest stem.

    Args:
        ds (mirdata.core.Dataset): Slakh dataset.
        mtrack_id (str): multitrack id.
        delete_source_files (bool): if True, delete the source audio file of
            every stem, but only after ALL stems of this multitrack have been
            loaded successfully. Default is False.

    Returns:
        Dict with keys:
            'mtrack_id' (str),
            'program' (np.int64 array, one entry per stem),
            'is_drum' (np.int64 array, one entry per stem, 1 for drum stems),
            'n_frames' (length in frames of the longest stem, per
                get_audio_file_info),
            'audio_array' (np.float16 array of shape (n_tracks, n_frames)).

    Raises:
        ValueError: if a stem is not 16 kHz mono.
    """
    mtrack = ds.multitrack(mtrack_id)
    track_ids = mtrack.track_ids

    max_length = 0
    program_numbers = []
    is_drum = []
    audio_tracks = []  # one 1-D float16 array per stem
    loaded_files = []  # source files, deleted only once every stem is loaded

    for track_id in track_ids:
        track = ds.track(track_id)
        audio_file = track.audio_path
        program_numbers.append(track.program_number)
        is_drum.append(1 if track.is_drum else 0)

        fs, n_frames, n_channels = get_audio_file_info(audio_file)
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if fs != 16000 or n_channels != 1:
            raise ValueError(f'Expected 16 kHz mono audio, got fs={fs}, '
                             f'n_channels={n_channels}: {audio_file}')
        max_length = max(max_length, n_frames)

        # NOTE(review): treated as an int16 sample array (it is scaled below);
        # an earlier comment claimed it "returns bytes" — confirm.
        audio = load_audio_file(audio_file, dtype=np.int16)
        audio_tracks.append((audio / 2**15).astype(np.float16))
        loaded_files.append(audio_file)

    # Delete only after every stem loaded, so a failure partway through does not
    # leave this multitrack with some source files already removed.
    if delete_source_files:
        for audio_file in loaded_files:
            print(f'๐๏ธ Deleting {audio_file} ...')
            os.remove(audio_file)

    # Collate all stems into a single zero-padded (n_tracks, max_length) array.
    audio_array = np.zeros((len(track_ids), max_length), dtype=np.float16)
    for j, audio in enumerate(audio_tracks):
        audio_array[j, :len(audio)] = audio

    return {
        'mtrack_id': mtrack_id,  # str
        'program': np.array(program_numbers, dtype=np.int64),
        'is_drum': np.array(is_drum, dtype=np.int64),
        'n_frames': max_length,  # int
        'audio_array': audio_array,  # (n_tracks, n_frames)
    }
def create_note_event_and_note_from_mtrack_mirdata(
        ds: mirdata.core.Dataset,
        mtrack_id: str,
        fix_bass_octave: bool = True) -> Tuple[Dict, Dict]:
    """Extract notes and note events, with metadata, from a multitrack's stem MIDI.

    Each stem MIDI file is converted with ``midi2note``. All stems are mixed
    into a single note list, which is then converted to note events.

    Args:
        ds (mirdata.core.Dataset): Slakh dataset.
        mtrack_id (str): multitrack id.
        fix_bass_octave (bool): if True, lower every note of stems whose program
            number is in 32-39 by 12 semitones. Stems rendered with the
            'scarbee_jay_bass_slap_both.nkm' plugin are exempt. Default is True.

    Returns:
        notes (dict): 'mtrack_id', 'program' and 'is_drum' (np.int64 arrays with
            one entry per stem; is_drum is 1 for drum stems), 'duration_sec'
            (the longest stem duration) and 'notes' (mixed list of Note
            instances).
        note_events (dict): the same metadata plus 'note_events' (list of
            NoteEvent instances built from the mixed notes).
    """
    mtrack = ds.multitrack(mtrack_id)
    program_numbers = []
    is_drum = []
    mixed_notes = []
    duration_sec = 0.

    # Mix the notes of every stem MIDI file into one list.
    for track_id in mtrack.track_ids:
        track = ds.track(track_id)
        notes, dur_sec = midi2note(
            track.midi_path,
            binary_velocity=True,
            ch_9_as_drum=False,  # checked safe to set to False in Slakh
            force_all_drum=bool(track.is_drum),
            force_all_program_to=None,  # Slakh always has program number
            trim_overlap=True,
            fix_offset=True,
            quantize=True,
            verbose=0,
            minimum_offset_sec=0.01,
            drum_offset_sec=0.01)

        # Program numbers 32-39 are treated as bass here.
        is_bass_program = track.program_number in range(32, 40)
        # NOTE(review): this plugin is exempt from the octave fix, presumably
        # because it is already annotated at the correct octave — confirm.
        if (fix_bass_octave and is_bass_program
                and track.plugin_name != 'scarbee_jay_bass_slap_both.nkm'):
            for note in notes:
                note.pitch -= 12
            print("Fixed bass octave for track", track_id)

        mixed_notes = mix_notes((mixed_notes, notes), True, True, True)
        program_numbers.append(track.program_number)
        is_drum.append(1 if track.is_drum else 0)
        duration_sec = max(duration_sec, dur_sec)

    # Convert the mixed notes to note events.
    mixed_note_events = note2note_event(mixed_notes, sort=True, return_activity=True)

    def _with_metadata(key: str, value) -> Dict:
        # Build fresh metadata arrays per result so the two dicts share no arrays.
        return {
            'mtrack_id': mtrack_id,  # str
            'program': np.array(program_numbers, dtype=np.int64),  # (n,)
            'is_drum': np.array(is_drum, dtype=np.int64),  # (n,) with 1 is drum
            'duration_sec': duration_sec,  # float
            key: value,
        }

    return (_with_metadata('notes', mixed_notes),
            _with_metadata('note_events', mixed_note_events))
def preprocess_slakh16k(data_home: str,
                        run_checksum: bool = False,
                        delete_source_files: bool = False,
                        fix_bass_octave: bool = True) -> None:
    """Preprocess the Slakh dataset into stems, notes, note events and file lists.

    For every multitrack of the 'train', 'validation' and 'test' splits, this
    saves three ``.npy`` files (pickled dicts) in the directory of the
    multitrack's mix audio: ``{mtrack_id}_stem.npy``, ``{mtrack_id}_notes.npy``
    and ``{mtrack_id}_note_events.npy``. The mix audio itself is not modified.

    Args:
        data_home (str): path to the Slakh data.
        run_checksum (bool): if True, validates the dataset using its checksum.
            Default is False.
        delete_source_files (bool): if True, deletes the original stem audio
            files once they have been loaded. Default is False.
        fix_bass_octave (bool): if True, fixes the bass to be -1 octave. Slakh
            bass is annotated as +1 octave. Default is True.

    Writes:
        {data_home}/yourmt3_indexes/slakh_{split}_file_list.json: a dict mapping
        the multitrack's index within the split (a string key once in JSON) to
            {
                'mtrack_id': mtrack_id,
                'n_frames': n of audio frames (longest stem),
                'stem_file': path to {mtrack_id}_stem.npy,
                'mix_audio_file': mtrack.mix_path,
                'notes_file': path to {mtrack_id}_notes.npy,
                'note_events_file': path to {mtrack_id}_note_events.npy,
                'midi_file': mtrack.midi_path,
            }
        Notes and note-event files are written for every split.
    """
    start_time = time.time()
    ds = slakh16k.Dataset(data_home=data_home, version='2100-yourmt3-16k')
    if run_checksum:
        print('Checksum for slakh dataset...')
        ds.validate()

    print('Preprocessing slakh dataset...')
    mtrack_split_dict = ds.get_mtrack_splits()
    summary_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(summary_dir, exist_ok=True)

    for split in ['train', 'validation', 'test']:
        file_list = {}  # one file list per split
        mtrack_ids = mtrack_split_dict[split]
        for i, mtrack_id in enumerate(mtrack_ids):
            print(f'๐๐ปโโ๏ธ: processing {mtrack_id} ({i+1}/{len(mtrack_ids)} in {split})')
            mtrack = ds.multitrack(mtrack_id)
            output_dir = os.path.dirname(mtrack.mix_path)  # same as mtrack

            # Audio: stem array + metadata (no preprocessing for mix audio).
            stem_content = create_audio_stem_from_mtrack(ds, mtrack_id, delete_source_files)
            stem_file = os.path.join(output_dir, mtrack_id + '_stem.npy')
            np.save(stem_file, stem_content)
            print(f'๐ฟ Created {stem_file}')

            # MIDI: notes and note events + metadata.
            notes, note_events = create_note_event_and_note_from_mtrack_mirdata(
                ds, mtrack_id, fix_bass_octave=fix_bass_octave)
            notes_file = os.path.join(output_dir, mtrack_id + '_notes.npy')
            np.save(notes_file, notes, allow_pickle=True, fix_imports=False)
            print(f'๐น Created {notes_file}')
            note_events_file = os.path.join(output_dir, mtrack_id + '_note_events.npy')
            np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False)
            print(f'๐น Created {note_events_file}')

            file_list[i] = {
                'mtrack_id': mtrack_id,
                'n_frames': stem_content['n_frames'],  # n of audio frames
                'stem_file': stem_file,
                'mix_audio_file': mtrack.mix_path,
                'notes_file': notes_file,
                'note_events_file': note_events_file,
                'midi_file': mtrack.midi_path,
            }

        # By split, save a file list as json.
        summary_file = os.path.join(summary_dir, f'slakh_{split}_file_list.json')
        with open(summary_file, 'w') as f:
            json.dump(file_list, f, indent=4)
        print(f'๐พ Created {summary_file}')

    elapsed_time = time.time() - start_time
    print(
        f"โฐ: {int(elapsed_time // 3600):02d}h {int(elapsed_time % 3600 // 60):02d}m {elapsed_time % 60:.2f}s"
    )
def add_program_and_is_drum_info_to_file_list(data_home: str):
    """Copy each stem's 'program' and 'is_drum' arrays into the split file lists.

    For the 'train', 'validation' and 'test' JSON file lists under
    ``{data_home}/yourmt3_indexes``, every entry's ``stem_file`` is loaded with
    ``np.load(..., allow_pickle=True).item()``. Its 'program' and 'is_drum'
    arrays are stored on the entry as plain lists, and the JSON file is then
    rewritten in place.
    """
    index_dir = os.path.join(data_home, 'yourmt3_indexes')
    for split in ('train', 'validation', 'test'):
        file = os.path.join(index_dir, f'slakh_{split}_file_list.json')
        with open(file, 'r') as f:
            file_list = json.load(f)

        for entry in file_list.values():
            # allow_pickle is needed because the stem file holds a pickled dict;
            # only load stem files produced by this preprocessing pipeline.
            stem = np.load(entry['stem_file'], allow_pickle=True).item()
            entry.update(program=stem['program'].tolist(),
                         is_drum=stem['is_drum'].tolist())

        with open(file, 'w') as f:
            json.dump(file_list, f, indent=4)
        print(f'๐พ Added program and drum info to {file}')
if __name__ == '__main__':
    # Script entry point: read the dataset root from the project's shared config
    # and run the full Slakh preprocessing, keeping the source stem audio files.
    from config.config import shared_cfg as _cfg
    preprocess_slakh16k(
        data_home=_cfg['PATH']['data_home'],
        delete_source_files=False,
    )