# File size: 6,277 bytes (revision 9965bf6)
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pretty_midi
import pandas as pd
import numpy as np
from tqdm import tqdm
import math
from music_rule_guidance.music_rules import MAX_PIANO, MIN_PIANO
import matplotlib.pyplot as plt
# Global matplotlib defaults set as an import side effect:
# 6x3-inch figures rendered and saved at 300 dpi.
plt.rcParams.update({
    "figure.figsize": (6, 3),
    "figure.dpi": 300,
    "savefig.dpi": 300,
})
# MIDI control-change number of the sustain pedal.
CC_SUSTAIN_PEDAL = 64
def split_csv(csv_path='merged_midi.csv', splits=('train', 'validation', 'test')):
    """Write one CSV per dataset split into a directory named after ``csv_path``.

    Rows of ``csv_path`` are grouped by their ``split`` column. For example,
    ``merged_midi.csv`` produces ``merged_midi/train.csv``,
    ``merged_midi/validation.csv`` and ``merged_midi/test.csv``.

    Args:
        csv_path: Path to the merged CSV; it must have a ``split`` column.
        splits: Split names to export, one output file each.

    Raises:
        FileExistsError: If the derived output directory path already exists
            as a regular file (e.g. ``csv_path`` has no extension).
    """
    df = pd.read_csv(csv_path)
    # Drop the extension to get the output directory. splitext never chops
    # characters off a path that lacks '.csv' (rfind would return -1).
    save_dir = os.path.splitext(csv_path)[0]
    # DataFrame.to_csv does not create parent directories.
    os.makedirs(save_dir, exist_ok=True)
    for split in splits:
        df_sub = df[df['split'] == split]
        df_sub.to_csv(os.path.join(save_dir, split + '.csv'), index=False)
def quantize_pedal(value, num_bins=8):
    """Quantize a control value in [0, 127] into ``num_bins`` equal-width bins.

    Args:
        value: Integer control-change value in ``[0, 127]``.
        num_bins: Number of bins covering 0-127, between 1 and 128. The
            default of 8 gives bins of width 16 (centres 8, 24, ..., 120).

    Returns:
        The centre of the bin containing ``value``, never above 127.

    Raises:
        ValueError: If ``value`` is outside ``[0, 127]`` or ``num_bins`` is
            outside ``[1, 128]``.
    """
    if value < 0 or value > 127:
        raise ValueError("Value should be between 0 and 127")
    if not 1 <= num_bins <= 128:
        # num_bins > 128 would make bin_size 0 and divide by zero below.
        raise ValueError("num_bins should be between 1 and 128")
    bin_size = 128 // num_bins  # 16 for the default 8 bins
    # If num_bins does not divide 128, the top few values would otherwise
    # land in an extra (num_bins + 1)-th bin; fold them into the last bin.
    bin_index = min(value // bin_size, num_bins - 1)
    bin_center = bin_size * bin_index + bin_size // 2
    return min(bin_center, 127)
def get_full_piano_roll(midi_data, fs: float, show: bool = False) -> np.ndarray:
    """Render a MIDI object into stacked note / onset / pedal rolls.

    Args:
        midi_data: Parsed MIDI object. Only ``get_piano_roll(...)`` and
            ``instruments[*].control_changes`` are used.
        fs: Frames per second. It is used for the piano roll and to map pedal
            event times to columns via ``int(cc.time * fs)``.
        show: If True, plot the first 1024 frames of the note roll and of the
            pedal roll.

    Returns:
        Array of shape ``(3, *piano_roll.shape)``. Channel 0 is the note
        roll, channel 1 the onset roll and channel 2 the pedal roll. The pedal
        roll is zero everywhere except the frames holding sustain-pedal
        events, where rows MIN_PIANO..MAX_PIANO contain
        ``quantize_pedal(cc.value)``.
    """
    # do not process sustain pedal (pedal_threshold=None); pedal state is kept
    # in its own channel below instead.
    # NOTE(review): upstream pretty_midi's get_piano_roll has no `onset`
    # argument; this presumably relies on a patched version returning
    # (piano_roll, onset_roll) -- confirm.
    piano_roll, onset_roll = midi_data.get_piano_roll(fs=fs, pedal_threshold=None, onset=True)
    # save pedal roll explicitly, same shape as the note roll
    pedal_roll = np.zeros_like(piano_roll)
    # process pedal: every instrument's sustain-pedal CC events; a later
    # event written to the same frame overwrites an earlier one.
    for instru in midi_data.instruments:
        pedal_changes = [_e for _e in instru.control_changes if _e.number == CC_SUSTAIN_PEDAL]
        for cc in pedal_changes:
            # Event time -> frame index (truncated toward zero).
            time_now = int(cc.time * fs)
            # Events at or beyond the last frame of the note roll are dropped.
            if time_now < pedal_roll.shape[-1]:
                # A stored 0 means "no event here". quantize_pedal maps CC
                # values 0-15 to 8 (default bins), so a pedal-off event is
                # still nonzero and distinguishable from the background.
                # In (presumably MuseScore) files a 0 may be immediately
                # followed by 127 in the same frame. If the frame already holds
                # an event that differs from the new raw CC value by more than
                # 64, write the new event 2 frames later (clamped to the last
                # frame) instead of overwriting it.
                # NOTE(review): this compares the *quantized* stored value
                # with the *raw* cc.value.
                if pedal_roll[MIN_PIANO, time_now] != 0. and abs(pedal_roll[MIN_PIANO, time_now] - cc.value) > 64:
                    # use shift 2 here to prevent missing change when using interpolation augmentation
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, min(time_now + 2, pedal_roll.shape[-1] - 1)] = quantize_pedal(cc.value)
                else:
                    # Broadcast the quantized value over rows MIN_PIANO..MAX_PIANO (inclusive).
                    pedal_roll[MIN_PIANO:MAX_PIANO + 1, time_now] = quantize_pedal(cc.value)
    # Stack along a new leading axis: (note, onset, pedal).
    full_roll = np.concatenate((piano_roll[None], onset_roll[None], pedal_roll[None]), axis=0)
    if show:
        # Debug preview of the first 1024 frames. The row axis is flipped so
        # higher rows are drawn at the top; the colour scale is fixed to 0-127.
        plt.imshow(piano_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
        plt.imshow(pedal_roll[::-1, :1024], vmin=0, vmax=127)
        plt.show()
    return full_roll
def _excerpt_base_name(midi_filename):
    """Return ``midi_filename`` stripped of its directory and extension."""
    # basename/splitext also handle names without '/' or with an upper-case
    # '.MID' extension; slicing at rfind('.mid') == -1 would drop the last char.
    return os.path.splitext(os.path.basename(midi_filename))[0]


def _iter_excerpts(full_roll, image_size, start):
    """Yield ``(j, excerpt)`` for windows of ``image_size`` frames at j = start, start + image_size, ...

    ``full_roll`` is indexed as (channel, row, frame). A window that runs past
    the last frame is zero-padded on the right, so every excerpt has the same width.
    """
    num_frames = full_roll.shape[-1]
    for j in range(start, num_frames, image_size):
        if j + image_size <= num_frames:
            excerpt = full_roll[:, :, j:j + image_size]
        else:
            excerpt = np.zeros(full_roll.shape[:2] + (image_size,))
            excerpt[:, :, :num_frames - j] = full_roll[:, :, j:]
        yield j, excerpt


def _save_excerpt(excerpt, out_path):
    """Save ``excerpt`` to ``out_path`` as uint8 unless it is entirely zero."""
    if math.isclose(excerpt.max(), 0.):
        return
    # The note roll can exceed 255 where velocities of overlapping notes are
    # summed; clip first so the uint8 cast saturates instead of wrapping.
    np.save(out_path, np.clip(excerpt, 0, 255).astype(np.uint8))


def preprocess_midi(target='merged', csv_path='merged_midi.csv', fs=100., image_size=128, overlap=False, show=False):
    """Slice every MIDI file listed in ``csv_path`` into fixed-width piano-roll excerpts.

    Each CSV row must provide ``midi_filename``, ``split`` and ``dataset``.
    The MIDI file is loaded from ``os.path.join(dataset, midi_filename)`` and
    rendered with ``get_full_piano_roll``. The roll is then cut into windows
    of ``image_size`` frames, which are saved as uint8 ``.npy`` files under
    ``<target>/<split>/`` with the name ``<name>_<k>.npy``. All-zero windows
    are skipped, and the last window is zero-padded.

    Args:
        target: Output root directory (created if missing).
        csv_path: CSV listing the MIDI files to process.
        fs: Frames per second passed to ``get_full_piano_roll``.
        image_size: Excerpt width in frames.
        overlap: If True, also save windows shifted by ``image_size // 2``
            frames, named ``shift_<name>_<k>.npy``.
        show: Forwarded to ``get_full_piano_roll`` to plot each roll.
    """
    df = pd.read_csv(csv_path)
    # Create a directory for every split named in the CSV ('train' and 'test'
    # always), so rows tagged e.g. 'validation' have somewhere to be saved.
    for split in {'train', 'test'} | set(df['split']):
        os.makedirs(os.path.join(target, split), exist_ok=True)

    rows = zip(df['midi_filename'], df['split'], df['dataset'])
    for midi_filename, split, dataset in tqdm(rows, total=len(df)):
        path = os.path.join(target, split)
        midi_data = pretty_midi.PrettyMIDI(os.path.join(dataset, midi_filename))
        full_roll = get_full_piano_roll(midi_data, fs=fs, show=show)
        # NOTE: if file names collide across datasets (e.g. for VAE training),
        # prefix the saved names with dataset + '_'.
        save_name = _excerpt_base_name(midi_filename)
        # Non-overlapping windows: <name>_<k> starts at frame k * image_size.
        for j, excerpt in _iter_excerpts(full_roll, image_size, start=0):
            _save_excerpt(excerpt, os.path.join(path, save_name + '_' + str(j // image_size) + '.npy'))
        if overlap:
            # Half-window-shifted windows: shift_<name>_<k> starts at
            # frame k * image_size + image_size // 2.
            for j, excerpt in _iter_excerpts(full_roll, image_size, start=image_size // 2):
                _save_excerpt(excerpt, os.path.join(path, 'shift_' + save_name + '_' + str(j // image_size) + '.npy'))
def main():
    """Build the default dataset: fs=100, 128-frame (1.28 s) excerpts, no overlap.

    The excerpts can be rearranged later. The commented-out calls below
    build alternative dataset variants.
    """
    default_config = dict(
        target='all-128-fs100',
        csv_path='all_midi.csv',
        fs=100,
        image_size=128,
        overlap=False,
        show=False,
    )
    preprocess_midi(**default_config)
    # fs=100, 256-frame (2.56 s) excerpts with half-window overlap, used for
    # VAE training; at load time a 1.28 s slice is selected from each excerpt.
    # preprocess_midi(target='all-256-overlap-fs100', csv_path='all_midi.csv', fs=100, image_size=256, overlap=True,
    #                 show=False)
    # fs=12.5 (0.08 s per frame) for the pixel-space diffusion model, rearrangement with length 2.
    # preprocess_midi(target='all-128-fs12.5', csv_path='all_midi.csv', fs=12.5, image_size=128, overlap=False,
    #                 show=False)
# Script entry point: build the default dataset when run directly.
if __name__ == "__main__":
    main()