from typing import Dict

import numpy as np

from musc.dtw.mrmsdtw import sync_via_mrmsdtw_with_anchors
from musc.dtw.utils import make_path_strictly_monotonic
from musc.transcriber import Transcriber


class Synchronizer(Transcriber):
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)

    def synchronize(self, audio, midi, batch_size=128, include_pitch_bends=True, to_midi=True, debug=False,
                    include_velocity=False, alignment_padding=50, timing_refinement_range_with_f0s=0):
        """
        Synchronize an audio file or mono waveform (numpy or torch) with a MIDI file.

        :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
        :param midi: str, pathlib.Path, or pretty_midi.PrettyMIDI
        :param batch_size: number of frames to process at once
        :param include_pitch_bends: whether to include pitch bends in the MIDI file
        :param to_midi: whether to return a MIDI file or a list of note events (as tuples)
        :param debug: whether to plot the alignment path and compare the alignment with the predicted notes
        :param include_velocity: whether to embed the note confidence in place of the velocity in the MIDI file
        :param alignment_padding: how many frames to pad the audio and MIDI representations with
        :param timing_refinement_range_with_f0s: how many frames to refine the alignment with the f0 confidence
        :return: aligned MIDI file as a pretty_midi.PrettyMIDI object
        """
        audio = self.predict(audio, batch_size)
        notes_and_midi = self.out2sync(audio, midi, include_velocity=include_velocity,
                                       alignment_padding=alignment_padding)
        if notes_and_midi:  # it might be None if no alignment was possible
            notes, midi = notes_and_midi

            if debug:
                import matplotlib.pyplot as plt
                import pandas as pd
                estimated_notes = self.out2note(audio, postprocessing='spotify', include_pitch_bends=True)
                est_df = pd.DataFrame(estimated_notes).sort_values(by=0)
                note_df = pd.DataFrame(notes).sort_values(by=0)

                fig, ax = plt.subplots(figsize=(20, 10))
                for row in notes:
                    t_start = row[0]  # sec
                    t_end = row[1]    # sec
                    pitch = row[2]    # MIDI pitch
                    ax.hlines(pitch, t_start, t_end, color='k', linewidth=3, zorder=2, alpha=0.5)
                for row in estimated_notes:
                    t_start = row[0]  # sec
                    t_end = row[1]    # sec
                    pitch = row[2]    # MIDI pitch
                    ax.hlines(pitch, t_start, t_end, color='r', linewidth=3, zorder=2, alpha=0.5)
                fig.suptitle('alignment (black) vs. estimated (red)')
                fig.show()

            if not include_pitch_bends:
                if to_midi:
                    return midi['midi']
                else:
                    return notes
            else:
                notes = [(np.argmin(np.abs(audio['time'] - note[0])),
                          np.argmin(np.abs(audio['time'] - note[1])),
                          note[2], note[3]) for note in notes]
                notes = self.get_pitch_bends(audio["f0"], notes, timing_refinement_range_with_f0s)
                notes = [
                    (audio['time'][note[0]], audio['time'][note[1]], note[2], note[3], note[4])
                    for note in notes
                ]
                if to_midi:
                    return self.note2midi(notes, 120)  # alternatively: int(midi['midi'].estimate_tempo())
                else:
                    return notes
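
    # Usage sketch (illustrative only, not part of this module): `labeling` stands for whatever
    # labeling object the Transcriber base class expects, and the file paths are hypothetical.
    #
    #   synchronizer = Synchronizer(labeling, instrument='Violin')
    #   aligned = synchronizer.synchronize('performance.wav', 'score.mid', to_midi=True)
    #   aligned.write('performance_aligned.mid')  # pretty_midi.PrettyMIDI.write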

    def out2sync_old(self, out: Dict[str, np.ndarray], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file. Older variant of out2sync: it passes
        the start/end anchors directly to MrMsDTW instead of cropping the features to the anchor region.

        Args:
            out: Model output dictionary
            midi: Path to the MIDI file or PrettyMIDI object
            include_velocity: Whether to encode the note confidence in place of velocity
            alignment_padding: Number of frames to pad the MIDI features with zeros
            debug: Visualize the alignment

        Returns:
            note events and the aligned PrettyMIDI object
        """
        midi = self.labeling.represent_midi(midi, self.sr / self.hop_length)

        audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr / self.hop_length,
                                                              pad_length=alignment_padding)
        if isinstance(audio_midi_anchors, str):
            print(audio_midi_anchors)
            return None  # the file is corrupted: no alignment is possible at all
        else:
            audio, midi, anchor_pairs = audio_midi_anchors

        ALPHA = 0.6  # weight of the onsets; offsets are weighted by 1 - ALPHA
        wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T,
                                           f_onset1=np.hstack([ALPHA * audio['onset'],
                                                               (1 - ALPHA) * audio['offset']]).T,
                                           f_chroma2=midi['note'].T,
                                           f_onset2=np.hstack([ALPHA * midi['onset'],
                                                               (1 - ALPHA) * midi['offset']]).T,
                                           input_feature_rate=self.sr / self.hop_length,
                                           step_weights=np.array([1.5, 1.5, 2.0]),
                                           threshold_rec=10 ** 6,
                                           verbose=debug, normalize_chroma=False,
                                           anchor_pairs=anchor_pairs)
        wp = make_path_strictly_monotonic(wp).astype(int)

        audio_time = np.take(audio['time'], wp[0])
        midi_time = np.take(midi['time'], wp[1])

        notes = []
        for instrument in midi['midi'].instruments:
            for note in instrument.notes:
                note.start = np.interp(note.start, midi_time, audio_time)
                note.end = np.interp(note.end, midi_time, audio_time)
                if note.end - note.start <= 0.012:  # notes should last at least 12 ms (i.e. 2 frames)
                    note.start = note.start - 0.003
                    note.end = note.start + 0.012
                if include_velocity:  # encode the note confidence in place of velocity
                    velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                       np.argmin(np.abs(audio['time'] - note.end)),
                                                       note.pitch - self.labeling.midi_centers[0]])
                    # velocity must be an int of at least 1, otherwise the note is dropped from the MIDI
                    note.velocity = max(1, int(round(velocity * 127)))
                else:
                    velocity = note.velocity / 127
                notes.append((note.start, note.end, note.pitch, velocity))
        return notes, midi

    def out2sync(self, out: Dict[str, np.ndarray], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file. Crops the audio and MIDI features to the
        region between the start and end anchors and runs MrMsDTW without internal anchor pairs.

        Args:
            out: Model output dictionary
            midi: Path to the MIDI file or PrettyMIDI object
            include_velocity: Whether to encode the note confidence in place of velocity
            alignment_padding: Number of frames to pad the MIDI features with zeros
            debug: Visualize the alignment

        Returns:
            note events and the aligned PrettyMIDI object
        """
        midi = self.labeling.represent_midi(midi, self.sr / self.hop_length)

        audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr / self.hop_length,
                                                              pad_length=alignment_padding)
        if isinstance(audio_midi_anchors, str):
            print(audio_midi_anchors)
            return None  # the file is corrupted: no alignment is possible at all
        else:
            audio, midi, anchor_pairs = audio_midi_anchors

        ALPHA = 0.6  # weight of the onsets; offsets are weighted by 1 - ALPHA
        starts = (np.array(anchor_pairs[0]) * self.sr / self.hop_length).astype(int)
        ends = (np.array(anchor_pairs[1]) * self.sr / self.hop_length).astype(int)
        wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T[:, starts[0]:ends[0]],
                                           f_onset1=np.hstack([ALPHA * audio['onset'],
                                                               (1 - ALPHA) * audio['offset']]).T[:, starts[0]:ends[0]],
                                           f_chroma2=midi['note'].T[:, starts[1]:ends[1]],
                                           f_onset2=np.hstack([ALPHA * midi['onset'],
                                                               (1 - ALPHA) * midi['offset']]).T[:, starts[1]:ends[1]],
                                           input_feature_rate=self.sr / self.hop_length,
                                           step_weights=np.array([1.5, 1.5, 2.0]),
                                           threshold_rec=10 ** 6,
                                           verbose=debug, normalize_chroma=False,
                                           anchor_pairs=None)
        wp = make_path_strictly_monotonic(wp).astype(int)
        wp[0] += starts[0]  # undo the cropping so the path indexes the full feature matrices
        wp[1] += starts[1]
        wp = np.hstack((wp, ends[:, np.newaxis]))  # append the end anchors as the final path point

        audio_time = np.take(audio['time'], wp[0])
        midi_time = np.take(midi['time'], wp[1])

        notes = []
        for instrument in midi['midi'].instruments:
            for note in instrument.notes:
                note.start = np.interp(note.start, midi_time, audio_time)
                note.end = np.interp(note.end, midi_time, audio_time)
                if note.end - note.start <= 0.012:  # notes should last at least 12 ms (i.e. 2 frames)
                    note.start = note.start - 0.003
                    note.end = note.start + 0.012
                if include_velocity:  # encode the note confidence in place of velocity
                    velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                       np.argmin(np.abs(audio['time'] - note.end)),
                                                       note.pitch - self.labeling.midi_centers[0]])
                    # velocity must be an int of at least 1, otherwise the note is dropped from the MIDI
                    note.velocity = max(1, int(round(velocity * 127)))
                else:
                    velocity = note.velocity / 127
                notes.append((note.start, note.end, note.pitch, velocity))
        return notes, midi
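
    # Toy illustration (not real alignment data) of the warping-path remapping used in out2sync:
    # given the matched (midi_time, audio_time) samples along the DTW path, np.interp maps any
    # score time onto performance time, e.g.
    #
    #   midi_time = np.array([0.0, 1.0, 2.0, 3.0])    # seconds in the MIDI score
    #   audio_time = np.array([0.0, 1.2, 2.1, 3.5])   # corresponding seconds in the recording
    #   np.interp(1.5, midi_time, audio_time)         # -> 1.65, the warped note onset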

    @staticmethod
    def pad_representations(dict_of_representations, pad_length=10):
        """
        Pad the representations so that the DTW does not force them to span the entire duration.

        Args:
            dict_of_representations: audio or MIDI representations
            pad_length: how many frames to pad

        Returns:
            padded representations
        """
        for key, value in dict_of_representations.items():
            if key == 'time':
                padded_time = dict_of_representations[key]
                padded_time = np.concatenate([padded_time[:2 * pad_length], padded_time + padded_time[2 * pad_length]])
                # shift so that index pad_length corresponds to the real zero time and the
                # padded frames before it get negative times
                dict_of_representations[key] = padded_time - padded_time[pad_length]
            elif key in ['onset', 'offset', 'note']:
                dict_of_representations[key] = np.pad(value, ((pad_length, pad_length), (0, 0)))
            elif key in ['start_anchor', 'end_anchor']:
                anchor_time = dict_of_representations[key][0][0]
                anchor_time = np.argmin(np.abs(dict_of_representations['time'] - anchor_time))
                dict_of_representations[key][:, 0] = anchor_time
                dict_of_representations[key] = dict_of_representations[key].astype(int)  # np.int is deprecated
        return dict_of_representations
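
    # Worked example of the 'time' padding above, assuming pad_length=2 and a 10 ms frame period
    # (numbers are illustrative only):
    #
    #   time = np.array([0.00, 0.01, 0.02, 0.03, 0.04])
    #   padded = np.concatenate([time[:4], time + time[4]])  # uniform grid, 4 frames longer
    #   padded -= padded[2]                                  # index pad_length is the real start
    #   # padded is now [-0.02, -0.01, 0.00, 0.01, ..., 0.06]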

    def prepare_for_synchronization(self, audio, midi, feature_rate=44100 / 256, pad_length=100):
        """
        MrMsDTW works better with start and end anchors. This function finds the start and end anchors for the
        audio based on the MIDI notes. It also pads the MIDI representations, since MIDI files most often start
        and end with an active note; without padding, the DTW would try to stretch those active notes over the
        entire duration of the audio, which is not desirable. Padding the MIDI representations with a few frames
        of silence at the beginning and the end avoids this.

        Args:
            audio: model output representations of the audio
            midi: MIDI representations as returned by labeling.represent_midi
            feature_rate: frames per second of the representations
            pad_length: how many frames to pad the MIDI representations with

        Returns:
            audio representations, padded MIDI representations, and the (start, end) anchor pairs in seconds,
            or the string 'corrupted' if no usable anchors can be found
        """
        # first pad the MIDI
        midi = self.pad_representations(midi, pad_length)

        # sometimes f0s are more reliable than the notes, so we use both together to find the start and end
        # anchors. f0_lookup_bins is the number of bins to look around the f0 when assigning a note to it.
        f0_lookup_bins = int(100 // (2 * self.labeling.f0_granularity_c))

        # find the start anchor for the audio:
        # first decide which notes to use for the start anchor (take the entire chord where the MIDI file starts)
        anchor_notes = midi['start_anchor'][:, 1] - self.labeling.midi_centers[0]
        # now find which f0 bins to look at for the start anchor
        anchor_f0s = [self.midi_pitch_to_contour_bin(an + self.labeling.midi_centers[0]) for an in anchor_notes]
        anchor_f0s = np.array([list(range(f0 - f0_lookup_bins, f0 + f0_lookup_bins + 1)) for f0 in anchor_f0s]).reshape(-1)
        # the first start-anchor proposals come from the notes
        anchor_vals = np.any(audio['note'][:, anchor_notes] > 0.5, axis=1)
        # now the f0s
        anchor_vals_f0 = np.any(audio['f0'][:, anchor_f0s] > 0.5, axis=1)
        # combine the two
        anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
        if not any(anchor_vals):
            return 'corrupted'  # do not consider the file if we cannot find the start anchor
        audio_start = np.argmax(anchor_vals)

        # now the end anchor (most string instruments use chords in cadences: in general the end anchor is polyphonic)
        anchor_notes = midi['end_anchor'][:, 1] - self.labeling.midi_centers[0]
        anchor_f0s = [self.midi_pitch_to_contour_bin(an + self.labeling.midi_centers[0]) for an in anchor_notes]
        anchor_f0s = np.array([list(range(f0 - f0_lookup_bins, f0 + f0_lookup_bins + 1)) for f0 in anchor_f0s]).reshape(-1)
        # the same procedure as above, but searching from the end of the audio
        anchor_vals = np.any(audio['note'][::-1, anchor_notes] > 0.5, axis=1)
        anchor_vals_f0 = np.any(audio['f0'][::-1, anchor_f0s] > 0.5, axis=1)
        anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
        if not any(anchor_vals):
            return 'corrupted'  # do not consider the file if we cannot find the end anchor
        audio_end = audio['note'].shape[0] - np.argmax(anchor_vals)

        if audio_end - audio_start < (midi['end_anchor'][0][0] - midi['start_anchor'][0][0]) / 10:  # no one plays 10x faster
            return 'corrupted'  # do not consider the file if the interval between the anchors is too short

        anchor_pairs = [(audio_start - 5, midi['start_anchor'][0][0] - 5),
                        (audio_end + 5, midi['end_anchor'][0][0] + 5)]
        if anchor_pairs[0][0] < 1:
            anchor_pairs[0] = (1, midi['start_anchor'][0][0])
        if anchor_pairs[1][0] > audio['note'].shape[0] - 1:
            anchor_pairs[1] = (audio['note'].shape[0] - 1, midi['end_anchor'][0][0])

        return audio, midi, [(anchor_pairs[0][0] / feature_rate, anchor_pairs[0][1] / feature_rate),
                             (anchor_pairs[1][0] / feature_rate, anchor_pairs[1][1] / feature_rate)]
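
    # Note on units: the anchors are computed in frames and returned in seconds by dividing by
    # feature_rate; out2sync converts back to frames the same way when cropping the features.
    # With the defaults used in this class (sr=16000, hop_length=160), for example:
    #
    #   feature_rate = 16000 / 160           # = 100 frames per second
    #   anchor_seconds = 250 / feature_rate  # frame 250 -> 2.5 s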