Hygee committed on
Commit e72f2a9 · verified · 1 Parent(s): 28a5c23

Upload 9 files

musc/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ __author__ = ""
2
+ __version__ = "0.0.2"
3
+ __description__ = "A Timbre-Aware Pitch Estimator that can transcribe an accompanied target instrument. No source separation, preprocessing, or postprocessing! From multi-instrument raw waveform to high-precision MIDI of the target."
musc/pathway.py ADDED
@@ -0,0 +1,114 @@
1
+ import numpy as np
2
+ import torch.nn as nn
3
+
4
+
5
+ class ConvBlock(nn.Module):
6
+ def __init__(self, f, w, s, d, in_channels):
7
+ super().__init__()
8
+ p1 = d*(w - 1) // 2
9
+ p2 = d*(w - 1) - p1
10
+ self.pad = nn.ZeroPad2d((0, 0, p1, p2))
11
+
12
+ self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1), dilation=(d, 1))
13
+ self.relu = nn.ReLU()
14
+ self.bn = nn.BatchNorm2d(f)
15
+ self.pool = nn.MaxPool2d(kernel_size=(2, 1))
16
+ self.dropout = nn.Dropout(0.25)
17
+
18
+ def forward(self, x):
19
+ x = self.pad(x)
20
+ x = self.conv2d(x)
21
+ x = self.relu(x)
22
+ x = self.bn(x)
23
+ x = self.pool(x)
24
+ x = self.dropout(x)
25
+ return x
26
+
27
+
28
+ class NoPadConvBlock(nn.Module):
29
+ def __init__(self, f, w, s, d, in_channels):
30
+ super().__init__()
31
+
32
+ self.conv2d = nn.Conv2d(in_channels=in_channels, out_channels=f, kernel_size=(w, 1), stride=(s, 1),
33
+ dilation=(d, 1))
34
+ self.relu = nn.ReLU()
35
+ self.bn = nn.BatchNorm2d(f)
36
+ self.pool = nn.MaxPool2d(kernel_size=(2, 1))
37
+ self.dropout = nn.Dropout(0.25)
38
+
39
+ def forward(self, x):
40
+ x = self.conv2d(x)
41
+ x = self.relu(x)
42
+ x = self.bn(x)
43
+ x = self.pool(x)
44
+ x = self.dropout(x)
45
+ return x
46
+
47
+
48
+ class TinyPathway(nn.Module):
49
+ def __init__(self, dilation=1, hop=256, localize=False,
50
+ model_capacity="full", n_layers=6, chunk_size=256):
51
+ super().__init__()
52
+
53
+ capacity_multiplier = {
54
+ 'tiny': 4, 'small': 8, 'medium': 16, 'large': 24, 'full': 32
55
+ }[model_capacity]
56
+ self.layers = [1, 2, 3, 4, 5, 6]
57
+ self.layers = self.layers[:n_layers]
58
+ filters = [n * capacity_multiplier for n in [32, 8, 8, 8, 8, 8]]
59
+ filters = [1] + filters
60
+ widths = [512, 64, 64, 64, 32, 32]
61
+ strides = self.deter_dilations(hop//(4*(2**n_layers)), localize=localize)
62
+ strides[0] = strides[0]*4 # apply 4 times more stride at the first layer
63
+ dilations = self.deter_dilations(dilation)
64
+
65
+ for i in range(len(self.layers)):
66
+ f, w, s, d, in_channel = filters[i + 1], widths[i], strides[i], dilations[i], filters[i]
67
+ self.add_module("conv%d" % i, NoPadConvBlock(f, w, s, d, in_channel))
68
+ self.chunk_size = chunk_size
69
+ self.input_window, self.hop = self.find_input_size_for_pathway()
70
+ self.out_dim = filters[n_layers]
71
+
72
+ def find_input_size_for_pathway(self):
73
+ def find_input_size(output_size, kernel_size, stride, dilation, padding):
74
+ num = (stride*(output_size-1)) + 1
75
+ input_size = num - 2*padding + dilation*(kernel_size-1)
76
+ return input_size
77
+ conv_calc, n = {}, 0
78
+ for i in self.layers:
79
+ layer = self.__getattr__("conv%d" % (i-1))
80
+ for mm in layer.modules():
81
+ if hasattr(mm, 'kernel_size'):
82
+ try:
83
+ d = mm.dilation[0]
84
+ except TypeError:
85
+ d = mm.dilation
86
+ conv_calc[n] = [mm.kernel_size[0], mm.stride[0], 0, d]
87
+ n += 1
88
+ out = self.chunk_size
89
+ hop = 1
90
+ for n in sorted(conv_calc.keys())[::-1]:
91
+ kernel_size_n, stride_n, padding_n, dilation_n = conv_calc[n]
92
+ out = find_input_size(out, kernel_size_n, stride_n, dilation_n, padding_n)
93
+ hop = hop*stride_n
94
+ return out, hop
95
+
96
+ def deter_dilations(self, total_dilation, localize=False):
97
+ n_layers = len(self.layers)
98
+ if localize: # e.g., 32*1023 window and 3 layers -> [1, 1, 32]
99
+ a = [total_dilation] + [1 for _ in range(n_layers-1)]
100
+ else: # e.g., 32*1023 window and 3 layers -> [4, 4, 2]
101
+ total_dilation = int(np.log2(total_dilation))
102
+ a = []
103
+ for layer in range(n_layers):
104
+ this_dilation = int(np.ceil(total_dilation/(n_layers-layer)))
105
+ a.append(2**this_dilation)
106
+ total_dilation = total_dilation - this_dilation
107
+ return a[::-1]
108
+
109
+ def forward(self, x):
110
+ x = x.view(x.shape[0], 1, -1, 1)
111
+ for i in range(len(self.layers)):
112
+ x = self.__getattr__("conv%d" % i)(x)
113
+ x = x.permute(0, 3, 2, 1)
114
+ return x
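TinyPathway derives its minimal input length (input_window) and effective hop from the stride and pooling chain above, so the quickest sanity check of a configuration is to push a dummy chunk through it. A minimal illustrative sketch (not from the repo), assuming the defaults above and that the package is importable as musc.pathway; the printed shape is an expectation, not a guarantee:

import torch
from musc.pathway import TinyPathway

pathway = TinyPathway(dilation=1, hop=256, n_layers=6, chunk_size=256)
print(pathway.input_window, pathway.hop, pathway.out_dim)  # samples per chunk, effective stride, output channels

dummy = torch.randn(2, pathway.input_window)  # batch of two raw-audio chunks
features = pathway(dummy)                     # forward() reshapes to (batch, 1, time, 1) internally
print(features.shape)                         # expected: (2, 1, chunk_size, out_dim) after the final permute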
musc/pitch_estimator.py ADDED
@@ -0,0 +1,206 @@
1
+ from torch import nn
2
+ import torch
3
+ import torchaudio
4
+ from typing import List, Optional, Tuple
5
+ import pathlib
6
+ from scipy.signal import medfilt
7
+ import numpy as np
8
+ import librosa
9
+ from librosa.sequence import viterbi_discriminative
10
+ from scipy.ndimage import gaussian_filter1d
11
+ from musc.postprocessing import spotify_create_notes
12
+
13
+
14
+ class PitchEstimator(nn.Module):
15
+ """
16
+ This is the base class that everything else inherits from. The hierarchy is:
17
+ PitchEstimator -> Transcriber -> Synchronizer -> AutonomousAgent -> The n-Head Music Performance Analysis Models
18
+ PitchEstimator handles reading the audio, predicting all the features, and
19
+ estimating a single frame-level f0 with Viterbi decoding; it also provides
20
+ MIDI pitch-bend creation for the predicted note events when used inside a Transcriber, and
21
+ score-informed f0 estimation when used inside a Synchronizer.
22
+ """
23
+ def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
24
+ super().__init__()
25
+ self.labeling = labeling
26
+ self.sr = sr
27
+ self.window_size = window_size
28
+ self.hop_length = hop_length
29
+ self.instrument = instrument
30
+ self.f0_bins_per_semitone = int(np.round(100/self.labeling.f0_granularity_c))
31
+
32
+
33
+ def read_audio(self, audio):
34
+ """
35
+ Read and resample an audio file, convert to mono, and unfold into representation frames.
36
+ The time array represents the center of each small frame, with a 5.8 ms hop length. This is different from the chunk-level
37
+ frames: the chunk-level frames represent the entire sequence the model sees, whereas predictions are made at the
38
+ small-frame interval (5.8 ms).
39
+ :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
40
+ :return: frames: (n_big_frames, frame_length), times: (n_small_frames,)
41
+ """
42
+ if isinstance(audio, str) or isinstance(audio, pathlib.Path):
43
+ audio, sample_rate = torchaudio.load(audio, normalize=True)
44
+ audio = audio.mean(axis=0) # convert to mono
45
+ if sample_rate != self.sr:
46
+ audio = torchaudio.functional.resample(audio, sample_rate, self.sr)
47
+ elif isinstance(audio, np.ndarray):
48
+ audio = torch.from_numpy(audio)
49
+ else:
50
+ assert isinstance(audio, torch.Tensor)
51
+ len_audio = audio.shape[-1]
52
+ n_frames = int(np.ceil((len_audio + sum(self.frame_overlap)) / (self.hop_length * self.chunk_size)))
53
+ audio = nn.functional.pad(audio, (self.frame_overlap[0],
54
+ self.frame_overlap[1] + (n_frames * self.hop_length * self.chunk_size) - len_audio))
55
+ frames = audio.unfold(0, self.max_window_size, self.hop_length*self.chunk_size)
56
+ times = np.arange(0, len_audio, self.hop_length) / self.sr # not tensor, we don't compute anything with it
57
+ return frames, times
58
+
59
+ def predict(self, audio, batch_size):
60
+ frames, times = self.read_audio(audio)
61
+ performance = {'f0': [], 'note': [], 'onset': [], 'offset': []}
62
+ self.eval()
63
+ device = self.main.conv0.conv2d.weight.device
64
+ with torch.no_grad():
65
+ for i in range(0, len(frames), batch_size):
66
+ f = frames[i:min(i + batch_size, len(frames))].to(device)
67
+ f -= (torch.mean(f, axis=1).unsqueeze(-1))
68
+ f /= (torch.std(f, axis=1).unsqueeze(-1))
69
+ out = self.forward(f)
70
+ for key, value in out.items():
71
+ value = torch.sigmoid(value)
72
+ value = torch.nan_to_num(value) # the model outputs nan when the frame is silent (this is an expected behavior due to normalization)
73
+ value = value.view(-1, value.shape[-1])
74
+ value = value.detach().cpu().numpy()
75
+ performance[key].append(value)
76
+ performance = {key: np.concatenate(value, axis=0)[:len(times)] for key, value in performance.items()}
77
+ performance['time'] = times
78
+ return performance
79
+
80
+ def estimate_pitch(self, audio, batch_size, viterbi=False):
81
+ out = self.predict(audio, batch_size)
82
+ f0_hz = self.out2f0(out, viterbi)
83
+ return out['time'], f0_hz
84
+
85
+ def out2f0(self, out, viterbi=False):
86
+ """
87
+ Monophonic f0 estimation from the model output. The viterbi postprocessing is specialized for the violin family.
88
+ """
89
+ salience = out['f0']
90
+ if viterbi == 'constrained':
91
+ assert hasattr(self, 'out2note')
92
+ notes = spotify_create_notes( out["note"], out["onset"], note_low=self.labeling.midi_centers[0],
93
+ note_high=self.labeling.midi_centers[-1], onset_thresh=0.5, frame_thresh=0.3,
94
+ infer_onsets=True, melodia_trick=True,
95
+ min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))))
96
+ note_cents = self.get_pitch_bends(salience, notes, to_midi=False, timing_refinement_range=0)
97
+ cents = np.zeros_like(out['time'])
98
+ cents[note_cents[:,0].astype(int)] = note_cents[:,1]
99
+ elif viterbi:
100
+ # transition probabilities inducing continuous pitch
101
+ # big changes are penalized with one order of magnitude
102
+ transition = gaussian_filter1d(np.eye(self.labeling.f0_n_bins), 30) + 99 * gaussian_filter1d(
103
+ np.eye(self.labeling.f0_n_bins), 2)
104
+ transition = transition / np.sum(transition, axis=1)[:, None]
105
+
106
+ p = salience / salience.sum(axis=1)[:, None]
107
+ p[np.isnan(p.sum(axis=1)), :] = np.ones(self.labeling.f0_n_bins) * 1 / self.labeling.f0_n_bins
108
+ path = viterbi_discriminative(p.T, transition)
109
+ cents = np.array([self.labeling.f0_label2c(salience[i, :], path[i]) for i in range(len(path))])
110
+ else:
111
+ cents = self.labeling.f0_label2c(salience, center=None) # use argmax for center
112
+
113
+ f0_hz = self.labeling.f0_c2hz(cents)
114
+ f0_hz[np.isnan(f0_hz)] = 0
115
+ return f0_hz
116
+
117
+ def get_pitch_bends(
118
+ self,
119
+ contours: np.ndarray, note_events: List[Tuple[int, int, int, float]],
120
+ timing_refinement_range: int = 0, to_midi: bool = True,
121
+ ) -> List[Tuple[int, int, int, float, Optional[List[int]]]]:
122
+ """Modified version of an excellent script from Spotify/basic_pitch!! Thank you!!!!
123
+ Given note events and contours, estimate pitch bends per note.
124
+ Pitch bends are represented as a sequence of evenly spaced midi pitch bend control units.
125
+ The time stamps of each pitch bend can be inferred by computing an evenly spaced grid between
126
+ the start and end times of each note event.
127
+ Args:
128
+ contours: Matrix of estimated pitch contours
129
+ note_events: note event tuple
130
+ timing_refinement_range: if > 0, refine onset/offset boundaries with f0 confidence
131
+ to_midi: whether to convert pitch bends to midi pitch bends. If False, return pitch estimates in the format
132
+ [time (index), pitch (Hz), confidence in range [0, 1]].
133
+ Returns:
134
+ note events with pitch bends
135
+ """
136
+
137
+ f0_matrix = [] # [time (index), pitch (Hz), confidence in range [0, 1]]
138
+ note_events_with_pitch_bends = []
139
+ for start_idx, end_idx, pitch_midi, amplitude in note_events:
140
+ if timing_refinement_range:
141
+ start_idx = np.max([0, start_idx - timing_refinement_range])
142
+ end_idx = np.min([contours.shape[0], end_idx + timing_refinement_range])
143
+ freq_idx = int(np.round(self.midi_pitch_to_contour_bin(pitch_midi)))
144
+ freq_start_idx = np.max([freq_idx - self.labeling.f0_tolerance_bins, 0])
145
+ freq_end_idx = np.min([self.labeling.f0_n_bins, freq_idx + self.labeling.f0_tolerance_bins + 1])
146
+
147
+ trans_start_idx = np.max([0, self.labeling.f0_tolerance_bins - freq_idx])
148
+ trans_end_idx = (2 * self.labeling.f0_tolerance_bins + 1) - \
149
+ np.max([0, freq_idx - (self.labeling.f0_n_bins - self.labeling.f0_tolerance_bins - 1)])
150
+
151
+ # apply regional viterbi to estimate the intonation
152
+ # observation probabilities come from the f0_roll matrix
153
+ observation = contours[start_idx:end_idx, freq_start_idx:freq_end_idx]
154
+ observation = observation / observation.sum(axis=1)[:, None]
155
+ observation[np.isnan(observation.sum(axis=1)), :] = np.ones(freq_end_idx - freq_start_idx) * 1 / (
156
+ freq_end_idx - freq_start_idx)
157
+
158
+ # transition probabilities assure continuity
159
+ transition = self.labeling.f0_transition_matrix[trans_start_idx:trans_end_idx,
160
+ trans_start_idx:trans_end_idx] + 1e-6
161
+ transition = transition / np.sum(transition, axis=1)[:, None]
162
+
163
+ path = viterbi_discriminative(observation.T / observation.sum(axis=1), transition) + freq_start_idx
164
+
165
+ cents = np.array([self.labeling.f0_label2c(contours[i + start_idx, :], path[i]) for i in range(len(path))])
166
+ bends = cents - self.labeling.midi_centers_c[pitch_midi - self.labeling.midi_centers[0]]
167
+ if to_midi:
168
+ bends = (bends * 4096 / 100).astype(int)
169
+ bends[bends > 8191] = 8191
170
+ bends[bends < -8192] = -8192
171
+
172
+ if timing_refinement_range:
173
+ confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
174
+ threshold = np.median(confidences)
175
+ threshold = (np.median(confidences > threshold) + threshold) / 2 # some magic
176
+ median_kernel = 2 * (timing_refinement_range // 2) + 1 # some more magic
177
+ confidences = medfilt(confidences, kernel_size=median_kernel)
178
+ conf_bool = confidences > threshold
179
+ onset_idx = np.argmax(conf_bool)
180
+ offset_idx = len(confidences) - np.argmax(conf_bool[::-1])
181
+ bends = bends[onset_idx:offset_idx]
182
+ start_idx = start_idx + onset_idx
183
+ end_idx = start_idx + offset_idx
184
+
185
+ note_events_with_pitch_bends.append((start_idx, end_idx, pitch_midi, amplitude, bends))
186
+ else:
187
+ confidences = np.array([contours[i + start_idx, path[i]] for i in range(len(path))])
188
+ time_idx = np.arange(len(path)) + start_idx
189
+ # f0_hz = self.labeling.f0_c2hz(cents)
190
+ possible_f0s = np.array([time_idx, cents, confidences]).T
191
+ f0_matrix.append(possible_f0s[np.abs(bends)<100]) # filter out pitch bends that are too large
192
+ if not to_midi:
193
+ return np.vstack(f0_matrix)
194
+ else:
195
+ return note_events_with_pitch_bends
196
+
197
+
198
+ def midi_pitch_to_contour_bin(self, pitch_midi: int) -> np.array:
199
+ """Convert midi pitch to corresponding index in contour matrix
200
+ Args:
201
+ pitch_midi: pitch in midi
202
+ Returns:
203
+ index in contour matrix
204
+ """
205
+ pitch_hz = librosa.midi_to_hz(pitch_midi)
206
+ return np.argmin(np.abs(self.labeling.f0_centers_hz - pitch_hz))
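Read top to bottom, the intended entry point for frame-level pitch is estimate_pitch, which chains read_audio, predict, and out2f0. A hypothetical call pattern, assuming `model` is a trained concrete subclass of PitchEstimator that defines forward(), labeling, chunk_size, and the other attributes referenced above (the audio path is a placeholder):

import numpy as np

times, f0_hz = model.estimate_pitch("performance.wav", batch_size=64, viterbi=True)
voiced = f0_hz > 0                        # out2f0() maps unvoiced/NaN frames to 0 Hz
print(np.round(f0_hz[voiced][:10], 1))    # first voiced f0 estimates in Hz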
musc/postprocessing.py ADDED
@@ -0,0 +1,533 @@
1
+ from typing import List, Tuple
2
+ import scipy
3
+ import numpy as np
4
+
5
+
6
+ # SPOTIFY
7
+
8
+ def get_inferred_onsets(onset_roll: np.array, note_roll: np.array, n_diff: int = 2) -> np.array:
9
+ """
10
+ Infer onsets from large changes in note roll matrix amplitudes.
11
+ Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
12
+ :param onset_roll: Onset activation matrix (n_times, n_freqs).
13
+ :param note_roll: Frame-level note activation matrix (n_times, n_freqs).
14
+ :param n_diff: Differences used to detect onsets.
15
+ :return: The maximum between the predicted onsets and its differences.
16
+ """
17
+
18
+ diffs = []
19
+ for n in range(1, n_diff + 1):
20
+ frames_appended = np.concatenate([np.zeros((n, note_roll.shape[1])), note_roll])
21
+ diffs.append(frames_appended[n:, :] - frames_appended[:-n, :])
22
+ frame_diff = np.min(diffs, axis=0)
23
+ frame_diff[frame_diff < 0] = 0
24
+ frame_diff[:n_diff, :] = 0
25
+ frame_diff = np.max(onset_roll) * frame_diff / np.max(frame_diff) # rescale to have the same max as onsets
26
+
27
+ max_onsets_diff = np.max([onset_roll, frame_diff],
28
+ axis=0) # use the max of the predicted onsets and the differences
29
+
30
+ return max_onsets_diff
31
+
32
+
33
+
34
+ def spotify_create_notes(
35
+ note_roll: np.array,
36
+ onset_roll: np.array,
37
+ onset_thresh: float,
38
+ frame_thresh: float,
39
+ min_note_len: int,
40
+ infer_onsets: bool,
41
+ note_low : int, #self.labeling.midi_centers[0]
42
+ note_high : int, #self.labeling.midi_centers[-1],
43
+ melodia_trick: bool = True,
44
+ energy_tol: int = 11,
45
+ ) -> List[Tuple[int, int, int, float]]:
46
+ """Decode raw model output to polyphonic note events
47
+ Modified from https://github.com/spotify/basic-pitch/blob/main/basic_pitch/note_creation.py
48
+ Args:
49
+ note_roll: Frame activation matrix (n_times, n_freqs).
50
+ onset_roll: Onset activation matrix (n_times, n_freqs).
51
+ onset_thresh: Minimum amplitude of an onset activation to be considered an onset.
52
+ frame_thresh: Minimum amplitude of a frame activation for a note to remain "on".
53
+ min_note_len: Minimum allowed note length in frames.
54
+ infer_onsets: If True, add additional onsets when there are large differences in frame amplitudes.
55
+ melodia_trick : Whether to use the melodia trick to better detect notes.
56
+ energy_tol: Drop notes below this energy.
57
+ Returns:
58
+ list of tuples [(start_time_frames, end_time_frames, pitch_midi, amplitude)]
59
+ representing the note events, where amplitude is a number between 0 and 1
60
+ """
61
+
62
+ n_frames = note_roll.shape[0]
63
+
64
+ # use onsets inferred from frames in addition to the predicted onsets
65
+ if infer_onsets:
66
+ onset_roll = get_inferred_onsets(onset_roll, note_roll)
67
+
68
+ peak_thresh_mat = np.zeros(onset_roll.shape)
69
+ peaks = scipy.signal.argrelmax(onset_roll, axis=0)
70
+ peak_thresh_mat[peaks] = onset_roll[peaks]
71
+
72
+ onset_idx = np.where(peak_thresh_mat >= onset_thresh)
73
+ onset_time_idx = onset_idx[0][::-1] # sort to go backwards in time
74
+ onset_freq_idx = onset_idx[1][::-1] # sort to go backwards in time
75
+
76
+ remaining_energy = np.zeros(note_roll.shape)
77
+ remaining_energy[:, :] = note_roll[:, :]
78
+
79
+ # loop over onsets
80
+ note_events = []
81
+ for note_start_idx, freq_idx in zip(onset_time_idx, onset_freq_idx):
82
+ # if we're too close to the end of the audio, continue
83
+ if note_start_idx >= n_frames - 1:
84
+ continue
85
+
86
+ # find time index at this frequency band where the frames drop below an energy threshold
87
+ i = note_start_idx + 1
88
+ k = 0 # number of frames since energy dropped below threshold
89
+ while i < n_frames - 1 and k < energy_tol:
90
+ if remaining_energy[i, freq_idx] < frame_thresh:
91
+ k += 1
92
+ else:
93
+ k = 0
94
+ i += 1
95
+
96
+ i -= k # go back to frame above threshold
97
+
98
+ # if the note is too short, skip it
99
+ if i - note_start_idx <= min_note_len:
100
+ continue
101
+
102
+ remaining_energy[note_start_idx:i, freq_idx] = 0
103
+ if freq_idx < note_high:
104
+ remaining_energy[note_start_idx:i, freq_idx + 1] = 0
105
+ if freq_idx > note_low:
106
+ remaining_energy[note_start_idx:i, freq_idx - 1] = 0
107
+
108
+ # add the note
109
+ amplitude = np.mean(note_roll[note_start_idx:i, freq_idx])
110
+ note_events.append(
111
+ (
112
+ note_start_idx,
113
+ i,
114
+ freq_idx + note_low,
115
+ amplitude,
116
+ )
117
+ )
118
+
119
+ if melodia_trick:
120
+ energy_shape = remaining_energy.shape
121
+
122
+ while np.max(remaining_energy) > frame_thresh:
123
+ i_mid, freq_idx = np.unravel_index(np.argmax(remaining_energy), energy_shape)
124
+ remaining_energy[i_mid, freq_idx] = 0
125
+
126
+ # forward pass
127
+ i = i_mid + 1
128
+ k = 0
129
+ while i < n_frames - 1 and k < energy_tol:
130
+ if remaining_energy[i, freq_idx] < frame_thresh:
131
+ k += 1
132
+ else:
133
+ k = 0
134
+
135
+ remaining_energy[i, freq_idx] = 0
136
+ if freq_idx < note_high:
137
+ remaining_energy[i, freq_idx + 1] = 0
138
+ if freq_idx > note_low:
139
+ remaining_energy[i, freq_idx - 1] = 0
140
+
141
+ i += 1
142
+
143
+ i_end = i - 1 - k # go back to frame above threshold
144
+
145
+ # backward pass
146
+ i = i_mid - 1
147
+ k = 0
148
+ while i > 0 and k < energy_tol:
149
+ if remaining_energy[i, freq_idx] < frame_thresh:
150
+ k += 1
151
+ else:
152
+ k = 0
153
+
154
+ remaining_energy[i, freq_idx] = 0
155
+ if freq_idx < note_high:
156
+ remaining_energy[i, freq_idx + 1] = 0
157
+ if freq_idx > note_low:
158
+ remaining_energy[i, freq_idx - 1] = 0
159
+
160
+ i -= 1
161
+
162
+ i_start = i + 1 + k # go back to frame above threshold
163
+ assert i_start >= 0, "{}".format(i_start)
164
+ assert i_end < n_frames
165
+
166
+ if i_end - i_start <= min_note_len:
167
+ # note is too short, skip it
168
+ continue
169
+
170
+ # add the note
171
+ amplitude = np.mean(note_roll[i_start:i_end, freq_idx])
172
+ note_events.append(
173
+ (
174
+ i_start,
175
+ i_end,
176
+ freq_idx + note_low,
177
+ amplitude,
178
+ )
179
+ )
180
+
181
+ return note_events
182
+
183
+
184
+
185
+ # TIKTOK
186
+
187
+
188
+ def note_detection_with_onset_offset_regress(frame_output, onset_output,
189
+ onset_shift_output, offset_output, offset_shift_output, velocity_output,
190
+ frame_threshold):
191
+ """Process prediction matrices to note events information.
192
+ First, detect onsets with onset outputs. Then, detect offsets
193
+ with frame and offset outputs.
194
+
195
+ Args:
196
+ frame_output: (frames_num,)
197
+ onset_output: (frames_num,)
198
+ onset_shift_output: (frames_num,)
199
+ offset_output: (frames_num,)
200
+ offset_shift_output: (frames_num,)
201
+ velocity_output: (frames_num,)
202
+ frame_threshold: float
203
+ Returns:
204
+ output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
205
+ e.g., [
206
+ [1821, 1909, 0.47498, 0.3048533, 0.72119445],
207
+ [1909, 1947, 0.30730522, -0.45764327, 0.64200014],
208
+ ...]
209
+ """
210
+ output_tuples = []
211
+ bgn = None
212
+ frame_disappear = None
213
+ offset_occur = None
214
+
215
+ for i in range(onset_output.shape[0]):
216
+ if onset_output[i] == 1:
217
+ """Onset detected"""
218
+ if bgn:
219
+ """Consecutive onsets. E.g., pedal is not released, but two
220
+ consecutive notes being played."""
221
+ fin = max(i - 1, 0)
222
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
223
+ 0, velocity_output[bgn]])
224
+ frame_disappear, offset_occur = None, None
225
+ bgn = i
226
+
227
+ if bgn and i > bgn:
228
+ """If onset found, then search offset"""
229
+ if frame_output[i] <= frame_threshold and not frame_disappear:
230
+ """Frame disappear detected"""
231
+ frame_disappear = i
232
+
233
+ if offset_output[i] == 1 and not offset_occur:
234
+ """Offset detected"""
235
+ offset_occur = i
236
+
237
+ if frame_disappear:
238
+ if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
239
+ """bgn --------- offset_occur --- frame_disappear"""
240
+ fin = offset_occur
241
+ else:
242
+ """bgn --- offset_occur --------- frame_disappear"""
243
+ fin = frame_disappear
244
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
245
+ offset_shift_output[fin], velocity_output[bgn]])
246
+ bgn, frame_disappear, offset_occur = None, None, None
247
+
248
+ if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
249
+ """Offset not detected"""
250
+ fin = i
251
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
252
+ offset_shift_output[fin], velocity_output[bgn]])
253
+ bgn, frame_disappear, offset_occur = None, None, None
254
+
255
+ # Sort pairs by onsets
256
+ output_tuples.sort(key=lambda pair: pair[0])
257
+
258
+ return output_tuples
259
+
260
+
261
+ class RegressionPostProcessor(object):
262
+ def __init__(self, frames_per_second, classes_num, onset_threshold,
263
+ offset_threshold, frame_threshold, pedal_offset_threshold,
264
+ begin_note):
265
+ """Postprocess the output probabilities of a transcription model to MIDI
266
+ events.
267
+
268
+ Args:
269
+ frames_per_second: float
270
+ classes_num: int
271
+ onset_threshold: float
272
+ offset_threshold: float
273
+ frame_threshold: float
274
+ pedal_offset_threshold: float
275
+ """
276
+ self.frames_per_second = frames_per_second
277
+ self.classes_num = classes_num
278
+ self.onset_threshold = onset_threshold
279
+ self.offset_threshold = offset_threshold
280
+ self.frame_threshold = frame_threshold
281
+ self.pedal_offset_threshold = pedal_offset_threshold
282
+ self.begin_note = begin_note
283
+ self.velocity_scale = 128
284
+
285
+ def output_dict_to_midi_events(self, output_dict):
286
+ """Main function. Post process model outputs to MIDI events.
287
+
288
+ Args:
289
+ output_dict: {
290
+ 'reg_onset_output': (segment_frames, classes_num),
291
+ 'reg_offset_output': (segment_frames, classes_num),
292
+ 'frame_output': (segment_frames, classes_num),
293
+ 'velocity_output': (segment_frames, classes_num),
294
+ 'reg_pedal_onset_output': (segment_frames, 1),
295
+ 'reg_pedal_offset_output': (segment_frames, 1),
296
+ 'pedal_frame_output': (segment_frames, 1)}
297
+
298
+ Outputs:
299
+ est_note_events: list of dict, e.g. [
300
+ {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83},
301
+ {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]
302
+
303
+ est_pedal_events: list of dict, e.g. [
304
+ {'onset_time': 0.17, 'offset_time': 0.96},
305
+ {'onset_time': 1.17, 'offset_time': 2.65}]
306
+ """
307
+ output_dict['frame_output'] = output_dict['note']
308
+ output_dict['velocity_output'] = output_dict['note']
309
+ output_dict['reg_onset_output'] = output_dict['onset']
310
+ output_dict['reg_offset_output'] = output_dict['offset']
311
+ # Post process piano note outputs to piano note and pedal events information
312
+ (est_on_off_note_vels, est_pedal_on_offs) = \
313
+ self.output_dict_to_note_pedal_arrays(output_dict)
314
+ """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity],
315
+ est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""
316
+
317
+ # Reformat notes to MIDI events
318
+ est_note_events = self.detected_notes_to_events(est_on_off_note_vels)
319
+
320
+ if est_pedal_on_offs is None:
321
+ est_pedal_events = None
322
+ else:
323
+ est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)
324
+
325
+ return est_note_events, est_pedal_events
326
+
327
+ def output_dict_to_note_pedal_arrays(self, output_dict):
328
+ """Postprocess the output probabilities of a transcription model to MIDI
329
+ events.
330
+
331
+ Args:
332
+ output_dict: dict, {
333
+ 'reg_onset_output': (frames_num, classes_num),
334
+ 'reg_offset_output': (frames_num, classes_num),
335
+ 'frame_output': (frames_num, classes_num),
336
+ 'velocity_output': (frames_num, classes_num),
337
+ ...}
338
+
339
+ Returns:
340
+ est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time,
341
+ offset_time, piano_note and velocity. E.g. [
342
+ [39.74, 39.87, 27, 0.65],
343
+ [11.98, 12.11, 33, 0.69],
344
+ ...]
345
+
346
+ est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time
347
+ and offset_time. E.g. [
348
+ [0.17, 0.96],
349
+ [1.17, 2.65],
350
+ ...]
351
+ """
352
+
353
+ # ------ 1. Process regression outputs to binarized outputs ------
354
+ # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
355
+ # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]
356
+
357
+ # Calculate binarized onset output from regression output
358
+ (onset_output, onset_shift_output) = \
359
+ self.get_binarized_output_from_regression(
360
+ reg_output=output_dict['reg_onset_output'],
361
+ threshold=self.onset_threshold, neighbour=2)
362
+
363
+ output_dict['onset_output'] = onset_output # Values are 0 or 1
364
+ output_dict['onset_shift_output'] = onset_shift_output
365
+
366
+ # Calculate binarized offset output from regression output
367
+ (offset_output, offset_shift_output) = \
368
+ self.get_binarized_output_from_regression(
369
+ reg_output=output_dict['reg_offset_output'],
370
+ threshold=self.offset_threshold, neighbour=4)
371
+
372
+ output_dict['offset_output'] = offset_output # Values are 0 or 1
373
+ output_dict['offset_shift_output'] = offset_shift_output
374
+
375
+ if 'reg_pedal_onset_output' in output_dict.keys():
376
+ """Pedal onsets are not used in inference. Instead, frame-wise pedal
377
+ predictions are used to detect onsets. We empirically found this is
378
+ more accurate to detect pedal onsets."""
379
+ pass
380
+
381
+ if 'reg_pedal_offset_output' in output_dict.keys():
382
+ # Calculate binarized pedal offset output from regression output
383
+ (pedal_offset_output, pedal_offset_shift_output) = \
384
+ self.get_binarized_output_from_regression(
385
+ reg_output=output_dict['reg_pedal_offset_output'],
386
+ threshold=self.pedal_offset_threshold, neighbour=4)
387
+
388
+ output_dict['pedal_offset_output'] = pedal_offset_output # Values are 0 or 1
389
+ output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output
390
+
391
+ # ------ 2. Process matrices results to event results ------
392
+ # Detect piano notes from output_dict
393
+ est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)
394
+
395
+ est_pedal_on_offs = None
396
+
397
+ return est_on_off_note_vels, est_pedal_on_offs
398
+
399
+ def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
400
+ """Calculate binarized output and shifts of onsets or offsets from the
401
+ regression results.
402
+
403
+ Args:
404
+ reg_output: (frames_num, classes_num)
405
+ threshold: float
406
+ neighbour: int
407
+
408
+ Returns:
409
+ binary_output: (frames_num, classes_num)
410
+ shift_output: (frames_num, classes_num)
411
+ """
412
+ binary_output = np.zeros_like(reg_output)
413
+ shift_output = np.zeros_like(reg_output)
414
+ (frames_num, classes_num) = reg_output.shape
415
+
416
+ for k in range(classes_num):
417
+ x = reg_output[:, k]
418
+ for n in range(neighbour, frames_num - neighbour):
419
+ if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
420
+ binary_output[n, k] = 1
421
+
422
+ """See Section III-D in [1] for deduction.
423
+ [1] Q. Kong, et al., High-resolution Piano Transcription
424
+ with Pedals by Regressing Onsets and Offsets Times, 2020."""
425
+ if x[n - 1] > x[n + 1]:
426
+ shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
427
+ else:
428
+ shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
429
+ shift_output[n, k] = shift
430
+
431
+ return binary_output, shift_output
432
+
433
+ def is_monotonic_neighbour(self, x, n, neighbour):
434
+ """Detect if values are monotonic in both side of x[n].
435
+
436
+ Args:
437
+ x: (frames_num,)
438
+ n: int
439
+ neighbour: int
440
+
441
+ Returns:
442
+ monotonic: bool
443
+ """
444
+ monotonic = True
445
+ for i in range(neighbour):
446
+ if x[n - i] < x[n - i - 1]:
447
+ monotonic = False
448
+ if x[n + i] < x[n + i + 1]:
449
+ monotonic = False
450
+
451
+ return monotonic
452
+
453
+ def output_dict_to_detected_notes(self, output_dict):
454
+ """Postprocess output_dict to piano notes.
455
+
456
+ Args:
457
+ output_dict: dict, e.g. {
458
+ 'onset_output': (frames_num, classes_num),
459
+ 'onset_shift_output': (frames_num, classes_num),
460
+ 'offset_output': (frames_num, classes_num),
461
+ 'offset_shift_output': (frames_num, classes_num),
462
+ 'frame_output': (frames_num, classes_num),
463
+ 'onset_output': (frames_num, classes_num),
464
+ ...}
465
+
466
+ Returns:
467
+ est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets,
468
+ MIDI notes and velocities. E.g.,
469
+ [[39.7375, 39.7500, 27., 0.6638],
470
+ [11.9824, 12.5000, 33., 0.6892],
471
+ ...]
472
+ """
473
+
474
+ est_tuples = []
475
+ est_midi_notes = []
476
+ classes_num = output_dict['frame_output'].shape[-1]
477
+
478
+ for piano_note in range(classes_num):
479
+ """Detect piano notes"""
480
+ est_tuples_per_note = note_detection_with_onset_offset_regress(
481
+ frame_output=output_dict['frame_output'][:, piano_note],
482
+ onset_output=output_dict['onset_output'][:, piano_note],
483
+ onset_shift_output=output_dict['onset_shift_output'][:, piano_note],
484
+ offset_output=output_dict['offset_output'][:, piano_note],
485
+ offset_shift_output=output_dict['offset_shift_output'][:, piano_note],
486
+ velocity_output=output_dict['velocity_output'][:, piano_note],
487
+ frame_threshold=self.frame_threshold)
488
+
489
+ est_tuples += est_tuples_per_note
490
+ est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)
491
+
492
+ est_tuples = np.array(est_tuples) # (notes, 5)
493
+ """(notes, 5), the five columns are onset, offset, onset_shift,
494
+ offset_shift and normalized_velocity"""
495
+
496
+ est_midi_notes = np.array(est_midi_notes) # (notes,)
497
+
498
+ onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
499
+ offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
500
+ velocities = est_tuples[:, 4]
501
+
502
+ est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
503
+ """(notes, 4), the four columns are onset_times, offset_times, midi_notes and velocities."""
504
+
505
+ est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)
506
+
507
+ return est_on_off_note_vels
508
+
509
+ def detected_notes_to_events(self, est_on_off_note_vels):
510
+ """Reformat detected notes to midi events.
511
+
512
+ Args:
513
+ est_on_off_note_vels: (notes, 4), the four columns are onset_times,
514
+ offset_times, midi_notes and velocities. E.g.
515
+ [[39.7375, 39.7500, 27., 0.6638],
516
+ [11.9824, 12.5000, 33., 0.6892],
517
+ ...]
518
+
519
+ Returns:
520
+ midi_events, list, e.g.,
521
+ [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
522
+ {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
523
+ ...]
524
+ """
525
+ midi_events = []
526
+ for i in range(est_on_off_note_vels.shape[0]):
527
+ midi_events.append({
528
+ 'onset_time': est_on_off_note_vels[i][0],
529
+ 'offset_time': est_on_off_note_vels[i][1],
530
+ 'midi_note': int(est_on_off_note_vels[i][2]),
531
+ 'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})
532
+
533
+ return midi_events
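spotify_create_notes is a pure NumPy function, so the decoding logic above can be exercised on a toy activation roll. A minimal sketch with synthetic data; the note range and thresholds are illustrative assumptions, not values used elsewhere in the repo:

import numpy as np
from musc.postprocessing import spotify_create_notes

n_frames, n_bins, note_low = 200, 88, 21       # piano-like range starting at MIDI 21
note_roll = np.zeros((n_frames, n_bins))
onset_roll = np.zeros((n_frames, n_bins))
note_roll[50:120, 39] = 0.9                    # sustained activation for MIDI 60 (bin 39 + note_low)
onset_roll[50, 39] = 1.0                       # a single clear onset peak

events = spotify_create_notes(note_roll, onset_roll, onset_thresh=0.5, frame_thresh=0.3,
                              min_note_len=5, infer_onsets=True,
                              note_low=note_low, note_high=note_low + n_bins - 1)
print(events)  # expected: one event close to (50, 120, 60, 0.9)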
musc/representations.py ADDED
@@ -0,0 +1,212 @@
1
+ from mir_eval import melody
2
+ import numpy as np
3
+ from scipy.stats import norm
4
+ import librosa
5
+ import pretty_midi
6
+ from scipy.ndimage import gaussian_filter1d
7
+
8
+
9
+ class PerformanceLabel:
10
+ """
11
+ The dataset labeling class for performance representations. Currently, includes onset, note, and fine-grained f0
12
+ representations. Note min, note max, and f0_bin_per_semitone values are to be arranged per instrument. The default
13
+ values are for violin performance analysis. Fretted instruments might not require such f0 resolutions per semitone.
14
+ """
15
+ def __init__(self, note_min='F#3', note_max='C8', f0_bins_per_semitone=9, f0_smooth_std_c=None,
16
+ onset_smooth_std=0.7, f0_tolerance_c=200):
17
+ midi_min = pretty_midi.note_name_to_number(note_min)
18
+ midi_max = pretty_midi.note_name_to_number(note_max)
19
+ self.midi_centers = np.arange(midi_min, midi_max)
20
+ self.onset_smooth_std=onset_smooth_std # onset smoothing along time axis (compensate for alignment)
21
+
22
+ f0_hz_range = librosa.note_to_hz([note_min, note_max])
23
+ f0_c_min, f0_c_max = melody.hz2cents(f0_hz_range)
24
+ self.f0_granularity_c = 100/f0_bins_per_semitone
25
+ if not f0_smooth_std_c:
26
+ f0_smooth_std_c = self.f0_granularity_c * 5/4 # Keep the ratio from the CREPE paper (20 cents and 25 cents)
27
+ self.f0_smooth_std_c = f0_smooth_std_c
28
+
29
+ self.f0_centers_c = np.arange(f0_c_min, f0_c_max, self.f0_granularity_c)
30
+ self.f0_centers_hz = 10 * 2 ** (self.f0_centers_c / 1200)
31
+ self.f0_n_bins = len(self.f0_centers_c)
32
+
33
+ self.pdf_normalizer = norm.pdf(0)
34
+
35
+ self.f0_c2hz = lambda c: 10*2**(c/1200)
36
+ self.f0_hz2c = melody.hz2cents
37
+ self.midi_centers_c = self.f0_hz2c(librosa.midi_to_hz(self.midi_centers))
38
+
39
+ self.f0_tolerance_bins = int(f0_tolerance_c/self.f0_granularity_c)
40
+ self.f0_transition_matrix = gaussian_filter1d(np.eye(2*self.f0_tolerance_bins + 1), 25/self.f0_granularity_c)
41
+
42
+ def f0_c2label(self, pitch_c):
43
+ """
44
+ Convert a single f0 value in cents to a one-hot label vector with smoothing (i.e., create a gaussian blur around
45
+ the target f0 bin for regularization and training stability). The blur is controlled by self.f0_smooth_std_c.
46
+ :param pitch_c: a single pitch value in cents
47
+ :return: one-hot label vector with frequency blur
48
+ """
49
+ result = norm.pdf((self.f0_centers_c - pitch_c) / self.f0_smooth_std_c).astype(np.float32)
50
+ result /= self.pdf_normalizer
51
+ return result
52
+
53
+ def f0_label2c(self, salience, center=None):
54
+ """
55
+ Convert the salience predictions to monophonic f0 in cents. Only outputs a single f0 value per frame!
56
+ :param salience: f0 activations
57
+ :param center: f0 center bin to calculate the weighted average. Use argmax if empty
58
+ :return: f0 array per frame (in cents).
59
+ """
60
+ if salience.ndim == 1:
61
+ if center is None:
62
+ center = int(np.argmax(salience))
63
+ start = max(0, center - 4)
64
+ end = min(len(salience), center + 5)
65
+ salience = salience[start:end]
66
+ product_sum = np.sum(salience * self.f0_centers_c[start:end])
67
+ weight_sum = np.sum(salience)
68
+ return product_sum / np.clip(weight_sum, 1e-8, None)
69
+ if salience.ndim == 2:
70
+ return np.array([self.f0_label2c(salience[i, :]) for i in range(salience.shape[0])])
71
+ raise Exception("label should be either 1d or 2d ndarray")
72
+
73
+ def fill_onset_matrix(self, onsets, window, feature_rate):
74
+ """
75
+ Create a sparse onset matrix from window and onsets (per-semitone). Apply a gaussian smoothing (along time)
76
+ so that we can tolerate better the alignment problems. This is similar to the frequency smoothing for the f0.
77
+ The temporal smoothing is controlled by the parameter self.onset_smooth_std
78
+ :param onsets: A 2d np.array of individual note onsets with their respective time values
79
+ (Nx2: time in seconds - midi number)
80
+ :param window: Timestamps for the frame centers of the sparse matrix
81
+ :param feature_rate: Window timestamps are integer, this is to convert them to seconds
82
+ :return: onset_roll: A sparse matrix filled with temporally blurred onsets.
83
+ """
84
+ onsets = self.get_window_feats(onsets, window, feature_rate)
85
+ onset_roll = np.zeros((len(window), len(self.midi_centers)))
86
+ for onset in onsets:
87
+ onset, note = onset # it was a pair with time and midi note
88
+ if self.midi_centers[0] < note < self.midi_centers[-1]: # midi note should be in the range defined
89
+ note = int(note) - self.midi_centers[0] # find the note index in our range
90
+ onset = (onset*feature_rate)-window[0] # onset index (as float but in frames, not in seconds!)
91
+ start = max(0, int(onset) - 3)
92
+ end = min(len(window) - 1, int(onset) + 3)
93
+ try:
94
+ vals = norm.pdf(np.linspace(start - onset, end - onset, end - start + 1) / self.onset_smooth_std)
95
+ # if you increase 0.7 you smooth the peak
96
+ # if you decrease it, e.g., 0.1, it becomes too peaky! around 0.5-0.7 seems ok
97
+ vals /= self.pdf_normalizer
98
+ onset_roll[start:end + 1, note] += vals
99
+ except ValueError:
100
+ print('start',start, 'onset', onset, 'end', end)
101
+ return onset_roll, onsets
102
+
103
+ def fill_note_matrix(self, notes, window, feature_rate):
104
+ """
105
+ Create the note matrix (piano roll) from window timestamps and note values per frame.
106
+ :param notes: A 2d np.array of individual notes with their active time values Nx2
107
+ :param window: Timestamps for the frame centers of the output
108
+ :param feature_rate: Window timestamps are integer, this is to convert them to seconds
109
+ :return note_roll: The piano roll in the defined range of [note_min, note_max).
110
+ """
111
+ notes = self.get_window_feats(notes, window, feature_rate)
112
+
113
+ # take the notes in the midi range defined
114
+ notes = notes[np.logical_and(notes[:,1]>=self.midi_centers[0], notes[:,1]<=self.midi_centers[-1]),:]
115
+
116
+ times = (notes[:,0]*feature_rate - window[0]).astype(int) # in feature samples (fs:self.hop/self.sr)
117
+ notes = (notes[:,1] - self.midi_centers[0]).astype(int)
118
+
119
+ note_roll = np.zeros((len(window), len(self.midi_centers)))
120
+ note_roll[(times, notes)] = 1
121
+ return note_roll, notes
122
+
123
+
124
+ def fill_f0_matrix(self, f0s, window, feature_rate):
125
+ """
126
+ Unlike the labels for onsets and notes, f0 label is only relevant for strictly monophonic regions! Thus, this
127
+ function returns a boolean which represents where to apply the given values.
128
+ Never back-propagate without the boolean! Empty frames mean that the label is not that reliable.
129
+
130
+ :param f0s: A 2d np.array of f0 values with the time they belong to (2xN: time in seconds - f0 in Hz)
131
+ :param window: Timestamps for the frame centers of the output
132
+ :param feature_rate: Window timestamps are integer, this is to convert them to seconds
133
+
134
+ :return f0_roll: f0 label matrix and
135
+ f0_hz: f0 values in Hz
136
+ annotation_bool: A boolean array representing which frames have reliable f0 annotations.
137
+ """
138
+ f0s = self.get_window_feats(f0s, window, feature_rate)
139
+ f0_cents = np.zeros_like(window, dtype=float)
140
+ f0s[:,1] = self.f0_hz2c(f0s[:,1]) # convert f0 in hz to cents
141
+
142
+ annotation_bool = np.zeros_like(window, dtype=bool)
143
+ f0_roll = np.zeros((len(window), len(self.f0_centers_c)))
144
+ times_in_frame = f0s[:, 0]*feature_rate - window[0]
145
+ for t, f0 in enumerate(f0s):
146
+ t = times_in_frame[t]
147
+ if t%1 < 0.25: # only consider it an annotation if the f0 value is really close to the frame center
148
+ t = int(np.round(t))
149
+ f0_roll[t] = self.f0_c2label(f0[1])
150
+ annotation_bool[t] = True
151
+ f0_cents[t] = f0[1]
152
+
153
+ return f0_roll, f0_cents, annotation_bool
154
+
155
+
156
+ @staticmethod
157
+ def get_window_feats(time_feature_matrix, window, feature_rate):
158
+ """
159
+ Restrict the feature matrix to the features that are inside the window
160
+ :param window: Timestamps for the frame centers of the output
161
+ :param time_feature_matrix: A 2d array of Nx2 per the entire file.
162
+ :param feature_rate: Window timestamps are integer, this is to convert them to seconds
163
+ :return: window_features: the features inside the given window
164
+ """
165
+ start = time_feature_matrix[:,0]>(window[0]-0.5)/feature_rate
166
+ end = time_feature_matrix[:,0]<(window[-1]+0.5)/feature_rate
167
+ window_features = np.logical_and(start, end)
168
+ window_features = np.array(time_feature_matrix[window_features,:])
169
+ return window_features
170
+
171
+ def represent_midi(self, midi, feature_rate):
172
+ """
173
+ Represent a midi file as sparse matrices of onsets, offsets, and notes. No f0 is included.
174
+ :param midi: A midi file (either a path or a pretty_midi.PrettyMIDI object)
175
+ :param feature_rate: The feature rate in Hz
176
+ :return: dict {onset, offset, note, time}: Same format with the model's learning and outputs
177
+ """
178
+ def _get_onsets_offsets_frames(midi_content):
179
+ if isinstance(midi_content, str):
180
+ midi_content = pretty_midi.PrettyMIDI(midi_content)
181
+ onsets = []
182
+ offsets = []
183
+ frames = []
184
+ for instrument in midi_content.instruments:
185
+ for note in instrument.notes:
186
+ start = int(np.round(note.start * feature_rate))
187
+ end = int(np.round(note.end * feature_rate))
188
+ note_times = (np.arange(start, end+0.5)/feature_rate)[:, np.newaxis]
189
+ note_pitch = np.full_like(note_times, fill_value=note.pitch)
190
+ onsets.append([note.start, note.pitch])
191
+ offsets.append([note.end, note.pitch])
192
+ frames.append(np.hstack([note_times, note_pitch]))
193
+ onsets = np.vstack(onsets)
194
+ offsets = np.vstack(offsets)
195
+ frames = np.vstack(frames)
196
+ return onsets, offsets, frames, midi_content
197
+ onset_array, offset_array, frame_array, midi_object = _get_onsets_offsets_frames(midi)
198
+ window = np.arange(frame_array[0, 0]*feature_rate, frame_array[-1, 0]*feature_rate, dtype=int)
199
+ onset_roll, _ = self.fill_onset_matrix(onset_array, window, feature_rate)
200
+ offset_roll, _ = self.fill_onset_matrix(offset_array, window, feature_rate)
201
+ note_roll, _ = self.fill_note_matrix(frame_array, window, feature_rate)
202
+ start_anchor = onset_array[onset_array[:, 0]==np.min(onset_array[:, 0])]
203
+ end_anchor = offset_array[offset_array[:, 0]==np.max(offset_array[:, 0])]
204
+ return {
205
+ 'midi': midi_object,
206
+ 'note': note_roll,
207
+ 'onset': onset_roll,
208
+ 'offset': offset_roll,
209
+ 'time': window/feature_rate,
210
+ 'start_anchor': start_anchor,
211
+ 'end_anchor': end_anchor
212
+ }
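PerformanceLabel fixes the label space (note range, f0 grid, smoothing widths) that the models and the post-processing above rely on. A short sketch of the default, violin-oriented configuration; the round-trip is approximate by construction of the Gaussian blur:

import numpy as np
from musc.representations import PerformanceLabel

label = PerformanceLabel()                        # defaults: F#3..C8, 9 f0 bins per semitone
print(len(label.midi_centers), label.f0_n_bins)   # number of note classes and f0 bins

a4_cents = label.f0_hz2c(np.array([440.0]))[0]    # 440 Hz in cents above the 10 Hz reference
blurred = label.f0_c2label(a4_cents)              # Gaussian-blurred one-hot f0 target
print(label.f0_label2c(blurred), a4_cents)        # decoded value should land close to a4_cents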
musc/synchronizer.py ADDED
@@ -0,0 +1,299 @@
1
+ from musc.dtw.mrmsdtw import sync_via_mrmsdtw_with_anchors
2
+ from musc.dtw.utils import make_path_strictly_monotonic
3
+ import numpy as np
4
+ from musc.transcriber import Transcriber
5
+ from typing import Dict
6
+
7
+ class Synchronizer(Transcriber):
8
+ def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
9
+ super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)
10
+ def synchronize(self, audio, midi, batch_size=128, include_pitch_bends=True, to_midi=True, debug=False,
11
+ include_velocity=False, alignment_padding=50, timing_refinement_range_with_f0s=0):
12
+ """
13
+ Synchronize an audio file or mono waveform in numpy or torch with a MIDI file.
14
+ :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
15
+ :param midi: str, pathlib.Path, or pretty_midi.PrettyMIDI
16
+ :param batch_size: frames to process at once
17
+ :param include_pitch_bends: whether to include pitch bends in the MIDI file
18
+ :param to_midi: whether to return a MIDI file or a list of note events (as tuple)
19
+ :param debug: whether to plot the alignment path and compare the alignment with the predicted notes
20
+ :param include_velocity: whether to embed the note confidence in place of the velocity in the MIDI file
21
+ :param alignment_padding: how many frames to pad the audio and MIDI representations with
22
+ :param timing_refinement_range_with_f0s: how many frames to refine the alignment with the f0 confidence
23
+ :return: aligned MIDI file as a pretty_midi.PrettyMIDI object
24
+
25
+ Args:
26
+ debug:
27
+ to_midi:
28
+ include_pitch_bends:
29
+ """
30
+
31
+ audio = self.predict(audio, batch_size)
32
+ notes_and_midi = self.out2sync(audio, midi, include_velocity=include_velocity,
33
+ alignment_padding=alignment_padding)
34
+ if notes_and_midi: # it might be none
35
+ notes, midi = notes_and_midi
36
+
37
+ if debug:
38
+ import matplotlib.pyplot as plt
39
+ import pandas as pd
40
+ estimated_notes = self.out2note(audio, postprocessing='spotify', include_pitch_bends=True)
41
+ est_df = pd.DataFrame(estimated_notes).sort_values(by=0)
42
+ note_df = pd.DataFrame(notes).sort_values(by=0)
43
+
44
+ fig, ax = plt.subplots(figsize=(20, 10))
45
+
46
+ for row in notes:
47
+ t_start = row[0] # sec
48
+ t_end = row[1] # sec
49
+ freq = row[2] # Hz
50
+ ax.hlines(freq, t_start, t_end, color='k', linewidth=3, zorder=2, alpha=0.5)
51
+
52
+ for row in estimated_notes:
53
+ t_start = row[0] # sec
54
+ t_end = row[1] # sec
55
+ freq = row[2] # Hz
56
+ ax.hlines(freq, t_start, t_end, color='r', linewidth=3, zorder=2, alpha=0.5)
57
+ fig.suptitle('alignment (black) vs. estimated (red)')
58
+ fig.show()
59
+
60
+ if not include_pitch_bends:
61
+ if to_midi:
62
+ return midi['midi']
63
+ else:
64
+ return notes
65
+ else:
66
+ notes = [(np.argmin(np.abs(audio['time']-note[0])),
67
+ np.argmin(np.abs(audio['time']-note[1])),
68
+ note[2], note[3]) for note in notes]
69
+ notes = self.get_pitch_bends(audio["f0"], notes, timing_refinement_range_with_f0s)
70
+ notes = [
71
+ (audio['time'][note[0]], audio['time'][note[1]], note[2], note[3], note[4]) for note in
72
+ notes
73
+ ]
74
+ if to_midi:
75
+ return self.note2midi(notes, 120) #int(midi['midi'].estimate_tempo()))
76
+ else:
77
+ return notes
78
+
79
+ def out2sync_old(self, out: Dict[str, np.array], midi, include_velocity=False, alignment_padding=50, debug=False):
80
+ """
81
+ Synchronizes the output of the model with the MIDI file.
82
+ Args:
83
+ out: Model output dictionary
84
+ midi: Path to the MIDI file or PrettyMIDI object
85
+ include_velocity: Whether to encode the note confidence in place of velocity
86
+ alignment_padding: Number of frames to pad the MIDI features with zeros
87
+ debug: Visualize the alignment
88
+
89
+ Returns:
90
+ note events and the aligned PrettyMIDI object
91
+ """
92
+ midi = self.labeling.represent_midi(midi, self.sr/self.hop_length)
93
+
94
+ audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr/self.hop_length,
95
+ pad_length=alignment_padding)
96
+ if isinstance(audio_midi_anchors, str):
97
+ print(audio_midi_anchors)
98
+ return None # the file is corrupted! no possible alignment at all
99
+ else:
100
+ audio, midi, anchor_pairs = audio_midi_anchors
101
+
102
+ ALPHA = 0.6 # This is the coefficient of onsets, 1 - ALPHA for offsets
103
+
104
+ wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T,
105
+ f_onset1=np.hstack([ALPHA * audio['onset'],
106
+ (1 - ALPHA) * audio['offset']]).T,
107
+ f_chroma2=midi['note'].T,
108
+ f_onset2=np.hstack([ALPHA * midi['onset'],
109
+ (1 - ALPHA) * midi['offset']]).T,
110
+ input_feature_rate=self.sr/self.hop_length,
111
+ step_weights=np.array([1.5, 1.5, 2.0]),
112
+ threshold_rec=10 ** 6,
113
+ verbose=debug, normalize_chroma=False,
114
+ anchor_pairs=anchor_pairs)
115
+ wp = make_path_strictly_monotonic(wp).astype(int)
116
+
117
+ audio_time = np.take(audio['time'], wp[0])
118
+ midi_time = np.take(midi['time'], wp[1])
119
+
120
+ notes = []
121
+ for instrument in midi['midi'].instruments:
122
+ for note in instrument.notes:
123
+ note.start = np.interp(note.start, midi_time, audio_time)
124
+ note.end = np.interp(note.end, midi_time, audio_time)
125
+
126
+ if note.end - note.start <= 0.012: # notes should be at least 12 ms (i.e. 2 frames)
127
+ note.start = note.start - 0.003
128
+ note.end = note.start + 0.012
129
+
130
+ if include_velocity: # encode the note confidence in place of velocity
131
+ velocity = np.median(audio['note'][np.argmin(np.abs(audio['time']-note.start)):
132
+ np.argmin(np.abs(audio['time']-note.end)),
133
+ note.pitch-self.labeling.midi_centers[0]])
134
+
135
+ note.velocity = max(1, velocity*127) # velocity should be at least 1 otherwise midi removes the note
136
+ else:
137
+ velocity = note.velocity/127
138
+ notes.append((note.start, note.end, note.pitch, velocity))
139
+ return notes, midi
140
+
141
+
142
+ def out2sync(self, out: Dict[str, np.array], midi, include_velocity=False, alignment_padding=50, debug=False):
143
+ """
144
+ Synchronizes the output of the model with the MIDI file.
145
+ Args:
146
+ out: Model output dictionary
147
+ midi: Path to the MIDI file or PrettyMIDI object
148
+ include_velocity: Whether to encode the note confidence in place of velocity
149
+ alignment_padding: Number of frames to pad the MIDI features with zeros
150
+ debug: Visualize the alignment
151
+
152
+ Returns:
153
+ note events and the aligned PrettyMIDI object
154
+ """
155
+ midi = self.labeling.represent_midi(midi, self.sr/self.hop_length)
156
+
157
+ audio_midi_anchors = self.prepare_for_synchronization(out, midi, feature_rate=self.sr/self.hop_length,
158
+ pad_length=alignment_padding)
159
+ if isinstance(audio_midi_anchors, str):
160
+ print(audio_midi_anchors)
161
+ return None # the file is corrupted! no possible alignment at all
162
+ else:
163
+ audio, midi, anchor_pairs = audio_midi_anchors
164
+
165
+ ALPHA = 0.6 # This is the coefficient of onsets, 1 - ALPHA for offsets
166
+
167
+ starts = (np.array(anchor_pairs[0])*self.sr/self.hop_length).astype(int)
168
+ ends = (np.array(anchor_pairs[1])*self.sr/self.hop_length).astype(int)
169
+
170
+ wp = sync_via_mrmsdtw_with_anchors(f_chroma1=audio['note'].T[:, starts[0]:ends[0]],
171
+ f_onset1=np.hstack([ALPHA * audio['onset'],
172
+ (1 - ALPHA) * audio['offset']]).T[:, starts[0]:ends[0]],
173
+ f_chroma2=midi['note'].T[:, starts[1]:ends[1]],
174
+ f_onset2=np.hstack([ALPHA * midi['onset'],
175
+ (1 - ALPHA) * midi['offset']]).T[:, starts[1]:ends[1]],
176
+ input_feature_rate=self.sr/self.hop_length,
177
+ step_weights=np.array([1.5, 1.5, 2.0]),
178
+ threshold_rec=10 ** 6,
179
+ verbose=debug, normalize_chroma=False,
180
+ anchor_pairs=None)
181
+ wp = make_path_strictly_monotonic(wp).astype(int)
182
+ wp[0] += starts[0]
183
+ wp[1] += starts[1]
184
+ wp = np.hstack((wp, ends[:,np.newaxis]))
185
+
186
+ audio_time = np.take(audio['time'], wp[0])
187
+ midi_time = np.take(midi['time'], wp[1])
188
+
189
+ notes = []
190
+ for instrument in midi['midi'].instruments:
191
+ for note in instrument.notes:
192
+ note.start = np.interp(note.start, midi_time, audio_time)
193
+ note.end = np.interp(note.end, midi_time, audio_time)
194
+
195
+ if note.end - note.start <= 0.012: # notes should be at least 12 ms (i.e. 2 frames)
196
+ note.start = note.start - 0.003
197
+ note.end = note.start + 0.012
198
+
199
+ if include_velocity: # encode the note confidence in place of velocity
200
+ velocity = np.median(audio['note'][np.argmin(np.abs(audio['time']-note.start)):
201
+ np.argmin(np.abs(audio['time']-note.end)),
202
+ note.pitch-self.labeling.midi_centers[0]])
203
+
204
+ note.velocity = max(1, velocity*127) # velocity should be at least 1 otherwise midi removes the note
205
+ else:
206
+ velocity = note.velocity/127
207
+ notes.append((note.start, note.end, note.pitch, velocity))
208
+ return notes, midi
209
+
210
+     @staticmethod
+     def pad_representations(dict_of_representations, pad_length=10):
+         """
+         Pad the representations so that the DTW does not force them to span the entire duration.
+         Args:
+             dict_of_representations: audio or MIDI representations
+             pad_length: how many frames to pad
+
+         Returns:
+             padded representations
+         """
+         for key, value in dict_of_representations.items():
+             if key == 'time':
+                 padded_time = dict_of_representations[key]
+                 padded_time = np.concatenate([padded_time[:2*pad_length], padded_time+padded_time[2*pad_length]])
+                 dict_of_representations[key] = padded_time - padded_time[pad_length]  # this ensures that the
+                 # first frame times are negative until the real zero time
+             elif key in ['onset', 'offset', 'note']:
+                 dict_of_representations[key] = np.pad(value, ((pad_length, pad_length), (0, 0)))
+             elif key in ['start_anchor', 'end_anchor']:
+                 anchor_time = dict_of_representations[key][0][0]
+                 anchor_time = np.argmin(np.abs(dict_of_representations['time'] - anchor_time))
+                 dict_of_representations[key][:, 0] = anchor_time
+                 dict_of_representations[key] = dict_of_representations[key].astype(int)  # np.int is deprecated
+         return dict_of_representations
+
+     def prepare_for_synchronization(self, audio, midi, feature_rate=44100/256, pad_length=100):
+         """
+         MrMsDTW works better with start and end anchors. This function finds the start and end anchors for the audio
+         based on the MIDI notes. It also pads the MIDI representations, since MIDI files most often start and end
+         with an active note; without padding, the DTW would try to align those active notes to the entire duration
+         of the audio, which is not desirable. Therefore, we pad the MIDI representations with a few frames of
+         silence at the beginning and the end.
+         Args:
+             audio: model output representations of the audio
+             midi: MIDI representations as returned by represent_midi
+             feature_rate: frames per second of the representations
+             pad_length: how many frames of silence to pad the MIDI representations with
+
+         Returns:
+             the audio and padded MIDI representations together with the anchor pairs in seconds,
+             or the string 'corrupted' if no anchors can be found
+         """
+         # first pad the MIDI
+         midi = self.pad_representations(midi, pad_length)
+
+         # sometimes f0s are more reliable than the notes. So, we use both the f0s and the notes together to find the
+         # start and end anchors. f0 lookup bins is the number of bins to look around the f0 to assign a note to it.
+         f0_lookup_bins = int(100//(2*self.labeling.f0_granularity_c))
+
+         # find the start anchor for the audio
+         # first decide on which notes to use for the start anchor (take the entire chord where the MIDI file starts)
+         anchor_notes = midi['start_anchor'][:, 1] - self.labeling.midi_centers[0]
+         # now find which f0 bins to look at for the start anchor
+         anchor_f0s = [self.midi_pitch_to_contour_bin(an+self.labeling.midi_centers[0]) for an in anchor_notes]
+         anchor_f0s = np.array([list(range(f0-f0_lookup_bins, f0+f0_lookup_bins+1)) for f0 in anchor_f0s]).reshape(-1)
+         # the first start anchor proposals come from the notes
+         anchor_vals = np.any(audio['note'][:, anchor_notes] > 0.5, axis=1)
+         # now the f0s
+         anchor_vals_f0 = np.any(audio['f0'][:, anchor_f0s] > 0.5, axis=1)
+         # combine the two
+         anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
+         if not any(anchor_vals):
+             return 'corrupted'  # do not consider the file if we cannot find the start anchor
+         audio_start = np.argmax(anchor_vals)
+
+         # now the end anchor (most string instruments use chords in cadences: in general the end anchor is polyphonic)
+         anchor_notes = midi['end_anchor'][:, 1] - self.labeling.midi_centers[0]
+         anchor_f0s = [self.midi_pitch_to_contour_bin(an+self.labeling.midi_centers[0]) for an in anchor_notes]
+         anchor_f0s = np.array([list(range(f0-f0_lookup_bins, f0+f0_lookup_bins+1)) for f0 in anchor_f0s]).reshape(-1)
+         # the same procedure as above, applied to the time-reversed representations
+         anchor_vals = np.any(audio['note'][::-1, anchor_notes] > 0.5, axis=1)
+         anchor_vals_f0 = np.any(audio['f0'][::-1, anchor_f0s] > 0.5, axis=1)
+         anchor_vals = np.logical_or(anchor_vals, anchor_vals_f0)
+         if not any(anchor_vals):
+             return 'corrupted'  # do not consider the file if we cannot find the end anchor
+         audio_end = audio['note'].shape[0] - np.argmax(anchor_vals)
+
+         if audio_end - audio_start < (midi['end_anchor'][0][0] - midi['start_anchor'][0][0])/10:  # no one plays 10x faster
+             return 'corrupted'  # do not consider the file if the interval between anchors is too short
+         anchor_pairs = [(audio_start - 5, midi['start_anchor'][0][0] - 5),
+                         (audio_end + 5, midi['end_anchor'][0][0] + 5)]
+
+         if anchor_pairs[0][0] < 1:
+             anchor_pairs[0] = (1, midi['start_anchor'][0][0])
+         if anchor_pairs[1][0] > audio['note'].shape[0] - 1:
+             anchor_pairs[1] = (audio['note'].shape[0] - 1, midi['end_anchor'][0][0])
+
+         return audio, midi, [(anchor_pairs[0][0]/feature_rate, anchor_pairs[0][1]/feature_rate),
+                              (anchor_pairs[1][0]/feature_rate, anchor_pairs[1][1]/feature_rate)]
+
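A minimal sketch of how the synchronization path above might be driven. The class that owns these methods is not named in this part of the diff, and predict() is assumed to behave as in transcriber.py below; only out2sync() and its return convention come from the code above.

    # 'model' is assumed to be an instance of the class defining out2sync()
    out = model.predict('performance.wav')                        # frame-level activations plus 'time'
    result = model.out2sync(out, 'score.mid', include_velocity=True)
    if result is None:
        print('no alignment possible (anchors not found)')        # the file was flagged as corrupted
    else:
        notes, midi_representations = result                      # note tuples and the warped MIDI dict
        midi_representations['midi'].write('score_aligned.mid')   # the aligned pretty_midi object
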
musc/transcriber.py ADDED
@@ -0,0 +1,163 @@
+ from collections import defaultdict
+ from typing import DefaultDict, Dict, List, Optional, Tuple
+ import pretty_midi
+ import numpy as np
+ from musc.postprocessing import RegressionPostProcessor, spotify_create_notes
+ from musc.pitch_estimator import PitchEstimator
+
+
+ class Transcriber(PitchEstimator):
+     def __init__(self, labeling, instrument='Violin', sr=16000, window_size=1024, hop_length=160):
+         super().__init__(labeling, instrument=instrument, sr=sr, window_size=window_size, hop_length=hop_length)
+
+     def transcribe(self, audio, batch_size=128, postprocessing='spotify', include_pitch_bends=True, to_midi=True,
+                    debug=False):
+         """
+         Transcribe an audio file or mono waveform (numpy or torch) into MIDI with pitch bends.
+         :param audio: str, pathlib.Path, np.ndarray, or torch.Tensor
+         :param batch_size: number of frames to process at once
+         :param postprocessing: note creation method: 'spotify' (default), 'rebab', or 'tiktok'
+         :param include_pitch_bends: whether to include pitch bends in the MIDI file
+         :param to_midi: whether to return a MIDI file or a list of note events (as tuples)
+         :param debug: whether to plot the model activations (f0, note, onset, offset)
+         :return: transcribed MIDI as a pretty_midi.PrettyMIDI object, or a list of note events if to_midi is False
+         """
+         out = self.predict(audio, batch_size)
+         if debug:
+             import matplotlib.pyplot as plt
+             plt.imshow(out['f0'].T, aspect='auto', origin='lower')
+             plt.show()
+             plt.imshow(out['note'].T, aspect='auto', origin='lower')
+             plt.show()
+
+             plt.imshow(out['onset'].T, aspect='auto', origin='lower')
+             plt.show()
+
+             plt.imshow(out['offset'].T, aspect='auto', origin='lower')
+             plt.show()
+
+         if to_midi:
+             return self.out2midi(out, postprocessing, include_pitch_bends)
+         else:
+             return self.out2note(out, postprocessing, include_pitch_bends)
+
+
+     def out2note(self, output: Dict[str, np.ndarray], postprocessing='spotify',
+                  include_pitch_bends: bool = True,
+                  ) -> List[Tuple[float, float, int, float, Optional[List[int]]]]:
+         """Convert model output to a list of note event tuples
+         (start_time_s, end_time_s, midi_pitch, amplitude, optional pitch bends).
+         """
+         if postprocessing == 'spotify':
+             estimated_notes = spotify_create_notes(
+                 output["note"],
+                 output["onset"],
+                 note_low=self.labeling.midi_centers[0],
+                 note_high=self.labeling.midi_centers[-1],
+                 onset_thresh=0.5,
+                 frame_thresh=0.3,
+                 infer_onsets=True,
+                 min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),  # 127.70 ms
+                 melodia_trick=True,
+             )
+
+         elif postprocessing == 'rebab':
+             estimated_notes = spotify_create_notes(
+                 output["note"],
+                 output["onset"],
+                 note_low=self.labeling.midi_centers[0],
+                 note_high=self.labeling.midi_centers[-1],
+                 onset_thresh=0.2,
+                 frame_thresh=0.2,
+                 infer_onsets=True,
+                 min_note_len=int(np.round(127.70 / 1000 * (self.sr / self.hop_length))),  # 127.70 ms
+                 melodia_trick=True,
+             )
+
+         elif postprocessing == 'tiktok':
+             postprocessor = RegressionPostProcessor(
+                 frames_per_second=self.sr / self.hop_length,
+                 classes_num=self.labeling.midi_centers.shape[0],
+                 begin_note=self.labeling.midi_centers[0],
+                 onset_threshold=0.2,
+                 offset_threshold=0.2,
+                 frame_threshold=0.3,
+                 pedal_offset_threshold=0.5,
+             )
+             tiktok_note_dict, _ = postprocessor.output_dict_to_midi_events(output)
+             estimated_notes = []
+             for list_item in tiktok_note_dict:
+                 if list_item['offset_time'] > 0.6 + list_item['onset_time']:
+                     estimated_notes.append((int(np.floor(list_item['onset_time']/(output['time'][1]))),
+                                             int(np.ceil(list_item['offset_time']/(output['time'][1]))),
+                                             list_item['midi_note'], list_item['velocity']/128))
+         if include_pitch_bends:
+             estimated_notes_with_pitch_bend = self.get_pitch_bends(output["f0"], estimated_notes)
+         else:
+             estimated_notes_with_pitch_bend = [(note[0], note[1], note[2], note[3], None) for note in estimated_notes]
+
+         times_s = output['time']
+         estimated_notes_time_seconds = [
+             (times_s[note[0]], times_s[note[1]], note[2], note[3], note[4]) for note in estimated_notes_with_pitch_bend
+         ]
+
+         return estimated_notes_time_seconds
+
+
+     def out2midi(self, output: Dict[str, np.ndarray], postprocessing: str = 'spotify', include_pitch_bends: bool = True,
+                  ) -> pretty_midi.PrettyMIDI:
+         """Convert model output to MIDI
+         Args:
+             output: A dictionary with keys
+                 {
+                     'note': array of shape (n_times, n_notes),
+                     'onset': array of shape (n_times, n_notes),
+                     'f0': array of shape (n_times, n_f0_bins),
+                     'time': array of frame times in seconds
+                 }
+                 representing the output of the model.
+             postprocessing: 'spotify', 'rebab', or 'tiktok' note creation.
+             include_pitch_bends: If True, include pitch bends.
+         Returns:
+             the transcription as a pretty_midi.PrettyMIDI object.
+         """
+         estimated_notes_time_seconds = self.out2note(output, postprocessing, include_pitch_bends)
+         midi_tempo = 120  # todo: infer tempo from the onsets
+         return self.note2midi(estimated_notes_time_seconds, midi_tempo)
+
+
+     def note2midi(
+             self, note_events_with_pitch_bends: List[Tuple[float, float, int, float, Optional[List[int]]]],
+             midi_tempo: float = 120,
+     ) -> pretty_midi.PrettyMIDI:
+         """Create a pretty_midi object from note events
+         :param note_events_with_pitch_bends: list of tuples
+             [(start_time_seconds, end_time_seconds, pitch_midi, amplitude, pitch_bends or None)]
+         :param midi_tempo: initial tempo of the MIDI file  # todo: infer tempo from the onsets
+         :return: transcribed MIDI file as a pretty_midi.PrettyMIDI object
+         """
+         mid = pretty_midi.PrettyMIDI(initial_tempo=midi_tempo)
+
+         program = pretty_midi.instrument_name_to_program(self.instrument)
+         instruments: DefaultDict[int, pretty_midi.Instrument] = defaultdict(
+             lambda: pretty_midi.Instrument(program=program)
+         )
+         for start_time, end_time, note_number, amplitude, pitch_bend in note_events_with_pitch_bends:
+             instrument = instruments[note_number]
+             note = pretty_midi.Note(
+                 velocity=int(np.round(127 * amplitude)),
+                 pitch=note_number,
+                 start=start_time,
+                 end=end_time,
+             )
+             instrument.notes.append(note)
+             if not isinstance(pitch_bend, np.ndarray):
+                 continue
+             pitch_bend_times = np.linspace(start_time, end_time, len(pitch_bend))
+
+             for pb_time, pb_midi in zip(pitch_bend_times, pitch_bend):
+                 instrument.pitch_bends.append(pretty_midi.PitchBend(pb_midi, pb_time))
+
+         mid.instruments.extend(instruments.values())
+
+         return mid
+
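A minimal usage sketch for the Transcriber class above. Only transcribe() and its parameters come from this file; how a pretrained model and its labeling object are obtained is not part of this diff, so load_pretrained_violin() below is a hypothetical placeholder.

    # hypothetical helper: however the package exposes its pretrained violin model
    transcriber = load_pretrained_violin()          # returns a Transcriber instance

    # path in, pretty_midi.PrettyMIDI out (Spotify-style note creation, pitch bends included)
    midi = transcriber.transcribe('performance.wav', postprocessing='spotify',
                                  include_pitch_bends=True, to_midi=True)
    midi.write('performance_transcription.mid')

    # or get raw note events (start_s, end_s, midi_pitch, amplitude, pitch_bends) instead of MIDI
    notes = transcriber.transcribe('performance.wav', to_midi=False)
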
musc/violin.json ADDED
@@ -0,0 +1,13 @@
+ { "model_file": "1FfUjC3usmZBoTxNT6rNVIcYw4pol4K1g",
2
+ "wiring": "parallel",
3
+ "sampling_rate": 44100,
4
+ "pathway_multiscale": 4,
5
+ "num_pathway_layers": 2,
6
+ "num_separator_layers": 16,
7
+ "num_representation_layers": 4,
8
+ "hop_length": 256,
9
+ "chunk_size": 512,
10
+ "minSNR": -32, "maxSNR": 96,
11
+ "note_low": "F#3", "note_high": "E8",
12
+ "f0_bins_per_semitone": 10, "f0_smooth_std_c": 12, "onset_smooth_std": 0.7
13
+ }
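The JSON above bundles the hyperparameters the violin model was built with (sampling rate, hop length, pitch range, f0 bin resolution, and the id of the checkpoint). A minimal sketch of consuming it with only the Python standard library; the field names come from the file above, but how musc itself maps them onto the model constructor is not shown in this diff:

    import json

    with open('musc/violin.json') as f:
        cfg = json.load(f)

    frame_rate = cfg['sampling_rate'] / cfg['hop_length']   # 44100 / 256 ≈ 172.3 frames per second
    print(cfg['note_low'], 'to', cfg['note_high'])          # F#3 to E8, the supported note range
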
musc/violin_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a913356f059be6dc930be41158ac864f7d5511889ef0b2a6b6ba75a4a8732750
+ size 218770231