from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple

import numpy as np
import pretty_midi

from musc.pitch_estimator import PitchEstimator
from musc.postprocessing import RegressionPostProcessor, spotify_create_notes


class Transcriber(PitchEstimator):
    """Note transcriber built on top of PitchEstimator: converts frame-level model
    activations into note events or a pretty_midi.PrettyMIDI object with pitch bends."""

    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)

    def transcribe(self, audio, batch_size=128, postprocessing='spotify', include_pitch_bends=True, to_midi=True,
                   debug=False):
        """
        Transcribe an audio file or a mono waveform (numpy or torch) into MIDI with pitch bends.
        :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
        :param batch_size: number of frames to process at once
        :param postprocessing: note-creation method: 'spotify' (default), 'rebab', or 'tiktok'
        :param include_pitch_bends: whether to include pitch bends in the MIDI file
        :param to_midi: if True, return a pretty_midi.PrettyMIDI object; otherwise return a list of
            note event tuples (start_time_s, end_time_s, pitch_midi, amplitude, pitch_bends)
        :param debug: if True, plot the raw model activations before postprocessing
        :return: a pretty_midi.PrettyMIDI object, or a list of note events if to_midi is False
        """
        out = self.predict(audio, batch_size)

        if debug:  # visualize the raw activation maps before postprocessing
            import matplotlib.pyplot as plt
            for key in ('f0', 'note', 'onset', 'offset'):
                plt.imshow(out[key].T, aspect='auto', origin='lower')
                plt.title(key)
                plt.show()

        if to_midi:
            return self.out2midi(out, postprocessing, include_pitch_bends)
        else:
            return self.out2note(out, postprocessing, include_pitch_bends)

    def out2note(self, output: Dict[str, np.ndarray], postprocessing='spotify',
                 include_pitch_bends: bool = True,
                 ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
        """Convert the model output into note event tuples
        (start_time_s, end_time_s, pitch_midi, amplitude, pitch_bends).
        """
        if postprocessing == 'spotify':
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.5,
                frame_thresh=0.3,
                infer_onsets=True,
                # 127.70 ms minimum note length, i.e. ~13 frames at the default 100 fps (sr=16000, hop=160)
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),
                melodia_trick=True,
            )
        elif postprocessing == 'rebab':  # same pipeline as 'spotify' with more permissive thresholds
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.2,
                frame_thresh=0.2,
                infer_onsets=True,
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),  # 127.70 ms in frames
                melodia_trick=True,
            )
        elif postprocessing == 'tiktok':
            postprocessor = RegressionPostProcessor(
                frames_per_second=self.sr / self.hop_length,
                classes_num=self.labeling.midi_centers.shape[0],
                begin_note=self.labeling.midi_centers[0],
                onset_threshold=0.2,
                offset_threshold=0.2,
                frame_threshold=0.3,
                pedal_offset_threshold=0.5,
            )
            tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)

            # Convert the event dicts to (start_idx, end_idx, pitch, amplitude) tuples, keeping only
            # notes longer than 0.6 s. output['time'][1] is the frame period (timestamps start at 0).
            estimated_notes = []
            for list_item in tiktok_note_dict:
                if list_item['offset_time'] > 0.6 + list_item['onset_time']:
                    estimated_notes.append((int(np.floor(list_item['onset_time'] / output['time'][1])),
                                            int(np.ceil(list_item['offset_time'] / output['time'][1])),
                                            list_item['midi_note'], list_item['velocity'] / 128))
        else:
            raise ValueError(f"Unknown postprocessing option: {postprocessing!r}")

        if include_pitch_bends:
            estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
        else:
            estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]

        # Convert frame indices to seconds using the frame timestamps.
        times_s = output['time']
        estimated_notes_time_seconds = [
            (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
        ]
        return estimated_notes_time_seconds
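
    # Example (sketch): the tuples returned by out2note can be consumed directly,
    # e.g. to print all notes longer than one second (assumes `t` is a Transcriber
    # and 'solo.wav' is a placeholder path):
    #
    #     notes = t.transcribe('solo.wav', to_midi=False)
    #     for start, end, pitch, amp, bends in notes:
    #         if end - start > 1.0:
    #             print(f"{pretty_midi.note_number_to_name(pitch)}: {start:.2f}-{end:.2f}s")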

    def out2midi(self, output: Dict[str, np.ndarray], postprocessing: str = 'spotify', include_pitch_bends: bool = True,
                 ) -> pretty_midi.PrettyMIDI:
        """Convert model output to MIDI
        Args:
            output: A dictionary with keys
                {
                    'note': array of shape (n_times, n_freqs),
                    'onset': array of shape (n_times, n_freqs),
                    'offset': array of shape (n_times, n_freqs),
                    'f0': frame-level pitch contour activations,
                    'time': array of frame timestamps in seconds,
                }
                representing the output of the model.
            postprocessing: 'spotify', 'rebab', or 'tiktok' postprocessing.
            include_pitch_bends: If True, include pitch bends.
        Returns:
            The transcription as a pretty_midi.PrettyMIDI object.
        """
        estimated_notes_time_seconds = self.out2note(output, postprocessing, include_pitch_bends)
        midi_tempo = 120  # TODO: infer tempo from the onsets
        return self.note2midi(estimated_notes_time_seconds, midi_tempo)

    def note2midi(
            self, note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
            midi_tempo: float = 120,
    ) -> pretty_midi.PrettyMIDI:
        """Create a pretty_midi object from note events
        :param note_events_with_pitch_bends: list of tuples
            [(start_time_s, end_time_s, pitch_midi, amplitude, pitch_bends)]
        :param midi_tempo: MIDI tempo in BPM (TODO: infer tempo from the onsets)
        :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
        """
        mid = pretty_midi.PrettyMIDI(initial_tempo=midi_tempo)
        program = pretty_midi.instrument_name_to_program(self.instrument)

        # One instrument per MIDI pitch: pitch bend is a channel-wide message in MIDI,
        # so separating pitches keeps the bends of overlapping notes from interfering.
        instruments: DefaultDict[int, pretty_midi.Instrument] = defaultdict(
            lambda: pretty_midi.Instrument(program=program)
        )
        for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
            instrument = instruments[note_number]
            note = pretty_midi.Note(
                velocity=int(np.round(127 * amplitude)),
                pitch=note_number,
                start=start_time,
                end=end_time,
            )
            instrument.notes.append(note)
            if not isinstance(pitch_bend, np.ndarray):
                continue
            # Spread the per-frame pitch-bend values evenly across the note duration.
            pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
            for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
                instrument.pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))
        mid.instruments.extend(instruments.values())
        return mid
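

# Example usage (a minimal sketch, not part of the library): the `labeling` object
# is project-specific and expected by PitchEstimator, so its construction is left
# abstract here; the file paths are placeholders.
#
#     transcriber = Transcriber(labeling, instrument='Violin')
#     midi = transcriber.transcribe('performance.wav', postprocessing='spotify')
#     midi.write('performance.mid')  # pretty_midi.PrettyMIDI.write saves a .mid file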