# musc/synchronizer.py
from typing import Dict

import numpy as np

from musc.dtw.mrmsdtw import sync_via_mrmsdtw_with_anchors
from musc.dtw.utils import make_path_strictly_monotonic
from musc.transcriber import Transcriber


class Synchronizer(Transcriber):
def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)
def synchronize(self, audio, midi, batch_size=128, include_pitch_bends=True, to_midi=True, debug=False,
include_velocity=False, alignment_padding=50, timing_refinement_range_with_f0s=0):
"""
Synchronize an audio file or mono waveform in numpy or torch with a MIDI file.
:param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
:param midi: str, pathlib.Path, or pretty_midi.PrettyMIDI
:param batch_size: frames to process at once
:param include_pitch_bends: whether to include pitch bends in the MIDI file
:param to_midi: whether to return a MIDI file or a list of note events (as tuple)
:param debug: whether to plot the alignment path and compare the alignment with the predicted notes
:param include_velocity: whether to embed the note confidence in place of the velocity in the MIDI file
:param alignment_padding: how many frames to pad the audio and MIDI representations with
:param timing_refinement_range_with_f0s: how many frames to refine the alignment with the f0 confidence
:return: aligned MIDI file as a pretty_midi.PrettyMIDI object
Args:
debug:
to_midi:
include_pitch_bends:
"""
audio = self.predict(audio, batch_size)
notes_and_midi = self.out2sync(audio, midi, include_velocity=include_velocity,
alignment_padding=alignment_padding)
        if notes_and_midi:  # it might be None
notes, midi = notes_and_midi
if debug:
import matplotlib.pyplot as plt
import pandas as pd
estimated_notes = self.out2note(audio, postprocessing='spotify', include_pitch_bends=True)
est_df = pd.DataFrame(estimated_notes).sort_values(by=0)
note_df = pd.DataFrame(notes).sort_values(by=0)
fig, ax = plt.subplots(figsize=(20, 10))
                for row in notes:
                    t_start = row[0]  # seconds
                    t_end = row[1]  # seconds
                    pitch = row[2]  # MIDI pitch
                    ax.hlines(pitch, t_start, t_end, color='k', linewidth=3, zorder=2, alpha=0.5)
                for row in estimated_notes:
                    t_start = row[0]  # seconds
                    t_end = row[1]  # seconds
                    pitch = row[2]  # MIDI pitch
                    ax.hlines(pitch, t_start, t_end, color='r', linewidth=3, zorder=2, alpha=0.5)
fig.suptitle('alignment (black) vs. estimated (red)')
fig.show()
if not include_pitch_bends:
if to_midi:
return midi['midi']
else:
return notes
else:
notes = [(np.argmin(np.abs(audio['time']-note[0])),
np.argmin(np.abs(audio['time']-note[1])),
note[2], note[3]) for note in notes]
notes = self.get_pitch_bends(audio["f0"], notes, timing_refinement_range_with_f0s)
notes = [
(audio['time'][note[0]], audio['time'][note[1]], note[2], note[3], note[4]) for note in
notes
]
if to_midi:
                    return self.note2midi(notes, 120)  # default 120 BPM; alternatively int(midi['midi'].estimate_tempo())
else:
return notes
    def out2sync_old(self, out: Dict[str, np.ndarray], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file.

        Legacy variant: runs MrMsDTW over the full-length features and hands the start/end anchors
        directly to the DTW (compare with out2sync below, which crops the features to the anchor region).

        Args:
            out: Model output dictionary
            midi: Path to the MIDI file or PrettyMIDI object
            include_velocity: Whether to encode the note confidence in place of velocity
            alignment_padding: Number of frames to pad the MIDI features with zeros
            debug: Visualize the alignment

        Returns:
            Note events and the aligned PrettyMIDI object
        """
midi = self.labeling.represent_midi(midi, self.sr/self.hop_length)
audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr/self.hop_length,
pad_length=alignment_padding)
if isinstance(audio_midi_anchors, str):
print(audio_midi_anchors)
return None # the file is corrupted! no possible alignment at all
else:
audio, midi, anchor_pairs = audio_midi_anchors
            ALPHA = 0.6  # weight of the onset features; the offsets are weighted by 1 - ALPHA
wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T,
f_onset1=np.hstack([ALPHA * audio['onset'],
(1 - ALPHA) * audio['offset']]).T,
f_chroma2=midi['note'].T,
f_onset2=np.hstack([ALPHA * midi['onset'],
(1 - ALPHA) * midi['offset']]).T,
input_feature_rate=self.sr/self.hop_length,
step_weights=np.array([1.5, 1.5, 2.0]),
threshold_rec=10 ** 6,
verbose=debug, normalize_chroma=False,
anchor_pairs=anchor_pairs)
wp = make_path_strictly_monotonic(wp).astype(int)
audio_time = np.take(audio['time'], wp[0])
midi_time = np.take(midi['time'], wp[1])
notes = []
for instrument in midi['midi'].instruments:
for note in instrument.notes:
note.start = np.interp(note.start, midi_time, audio_time)
note.end = np.interp(note.end, midi_time, audio_time)
if note.end - note.start <= 0.012: # notes should be at least 12 ms (i.e. 2 frames)
note.start = note.start - 0.003
note.end = note.start + 0.012
                    if include_velocity:  # encode the note confidence in place of velocity
                        velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                           np.argmin(np.abs(audio['time'] - note.end)),
                                                           note.pitch - self.labeling.midi_centers[0]])
                        # MIDI velocities must be integers of at least 1, otherwise the note is dropped
                        note.velocity = int(max(1, velocity * 127))
else:
velocity = note.velocity/127
notes.append((note.start, note.end, note.pitch, velocity))
return notes, midi
    def out2sync(self, out: Dict[str, np.ndarray], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file.

        Crops the audio and MIDI features to the region between the start and end anchors, runs MrMsDTW
        on the cropped features, and shifts the resulting warping path back to the original frame indices.

        Args:
            out: Model output dictionary
            midi: Path to the MIDI file or PrettyMIDI object
            include_velocity: Whether to encode the note confidence in place of velocity
            alignment_padding: Number of frames to pad the MIDI features with zeros
            debug: Visualize the alignment

        Returns:
            Note events and the aligned PrettyMIDI object
        """
midi = self.labeling.represent_midi(midi, self.sr/self.hop_length)
audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr/self.hop_length,
pad_length=alignment_padding)
if isinstance(audio_midi_anchors, str):
print(audio_midi_anchors)
return None # the file is corrupted! no possible alignment at all
else:
audio, midi, anchor_pairs = audio_midi_anchors
            ALPHA = 0.6  # weight of the onset features; the offsets are weighted by 1 - ALPHA
starts = (np.array(anchor_pairs[0])*self.sr/self.hop_length).astype(int)
ends = (np.array(anchor_pairs[1])*self.sr/self.hop_length).astype(int)
wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T[:, starts[0]:ends[0]],
f_onset1=np.hstack([ALPHA * audio['onset'],
(1 - ALPHA) * audio['offset']]).T[:, starts[0]:ends[0]],
f_chroma2=midi['note'].T[:, starts[1]:ends[1]],
f_onset2=np.hstack([ALPHA * midi['onset'],
(1 - ALPHA) * midi['offset']]).T[:, starts[1]:ends[1]],
input_feature_rate=self.sr/self.hop_length,
step_weights=np.array([1.5, 1.5, 2.0]),
threshold_rec=10 ** 6,
verbose=debug, normalize_chroma=False,
anchor_pairs=None)
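            # The warping path indices refer to the cropped features: make the path strictly monotonic,
            # shift it back by the crop offsets, and append the end anchors as the final correspondence.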
wp = make_path_strictly_monotonic(wp).astype(int)
wp[0] += starts[0]
wp[1] += starts[1]
wp = np.hstack((wp, ends[:,np.newaxis]))
audio_time = np.take(audio['time'], wp[0])
midi_time = np.take(midi['time'], wp[1])
notes = []
for instrument in midi['midi'].instruments:
for note in instrument.notes:
note.start = np.interp(note.start, midi_time, audio_time)
note.end = np.interp(note.end, midi_time, audio_time)
if note.end - note.start <= 0.012: # notes should be at least 12 ms (i.e. 2 frames)
note.start = note.start - 0.003
note.end = note.start + 0.012
                    if include_velocity:  # encode the note confidence in place of velocity
                        velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                           np.argmin(np.abs(audio['time'] - note.end)),
                                                           note.pitch - self.labeling.midi_centers[0]])
                        # MIDI velocities must be integers of at least 1, otherwise the note is dropped
                        note.velocity = int(max(1, velocity * 127))
else:
velocity = note.velocity/127
notes.append((note.start, note.end, note.pitch, velocity))
return notes, midi
@staticmethod
def pad_representations(dict_of_representations, pad_length=10):
"""
Pad the representations so that the DTW does not enforce them to encompass the entire duration.
Args:
dict_of_representations: audio or midi representations
pad_length: how many frames to pad
Returns:
padded representations
"""
        for key, value in dict_of_representations.items():
            if key == 'time':
                padded_time = dict_of_representations[key]
                padded_time = np.concatenate([padded_time[:2*pad_length], padded_time + padded_time[2*pad_length]])
                # shift so that the padded frames before the real start have negative times
                dict_of_representations[key] = padded_time - padded_time[pad_length]
            elif key in ['onset', 'offset', 'note']:
                dict_of_representations[key] = np.pad(value, ((pad_length, pad_length), (0, 0)))
            elif key in ['start_anchor', 'end_anchor']:
                anchor_time = dict_of_representations[key][0][0]
                anchor_frame = np.argmin(np.abs(dict_of_representations['time'] - anchor_time))
                dict_of_representations[key][:, 0] = anchor_frame
                dict_of_representations[key] = dict_of_representations[key].astype(int)
        return dict_of_representations
def prepare_for_synchronization(self, audio, midi, feature_rate=44100/256, pad_length=100):
"""
MrMsDTW works better with start and end anchors. This function finds the start and end anchors for audio
based on the midi notes. It also pads the MIDI representations since MIDI files most often start with an active
note and end with an active note. Thus, the DTW will try to align the active notes to the entire duration of the
audio. This is not desirable. Therefore, we pad the MIDI representations with a few frames of silence at the
beginning and end of the audio. This way, the DTW will not try to align the active notes to the entire duration.
Args:
audio:
midi:
feature_rate:
pad_length:
Returns:
"""
# first pad the MIDI
midi = self.pad_representations(midi, pad_length)
# sometimes f0s are more reliable than the notes. So, we use both the f0s and the notes together to find the
# start and end anchors. f0 lookup bins is the number of bins to look around the f0 to assign a note to it.
f0_lookup_bins = int(100//(2*self.labeling.f0_granularity_c))
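        # e.g. with an f0 granularity of 20 cents (an assumed value), 100 // (2 * 20) = 2, i.e. the lookup
        # window spans +/-2 f0 bins (~ +/-40 cents) around each anchor note's pitch.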
# find the start anchor for the audio
# first decide on which notes to use for the start anchor (take the entire chord where the MIDI file starts)
anchor_notes = midi['start_anchor'][:, 1] - self.labeling.midi_centers[0]
# now find which f0 bins to look at for the start anchor
anchor_f0s = [self.midi_pitch_to_contour_bin(an+self.labeling.midi_centers[0]) for an in anchor_notes]
anchor_f0s = np.array([list(range(f0-f0_lookup_bins, f0+f0_lookup_bins+1)) for f0 in anchor_f0s]).reshape(-1)
# first start anchor proposals come from the notes
anchor_vals = np.any(audio['note'][:, anchor_notes]>0.5, axis=1)
# now the f0s
anchor_vals_f0 = np.any(audio['f0'][:, anchor_f0s]>0.5, axis=1)
# combine the two
anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
if not any(anchor_vals):
return 'corrupted' # do not consider the file if we cannot find the start anchor
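        # np.argmax on a boolean array returns the index of the first True, i.e. the first frame in which
        # the opening note (or a matching f0 bin) is active, e.g. [False, False, True, ...] -> 2.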
audio_start = np.argmax(anchor_vals)
# now the end anchor (most string instruments use chords in cadences: in general the end anchor is polyphonic)
anchor_notes = midi['end_anchor'][:, 1] - self.labeling.midi_centers[0]
anchor_f0s = [self.midi_pitch_to_contour_bin(an+self.labeling.midi_centers[0]) for an in anchor_notes]
anchor_f0s = np.array([list(range(f0-f0_lookup_bins, f0+f0_lookup_bins+1)) for f0 in anchor_f0s]).reshape(-1)
# the same procedure as above
anchor_vals = np.any(audio['note'][::-1, anchor_notes]>0.5, axis=1)
anchor_vals_f0 = np.any(audio['f0'][::-1, anchor_f0s]>0.5, axis=1)
anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
if not any(anchor_vals):
return 'corrupted' # do not consider the file if we cannot find the end anchor
audio_end = audio['note'].shape[0] - np.argmax(anchor_vals)
        if audio_end - audio_start < (midi['end_anchor'][0][0] - midi['start_anchor'][0][0])/10:  # no one plays 10x faster
            return 'corrupted'  # do not consider the file if the interval between the anchors is too short
anchor_pairs = [(audio_start - 5, midi['start_anchor'][0][0] - 5),
(audio_end + 5, midi['end_anchor'][0][0] + 5)]
if anchor_pairs[0][0] < 1:
anchor_pairs[0] = (1, midi['start_anchor'][0][0])
if anchor_pairs[1][0] > audio['note'].shape[0] - 1:
anchor_pairs[1] = (audio['note'].shape[0] - 1, midi['end_anchor'][0][0])
return audio, midi, [(anchor_pairs[0][0]/feature_rate, anchor_pairs[0][1]/feature_rate),
(anchor_pairs[1][0]/feature_rate, anchor_pairs[1][1]/feature_rate)]
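

# ---------------------------------------------------------------------------
# Hedged end-to-end sketch (comments only, not part of the original module).
# The labeling object below is an assumption: it is whatever labeling the
# musc package uses to construct the underlying Transcriber.
# ---------------------------------------------------------------------------
#   labeling = ...  # assumed: the same labeling object used by the Transcriber
#   synchronizer = Synchronizer(labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160)
#   aligned_midi = synchronizer.synchronize('performance.wav', 'score.mid', include_velocity=True)
#   aligned_midi.write('performance_aligned.mid')  # pretty_midi.PrettyMIDI.write()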