|
""" |
|
Compute rule values (e.g. pitch histogram, note density, chord progression) from saved MIDI files.

The results can differ slightly from rules computed online from the piano roll, because the saved MIDI is a cleaned-up version of that piano roll.
|
""" |
|
import os |
|
import pretty_midi |
|
import pandas as pd |
|
import numpy as np |
|
from tqdm import tqdm |
|
import math |
|
import glob |
|
import torch as th |
|
import multiprocessing |
|
from argparse import ArgumentParser |
|
from guided_diffusion import midi_util |
|
import torch.nn.functional as F |
|
|
|
def _midi_to_piano_roll(midi, fs=100, num_frames=1024):
    """Convert a loaded MIDI object into a normalized, fixed-length piano-roll tensor.

    Args:
        midi: a ``pretty_midi.PrettyMIDI`` instance.
        fs: sampling rate (frames per second) passed to ``get_piano_roll``.
        num_frames: length of the time axis of the returned tensor.

    Returns:
        A tensor of shape (1, 1, P, num_frames), where P is the pitch axis
        returned by ``get_piano_roll``. Values are ``roll / 63.5 - 1``, so a
        velocity of 0 maps to -1 and 127 maps to 1.
    """
    roll = midi.get_piano_roll(fs=fs, pedal_threshold=None, onset=False)
    roll = th.from_numpy(roll)[None, None] / 63.5 - 1
    # Right-pad the time axis with -1 (the value of an empty cell) up to
    # num_frames. For rolls longer than num_frames the pad amount is negative,
    # and constant-mode F.pad then crops instead of padding.
    # NOTE(review): that truncation is presumably intended; confirm.
    return F.pad(roll, (0, num_frames - roll.shape[3]), "constant", -1)


def main():
    """Compute rule metrics for generated MIDI files against their ground truth.

    Every ``*.midi`` / ``*.mid`` file directly inside ``--root_dir`` is paired
    with the file of the same name in ``<root_dir>/gt``. Files without a
    loadable ground-truth counterpart are reported and skipped. The per-file
    results of ``midi_util.compute_rule`` are written to
    ``<root_dir>/results_computed.csv``.

    ``--rule_name`` selects a single rule. Passing ``all`` computes
    pitch_hist, note_density and chord_progression.
    """
    parser = ArgumentParser()
    parser.add_argument('--root_dir', type=str, default='loggings/cond_table/all/beam_50_1_2_cls_1',
                        help='Path to the folder that contains generated samples for rule guidance')
    parser.add_argument('--rule_name', type=str, default='chord_progression',
                        help="rule to compute from midis ('all' computes pitch_hist, "
                             "note_density and chord_progression)")
    args = parser.parse_args()

    # Escape the directory so glob metacharacters in it ([, *, ?) are matched
    # literally; only the file-name part is a pattern.
    pattern_root = glob.escape(args.root_dir)
    gen_files = sorted(glob.glob(f'{pattern_root}/*.midi') + glob.glob(f'{pattern_root}/*.mid'))
    orig_dir = f'{args.root_dir}/gt'

    all_rules = ['pitch_hist', 'note_density', 'chord_progression']
    # argparse never yields None for an option with a non-None default, so
    # 'all' is the reachable way to request every rule. None is kept for
    # programmatic callers.
    if args.rule_name is None or args.rule_name == 'all':
        target_rules = all_rules
    else:
        target_rules = [args.rule_name]

    per_file_results = []
    for file in tqdm(gen_files):
        basename = os.path.basename(file)
        orig_file = os.path.join(orig_dir, basename)

        # Load the ground truth first, so samples without a counterpart are
        # skipped before any work is spent on the generated file.
        try:
            orig_midi = pretty_midi.PrettyMIDI(orig_file)
        except Exception as err:  # missing or unparseable ground-truth file
            print(f'{basename}: skipped, cannot load ground truth {orig_file} ({err})')
            continue

        gen_piano_roll = _midi_to_piano_roll(pretty_midi.PrettyMIDI(file))
        orig_piano_roll = _midi_to_piano_roll(orig_midi)

        # presumably a DataFrame of rule values for this file (it is passed to
        # pd.concat below) -- confirm against midi_util.compute_rule
        per_file_results.append(midi_util.compute_rule(gen_piano_roll, orig_piano_roll, target_rules))

    # Concatenate once at the end. Growing a frame with pd.concat inside the
    # loop copies all previous rows on every iteration.
    if per_file_results:
        all_results = pd.concat(per_file_results, ignore_index=True)
    else:
        all_results = pd.DataFrame()

    all_results.to_csv(f'{args.root_dir}/results_computed.csv', index=False)
|
|
|
|
|
if __name__ == "__main__":
    # Select the 'spawn' start method before doing any work. force=True
    # replaces a start method that may already have been set.
    # NOTE(review): this script does not start processes itself; presumably a
    # dependency (e.g. midi_util or torch workers) needs it. Confirm before removing.
    multiprocessing.set_start_method(method='spawn', force=True)
    main()
|
|