# Violin_midi_pro/musc/transcriber.py
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple
import pretty_midi
import numpy as np
from musc.postprocessing import RegressionPostProcessor, spotify_create_notes
from musc.pitch_estimator import PitchEstimator
class Transcriber(PitchEstimator):
    """Pitch estimator extended with note segmentation, producing MIDI with per-note pitch bends."""
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)
def transcribe(self, audio, batch_size=128, postprocessing='spotify', include_pitch_bends=True, to_midi=True,
debug=False):
"""
Transcribe an audio file or mono waveform in numpy or torch into MIDI with pitch bends.
:param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
:param batch_size: frames to process at once
:param postprocessing: note creation method. 'spotify'(default) or 'tiktok'
:param include_pitch_bends: whether to include pitch bends in the MIDI file
:param to_midi: whether to return a MIDI file or a list of note events (as tuple)
:return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
"""
out = self.predict(audio, batch_size)
        if debug:
            import matplotlib.pyplot as plt
            # visualise the raw activation maps (frequency/pitch bins over time)
            plt.imshow(out['f0'].T, aspect='auto', origin='lower')
            plt.show()
            plt.imshow(out['note'].T, aspect='auto', origin='lower')
            plt.show()
            plt.imshow(out['onset'].T, aspect='auto', origin='lower')
            plt.show()
            plt.imshow(out['offset'].T, aspect='auto', origin='lower')
            plt.show()
if to_midi:
return self.out2midi(out, postprocessing, include_pitch_bends)
else:
return self.out2note(out, postprocessing, include_pitch_bends)
    def out2note(self, output: Dict[str, np.ndarray], postprocessing='spotify',
                 include_pitch_bends: bool = True,
                 ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
        """Convert raw model output into note events.
        :return: list of tuples (start_time_s, end_time_s, midi_pitch, amplitude, pitch_bends or None)
        """
        if postprocessing == 'spotify':
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.5,
                frame_thresh=0.3,
                infer_onsets=True,
                # minimum note length of ~127.7 ms, expressed in frames (sr / hop_length frames per second)
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),
                melodia_trick=True,
            )
        elif postprocessing == 'rebab':
            # same pipeline as 'spotify' but with more permissive onset/frame thresholds
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.2,
                frame_thresh=0.2,
                infer_onsets=True,
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),
                melodia_trick=True,
            )
        elif postprocessing == 'tiktok':
            postprocessor = RegressionPostProcessor(
                frames_per_second=self.sr / self.hop_length,
                classes_num=self.labeling.midi_centers.shape[0],
                begin_note=self.labeling.midi_centers[0],
                onset_threshold=0.2,
                offset_threshold=0.2,
                frame_threshold=0.3,
                pedal_offset_threshold=0.5,
            )
            tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)
            estimated_notes = []
            for list_item in tiktok_note_dict:
                # keep only notes longer than 0.6 s; convert note times to frame indices
                # (output['time'][1] is the frame spacing in seconds) and normalise velocity to [0, 1]
                if list_item['offset_time'] > 0.6 + list_item['onset_time']:
                    estimated_notes.append((int(np.floor(list_item['onset_time'] / output['time'][1])),
                                            int(np.ceil(list_item['offset_time'] / output['time'][1])),
                                            list_item['midi_note'], list_item['velocity'] / 128))
        else:
            raise ValueError(f"Unknown postprocessing option: {postprocessing}")
if include_pitch_bends:
estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
else:
estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]
times_s = output['time']
estimated_notes_time_seconds = [
(times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
]
return estimated_notes_time_seconds
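
    # A returned note event looks like the following (illustrative values only):
    #   (0.52, 1.31, 69, 0.80, [0, 512, -256, ...])
    # i.e. start and end time in seconds, MIDI pitch, amplitude in [0, 1], and an optional list of
    # pitch-bend values (None when include_pitch_bends is False).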
    def out2midi(self, output: Dict[str, np.ndarray], postprocessing: str = 'spotify', include_pitch_bends: bool = True,
                 ) -> pretty_midi.PrettyMIDI:
        """Convert model output to MIDI
        Args:
            output: A dictionary with keys
                {
                    'note': array of shape (n_times, n_pitches),
                    'onset': array of shape (n_times, n_pitches),
                    'offset': array of shape (n_times, n_pitches),
                    'f0': array of shape (n_times, n_f0_bins),
                    'time': array of frame times in seconds,
                }
                representing the output of the pitch/note model.
            postprocessing: 'spotify', 'rebab', or 'tiktok' postprocessing.
            include_pitch_bends: If True, include pitch bends.
        Returns:
            The transcription as a pretty_midi.PrettyMIDI object. The tempo is fixed at 120 BPM;
            since note timings are given in seconds, the tempo only affects tick resolution, not playback.
        """
        estimated_notes_time_seconds = self.out2note(output, postprocessing, include_pitch_bends)
        midi_tempo = 120  # TODO: infer tempo from the onsets
        return self.note2midi(estimated_notes_time_seconds, midi_tempo)
def note2midi(
self, note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
midi_tempo: float = 120,
) -> pretty_midi.PrettyMIDI:
"""Create a pretty_midi object from note events
:param note_events_with_pitch_bends: list of tuples
[(start_time_seconds, end_time_seconds, pitch_midi, amplitude)]
:param midi_tempo: #todo: infer tempo from the onsets
:return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
"""
mid = pretty_midi.PrettyMIDI(initial_tempo=midi_tempo)
program = pretty_midi.instrument_name_to_program(self.instrument)
instruments: DefaultDict[int, pretty_midi.Instrument] = defaultdict(
lambda: pretty_midi.Instrument(program=program)
)
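        # Notes are grouped by MIDI pitch, one pretty_midi.Instrument per pitch: pitch bends are
        # channel-wide in MIDI, so keeping each pitch on its own track prevents one note's bends
        # from warping simultaneous notes at other pitches.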
for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
instrument = instruments[note_number]
note = pretty_midi.Note(
velocity=int(np.round(127 * amplitude)),
pitch=note_number,
start=start_time,
end=end_time,
)
instrument.notes.append(note)
            if not isinstance(pitch_bend, np.ndarray):
                continue  # no pitch-bend data for this note
            # spread the per-note pitch-bend values evenly over the note's duration
            pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
            for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
                instrument.pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))
mid.instruments.extend(instruments.values())
return mid
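
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). How to build
# the `labeling` object is defined elsewhere in the package, so it is left
# abstract here, and 'violin_take.wav' / 'violin_take.mid' are hypothetical
# paths. The keyword arguments below simply repeat the constructor defaults.
#
#     transcriber = Transcriber(labeling, instrument='Violin', sr=16000,
#                               window_size=1024, hop_length=160)
#     midi = transcriber.transcribe('violin_take.wav', postprocessing='spotify',
#                                   include_pitch_bends=True)
#     midi.write('violin_take.mid')
# ---------------------------------------------------------------------------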