from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple

import numpy as np
import pretty_midi

from musc.pitch_estimator import PitchEstimator
from musc.postprocessing import RegressionPostProcessor, spotify_create_notes


class Transcriber(PitchEstimator):
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)

    def transcribe(self, audio, batch_size=128, postprocessing='spotify', include_pitch_bends=True, to_midi=True,
                   debug=False):
        """
        Transcribe an audio file or a mono waveform (numpy or torch) into MIDI with pitch bends.
        :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
        :param batch_size: number of frames to process at once
        :param postprocessing: note-creation method: 'spotify' (default), 'rebab', or 'tiktok'
        :param include_pitch_bends: whether to include pitch bends in the MIDI file
        :param to_midi: whether to return a MIDI file or a list of note events (as tuples)
        :param debug: if True, plot the raw model outputs (f0, note, onset, offset)
        :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object, or a list of
            note-event tuples if to_midi is False
        """
        out = self.predict(audio, batch_size)
        if debug:
            import matplotlib.pyplot as plt
            for key in ('f0', 'note', 'onset', 'offset'):  # visualize each raw model output
                plt.imshow(out[key].T, aspect='auto', origin='lower')
                plt.title(key)
                plt.show()
        if to_midi:
            return self.out2midi(out, postprocessing, include_pitch_bends)
        else:
            return self.out2note(out, postprocessing, include_pitch_bends)

    def out2note(self, output: Dict[str, np.ndarray], postprocessing: str = 'spotify',
                 include_pitch_bends: bool = True,
                 ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
        """Convert raw model output to note events
        :param output: dictionary of model outputs; at least the 'note', 'onset', 'f0',
            and 'time' arrays are used by the postprocessors
        :param postprocessing: 'spotify' (default), 'rebab', or 'tiktok'
        :param include_pitch_bends: whether to estimate per-note pitch bends from the f0 output
        :return: list of (start_time_s, end_time_s, midi_pitch, amplitude, pitch_bends) tuples
        """
        if postprocessing == 'spotify':
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.5,
                frame_thresh=0.3,
                infer_onsets=True,
                # minimum note length of ~127.7 ms, expressed in frames
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),
                melodia_trick=True,
            )
        elif postprocessing == 'rebab':
            # same note-creation routine as 'spotify', with more permissive thresholds
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.2,
                frame_thresh=0.2,
                infer_onsets=True,
                # minimum note length of ~127.7 ms, expressed in frames
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),
                melodia_trick=True,
            )
        elif postprocessing == 'tiktok':
            postprocessor = RegressionPostProcessor(
                frames_per_second=self.sr / self.hop_length,
                classes_num=self.labeling.midi_centers.shape[0],
                begin_note=self.labeling.midi_centers[0],
                onset_threshold=0.2,
                offset_threshold=0.2,
                frame_threshold=0.3,
                pedal_offset_threshold=0.5,
            )
            tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)
            estimated_notes = []
            frame_period = output['time'][1]  # assumes uniformly spaced frame times starting at 0
            for list_item in tiktok_note_dict:
                # keep only notes longer than 0.6 s, converting times back to frame indices
                if list_item['offset_time'] > 0.6 + list_item['onset_time']:
                    estimated_notes.append((int(np.floor(list_item['onset_time'] / frame_period)),
                                            int(np.ceil(list_item['offset_time'] / frame_period)),
                                            list_item['midi_note'], list_item['velocity'] / 128))
        else:
            raise ValueError(f"Unknown postprocessing option: {postprocessing!r}")
        if include_pitch_bends:
            estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
        else:
            estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]
        times_s = output['time']  # convert frame indices to seconds
        estimated_notes_time_seconds = [
            (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
        ]
        return estimated_notes_time_seconds
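
    # Illustrative shape of the tuples out2note returns (the values below are
    # made up, not produced by the model):
    #     (0.53, 1.20, 69, 0.82, [104, 212, ...])  # include_pitch_bends=True
    #     (0.53, 1.20, 69, 0.82, None)             # include_pitch_bends=False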

    def out2midi(self, output: Dict[str, np.ndarray], postprocessing: str = 'spotify',
                 include_pitch_bends: bool = True,
                 ) -> pretty_midi.PrettyMIDI:
        """Convert raw model output to MIDI
        :param output: dictionary of model outputs; at least the 'note', 'onset', 'f0',
            and 'time' arrays are used by the postprocessors
        :param postprocessing: 'spotify' (default), 'rebab', or 'tiktok'
        :param include_pitch_bends: if True, include pitch bends
        :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
        """
        estimated_notes_time_seconds = self.out2note(output, postprocessing, include_pitch_bends)
        midi_tempo = 120  # TODO: infer tempo from the onsets
        return self.note2midi(estimated_notes_time_seconds, midi_tempo)

    def note2midi(
            self, note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
            midi_tempo: float = 120,
    ) -> pretty_midi.PrettyMIDI:
        """Create a pretty_midi object from note events
        :param note_events_with_pitch_bends: list of tuples
            [(start_time_s, end_time_s, midi_pitch, amplitude, pitch_bends)]
        :param midi_tempo: MIDI tempo in BPM (TODO: infer tempo from the onsets)
        :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
        """
        mid = pretty_midi.PrettyMIDI(initial_tempo=midi_tempo)
        program = pretty_midi.instrument_name_to_program(self.instrument)
        # one pretty_midi.Instrument per MIDI pitch, so that the pitch bends of
        # simultaneous notes do not interfere with each other
        instruments: DefaultDict[int, pretty_midi.Instrument] = defaultdict(
            lambda: pretty_midi.Instrument(program=program)
        )
        for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
            instrument = instruments[note_number]
            note = pretty_midi.Note(
                velocity=int(np.round(127 * amplitude)),
                pitch=note_number,
                start=start_time,
                end=end_time,
            )
            instrument.notes.append(note)
            if not isinstance(pitch_bend, np.ndarray):
                continue
            # spread the pitch-bend samples evenly across the note's duration
            pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
            for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
                instrument.pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))
        mid.instruments.extend(instruments.values())
        return mid
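

# Usage sketch, assuming a project-specific `labeling` object: PitchEstimator
# expects it to expose `midi_centers` (a np.ndarray of MIDI pitch centers), as
# used above. Its construction lives elsewhere in musc, so it is only hinted
# at here, and the audio path is illustrative.
#
#     transcriber = Transcriber(labeling, instrument='Violin')
#     midi = transcriber.transcribe('solo_violin.wav', postprocessing='spotify')
#     midi.write('solo_violin_transcribed.mid')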