Upload 9 files
- musc/__init__.py +3 -0
- musc/pathway.py +114 -0
- musc/pitch_estimator.py +206 -0
- musc/postprocessing.py +533 -0
- musc/representations.py +212 -0
- musc/synchronizer.py +299 -0
- musc/transcriber.py +163 -0
- musc/violin.json +13 -0
- musc/violin_model.pt +3 -0
musc/__init__.py
ADDED
@@ -0,0 +1,3 @@
__author__ = ""
__version__ = "0.0.2"
__description__ = "A Timbre-Aware Pitch Estimator that can transcribe an accompanied target instrument. No source separation, preprocessing, or postprocessing! From multi-instrument raw waveform to high-precision MIDI of the target."
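
The description above promises multi-instrument raw audio in, target-instrument MIDI out. A hedged orientation sketch follows; only the metadata access is guaranteed by this file, and the commented-out calls are hypothetical because the class that loads musc/violin.json and musc/violin_model.pt into the Transcriber interface is not among the files shown in this commit.

# Package metadata is importable directly; everything past the prints is hypothetical,
# since the class wiring the checkpoint into the Transcriber is not part of this diff.
import musc

print(musc.__version__)       # "0.0.2"
print(musc.__description__)

# model = ...load musc/violin_model.pt with the musc/violin.json configuration...  (hypothetical)
# midi = model.transcribe('mixture.wav')                                            (hypothetical)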
musc/pathway.py
ADDED
@@ -0,0 +1,114 @@
import numpy as np
import torch.nn as nn


class ConvBlock(nn.Module):
    def __init__(self, f, w, s, d, in_channels):
        super().__init__()
        p1 = d*(w - 1) // 2
        p2 = d*(w - 1) - p1
        self.pad = nn.ZeroPad2d((0, 0, p1, p2))

        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1), dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pad(x)
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class NoPadConvBlock(nn.Module):
    def __init__(self, f, w, s, d, in_channels):
        super().__init__()

        self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1),
                                dilation=(d, 1))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(f)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.conv2d(x)
        x = self.relu(x)
        x = self.bn(x)
        x = self.pool(x)
        x = self.dropout(x)
        return x


class TinyPathway(nn.Module):
    def __init__(self, dilation=1, hop=256, localize=False,
                 model_capacity="full", n_layers=6, chunk_size=256):
        super().__init__()

        capacity_multiplier = {
            'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
        }[model_capacity]
        self.layers = [1, 2, 3, 4, 5, 6]
        self.layers = self.layers[:n_layers]
        filters = [n * capacity_multiplier for n in [32, 8, 8, 8, 8, 8]]
        filters = [1] + filters
        widths = [512, 64, 64, 64, 32, 32]
        strides = self.deter_dilations(hop//(4*(2**n_layers)), localize=localize)
        strides[0] = strides[0]*4  # apply 4 times more stride at the first layer
        dilations = self.deter_dilations(dilation)

        for i in range(len(self.layers)):
            f, w, s, d, in_channel = filters[i + 1], widths[i], strides[i], dilations[i], filters[i]
            self.add_module("conv%d" % i, NoPadConvBlock(f, w, s, d, in_channel))
        self.chunk_size = chunk_size
        self.input_window, self.hop = self.find_input_size_for_pathway()
        self.out_dim = filters[n_layers]

    def find_input_size_for_pathway(self):
        def find_input_size(output_size, kernel_size, stride, dilation, padding):
            num = (stride*(output_size-1)) + 1
            input_size = num - 2*padding + dilation*(kernel_size-1)
            return input_size
        conv_calc, n = {}, 0
        for i in self.layers:
            layer = self.__getattr__("conv%d" % (i-1))
            for mm in layer.modules():
                if hasattr(mm, 'kernel_size'):
                    try:
                        d = mm.dilation[0]
                    except TypeError:
                        d = mm.dilation
                    conv_calc[n] = [mm.kernel_size[0], mm.stride[0], 0, d]
                    n += 1
        out = self.chunk_size
        hop = 1
        for n in sorted(conv_calc.keys())[::-1]:
            kernel_size_n, stride_n, padding_n, dilation_n = conv_calc[n]
            out = find_input_size(out, kernel_size_n, stride_n, dilation_n, padding_n)
            hop = hop*stride_n
        return out, hop

    def deter_dilations(self, total_dilation, localize=False):
        n_layers = len(self.layers)
        if localize:  # e.g., 32*1023 window and 3 layers -> [1, 1, 32]
            a = [total_dilation] + [1 for _ in range(n_layers-1)]
        else:  # e.g., 32*1023 window and 3 layers -> [4, 4, 2]
            total_dilation = int(np.log2(total_dilation))
            a = []
            for layer in range(n_layers):
                this_dilation = int(np.ceil(total_dilation/(n_layers-layer)))
                a.append(2**this_dilation)
                total_dilation = total_dilation - this_dilation
        return a[::-1]

    def forward(self, x):
        x = x.view(x.shape[0], 1, -1, 1)
        for i in range(len(self.layers)):
            x = self.__getattr__("conv%d" % i)(x)
        x = x.permute(0, 3, 2, 1)
        return x
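
For reference, a quick shape check of TinyPathway with its default arguments; a minimal sketch assuming torch is installed and this file is importable as musc.pathway. The required input length and the printed output shape follow from the strides and pooling defined above.

# Build the pathway with defaults (hop=256, n_layers=6, chunk_size=256) and run one chunk.
import torch
from musc.pathway import TinyPathway

pathway = TinyPathway()
x = torch.randn(1, pathway.input_window)   # one raw-audio chunk of exactly the required length
with torch.no_grad():
    y = pathway(x)
print(pathway.input_window, pathway.hop, pathway.out_dim)
print(y.shape)                             # (batch, 1, chunk_size, out_dim) with the defaults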
musc/pitch_estimator.py
ADDED
@@ -0,0 +1,206 @@
from torch import nn
import torch
import torchaudio
from typing import List, Optional, Tuple
import pathlib
from scipy.signal import medfilt
import numpy as np
import librosa
from librosa.sequence import viterbi_discriminative
from scipy.ndimage import gaussian_filter1d
from musc.postprocessing import spotify_create_notes


class PitchEstimator(nn.Module):
    """
    This is the base class that everything else inherits from. The hierarchy is:
    PitchEstimator -> Transcriber -> Synchronizer -> AutonomousAgent -> The n-Head Music Performance Analysis Models.
    PitchEstimator handles reading the audio, predicting all the features, and estimating a single frame-level f0
    using Viterbi; it also provides MIDI pitch bend creation for the predicted note events when used inside a
    Transcriber, and score-informed f0 estimation when used inside a Synchronizer.
    """
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__()
        self.labeling = labeling
        self.sr = sr
        self.window_size = window_size
        self.hop_length = hop_length
        self.instrument = instrument
        self.f0_bins_per_semitone = int(np.round(100/self.labeling.f0_granularity_c))

    def read_audio(self, audio):
        """
        Read and resample an audio file, convert to mono, and unfold into representation frames.
        The time array represents the center of each small frame (5.8 ms hop). These are different from the
        chunk-level frames: a chunk-level frame is the entire sequence the model sees at once, whereas
        predictions are made at the small-frame interval (5.8 ms).
        :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
        :return: frames: (n_big_frames, frame_length), times: (n_small_frames,)
        """
        if isinstance(audio, str) or isinstance(audio, pathlib.Path):
            audio, sample_rate = torchaudio.load(audio, normalize=True)
            audio = audio.mean(axis=0)  # convert to mono
            if sample_rate != self.sr:
                audio = torchaudio.functional.resample(audio, sample_rate, self.sr)
        elif isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        else:
            assert isinstance(audio, torch.Tensor)
        len_audio = audio.shape[-1]
        n_frames = int(np.ceil((len_audio + sum(self.frame_overlap)) / (self.hop_length * self.chunk_size)))
        audio = nn.functional.pad(audio, (self.frame_overlap[0],
                                          self.frame_overlap[1] + (n_frames * self.hop_length * self.chunk_size) - len_audio))
        frames = audio.unfold(0, self.max_window_size, self.hop_length*self.chunk_size)
        times = np.arange(0, len_audio, self.hop_length) / self.sr  # not a tensor, we don't compute anything with it
        return frames, times

    def predict(self, audio, batch_size):
        frames, times = self.read_audio(audio)
        performance = {'f0': [], 'note': [], 'onset': [], 'offset': []}
        self.eval()
        device = self.main.conv0.conv2d.weight.device
        with torch.no_grad():
            for i in range(0, len(frames), batch_size):
                f = frames[i:min(i + batch_size, len(frames))].to(device)
                f -= (torch.mean(f, axis=1).unsqueeze(-1))
                f /= (torch.std(f, axis=1).unsqueeze(-1))
                out = self.forward(f)
                for key, value in out.items():
                    value = torch.sigmoid(value)
                    value = torch.nan_to_num(value)  # the model outputs nan when the frame is silent (expected behavior due to normalization)
                    value = value.view(-1, value.shape[-1])
                    value = value.detach().cpu().numpy()
                    performance[key].append(value)
        performance = {key: np.concatenate(value, axis=0)[:len(times)] for key, value in performance.items()}
        performance['time'] = times
        return performance

    def estimate_pitch(self, audio, batch_size, viterbi=False):
        out = self.predict(audio, batch_size)
        f0_hz = self.out2f0(out, viterbi)
        return out['time'], f0_hz

    def out2f0(self, out, viterbi=False):
        """
        Monophonic f0 estimation from the model output. The Viterbi postprocessing is specialized for the violin family.
        """
        salience = out['f0']
        if viterbi == 'constrained':
            assert hasattr(self, 'out2note')
            notes = spotify_create_notes(out["note"], out["onset"], note_low=self.labeling.midi_centers[0],
                                         note_high=self.labeling.midi_centers[-1], onset_thresh=0.5, frame_thresh=0.3,
                                         infer_onsets=True, melodia_trick=True,
                                         min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))))
            note_cents = self.get_pitch_bends(salience, notes, to_midi=False, timing_refinement_range=0)
            cents = np.zeros_like(out['time'])
            cents[note_cents[:, 0].astype(int)] = note_cents[:, 1]
        elif viterbi:
            # transition probabilities inducing continuous pitch
            # big changes are penalized with one order of magnitude
            transition = gaussian_filter1d(np.eye(self.labeling.f0_n_bins), 30) + 99 * gaussian_filter1d(
                np.eye(self.labeling.f0_n_bins), 2)
            transition = transition / np.sum(transition, axis=1)[:, None]

            p = salience / salience.sum(axis=1)[:, None]
            p[np.isnan(p.sum(axis=1)), :] = np.ones(self.labeling.f0_n_bins) * 1 / self.labeling.f0_n_bins
            path = viterbi_discriminative(p.T, transition)
            cents = np.array([self.labeling.f0_label2c(salience[i, :], path[i]) for i in range(len(path))])
        else:
            cents = self.labeling.f0_label2c(salience, center=None)  # use argmax for center

        f0_hz = self.labeling.f0_c2hz(cents)
        f0_hz[np.isnan(f0_hz)] = 0
        return f0_hz

    def get_pitch_bends(
            self,
            contours: np.ndarray, note_events: List[Tuple[int, int, int, float]],
            timing_refinement_range: int = 0, to_midi: bool = True,
    ) -> List[Tuple[int, int, int, float, Optional[List[int]]]]:
        """Modified version of an excellent script from Spotify/basic_pitch!! Thank you!!!!
        Given note events and contours, estimate pitch bends per note.
        Pitch bends are represented as a sequence of evenly spaced midi pitch bend control units.
        The time stamps of each pitch bend can be inferred by computing an evenly spaced grid between
        the start and end times of each note event.
        Args:
            contours: Matrix of estimated pitch contours
            note_events: note event tuple
            timing_refinement_range: if > 0, refine onset/offset boundaries with f0 confidence
            to_midi: whether to convert pitch bends to midi pitch bends. If False, return pitch estimates in the format
                [time (index), pitch (Hz), confidence in range [0, 1]].
        Returns:
            note events with pitch bends
        """

        f0_matrix = []  # [time (index), pitch (Hz), confidence in range [0, 1]]
        note_events_with_pitch_bends = []
        for start_idx, end_idx, pitch_midi, amplitude in note_events:
            if timing_refinement_range:
                start_idx = np.max([0, start_idx - timing_refinement_range])
                end_idx = np.min([contours.shape[0], end_idx + timing_refinement_range])
            freq_idx = int(np.round(self.midi_pitch_to_contour_bin(pitch_midi)))
            freq_start_idx = np.max([freq_idx - self.labeling.f0_tolerance_bins, 0])
            freq_end_idx = np.min([self.labeling.f0_n_bins, freq_idx + self.labeling.f0_tolerance_bins + 1])

            trans_start_idx = np.max([0, self.labeling.f0_tolerance_bins - freq_idx])
            trans_end_idx = (2 * self.labeling.f0_tolerance_bins + 1) - \
                np.max([0, freq_idx - (self.labeling.f0_n_bins - self.labeling.f0_tolerance_bins - 1)])

            # apply regional viterbi to estimate the intonation
            # observation probabilities come from the f0_roll matrix
            observation = contours[start_idx:end_idx, freq_start_idx:freq_end_idx]
            observation = observation / observation.sum(axis=1)[:, None]
            observation[np.isnan(observation.sum(axis=1)), :] = np.ones(freq_end_idx - freq_start_idx) * 1 / (
                freq_end_idx - freq_start_idx)

            # transition probabilities assure continuity
            transition = self.labeling.f0_transition_matrix[trans_start_idx:trans_end_idx,
                                                            trans_start_idx:trans_end_idx] + 1e-6
            transition = transition / np.sum(transition, axis=1)[:, None]

            path = viterbi_discriminative(observation.T / observation.sum(axis=1), transition) + freq_start_idx

            cents = np.array([self.labeling.f0_label2c(contours[i + start_idx, :], path[i]) for i in range(len(path))])
            bends = cents - self.labeling.midi_centers_c[pitch_midi - self.labeling.midi_centers[0]]
            if to_midi:
                bends = (bends * 4096 / 100).astype(int)
                bends[bends > 8191] = 8191
                bends[bends < -8192] = -8192

                if timing_refinement_range:
                    confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
                    threshold = np.median(confidences)
                    threshold = (np.median(confidences > threshold) + threshold) / 2  # some magic
                    median_kernel = 2 * (timing_refinement_range // 2) + 1  # some more magic
                    confidences = medfilt(confidences, kernel_size=median_kernel)
                    conf_bool = confidences > threshold
                    onset_idx = np.argmax(conf_bool)
                    offset_idx = len(confidences) - np.argmax(conf_bool[::-1])
                    bends = bends[onset_idx:offset_idx]
                    start_idx = start_idx + onset_idx
                    end_idx = start_idx + offset_idx

                note_events_with_pitch_bends.append((start_idx, end_idx, pitch_midi, amplitude, bends))
            else:
                confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
                time_idx = np.arange(len(path)) + start_idx
                # f0_hz = self.labeling.f0_c2hz(cents)
                possible_f0s = np.array([time_idx, cents, confidences]).T
                f0_matrix.append(possible_f0s[np.abs(bends) < 100])  # filter out pitch bends that are too large
        if not to_midi:
            return np.vstack(f0_matrix)
        else:
            return note_events_with_pitch_bends

    def midi_pitch_to_contour_bin(self, pitch_midi: int) -> np.array:
        """Convert a midi pitch to the corresponding index in the contour matrix.
        Args:
            pitch_midi: pitch in midi
        Returns:
            index in contour matrix
        """
        pitch_hz = librosa.midi_to_hz(pitch_midi)
        return np.argmin(np.abs(self.labeling.f0_centers_hz - pitch_hz))
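
For reference, a self-contained toy of the Viterbi smoothing used in out2f0's viterbi branch: per-frame salience rows are normalized into observation probabilities and decoded against a near-diagonal transition matrix with librosa's discriminative Viterbi. The bin count, filter widths, and the injected dominant bin below are arbitrary toy values, not the model's.

# Toy Viterbi decode over a synthetic salience matrix (same construction idea as out2f0).
import numpy as np
from librosa.sequence import viterbi_discriminative
from scipy.ndimage import gaussian_filter1d

n_bins, n_frames = 40, 30
rng = np.random.default_rng(0)
salience = rng.random((n_frames, n_bins))
salience[:, 20] += 5.0                           # a noisy but dominant pitch bin

# transition matrix favoring small bin-to-bin movement, big jumps heavily penalized
transition = gaussian_filter1d(np.eye(n_bins), 8) + 99 * gaussian_filter1d(np.eye(n_bins), 2)
transition = transition / transition.sum(axis=1)[:, None]

p = salience / salience.sum(axis=1)[:, None]     # observation probabilities per frame
path = viterbi_discriminative(p.T, transition)
print(path)                                      # should stay at or near bin 20 in every frame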
musc/postprocessing.py
ADDED
@@ -0,0 +1,533 @@
from typing import List, Tuple
import scipy
import numpy as np


# SPOTIFY

def get_inferred_onsets(onset_roll: np.array, note_roll: np.array, n_diff: int = 2) -> np.array:
    """
    Infer onsets from large changes in note roll matrix amplitudes.
    Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
    :param onset_roll: Onset activation matrix (n_times, n_freqs).
    :param note_roll: Frame-level note activation matrix (n_times, n_freqs).
    :param n_diff: Differences used to detect onsets.
    :return: The maximum between the predicted onsets and its differences.
    """

    diffs = []
    for n in range(1, n_diff + 1):
        frames_appended = np.concatenate([np.zeros((n, note_roll.shape[1])), note_roll])
        diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
    frame_diff = np.min(diffs, axis=0)
    frame_diff[frame_diff < 0] = 0
    frame_diff[:n_diff, :] = 0
    frame_diff = np.max(onset_roll) * frame_diff / np.max(frame_diff)  # rescale to have the same max as onsets

    max_onsets_diff = np.max([onset_roll, frame_diff],
                             axis=0)  # use the max of the predicted onsets and the differences

    return max_onsets_diff


def spotify_create_notes(
        note_roll: np.array,
        onset_roll: np.array,
        onset_thresh: float,
        frame_thresh: float,
        min_note_len: int,
        infer_onsets: bool,
        note_low: int,   # self.labeling.midi_centers[0]
        note_high: int,  # self.labeling.midi_centers[-1]
        melodia_trick: bool = True,
        energy_tol: int = 11,
) -> List[Tuple[int, int, int, float]]:
    """Decode raw model output to polyphonic note events.
    Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
    Args:
        note_roll: Frame activation matrix (n_times, n_freqs).
        onset_roll: Onset activation matrix (n_times, n_freqs).
        onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
        frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
        min_note_len: Minimum allowed note length in frames.
        infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
        melodia_trick: Whether to use the melodia trick to better detect notes.
        energy_tol: Drop notes below this energy.
    Returns:
        list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
        representing the note events, where amplitude is a number between 0 and 1
    """

    n_frames = note_roll.shape[0]

    # use onsets inferred from frames in addition to the predicted onsets
    if infer_onsets:
        onset_roll = get_inferred_onsets(onset_roll, note_roll)

    peak_thresh_mat = np.zeros(onset_roll.shape)
    peaks = scipy.signal.argrelmax(onset_roll, axis=0)
    peak_thresh_mat[peaks] = onset_roll[peaks]

    onset_idx = np.where(peak_thresh_mat >= onset_thresh)
    onset_time_idx = onset_idx[0][::-1]  # sort to go backwards in time
    onset_freq_idx = onset_idx[1][::-1]  # sort to go backwards in time

    remaining_energy = np.zeros(note_roll.shape)
    remaining_energy[:, :] = note_roll[:, :]

    # loop over onsets
    note_events = []
    for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
        # if we're too close to the end of the audio, continue
        if note_start_idx >= n_frames - 1:
            continue

        # find time index at this frequency band where the frames drop below an energy threshold
        i = note_start_idx + 1
        k = 0  # number of frames since energy dropped below threshold
        while i < n_frames - 1 and k < energy_tol:
            if remaining_energy[i, freq_idx] < frame_thresh:
                k += 1
            else:
                k = 0
            i += 1

        i -= k  # go back to frame above threshold

        # if the note is too short, skip it
        if i - note_start_idx <= min_note_len:
            continue

        remaining_energy[note_start_idx:i, freq_idx] = 0
        if freq_idx < note_high:
            remaining_energy[note_start_idx:i, freq_idx + 1] = 0
        if freq_idx > note_low:
            remaining_energy[note_start_idx:i, freq_idx - 1] = 0

        # add the note
        amplitude = np.mean(note_roll[note_start_idx:i, freq_idx])
        note_events.append(
            (
                note_start_idx,
                i,
                freq_idx + note_low,
                amplitude,
            )
        )

    if melodia_trick:
        energy_shape = remaining_energy.shape

        while np.max(remaining_energy) > frame_thresh:
            i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
            remaining_energy[i_mid, freq_idx] = 0

            # forward pass
            i = i_mid + 1
            k = 0
            while i < n_frames - 1 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0

                remaining_energy[i, freq_idx] = 0
                if freq_idx < note_high:
                    remaining_energy[i, freq_idx + 1] = 0
                if freq_idx > note_low:
                    remaining_energy[i, freq_idx - 1] = 0

                i += 1

            i_end = i - 1 - k  # go back to frame above threshold

            # backward pass
            i = i_mid - 1
            k = 0
            while i > 0 and k < energy_tol:
                if remaining_energy[i, freq_idx] < frame_thresh:
                    k += 1
                else:
                    k = 0

                remaining_energy[i, freq_idx] = 0
                if freq_idx < note_high:
                    remaining_energy[i, freq_idx + 1] = 0
                if freq_idx > note_low:
                    remaining_energy[i, freq_idx - 1] = 0

                i -= 1

            i_start = i + 1 + k  # go back to frame above threshold
            assert i_start >= 0, "{}".format(i_start)
            assert i_end < n_frames

            if i_end - i_start <= min_note_len:
                # note is too short, skip it
                continue

            # add the note
            amplitude = np.mean(note_roll[i_start:i_end, freq_idx])
            note_events.append(
                (
                    i_start,
                    i_end,
                    freq_idx + note_low,
                    amplitude,
                )
            )

    return note_events


# TIKTOK


def note_detection_with_onset_offset_regress(frame_output, onset_output,
                                             onset_shift_output, offset_output, offset_shift_output, velocity_output,
                                             frame_threshold):
    """Process prediction matrices to note events information.
    First, detect onsets with onset outputs. Then, detect offsets
    with frame and offset outputs.

    Args:
        frame_output: (frames_num,)
        onset_output: (frames_num,)
        onset_shift_output: (frames_num,)
        offset_output: (frames_num,)
        offset_shift_output: (frames_num,)
        velocity_output: (frames_num,)
        frame_threshold: float
    Returns:
        output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
        e.g., [
            [1821, 1909, 0.47498, 0.3048533, 0.72119445],
            [1909, 1947, 0.30730522, -0.45764327, 0.64200014],
            ...]
    """
    output_tuples = []
    bgn = None
    frame_disappear = None
    offset_occur = None

    for i in range(onset_output.shape[0]):
        if onset_output[i] == 1:
            """Onset detected"""
            if bgn:
                """Consecutive onsets. E.g., pedal is not released, but two
                consecutive notes being played."""
                fin = max(i - 1, 0)
                output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                      0, velocity_output[bgn]])
                frame_disappear, offset_occur = None, None
            bgn = i

        if bgn and i > bgn:
            """If onset found, then search offset"""
            if frame_output[i] <= frame_threshold and not frame_disappear:
                """Frame disappear detected"""
                frame_disappear = i

            if offset_output[i] == 1 and not offset_occur:
                """Offset detected"""
                offset_occur = i

            if frame_disappear:
                if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
                    """bgn --------- offset_occur --- frame_disappear"""
                    fin = offset_occur
                else:
                    """bgn --- offset_occur --------- frame_disappear"""
                    fin = frame_disappear
                output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                      offset_shift_output[fin], velocity_output[bgn]])
                bgn, frame_disappear, offset_occur = None, None, None

            if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
                """Offset not detected"""
                fin = i
                output_tuples.append([bgn, fin, onset_shift_output[bgn],
                                      offset_shift_output[fin], velocity_output[bgn]])
                bgn, frame_disappear, offset_occur = None, None, None

    # Sort pairs by onsets
    output_tuples.sort(key=lambda pair: pair[0])

    return output_tuples


class RegressionPostProcessor(object):
    def __init__(self, frames_per_second, classes_num, onset_threshold,
                 offset_threshold, frame_threshold, pedal_offset_threshold,
                 begin_note):
        """Postprocess the output probabilities of a transcription model to MIDI
        events.

        Args:
            frames_per_second: float
            classes_num: int
            onset_threshold: float
            offset_threshold: float
            frame_threshold: float
            pedal_offset_threshold: float
        """
        self.frames_per_second = frames_per_second
        self.classes_num = classes_num
        self.onset_threshold = onset_threshold
        self.offset_threshold = offset_threshold
        self.frame_threshold = frame_threshold
        self.pedal_offset_threshold = pedal_offset_threshold
        self.begin_note = begin_note
        self.velocity_scale = 128

    def output_dict_to_midi_events(self, output_dict):
        """Main function. Post process model outputs to MIDI events.

        Args:
            output_dict: {
                'reg_onset_output': (segment_frames, classes_num),
                'reg_offset_output': (segment_frames, classes_num),
                'frame_output': (segment_frames, classes_num),
                'velocity_output': (segment_frames, classes_num),
                'reg_pedal_onset_output': (segment_frames, 1),
                'reg_pedal_offset_output': (segment_frames, 1),
                'pedal_frame_output': (segment_frames, 1)}

        Outputs:
            est_note_events: list of dict, e.g. [
                {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83},
                {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]

            est_pedal_events: list of dict, e.g. [
                {'onset_time': 0.17, 'offset_time': 0.96},
                {'onset_time': 1.17, 'offset_time': 2.65}]
        """
        output_dict['frame_output'] = output_dict['note']
        output_dict['velocity_output'] = output_dict['note']
        output_dict['reg_onset_output'] = output_dict['onset']
        output_dict['reg_offset_output'] = output_dict['offset']
        # Post process piano note outputs to piano note and pedal events information
        (est_on_off_note_vels, est_pedal_on_offs) = \
            self.output_dict_to_note_pedal_arrays(output_dict)
        """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity],
        est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""

        # Reformat notes to MIDI events
        est_note_events = self.detected_notes_to_events(est_on_off_note_vels)

        if est_pedal_on_offs is None:
            est_pedal_events = None
        else:
            est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)

        return est_note_events, est_pedal_events

    def output_dict_to_note_pedal_arrays(self, output_dict):
        """Postprocess the output probabilities of a transcription model to MIDI
        events.

        Args:
            output_dict: dict, {
                'reg_onset_output': (frames_num, classes_num),
                'reg_offset_output': (frames_num, classes_num),
                'frame_output': (frames_num, classes_num),
                'velocity_output': (frames_num, classes_num),
                ...}

        Returns:
            est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time,
                offset_time, piano_note and velocity. E.g. [
                [39.74, 39.87, 27, 0.65],
                [11.98, 12.11, 33, 0.69],
                ...]

            est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time
                and offset_time. E.g. [
                [0.17, 0.96],
                [1.17, 2.65],
                ...]
        """

        # ------ 1. Process regression outputs to binarized outputs ------
        # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
        # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]

        # Calculate binarized onset output from regression output
        (onset_output, onset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_onset_output'],
                threshold=self.onset_threshold, neighbour=2)

        output_dict['onset_output'] = onset_output  # Values are 0 or 1
        output_dict['onset_shift_output'] = onset_shift_output

        # Calculate binarized offset output from regression output
        (offset_output, offset_shift_output) = \
            self.get_binarized_output_from_regression(
                reg_output=output_dict['reg_offset_output'],
                threshold=self.offset_threshold, neighbour=4)

        output_dict['offset_output'] = offset_output  # Values are 0 or 1
        output_dict['offset_shift_output'] = offset_shift_output

        if 'reg_pedal_onset_output' in output_dict.keys():
            """Pedal onsets are not used in inference. Instead, frame-wise pedal
            predictions are used to detect onsets. We empirically found this is
            more accurate to detect pedal onsets."""
            pass

        if 'reg_pedal_offset_output' in output_dict.keys():
            # Calculate binarized pedal offset output from regression output
            (pedal_offset_output, pedal_offset_shift_output) = \
                self.get_binarized_output_from_regression(
                    reg_output=output_dict['reg_pedal_offset_output'],
                    threshold=self.pedal_offset_threshold, neighbour=4)

            output_dict['pedal_offset_output'] = pedal_offset_output  # Values are 0 or 1
            output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output

        # ------ 2. Process matrices results to event results ------
        # Detect piano notes from output_dict
        est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)

        est_pedal_on_offs = None

        return est_on_off_note_vels, est_pedal_on_offs

    def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
        """Calculate binarized output and shifts of onsets or offsets from the
        regression results.

        Args:
            reg_output: (frames_num, classes_num)
            threshold: float
            neighbour: int

        Returns:
            binary_output: (frames_num, classes_num)
            shift_output: (frames_num, classes_num)
        """
        binary_output = np.zeros_like(reg_output)
        shift_output = np.zeros_like(reg_output)
        (frames_num, classes_num) = reg_output.shape

        for k in range(classes_num):
            x = reg_output[:, k]
            for n in range(neighbour, frames_num - neighbour):
                if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
                    binary_output[n, k] = 1

                    """See Section III-D in [1] for deduction.
                    [1] Q. Kong, et al., High-resolution Piano Transcription
                    with Pedals by Regressing Onsets and Offsets Times, 2020."""
                    if x[n - 1] > x[n + 1]:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
                    else:
                        shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
                    shift_output[n, k] = shift

        return binary_output, shift_output

    def is_monotonic_neighbour(self, x, n, neighbour):
        """Detect if values are monotonic on both sides of x[n].

        Args:
            x: (frames_num,)
            n: int
            neighbour: int

        Returns:
            monotonic: bool
        """
        monotonic = True
        for i in range(neighbour):
            if x[n - i] < x[n - i - 1]:
                monotonic = False
            if x[n + i] < x[n + i + 1]:
                monotonic = False

        return monotonic

    def output_dict_to_detected_notes(self, output_dict):
        """Postprocess output_dict to piano notes.

        Args:
            output_dict: dict, e.g. {
                'onset_output': (frames_num, classes_num),
                'onset_shift_output': (frames_num, classes_num),
                'offset_output': (frames_num, classes_num),
                'offset_shift_output': (frames_num, classes_num),
                'frame_output': (frames_num, classes_num),
                'velocity_output': (frames_num, classes_num),
                ...}

        Returns:
            est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets,
                MIDI notes and velocities. E.g.,
                [[39.7375, 39.7500, 27., 0.6638],
                 [11.9824, 12.5000, 33., 0.6892],
                 ...]
        """

        est_tuples = []
        est_midi_notes = []
        classes_num = output_dict['frame_output'].shape[-1]

        for piano_note in range(classes_num):
            """Detect piano notes"""
            est_tuples_per_note = note_detection_with_onset_offset_regress(
                frame_output=output_dict['frame_output'][:, piano_note],
                onset_output=output_dict['onset_output'][:, piano_note],
                onset_shift_output=output_dict['onset_shift_output'][:, piano_note],
                offset_output=output_dict['offset_output'][:, piano_note],
                offset_shift_output=output_dict['offset_shift_output'][:, piano_note],
                velocity_output=output_dict['velocity_output'][:, piano_note],
                frame_threshold=self.frame_threshold)

            est_tuples += est_tuples_per_note
            est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)

        est_tuples = np.array(est_tuples)  # (notes, 5)
        """(notes, 5), the five columns are onset, offset, onset_shift,
        offset_shift and normalized_velocity"""

        est_midi_notes = np.array(est_midi_notes)  # (notes,)

        onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
        offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
        velocities = est_tuples[:, 4]

        est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
        """(notes, 4), the four columns are onset_times, offset_times, midi_notes and velocities."""

        est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)

        return est_on_off_note_vels

    def detected_notes_to_events(self, est_on_off_note_vels):
        """Reformat detected notes to midi events.

        Args:
            est_on_off_note_vels: (notes, 4), the four columns are onset_times,
                offset_times, midi_notes and velocities.

        Returns:
            midi_events, list, e.g.,
            [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
             {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
             ...]
        """
        midi_events = []
        for i in range(est_on_off_note_vels.shape[0]):
            midi_events.append({
                'onset_time': est_on_off_note_vels[i][0],
                'offset_time': est_on_off_note_vels[i][1],
                'midi_note': int(est_on_off_note_vels[i][2]),
                'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})

        return midi_events
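
For reference, a minimal toy decode with spotify_create_notes on a synthetic roll pair; the thresholds and MIDI range below are illustrative values only, not the defaults used elsewhere in the package.

# One synthetic note held in a single MIDI bin, decoded to a note event.
import numpy as np
from musc.postprocessing import spotify_create_notes

note_low, note_high = 54, 95                  # e.g., F#3..B6 in MIDI numbers (illustrative)
n_frames, n_bins = 200, note_high - note_low + 1
note_roll = np.zeros((n_frames, n_bins))
onset_roll = np.zeros((n_frames, n_bins))
note_roll[50:90, 10] = 0.9                    # a ~40-frame note in bin 10
onset_roll[50, 10] = 1.0                      # with a clear onset peak

events = spotify_create_notes(note_roll, onset_roll, onset_thresh=0.5, frame_thresh=0.3,
                              min_note_len=11, infer_onsets=True, note_low=note_low,
                              note_high=note_high, melodia_trick=True)
print(events)                                 # roughly [(50, 90, 64, 0.9)]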
musc/representations.py
ADDED
@@ -0,0 +1,212 @@
from mir_eval import melody
import numpy as np
from scipy.stats import norm
import librosa
import pretty_midi
from scipy.ndimage import gaussian_filter1d


class PerformanceLabel:
    """
    The dataset labeling class for performance representations. Currently includes onset, note, and fine-grained f0
    representations. Note min, note max, and f0_bins_per_semitone values are to be arranged per instrument. The default
    values are for violin performance analysis. Fretted instruments might not require such f0 resolution per semitone.
    """
    def __init__(self, note_min='F#3', note_max='C8', f0_bins_per_semitone=9, f0_smooth_std_c=None,
                 onset_smooth_std=0.7, f0_tolerance_c=200):
        midi_min = pretty_midi.note_name_to_number(note_min)
        midi_max = pretty_midi.note_name_to_number(note_max)
        self.midi_centers = np.arange(midi_min, midi_max)
        self.onset_smooth_std = onset_smooth_std  # onset smoothing along time axis (compensate for alignment)

        f0_hz_range = librosa.note_to_hz([note_min, note_max])
        f0_c_min, f0_c_max = melody.hz2cents(f0_hz_range)
        self.f0_granularity_c = 100/f0_bins_per_semitone
        if not f0_smooth_std_c:
            f0_smooth_std_c = self.f0_granularity_c * 5/4  # keep the ratio from the CREPE paper (20 cents and 25 cents)
        self.f0_smooth_std_c = f0_smooth_std_c

        self.f0_centers_c = np.arange(f0_c_min, f0_c_max, self.f0_granularity_c)
        self.f0_centers_hz = 10 * 2 ** (self.f0_centers_c / 1200)
        self.f0_n_bins = len(self.f0_centers_c)

        self.pdf_normalizer = norm.pdf(0)

        self.f0_c2hz = lambda c: 10*2**(c/1200)
        self.f0_hz2c = melody.hz2cents
        self.midi_centers_c = self.f0_hz2c(librosa.midi_to_hz(self.midi_centers))

        self.f0_tolerance_bins = int(f0_tolerance_c/self.f0_granularity_c)
        self.f0_transition_matrix = gaussian_filter1d(np.eye(2*self.f0_tolerance_bins + 1), 25/self.f0_granularity_c)

    def f0_c2label(self, pitch_c):
        """
        Convert a single f0 value in cents to a one-hot label vector with smoothing (i.e., create a gaussian blur around
        the target f0 bin for regularization and training stability). The blur is controlled by self.f0_smooth_std_c.
        :param pitch_c: a single pitch value in cents
        :return: one-hot label vector with frequency blur
        """
        result = norm.pdf((self.f0_centers_c - pitch_c) / self.f0_smooth_std_c).astype(np.float32)
        result /= self.pdf_normalizer
        return result

    def f0_label2c(self, salience, center=None):
        """
        Convert the salience predictions to monophonic f0 in cents. Only outputs a single f0 value per frame!
        :param salience: f0 activations
        :param center: f0 center bin to calculate the weighted average. Use argmax if empty
        :return: f0 array per frame (in cents).
        """
        if salience.ndim == 1:
            if center is None:
                center = int(np.argmax(salience))
            start = max(0, center - 4)
            end = min(len(salience), center + 5)
            salience = salience[start:end]
            product_sum = np.sum(salience * self.f0_centers_c[start:end])
            weight_sum = np.sum(salience)
            return product_sum / np.clip(weight_sum, 1e-8, None)
        if salience.ndim == 2:
            return np.array([self.f0_label2c(salience[i, :]) for i in range(salience.shape[0])])
        raise Exception("label should be either 1d or 2d ndarray")

    def fill_onset_matrix(self, onsets, window, feature_rate):
        """
        Create a sparse onset matrix from window and onsets (per semitone). Apply a gaussian smoothing (along time)
        so that alignment problems are better tolerated. This is similar to the frequency smoothing for the f0.
        The temporal smoothing is controlled by the parameter self.onset_smooth_std.
        :param onsets: A 2d np.array of individual note onsets with their respective time values
            (Nx2: time in seconds - midi number)
        :param window: Timestamps for the frame centers of the sparse matrix
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds
        :return: onset_roll: A sparse matrix filled with temporally blurred onsets.
        """
        onsets = self.get_window_feats(onsets, window, feature_rate)
        onset_roll = np.zeros((len(window), len(self.midi_centers)))
        for onset in onsets:
            onset, note = onset  # it was a pair with time and midi note
            if self.midi_centers[0] < note < self.midi_centers[-1]:  # midi note should be in the range defined
                note = int(note) - self.midi_centers[0]  # find the note index in our range
                onset = (onset*feature_rate) - window[0]  # onset index (as float but in frames, not in seconds!)
                start = max(0, int(onset) - 3)
                end = min(len(window) - 1, int(onset) + 3)
                try:
                    vals = norm.pdf(np.linspace(start - onset, end - onset, end - start + 1) / self.onset_smooth_std)
                    # if you increase 0.7 you smooth the peak
                    # if you decrease it, e.g., 0.1, it becomes too peaky! around 0.5-0.7 seems ok
                    vals /= self.pdf_normalizer
                    onset_roll[start:end + 1, note] += vals
                except ValueError:
                    print('start', start, 'onset', onset, 'end', end)
        return onset_roll, onsets

    def fill_note_matrix(self, notes, window, feature_rate):
        """
        Create the note matrix (piano roll) from window timestamps and note values per frame.
        :param notes: A 2d np.array of individual notes with their active time values Nx2
        :param window: Timestamps for the frame centers of the output
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds
        :return note_roll: The piano roll in the defined range of [note_min, note_max).
        """
        notes = self.get_window_feats(notes, window, feature_rate)

        # take the notes in the midi range defined
        notes = notes[np.logical_and(notes[:, 1] >= self.midi_centers[0], notes[:, 1] <= self.midi_centers[-1]), :]

        times = (notes[:, 0]*feature_rate - window[0]).astype(int)  # in feature samples (fs: self.hop/self.sr)
        notes = (notes[:, 1] - self.midi_centers[0]).astype(int)

        note_roll = np.zeros((len(window), len(self.midi_centers)))
        note_roll[(times, notes)] = 1
        return note_roll, notes

    def fill_f0_matrix(self, f0s, window, feature_rate):
        """
        Unlike the labels for onsets and notes, the f0 label is only relevant for strictly monophonic regions! Thus, this
        function returns a boolean which represents where to apply the given values.
        Never back-propagate without the boolean! Empty frames mean that the label is not that reliable.

        :param f0s: A 2d np.array of f0 values with the time they belong to (2xN: time in seconds - f0 in Hz)
        :param window: Timestamps for the frame centers of the output
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds

        :return f0_roll: f0 label matrix and
                f0_hz: f0 values in Hz
                annotation_bool: A boolean array representing which frames have reliable f0 annotations.
        """
        f0s = self.get_window_feats(f0s, window, feature_rate)
        f0_cents = np.zeros_like(window, dtype=float)
        f0s[:, 1] = self.f0_hz2c(f0s[:, 1])  # convert f0 in hz to cents

        annotation_bool = np.zeros_like(window, dtype=bool)
        f0_roll = np.zeros((len(window), len(self.f0_centers_c)))
        times_in_frame = f0s[:, 0]*feature_rate - window[0]
        for t, f0 in enumerate(f0s):
            t = times_in_frame[t]
            if t % 1 < 0.25:  # only consider it as an annotation if the f0 value is really close to the frame center
                t = int(np.round(t))
                f0_roll[t] = self.f0_c2label(f0[1])
                annotation_bool[t] = True
                f0_cents[t] = f0[1]

        return f0_roll, f0_cents, annotation_bool

    @staticmethod
    def get_window_feats(time_feature_matrix, window, feature_rate):
        """
        Restrict the feature matrix to the features that are inside the window.
        :param window: Timestamps for the frame centers of the output
        :param time_feature_matrix: A 2d array of Nx2 per the entire file.
        :param feature_rate: Window timestamps are integer, this is to convert them to seconds
        :return: window_features: the features inside the given window
        """
        start = time_feature_matrix[:, 0] > (window[0]-0.5)/feature_rate
        end = time_feature_matrix[:, 0] < (window[-1]+0.5)/feature_rate
        window_features = np.logical_and(start, end)
        window_features = np.array(time_feature_matrix[window_features, :])
        return window_features

    def represent_midi(self, midi, feature_rate):
        """
        Represent a midi file as sparse matrices of onsets, offsets, and notes. No f0 is included.
        :param midi: A midi file (either a path or a pretty_midi.PrettyMIDI object)
        :param feature_rate: The feature rate in Hz
        :return: dict {onset, offset, note, time}: Same format as the model's learning targets and outputs
        """
        def _get_onsets_offsets_frames(midi_content):
            if isinstance(midi_content, str):
                midi_content = pretty_midi.PrettyMIDI(midi_content)
            onsets = []
            offsets = []
            frames = []
            for instrument in midi_content.instruments:
                for note in instrument.notes:
                    start = int(np.round(note.start * feature_rate))
                    end = int(np.round(note.end * feature_rate))
                    note_times = (np.arange(start, end+0.5)/feature_rate)[:, np.newaxis]
                    note_pitch = np.full_like(note_times, fill_value=note.pitch)
                    onsets.append([note.start, note.pitch])
                    offsets.append([note.end, note.pitch])
                    frames.append(np.hstack([note_times, note_pitch]))
            onsets = np.vstack(onsets)
            offsets = np.vstack(offsets)
            frames = np.vstack(frames)
            return onsets, offsets, frames, midi_content
        onset_array, offset_array, frame_array, midi_object = _get_onsets_offsets_frames(midi)
        window = np.arange(frame_array[0, 0]*feature_rate, frame_array[-1, 0]*feature_rate, dtype=int)
        onset_roll, _ = self.fill_onset_matrix(onset_array, window, feature_rate)
        offset_roll, _ = self.fill_onset_matrix(offset_array, window, feature_rate)
        note_roll, _ = self.fill_note_matrix(frame_array, window, feature_rate)
        start_anchor = onset_array[onset_array[:, 0] == np.min(onset_array[:, 0])]
        end_anchor = offset_array[offset_array[:, 0] == np.max(offset_array[:, 0])]
        return {
            'midi': midi_object,
            'note': note_roll,
            'onset': onset_roll,
            'offset': offset_roll,
            'time': window/feature_rate,
            'start_anchor': start_anchor,
            'end_anchor': end_anchor
        }
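
For reference, a small sketch exercising PerformanceLabel on a one-note PrettyMIDI object (assuming mir_eval, librosa, and pretty_midi are installed); the 100 Hz feature rate corresponds to sr=16000 with hop_length=160.

# Build the default (violin) labeling and turn one MIDI note into onset/note/offset rolls.
import pretty_midi
from musc.representations import PerformanceLabel

labeling = PerformanceLabel()                       # F#3..C8, 9 f0 bins per semitone
print(labeling.f0_n_bins, labeling.f0_granularity_c)

midi = pretty_midi.PrettyMIDI()
violin = pretty_midi.Instrument(program=40)         # General MIDI program 40 (0-indexed): violin
violin.notes.append(pretty_midi.Note(velocity=80, pitch=69, start=0.5, end=1.5))  # A4
midi.instruments.append(violin)

rep = labeling.represent_midi(midi, feature_rate=100)
print(rep['note'].shape, rep['onset'].shape)        # (n_frames, n_midi_bins)
print(rep['start_anchor'], rep['end_anchor'])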
musc/synchronizer.py
ADDED
@@ -0,0 +1,299 @@
from musc.dtw.mrmsdtw import sync_via_mrmsdtw_with_anchors
from musc.dtw.utils import make_path_strictly_monotonic
import numpy as np
from musc.transcriber import Transcriber
from typing import Dict


class Synchronizer(Transcriber):
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)

    def synchronize(self, audio, midi, batch_size=128, include_pitch_bends=True, to_midi=True, debug=False,
                    include_velocity=False, alignment_padding=50, timing_refinement_range_with_f0s=0):
        """
        Synchronize an audio file or mono waveform in numpy or torch with a MIDI file.
        :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
        :param midi: str, pathlib.Path, or pretty_midi.PrettyMIDI
        :param batch_size: frames to process at once
        :param include_pitch_bends: whether to include pitch bends in the MIDI file
        :param to_midi: whether to return a MIDI file or a list of note events (as tuples)
        :param debug: whether to plot the alignment path and compare the alignment with the predicted notes
        :param include_velocity: whether to embed the note confidence in place of the velocity in the MIDI file
        :param alignment_padding: how many frames to pad the audio and MIDI representations with
        :param timing_refinement_range_with_f0s: how many frames to refine the alignment with the f0 confidence
        :return: aligned MIDI file as a pretty_midi.PrettyMIDI object, or a list of note events if to_midi is False
        """
        audio = self.predict(audio, batch_size)
        notes_and_midi = self.out2sync(audio, midi, include_velocity=include_velocity,
                                       alignment_padding=alignment_padding)
        if notes_and_midi:  # it might be None
            notes, midi = notes_and_midi

            if debug:
                import matplotlib.pyplot as plt
                import pandas as pd
                estimated_notes = self.out2note(audio, postprocessing='spotify', include_pitch_bends=True)
                est_df = pd.DataFrame(estimated_notes).sort_values(by=0)
                note_df = pd.DataFrame(notes).sort_values(by=0)

                fig, ax = plt.subplots(figsize=(20, 10))

                for row in notes:
                    t_start = row[0]  # sec
                    t_end = row[1]  # sec
                    freq = row[2]  # MIDI pitch
                    ax.hlines(freq, t_start, t_end, color='k', linewidth=3, zorder=2, alpha=0.5)

                for row in estimated_notes:
                    t_start = row[0]  # sec
                    t_end = row[1]  # sec
                    freq = row[2]  # MIDI pitch
                    ax.hlines(freq, t_start, t_end, color='r', linewidth=3, zorder=2, alpha=0.5)
                fig.suptitle('alignment (black) vs. estimated (red)')
                fig.show()

            if not include_pitch_bends:
                if to_midi:
                    return midi['midi']
                else:
                    return notes
            else:
                notes = [(np.argmin(np.abs(audio['time'] - note[0])),
                          np.argmin(np.abs(audio['time'] - note[1])),
                          note[2], note[3]) for note in notes]
                notes = self.get_pitch_bends(audio["f0"], notes, timing_refinement_range_with_f0s)
                notes = [
                    (audio['time'][note[0]], audio['time'][note[1]], note[2], note[3], note[4]) for note in notes
                ]
                if to_midi:
                    return self.note2midi(notes, 120)  # or int(midi['midi'].estimate_tempo())
                else:
                    return notes

    def out2sync_old(self, out: Dict[str, np.array], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file.
        Args:
            out: Model output dictionary
            midi: Path to the MIDI file or PrettyMIDI object
            include_velocity: Whether to encode the note confidence in place of velocity
            alignment_padding: Number of frames to pad the MIDI features with zeros
            debug: Visualize the alignment

        Returns:
            note events and the aligned PrettyMIDI object
        """
        midi = self.labeling.represent_midi(midi, self.sr/self.hop_length)

        audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr/self.hop_length,
                                                              pad_length=alignment_padding)
        if isinstance(audio_midi_anchors, str):
            print(audio_midi_anchors)
            return None  # the file is corrupted: no alignment is possible at all
        else:
            audio, midi, anchor_pairs = audio_midi_anchors

        ALPHA = 0.6  # weight of the onsets; offsets are weighted by 1 - ALPHA

        wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T,
                                           f_onset1=np.hstack([ALPHA * audio['onset'],
                                                               (1 - ALPHA) * audio['offset']]).T,
                                           f_chroma2=midi['note'].T,
                                           f_onset2=np.hstack([ALPHA * midi['onset'],
                                                               (1 - ALPHA) * midi['offset']]).T,
                                           input_feature_rate=self.sr/self.hop_length,
                                           step_weights=np.array([1.5, 1.5, 2.0]),
                                           threshold_rec=10 ** 6,
                                           verbose=debug, normalize_chroma=False,
                                           anchor_pairs=anchor_pairs)
        wp = make_path_strictly_monotonic(wp).astype(int)

        audio_time = np.take(audio['time'], wp[0])
        midi_time = np.take(midi['time'], wp[1])

        notes = []
        for instrument in midi['midi'].instruments:
            for note in instrument.notes:
                note.start = np.interp(note.start, midi_time, audio_time)
                note.end = np.interp(note.end, midi_time, audio_time)

                if note.end - note.start <= 0.012:  # notes should be at least 12 ms long (i.e. 2 frames)
                    note.start = note.start - 0.003
                    note.end = note.start + 0.012

                if include_velocity:  # encode the note confidence in place of velocity
                    velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                       np.argmin(np.abs(audio['time'] - note.end)),
                                                       note.pitch - self.labeling.midi_centers[0]])
                    # velocity must be at least 1, otherwise MIDI drops the note; cast to int for writing
                    note.velocity = max(1, int(round(velocity * 127)))
                else:
                    velocity = note.velocity / 127
                notes.append((note.start, note.end, note.pitch, velocity))
        return notes, midi

    def out2sync(self, out: Dict[str, np.array], midi, include_velocity=False, alignment_padding=50, debug=False):
        """
        Synchronizes the output of the model with the MIDI file.
        Args:
            out: Model output dictionary
            midi: Path to the MIDI file or PrettyMIDI object
            include_velocity: Whether to encode the note confidence in place of velocity
            alignment_padding: Number of frames to pad the MIDI features with zeros
            debug: Visualize the alignment

        Returns:
            note events and the aligned PrettyMIDI object
        """
        midi = self.labeling.represent_midi(midi, self.sr/self.hop_length)

        audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr/self.hop_length,
                                                              pad_length=alignment_padding)
        if isinstance(audio_midi_anchors, str):
            print(audio_midi_anchors)
            return None  # the file is corrupted: no alignment is possible at all
        else:
            audio, midi, anchor_pairs = audio_midi_anchors

        ALPHA = 0.6  # weight of the onsets; offsets are weighted by 1 - ALPHA

        starts = (np.array(anchor_pairs[0]) * self.sr / self.hop_length).astype(int)
        ends = (np.array(anchor_pairs[1]) * self.sr / self.hop_length).astype(int)

        wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T[:, starts[0]:ends[0]],
                                           f_onset1=np.hstack([ALPHA * audio['onset'],
                                                               (1 - ALPHA) * audio['offset']]).T[:, starts[0]:ends[0]],
                                           f_chroma2=midi['note'].T[:, starts[1]:ends[1]],
                                           f_onset2=np.hstack([ALPHA * midi['onset'],
                                                               (1 - ALPHA) * midi['offset']]).T[:, starts[1]:ends[1]],
                                           input_feature_rate=self.sr/self.hop_length,
                                           step_weights=np.array([1.5, 1.5, 2.0]),
                                           threshold_rec=10 ** 6,
                                           verbose=debug, normalize_chroma=False,
                                           anchor_pairs=None)
        wp = make_path_strictly_monotonic(wp).astype(int)
        wp[0] += starts[0]
        wp[1] += starts[1]
        wp = np.hstack((wp, ends[:, np.newaxis]))

        audio_time = np.take(audio['time'], wp[0])
        midi_time = np.take(midi['time'], wp[1])

        notes = []
        for instrument in midi['midi'].instruments:
            for note in instrument.notes:
                note.start = np.interp(note.start, midi_time, audio_time)
                note.end = np.interp(note.end, midi_time, audio_time)

                if note.end - note.start <= 0.012:  # notes should be at least 12 ms long (i.e. 2 frames)
                    note.start = note.start - 0.003
                    note.end = note.start + 0.012

                if include_velocity:  # encode the note confidence in place of velocity
                    velocity = np.median(audio['note'][np.argmin(np.abs(audio['time'] - note.start)):
                                                       np.argmin(np.abs(audio['time'] - note.end)),
                                                       note.pitch - self.labeling.midi_centers[0]])
                    # velocity must be at least 1, otherwise MIDI drops the note; cast to int for writing
                    note.velocity = max(1, int(round(velocity * 127)))
                else:
                    velocity = note.velocity / 127
                notes.append((note.start, note.end, note.pitch, velocity))
        return notes, midi

    @staticmethod
    def pad_representations(dict_of_representations, pad_length=10):
        """
        Pad the representations so that the DTW does not force them to span the entire duration.
        Args:
            dict_of_representations: audio or MIDI representations
            pad_length: how many frames to pad

        Returns:
            padded representations
        """
        for key, value in dict_of_representations.items():
            if key == 'time':
                padded_time = dict_of_representations[key]
                padded_time = np.concatenate([padded_time[:2*pad_length], padded_time + padded_time[2*pad_length]])
                dict_of_representations[key] = padded_time - padded_time[pad_length]  # this ensures that the
                # first frame times are negative until the real zero time
            elif key in ['onset', 'offset', 'note']:
                dict_of_representations[key] = np.pad(value, ((pad_length, pad_length), (0, 0)))
            elif key in ['start_anchor', 'end_anchor']:
                anchor_time = dict_of_representations[key][0][0]
                anchor_time = np.argmin(np.abs(dict_of_representations['time'] - anchor_time))
                dict_of_representations[key][:, 0] = anchor_time
                dict_of_representations[key] = dict_of_representations[key].astype(int)  # np.int is removed in recent NumPy
        return dict_of_representations

    def prepare_for_synchronization(self, audio, midi, feature_rate=44100/256, pad_length=100):
        """
        MrMsDTW works better with start and end anchors. This function finds the start and end anchors for the audio
        based on the MIDI notes. It also pads the MIDI representations, since MIDI files most often start and end
        with an active note; without padding, the DTW would try to stretch those active notes over the entire
        duration of the audio, which is not desirable. Padding the MIDI representations with a few frames of
        silence at the beginning and end avoids this.
        Args:
            audio: audio representations (model output dictionary)
            midi: MIDI representations
            feature_rate: frames per second of the representations
            pad_length: how many frames of silence to pad the MIDI representations with

        Returns:
            the audio representations, the padded MIDI representations, and the (audio, midi) start and end
            anchor pairs in seconds, or the string 'corrupted' when no reliable anchors can be found
        """
        # first pad the MIDI
        midi = self.pad_representations(midi, pad_length)

        # sometimes f0s are more reliable than the notes. So, we use both the f0s and the notes together to find the
        # start and end anchors. f0_lookup_bins is the number of bins to look around the f0 to assign a note to it.
        f0_lookup_bins = int(100//(2*self.labeling.f0_granularity_c))

        # find the start anchor for the audio
        # first decide on which notes to use for the start anchor (take the entire chord where the MIDI file starts)
        anchor_notes = midi['start_anchor'][:, 1] - self.labeling.midi_centers[0]
        # now find which f0 bins to look at for the start anchor
        anchor_f0s = [self.midi_pitch_to_contour_bin(an + self.labeling.midi_centers[0]) for an in anchor_notes]
        anchor_f0s = np.array([list(range(f0 - f0_lookup_bins, f0 + f0_lookup_bins + 1)) for f0 in anchor_f0s]).reshape(-1)
        # the first start anchor proposals come from the notes
        anchor_vals = np.any(audio['note'][:, anchor_notes] > 0.5, axis=1)
        # now the f0s
        anchor_vals_f0 = np.any(audio['f0'][:, anchor_f0s] > 0.5, axis=1)
        # combine the two
        anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
        if not any(anchor_vals):
            return 'corrupted'  # do not consider the file if we cannot find the start anchor
        audio_start = np.argmax(anchor_vals)

        # now the end anchor (most string instruments use chords in cadences: in general the end anchor is polyphonic)
        anchor_notes = midi['end_anchor'][:, 1] - self.labeling.midi_centers[0]
        anchor_f0s = [self.midi_pitch_to_contour_bin(an + self.labeling.midi_centers[0]) for an in anchor_notes]
        anchor_f0s = np.array([list(range(f0 - f0_lookup_bins, f0 + f0_lookup_bins + 1)) for f0 in anchor_f0s]).reshape(-1)
        # the same procedure as above
        anchor_vals = np.any(audio['note'][::-1, anchor_notes] > 0.5, axis=1)
        anchor_vals_f0 = np.any(audio['f0'][::-1, anchor_f0s] > 0.5, axis=1)
        anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
        if not any(anchor_vals):
            return 'corrupted'  # do not consider the file if we cannot find the end anchor
        audio_end = audio['note'].shape[0] - np.argmax(anchor_vals)

        if audio_end - audio_start < (midi['end_anchor'][0][0] - midi['start_anchor'][0][0])/10:  # no one plays 10x faster
            return 'corrupted'  # do not consider the file if the interval between the anchors is too short
        anchor_pairs = [(audio_start - 5, midi['start_anchor'][0][0] - 5),
                        (audio_end + 5, midi['end_anchor'][0][0] + 5)]

        if anchor_pairs[0][0] < 1:
            anchor_pairs[0] = (1, midi['start_anchor'][0][0])
        if anchor_pairs[1][0] > audio['note'].shape[0] - 1:
            anchor_pairs[1] = (audio['note'].shape[0] - 1, midi['end_anchor'][0][0])

        return audio, midi, [(anchor_pairs[0][0]/feature_rate, anchor_pairs[0][1]/feature_rate),
                             (anchor_pairs[1][0]/feature_rate, anchor_pairs[1][1]/feature_rate)]
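Usage sketch (illustrative, not part of the upload): aligning a reference MIDI file to an audio recording with the Synchronizer class defined above. The labeling object and the file paths are placeholders; the constructor defaults (sr=16000, hop_length=160) are taken from the signature above.

# Illustrative only: `labeling` and the file paths are placeholders.
synchronizer = Synchronizer(labeling, instrument='Violin', sr=16000, hop_length=160)

# Returns a pretty_midi.PrettyMIDI whose note timings follow the audio performance.
aligned_midi = synchronizer.synchronize('performance.wav', 'score.mid',
                                        include_pitch_bends=True, alignment_padding=50)
if aligned_midi is not None:  # None means no reliable start/end anchors were found
    aligned_midi.write('performance_aligned.mid')

# With to_midi=False the same call returns note event tuples
# (start_s, end_s, midi_pitch, velocity, pitch_bends) instead of a PrettyMIDI object.
note_events = synchronizer.synchronize('performance.wav', 'score.mid', to_midi=False)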
musc/transcriber.py
ADDED
@@ -0,0 +1,163 @@
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple
import pretty_midi
import numpy as np
from musc.postprocessing import RegressionPostProcessor, spotify_create_notes
from musc.pitch_estimator import PitchEstimator


class Transcriber(PitchEstimator):
    def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
        super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)

    def transcribe(self, audio, batch_size=128, postprocessing='spotify', include_pitch_bends=True, to_midi=True,
                   debug=False):
        """
        Transcribe an audio file or mono waveform in numpy or torch into MIDI with pitch bends.
        :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
        :param batch_size: frames to process at once
        :param postprocessing: note creation method: 'spotify' (default), 'rebab', or 'tiktok'
        :param include_pitch_bends: whether to include pitch bends in the MIDI file
        :param to_midi: whether to return a MIDI file or a list of note events (as tuples)
        :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
        """
        out = self.predict(audio, batch_size)
        if debug:
            import matplotlib.pyplot as plt
            plt.imshow(out['f0'].T, aspect='auto', origin='lower')
            plt.show()
            plt.imshow(out['note'].T, aspect='auto', origin='lower')
            plt.show()

            plt.imshow(out['onset'].T, aspect='auto', origin='lower')
            plt.show()

            plt.imshow(out['offset'].T, aspect='auto', origin='lower')
            plt.show()

        if to_midi:
            return self.out2midi(out, postprocessing, include_pitch_bends)
        else:
            return self.out2note(out, postprocessing, include_pitch_bends)

    def out2note(self, output: Dict[str, np.array], postprocessing='spotify',
                 include_pitch_bends: bool = True,
                 ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
        """Convert model output to note events."""
        if postprocessing == 'spotify':
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.5,
                frame_thresh=0.3,
                infer_onsets=True,
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),  # 127.70 ms
                melodia_trick=True,
            )

        elif postprocessing == 'rebab':
            estimated_notes = spotify_create_notes(
                output["note"],
                output["onset"],
                note_low=self.labeling.midi_centers[0],
                note_high=self.labeling.midi_centers[-1],
                onset_thresh=0.2,
                frame_thresh=0.2,
                infer_onsets=True,
                min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),  # 127.70 ms
                melodia_trick=True,
            )

        elif postprocessing == 'tiktok':
            postprocessor = RegressionPostProcessor(
                frames_per_second=self.sr / self.hop_length,
                classes_num=self.labeling.midi_centers.shape[0],
                begin_note=self.labeling.midi_centers[0],
                onset_threshold=0.2,
                offset_threshold=0.2,
                frame_threshold=0.3,
                pedal_offset_threshold=0.5,
            )
            tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)
            estimated_notes = []
            for list_item in tiktok_note_dict:
                if list_item['offset_time'] > 0.6 + list_item['onset_time']:
                    estimated_notes.append((int(np.floor(list_item['onset_time'] / (output['time'][1]))),
                                            int(np.ceil(list_item['offset_time'] / (output['time'][1]))),
                                            list_item['midi_note'], list_item['velocity'] / 128))
        if include_pitch_bends:
            estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
        else:
            estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]

        times_s = output['time']
        estimated_notes_time_seconds = [
            (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
        ]

        return estimated_notes_time_seconds

    def out2midi(self, output: Dict[str, np.array], postprocessing: str = 'spotify', include_pitch_bends: bool = True,
                 ) -> pretty_midi.PrettyMIDI:
        """Convert model output to MIDI.
        Args:
            output: The model output dictionary with keys 'note', 'onset', 'offset' (arrays of shape
                (n_times, n_pitches)), 'f0' (the pitch contour posteriors), and 'time' (frame times in seconds).
            postprocessing: 'spotify', 'rebab', or 'tiktok' postprocessing.
            include_pitch_bends: If True, include pitch bends.
        Returns:
            The transcription as a pretty_midi.PrettyMIDI object.
        """
        estimated_notes_time_seconds = self.out2note(output, postprocessing, include_pitch_bends)
        midi_tempo = 120  # todo: infer tempo from the onsets
        return self.note2midi(estimated_notes_time_seconds, midi_tempo)

    def note2midi(
            self, note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
            midi_tempo: float = 120,
    ) -> pretty_midi.PrettyMIDI:
        """Create a pretty_midi object from note events.
        :param note_events_with_pitch_bends: list of tuples
            [(start_time_seconds, end_time_seconds, pitch_midi, amplitude, pitch_bends)]
        :param midi_tempo: initial tempo of the MIDI file  # todo: infer tempo from the onsets
        :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
        """
        mid = pretty_midi.PrettyMIDI(initial_tempo=midi_tempo)

        program = pretty_midi.instrument_name_to_program(self.instrument)
        instruments: DefaultDict[int, pretty_midi.Instrument] = defaultdict(
            lambda: pretty_midi.Instrument(program=program)
        )
        for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
            instrument = instruments[note_number]
            note = pretty_midi.Note(
                velocity=int(np.round(127 * amplitude)),
                pitch=note_number,
                start=start_time,
                end=end_time,
            )
            instrument.notes.append(note)
            if not isinstance(pitch_bend, np.ndarray):
                continue
            pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))

            for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
                instrument.pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))

        mid.instruments.extend(instruments.values())

        return mid
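Usage sketch (illustrative, not part of the upload): transcribing a recording with the Transcriber class defined above. The labeling object and the audio path are placeholders; the postprocessing options are the ones handled in out2note.

# Illustrative only: `labeling` and the audio path are placeholders.
transcriber = Transcriber(labeling, instrument='Violin', sr=16000, hop_length=160)

# 'spotify' (default), 'rebab', and 'tiktok' select the note-creation method in out2note.
midi_out = transcriber.transcribe('performance.wav', postprocessing='spotify',
                                  include_pitch_bends=True, to_midi=True)
midi_out.write('performance_transcribed.mid')

# to_midi=False returns (start_s, end_s, midi_pitch, amplitude, pitch_bends) tuples instead.
note_events = transcriber.transcribe('performance.wav', to_midi=False)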
musc/violin.json
ADDED
@@ -0,0 +1,13 @@
{ "model_file": "1FfUjC3usmZBoTxNT6rNVIcYw4pol4K1g",
  "wiring": "parallel",
  "sampling_rate": 44100,
  "pathway_multiscale": 4,
  "num_pathway_layers": 2,
  "num_separator_layers": 16,
  "num_representation_layers": 4,
  "hop_length": 256,
  "chunk_size": 512,
  "minSNR": -32, "maxSNR": 96,
  "note_low": "F#3", "note_high": "E8",
  "f0_bins_per_semitone": 10, "f0_smooth_std_c": 12, "onset_smooth_std": 0.7
}
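A minimal sketch (an assumption, not part of the upload) of reading this configuration; the key names are exactly the ones listed above, while how the package actually consumes them is not shown in this file.

import json

# Illustrative only: read the violin configuration shipped with the package.
with open('musc/violin.json') as f:
    cfg = json.load(f)

frame_rate = cfg['sampling_rate'] / cfg['hop_length']  # 44100 / 256, roughly 172.3 frames per second
print(cfg['note_low'], cfg['note_high'], cfg['f0_bins_per_semitone'], frame_rate)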
musc/violin_model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a913356f059be6dc930be41158ac864f7d5511889ef0b2a6b6ba75a4a8732750
size 218770231
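The .pt file above is stored as a Git LFS pointer, so only its hash and size are versioned here. A hedged sketch of loading the checkpoint, assuming it is a torch-serialized object (whether it holds a plain state_dict or a full model is not visible from the pointer):

import torch

# Illustrative only: load whatever object was serialized into the checkpoint.
checkpoint = torch.load('musc/violin_model.pt', map_location='cpu')
print(type(checkpoint))  # e.g. a dict of weights or a full nn.Module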