from typing import List, Tuple

import numpy as np
import scipy.signal


# SPOTIFY
def get_inferred_onsets(onset_roll: np.ndarray, note_roll: np.ndarray, n_diff: int = 2) -> np.ndarray:
    """
    Infer onsets from large changes in the note roll matrix amplitudes.
    Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py

    :param onset_roll: Onset activation matrix (n_times, n_freqs).
    :param note_roll: Frame-level note activation matrix (n_times, n_freqs).
    :param n_diff: Number of frame lags used to detect onsets.
    :return: The element-wise maximum of the predicted onsets and the frame differences.
    """
    diffs = []
    for n in range(1, n_diff + 1):
        # pad with n leading zero frames so each lagged difference stays aligned with the input
        frames_appended = np.concatenate([np.zeros((n, note_roll.shape[1])), note_roll])
        diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
    frame_diff = np.min(diffs, axis=0)
    frame_diff[frame_diff < 0] = 0
    frame_diff[:n_diff, :] = 0
    frame_diff = np.max(onset_roll) * frame_diff / np.max(frame_diff)  # rescale to have the same max as onsets

    # use the max of the predicted onsets and the differences
    max_onsets_diff = np.max([onset_roll, frame_diff], axis=0)
    return max_onsets_diff
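
# Illustrative sketch (toy values, not from the original repo): an onset the model
# missed shows up as a jump in the note roll at frame 2 and is recovered here.
#   onset_roll = np.array([[1.0], [0.0], [0.0], [0.0]])
#   note_roll  = np.array([[0.0], [0.0], [0.8], [0.8]])
#   get_inferred_onsets(onset_roll, note_roll)  # -> [[1.], [0.], [1.], [0.]]
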
def spotify_create_notes(
        note_roll: np.ndarray,
        onset_roll: np.ndarray,
        onset_thresh: float,
        frame_thresh: float,
        min_note_len: int,
        infer_onsets: bool,
        note_low: int,   # self.labeling.midi_centers[0]
        note_high: int,  # self.labeling.midi_centers[-1]
        melodia_trick: bool = True,
        energy_tol: int = 11,
) -> List[Tuple[int, int, int, float]]:
    """Decode raw model output to polyphonic note events.
    Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py

    Args:
        note_roll: Frame activation matrix (n_times, n_freqs).
        onset_roll: Onset activation matrix (n_times, n_freqs).
        onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
        frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
        min_note_len: Minimum allowed note length in frames.
        infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
        note_low: MIDI pitch of the lowest frequency bin.
        note_high: MIDI pitch of the highest frequency bin.
        melodia_trick: Whether to use the melodia trick to better detect notes.
        energy_tol: Number of consecutive frames allowed below frame_thresh before a note is ended.

    Returns:
        list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
        representing the note events, where amplitude is a number between 0 and 1
    """
    n_frames = note_roll.shape[0]

    # use onsets inferred from frames in addition to the predicted onsets
    if infer_onsets:
        onset_roll = get_inferred_onsets(onset_roll, note_roll)

    peak_thresh_mat = np.zeros(onset_roll.shape)
    peaks = scipy.signal.argrelmax(onset_roll, axis=0)
    peak_thresh_mat[peaks] = onset_roll[peaks]

    onset_idx = np.where(peak_thresh_mat >= onset_thresh)
    onset_time_idx = onset_idx[0][::-1]  # sort to go backwards in time
    onset_freq_idx = onset_idx[1][::-1]  # sort to go backwards in time

    remaining_energy = np.zeros(note_roll.shape)
    remaining_energy[:, :] = note_roll[:, :]

    # loop over onsets
    note_events = []
    for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
        # if we're too close to the end of the audio, continue
        if note_start_idx >= n_frames - 1:
            continue

        # find time index at this frequency band where the frames drop below an energy threshold
        i = note_start_idx + 1
        k = 0  # number of frames since energy dropped below threshold
        while i < n_frames - 1 and k < energy_tol:
            if remaining_energy[i, freq_idx] < frame_thresh:
                k += 1
            else:
                k = 0
            i += 1

        i -= k  # go back to frame above threshold

        # if the note is too short, skip it
        if i - note_start_idx <= min_note_len:
            continue

        remaining_energy[note_start_idx:i, freq_idx] = 0
        # freq_idx is a bin index, so bound the neighbour suppression by the roll's width
        # (note_high / note_low are MIDI pitches and would let freq_idx + 1 run out of range)
        if freq_idx < note_roll.shape[1] - 1:
            remaining_energy[note_start_idx:i, freq_idx + 1] = 0
        if freq_idx > 0:
            remaining_energy[note_start_idx:i, freq_idx - 1] = 0
        # add the note
        amplitude = np.mean(note_roll[note_start_idx:i, freq_idx])
        note_events.append(
            (
                note_start_idx,
                i,
                freq_idx + note_low,
                amplitude,
            )
        )

    if melodia_trick:
        energy_shape = remaining_energy.shape

        while np.max(remaining_energy) > frame_thresh:
            i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
            remaining_energy[i_mid, freq_idx] = 0

            # forward pass
            i = i_mid + 1
            k = 0
            while i < n_frames - 1 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0

                remaining_energy[i, freq_idx] = 0
                if freq_idx < note_roll.shape[1] - 1:  # bin-index bounds check, as in the onset loop
                    remaining_energy[i, freq_idx + 1] = 0
                if freq_idx > 0:
                    remaining_energy[i, freq_idx - 1] = 0

                i += 1

            i_end = i - 1 - k  # go back to frame above threshold

            # backward pass
            i = i_mid - 1
            k = 0
            while i > 0 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0

                remaining_energy[i, freq_idx] = 0
                if freq_idx < note_roll.shape[1] - 1:  # same bin-index bounds check as the forward pass
                    remaining_energy[i, freq_idx + 1] = 0
                if freq_idx > 0:
                    remaining_energy[i, freq_idx - 1] = 0

                i -= 1

            i_start = i + 1 + k  # go back to frame above threshold
            assert i_start >= 0, "{}".format(i_start)
            assert i_end < n_frames

            if i_end - i_start <= min_note_len:
                # note is too short, skip it
                continue

            # add the note
            amplitude = np.mean(note_roll[i_start:i_end, freq_idx])
            note_events.append(
                (
                    i_start,
                    i_end,
                    freq_idx + note_low,
                    amplitude,
                )
            )

    return note_events
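
# Hedged usage sketch (thresholds, hop size and MIDI range below are assumptions,
# not values from this repo):
#   note_events = spotify_create_notes(note_roll, onset_roll, onset_thresh=0.5,
#                                      frame_thresh=0.3, min_note_len=5, infer_onsets=True,
#                                      note_low=21, note_high=108)
#   for start_frame, end_frame, pitch_midi, amplitude in note_events:
#       onset_sec = start_frame * hop_size / sample_rate  # hop_size / sample_rate assumed known
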

# TIKTOK
def note_detection_with_onset_offset_regress(frame_output, onset_output,
                                             onset_shift_output, offset_output, offset_shift_output, velocity_output,
                                             frame_threshold):
    """Process prediction matrices to note event information.
    First, detect onsets with the onset outputs. Then, detect offsets
    with the frame and offset outputs.

    Args:
        frame_output: (frames_num,)
        onset_output: (frames_num,)
        onset_shift_output: (frames_num,)
        offset_output: (frames_num,)
        offset_shift_output: (frames_num,)
        velocity_output: (frames_num,)
        frame_threshold: float

    Returns:
        output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
        e.g., [
            [1821, 1909, 0.47498, 0.3048533, 0.72119445],
            [1909, 1947, 0.30730522, -0.45764327, 0.64200014],
            ...]
    """
    output_tuples = []
    bgn = None
    frame_disappear = None
    offset_occur = None

    for i in range(onset_output.shape[0]):
        if onset_output[i] == 1:
            """Onset detected"""
            if bgn:
                """Consecutive onsets. E.g., pedal is not released, but two
                consecutive notes being played."""
                fin = max(i - 1, 0)
                output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                      0, velocity_output[bgn]])
                frame_disappear, offset_occur = None, None
            bgn = i

        if bgn and i > bgn:
            """If onset found, then search offset"""
            if frame_output[i] <= frame_threshold and not frame_disappear:
                """Frame disappear detected"""
                frame_disappear = i

            if offset_output[i] == 1 and not offset_occur:
                """Offset detected"""
                offset_occur = i

            if frame_disappear:
                if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
                    """bgn --------- offset_occur --- frame_disappear"""
                    fin = offset_occur
                else:
                    """bgn --- offset_occur --------- frame_disappear"""
                    fin = frame_disappear
                output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                      offset_shift_output[fin], velocity_output[bgn]])
                bgn, frame_disappear, offset_occur = None, None, None

            if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
                """Offset not detected"""
                fin = i
                output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                      offset_shift_output[fin], velocity_output[bgn]])
                bgn, frame_disappear, offset_occur = None, None, None

    # Sort pairs by onsets
    output_tuples.sort(key=lambda pair: pair[0])

    return output_tuples
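
# Worked micro-example (assumed toy values, not repo data), frame_threshold=0.5:
#   frame_output  = np.array([0.1, 0.9, 0.9, 0.9, 0.2, 0.1, 0.1])
#   onset_output  = np.array([0, 1, 0, 0, 0, 0, 0])
#   offset_output = np.zeros(7); shifts = np.zeros(7); velocity = np.full(7, 0.8)
#   note_detection_with_onset_offset_regress(frame_output, onset_output, shifts,
#                                            offset_output, shifts, velocity, 0.5)
#   # -> [[1, 4, 0.0, 0.0, 0.8]]: the note starts at the onset peak (frame 1) and
#   #    ends at frame 4, where the frame output first drops below the threshold.
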

class RegressionPostProcessor(object):
    def __init__(self, frames_per_second, classes_num, onset_threshold,
                 offset_threshold, frame_threshold, pedal_offset_threshold,
                 begin_note):
        """Postprocess the output probabilities of a transcription model to MIDI
        events.

        Args:
            frames_per_second: float
            classes_num: int
            onset_threshold: float
            offset_threshold: float
            frame_threshold: float
            pedal_offset_threshold: float
            begin_note: int, MIDI note of the lowest class
        """
        self.frames_per_second = frames_per_second
        self.classes_num = classes_num
        self.onset_threshold = onset_threshold
        self.offset_threshold = offset_threshold
        self.frame_threshold = frame_threshold
        self.pedal_offset_threshold = pedal_offset_threshold
        self.begin_note = begin_note
        self.velocity_scale = 128

    def output_dict_to_midi_events(self, output_dict):
        """Main function. Post process model outputs to MIDI events.

        Args:
            output_dict: {
                'reg_onset_output': (segment_frames, classes_num),
                'reg_offset_output': (segment_frames, classes_num),
                'frame_output': (segment_frames, classes_num),
                'velocity_output': (segment_frames, classes_num),
                'reg_pedal_onset_output': (segment_frames, 1),
                'reg_pedal_offset_output': (segment_frames, 1),
                'pedal_frame_output': (segment_frames, 1)}

        Outputs:
            est_note_events: list of dict, e.g. [
                {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83},
                {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]

            est_pedal_events: list of dict, e.g. [
                {'onset_time': 0.17, 'offset_time': 0.96},
                {'onset_time': 1.17, 'offset_time': 2.65}]
        """
        # Map the model's raw outputs onto the keys this post-processor expects
        output_dict['frame_output'] = output_dict['note']
        output_dict['velocity_output'] = output_dict['note']
        output_dict['reg_onset_output'] = output_dict['onset']
        output_dict['reg_offset_output'] = output_dict['offset']

        # Post process piano note outputs to piano note and pedal events information
        (est_on_off_note_vels, est_pedal_on_offs) = \
            self.output_dict_to_note_pedal_arrays(output_dict)
        """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity],
        est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""

        # Reformat notes to MIDI events
        est_note_events = self.detected_notes_to_events(est_on_off_note_vels)

        if est_pedal_on_offs is None:
            est_pedal_events = None
        else:
            est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)

        return est_note_events, est_pedal_events

    def output_dict_to_note_pedal_arrays(self, output_dict):
        """Postprocess the output probabilities of a transcription model to MIDI
        events.

        Args:
            output_dict: dict, {
                'reg_onset_output': (frames_num, classes_num),
                'reg_offset_output': (frames_num, classes_num),
                'frame_output': (frames_num, classes_num),
                'velocity_output': (frames_num, classes_num),
                ...}

        Returns:
            est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time,
                offset_time, piano_note and velocity. E.g. [
                [39.74, 39.87, 27, 0.65],
                [11.98, 12.11, 33, 0.69],
                ...]

            est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time
                and offset_time. E.g. [
                [0.17, 0.96],
                [1.17, 2.65],
                ...]
        """
        # ------ 1. Process regression outputs to binarized outputs ------
        # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
        # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]

        # Calculate binarized onset output from regression output
        (onset_output, onset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_onset_output'],
                threshold=self.onset_threshold, neighbour=2)

        output_dict['onset_output'] = onset_output  # Values are 0 or 1
        output_dict['onset_shift_output'] = onset_shift_output

        # Calculate binarized offset output from regression output
        (offset_output, offset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_offset_output'],
                threshold=self.offset_threshold, neighbour=4)

        output_dict['offset_output'] = offset_output  # Values are 0 or 1
        output_dict['offset_shift_output'] = offset_shift_output

        if 'reg_pedal_onset_output' in output_dict.keys():
            """Pedal onsets are not used in inference. Instead, frame-wise pedal
            predictions are used to detect onsets. We empirically found this is
            more accurate to detect pedal onsets."""
            pass

        if 'reg_pedal_offset_output' in output_dict.keys():
            # Calculate binarized pedal offset output from regression output
            (pedal_offset_output, pedal_offset_shift_output) = \
                self.get_binarized_output_from_regression(
                    reg_output=output_dict['reg_pedal_offset_output'],
                    threshold=self.pedal_offset_threshold, neighbour=4)

            output_dict['pedal_offset_output'] = pedal_offset_output  # Values are 0 or 1
            output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output

        # ------ 2. Process matrices results to event results ------
        # Detect piano notes from output_dict
        est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)
        est_pedal_on_offs = None

        return est_on_off_note_vels, est_pedal_on_offs

    def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
        """Calculate binarized output and shifts of onsets or offsets from the
        regression results.

        Args:
            reg_output: (frames_num, classes_num)
            threshold: float
            neighbour: int

        Returns:
            binary_output: (frames_num, classes_num)
            shift_output: (frames_num, classes_num)
        """
        binary_output = np.zeros_like(reg_output)
        shift_output = np.zeros_like(reg_output)
        (frames_num, classes_num) = reg_output.shape

        for k in range(classes_num):
            x = reg_output[:, k]
            for n in range(neighbour, frames_num - neighbour):
                if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
                    binary_output[n, k] = 1

                    """See Section III-D in [1] for deduction.
                    [1] Q. Kong, et al., High-resolution Piano Transcription
                    with Pedals by Regressing Onsets and Offsets Times, 2020."""
                    if x[n - 1] > x[n + 1]:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
                    else:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
                    shift_output[n, k] = shift

        return binary_output, shift_output
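
    # Worked example (illustrative values, assuming threshold=0.3 and neighbour=2):
    # for x = [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, ...] the local maximum x[4] = 0.40
    # exceeds the threshold and is monotonic on both sides, so binary_output[4, k] = 1.
    # Since x[3] = 0.30 < x[5] = 0.35, the second branch gives
    #   shift = (0.35 - 0.30) / (0.40 - 0.30) / 2 = 0.25,
    # i.e. the refined onset/offset sits a quarter of a frame after frame 4.
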

    def is_monotonic_neighbour(self, x, n, neighbour):
        """Detect if values are monotonic on both sides of x[n].

        Args:
            x: (frames_num,)
            n: int
            neighbour: int

        Returns:
            monotonic: bool
        """
        monotonic = True
        for i in range(neighbour):
            if x[n - i] < x[n - i - 1]:
                monotonic = False
            if x[n + i] < x[n + i + 1]:
                monotonic = False

        return monotonic

    def output_dict_to_detected_notes(self, output_dict):
        """Postprocess output_dict to piano notes.

        Args:
            output_dict: dict, e.g. {
                'onset_output': (frames_num, classes_num),
                'onset_shift_output': (frames_num, classes_num),
                'offset_output': (frames_num, classes_num),
                'offset_shift_output': (frames_num, classes_num),
                'frame_output': (frames_num, classes_num),
                'velocity_output': (frames_num, classes_num),
                ...}

        Returns:
            est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets,
                MIDI notes and velocities. E.g.,
                [[39.7375, 39.7500, 27., 0.6638],
                 [11.9824, 12.5000, 33., 0.6892],
                 ...]
        """
        est_tuples = []
        est_midi_notes = []
        classes_num = output_dict['frame_output'].shape[-1]

        for piano_note in range(classes_num):
            """Detect piano notes"""
            est_tuples_per_note = note_detection_with_onset_offset_regress(
                frame_output=output_dict['frame_output'][:, piano_note],
                onset_output=output_dict['onset_output'][:, piano_note],
                onset_shift_output=output_dict['onset_shift_output'][:, piano_note],
                offset_output=output_dict['offset_output'][:, piano_note],
                offset_shift_output=output_dict['offset_shift_output'][:, piano_note],
                velocity_output=output_dict['velocity_output'][:, piano_note],
                frame_threshold=self.frame_threshold)

            est_tuples += est_tuples_per_note
            est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)

        est_tuples = np.array(est_tuples)  # (notes, 5)
        """(notes, 5), the five columns are onset, offset, onset_shift,
        offset_shift and normalized_velocity"""

        est_midi_notes = np.array(est_midi_notes)  # (notes,)

        onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
        offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
        velocities = est_tuples[:, 4]

        est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
        """(notes, 4), the four columns are onset_times, offset_times, midi_notes and velocities."""

        est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)

        return est_on_off_note_vels

    def detected_notes_to_events(self, est_on_off_note_vels):
        """Reformat detected notes to MIDI events.

        Args:
            est_on_off_note_vels: (notes, 4), the four columns are onset_times,
                offset_times, midi_notes and velocities. E.g.
                [[39.7375, 39.7500, 27., 0.6638],
                 [11.9824, 12.5000, 33., 0.6892],
                 ...]

        Returns:
            midi_events: list, e.g.,
                [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
                 {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
                 ...]
        """
        midi_events = []
        for i in range(est_on_off_note_vels.shape[0]):
            midi_events.append({
                'onset_time': est_on_off_note_vels[i][0],
                'offset_time': est_on_off_note_vels[i][1],
                'midi_note': int(est_on_off_note_vels[i][2]),
                'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})

        return midi_events
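

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original pipelines: it pushes random
    # activations through the TikTok-style post-processor just to exercise the code
    # path end to end. Shapes, thresholds, frame rate and the MIDI range (a piano
    # keyboard starting at note 21) are assumptions for this example only.
    rng = np.random.default_rng(0)
    n_frames, n_classes = 200, 88
    output_dict = {
        'note': rng.random((n_frames, n_classes)).astype(np.float32),
        'onset': rng.random((n_frames, n_classes)).astype(np.float32),
        'offset': rng.random((n_frames, n_classes)).astype(np.float32),
    }

    post_processor = RegressionPostProcessor(
        frames_per_second=100,
        classes_num=n_classes,
        onset_threshold=0.3,
        offset_threshold=0.3,
        frame_threshold=0.1,
        pedal_offset_threshold=0.2,
        begin_note=21,
    )
    note_events, pedal_events = post_processor.output_dict_to_midi_events(output_dict)
    print(f"{len(note_events)} note events, e.g. {note_events[0] if note_events else None}")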