dennisvdang committed
Commit 6aa61fa
1 Parent(s): 56dbeab

Script fixes

Files changed (1): app.py (+411, -70)

app.py CHANGED
@@ -50,108 +50,449 @@ def extract_audio(url, output_path=AUDIO_TEMP_PATH):
         st.error(f"An error occurred: {e}")
         return None, None
 
 def strip_silence(audio_path):
     sound = AudioSegment.from_file(audio_path)
-    nonsilent_ranges = detect_nonsilent(sound, min_silence_len=500, silence_thresh=-50)
-    stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]], nonsilent_ranges, AudioSegment.empty())
     stripped.export(audio_path, format='mp3')
 
 class AudioFeature:
     def __init__(self, audio_path, sr=SR, hop_length=HOP_LENGTH):
         self.audio_path = audio_path
-        self.sr = sr
         self.hop_length = hop_length
-        self.y = None
-        self.y_harm, self.y_perc = None, None
-        self.spectrogram = None
-        self.rms = None
-        self.melspectrogram = None
         self.mel_acts = None
-        self.chromagram = None
-        self.chroma_acts = None
-        self.onset_env = None
-        self.tempogram = None
-        self.tempogram_acts = None
         self.mfccs = None
         self.mfcc_acts = None
-        self.combined_features = None
         self.n_frames = None
         self.tempo = None
-        self.beats = None
-        self.meter_grid = None
-        self.key, self.mode = None, None
 
-    def detect_key(self, chroma_vals):
-        note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
-        major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
-        minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
         major_profile /= np.linalg.norm(major_profile)
         minor_profile /= np.linalg.norm(minor_profile)
 
-        major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
-        minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]
 
         max_major_idx = np.argmax(major_correlations)
         max_minor_idx = np.argmax(minor_correlations)
 
         self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
-        self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
         return self.key, self.mode
 
-    def calculate_ki_chroma(self, waveform, sr, hop_length):
-        chromagram = librosa.feature.chroma_cqt(y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
-        chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min())
         chroma_vals = np.sum(chromagram, axis=1)
         key, mode = self.detect_key(chroma_vals)
-        key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
         shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
         return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
-
     def extract_features(self):
         self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
         self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
-        self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
-        self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
-        self.melspectrogram = librosa.feature.melspectrogram(y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
-        self.mel_acts = librosa.decompose.decompose(self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
-        self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
-        self.chroma_acts = librosa.decompose.decompose(self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
-        self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
-        self.tempogram = np.clip(librosa.feature.tempogram(onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, np.percentile(self.tempogram, 99)).astype(np.float32)
-        self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1].astype(np.float32)
-        self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=13, hop_length=self.hop_length).astype(np.float32)
-        self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=3, sort=True)[1].astype(np.float32)
-        self.combined_features = np.vstack([self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts])
-        self.n_frames = self.combined_features.shape[1]
-        self.tempo, self.beats = librosa.beat.beat_track(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
-        self.meter_grid = librosa.util.fix_frames(librosa.util.frame(self.beats, frame_length=MAX_METERS, hop_length=1), x_min=0, x_max=self.n_frames)
-        self.key, self.mode = self.detect_key(np.sum(self.chromagram, axis=1))
-
-    def get_features(self):
-        self.extract_features()
-        return self.combined_features, self.n_frames, self.tempo, self.beats, self.meter_grid, self.key, self.mode
 
 def load_model(model_path=MODEL_PATH):
-    return tf.keras.models.load_model(model_path)
-
-def predict_chorus(audio_features, model):
-    features, n_frames, tempo, beats, meter_grid, key, mode = audio_features.get_features()
-    features = features[:, :MAX_FRAMES]
-    features = np.expand_dims(features, axis=0)
-    scaler = StandardScaler()
-    features = scaler.fit_transform(features.reshape(-1, features.shape[-1])).reshape(features.shape)
-    predictions = model.predict(features)
-    return predictions
-
-def plot_predictions(predictions, title):
-    plt.figure(figsize=(10, 4))
-    plt.plot(predictions[0], label='Chorus Probability')
-    plt.title(title)
-    plt.xlabel('Frame')
-    plt.ylabel('Probability')
-    plt.legend()
     st.pyplot(plt)
 
 def main():
     st.title("Chorus Finder")
     st.write("Upload a YouTube URL to find the chorus in the song.")
@@ -161,10 +502,10 @@ def main():
     audio_file, video_title = extract_audio(url)
     if audio_file:
         strip_silence(audio_file)
-        audio_features = AudioFeature(audio_file)
         model = load_model()
-        predictions = predict_chorus(audio_features, model)
-        plot_predictions(predictions, video_title)
         shutil.rmtree(AUDIO_TEMP_PATH)
     else:
         st.error("Please enter a valid YouTube URL")
 
         st.error(f"An error occurred: {e}")
         return None, None
 
+
 def strip_silence(audio_path):
+    """Removes silent parts from an audio file."""
     sound = AudioSegment.from_file(audio_path)
+    nonsilent_ranges = detect_nonsilent(
+        sound, min_silence_len=500, silence_thresh=-50)
+    stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]],
+                      nonsilent_ranges, AudioSegment.empty())
     stripped.export(audio_path, format='mp3')
 
+
 class AudioFeature:
+    """Class for extracting and processing audio features."""
+
     def __init__(self, audio_path, sr=SR, hop_length=HOP_LENGTH):
         self.audio_path = audio_path
+        self.beats = None
+        self.chroma_acts = None
+        self.chromagram = None
+        self.combined_features = None
         self.hop_length = hop_length
+        self.key, self.mode = None, None
         self.mel_acts = None
+        self.melspectrogram = None
+        self.meter_grid = None
         self.mfccs = None
         self.mfcc_acts = None
         self.n_frames = None
+        self.onset_env = None
+        self.rms = None
+        self.spectrogram = None
+        self.sr = sr
         self.tempo = None
+        self.tempogram = None
+        self.tempogram_acts = None
+        self.time_signature = 4
+        self.y = None
+        self.y_harm, self.y_perc = None, None
 
+    def detect_key(self, chroma_vals: np.ndarray) -> Tuple[str, str]:
+        """Detect the key and mode (major or minor) of the audio segment."""
+        note_names = ['C', 'C#', 'D', 'D#', 'E',
+                      'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+        major_profile = np.array(
+            [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
+        minor_profile = np.array(
+            [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
         major_profile /= np.linalg.norm(major_profile)
         minor_profile /= np.linalg.norm(minor_profile)
 
+        major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[
+            0, 1] for i in range(12)]
+        minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[
+            0, 1] for i in range(12)]
 
         max_major_idx = np.argmax(major_correlations)
         max_minor_idx = np.argmax(minor_correlations)
 
         self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
+        self.key = note_names[max_major_idx if self.mode ==
+                              'major' else max_minor_idx]
         return self.key, self.mode
 
+    def calculate_ki_chroma(self, waveform: np.ndarray, sr: int, hop_length: int) -> np.ndarray:
+        """Calculate a normalized, key-invariant chromagram for the given audio waveform."""
+        chromagram = librosa.feature.chroma_cqt(
+            y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
+        chromagram = (chromagram - chromagram.min()) / \
+            (chromagram.max() - chromagram.min())
         chroma_vals = np.sum(chromagram, axis=1)
         key, mode = self.detect_key(chroma_vals)
+        key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F',
+                   'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
         shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
         return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
+
     def extract_features(self):
+        """Extract various audio features from the loaded audio."""
         self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
         self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
+        self.spectrogram, _ = librosa.magphase(
+            librosa.stft(self.y, hop_length=self.hop_length))
+        self.rms = librosa.feature.rms(
+            S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
+        self.melspectrogram = librosa.feature.melspectrogram(
+            y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
+        self.mel_acts = librosa.decompose.decompose(
+            self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
+        self.chromagram = self.calculate_ki_chroma(
+            self.y_harm, self.sr, self.hop_length).astype(np.float32)
+        self.chroma_acts = librosa.decompose.decompose(
+            self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
+        self.onset_env = librosa.onset.onset_strength(
+            y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
+        self.tempogram = np.clip(librosa.feature.tempogram(
+            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, None)
+        self.tempogram_acts = librosa.decompose.decompose(
+            self.tempogram, n_components=3, sort=True)[1]
+        self.mfccs = librosa.feature.mfcc(
+            y=self.y, sr=self.sr, n_mfcc=20, hop_length=self.hop_length)
+        self.mfccs += abs(np.min(self.mfccs))
+        self.mfcc_acts = librosa.decompose.decompose(
+            self.mfccs, n_components=4, sort=True)[1].astype(np.float32)
+
+        features = [self.rms, self.mel_acts, self.chroma_acts,
+                    self.tempogram_acts, self.mfcc_acts]
+        feature_names = ['rms', 'mel_acts', 'chroma_acts',
+                         'tempogram_acts', 'mfcc_acts']
+        dims = {name: feature.shape[0]
+                for feature, name in zip(features, feature_names)}
+        total_inv_dim = sum(1 / dim for dim in dims.values())
+        weights = {name: 1 / (dims[name] * total_inv_dim)
+                   for name in feature_names}
+        std_weighted_features = [StandardScaler().fit_transform(feature.T).T * weights[name]
+                                 for feature, name in zip(features, feature_names)]
+        self.combined_features = np.concatenate(
+            std_weighted_features, axis=0).T.astype(np.float32)
+        self.n_frames = len(self.combined_features)
+
+    def create_meter_grid(self):
+        """Create a grid based on the meter of the song, using tempo and beats."""
+        self.tempo, self.beats = librosa.beat.beat_track(
+            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
+        self.tempo = self.tempo * 2 if self.tempo < 70 else self.tempo / \
+            2 if self.tempo > 140 else self.tempo
+        self.meter_grid = self._create_meter_grid()
+        return self.meter_grid
+
+    def _create_meter_grid(self) -> np.ndarray:
+        """
+        Helper function to create a meter grid for the song, extrapolating both forwards and backwards from an anchor frame.
+
+        Returns:
+        - np.ndarray: The meter grid.
+        """
+        seconds_per_beat = 60 / self.tempo
+        beat_interval = int(librosa.time_to_frames(
+            seconds_per_beat, sr=self.sr, hop_length=self.hop_length))
+
+        # Find the best matching start beat based on the tempo and existing beats
+        best_match_start = max((1 - abs(np.mean(self.beats[i:i+3]) - beat_interval) / beat_interval, self.beats[i])
+                               for i in range(len(self.beats) - 2))[1]
+        anchor_frame = best_match_start if best_match_start > 0.95 else self.beats[0]
+        first_beat_time = librosa.frames_to_time(
+            anchor_frame, sr=self.sr, hop_length=self.hop_length)
+
+        # Calculate the number of beats forward and backward
+        time_duration = librosa.frames_to_time(
+            self.n_frames, sr=self.sr, hop_length=self.hop_length)
+        num_beats_forward = int(
+            (time_duration - first_beat_time) / seconds_per_beat)
+        num_beats_backward = int(first_beat_time / seconds_per_beat) + 1
+
+        # Create beat times forward and backward
+        beat_times_forward = first_beat_time + \
+            np.arange(num_beats_forward) * seconds_per_beat
+        beat_times_backward = first_beat_time - \
+            np.arange(1, num_beats_backward) * seconds_per_beat
+
+        # Combine and sort the beat times
+        beat_grid = np.concatenate(
+            (np.array([0.0]), beat_times_backward[::-1], beat_times_forward))
+        meter_indices = np.arange(0, len(beat_grid), self.time_signature)
+        meter_grid = beat_grid[meter_indices]
+
+        # Ensure the meter grid starts at 0 and ends at frame_duration
+        if meter_grid[0] != 0.0:
+            meter_grid = np.insert(meter_grid, 0, 0.0)
+        meter_grid = librosa.time_to_frames(
+            meter_grid, sr=self.sr, hop_length=self.hop_length)
+        if meter_grid[-1] != self.n_frames:
+            meter_grid = np.append(meter_grid, self.n_frames)
+
+        return meter_grid
+
+
+def segment_data_meters(data: np.ndarray, meter_grid: List[int]) -> List[np.ndarray]:
+    """
+    Divide song data into segments based on measure grid frames.
+
+    Parameters:
+    - data (np.ndarray): The song data to be segmented.
+    - meter_grid (List[int]): The grid indicating the start of each measure.
+
+    Returns:
+    - List[np.ndarray]: A list of song data segments.
+    """
+    meter_segments = [data[s:e]
+                      for s, e in zip(meter_grid[:-1], meter_grid[1:])]
+    meter_segments = [segment.astype(np.float32) for segment in meter_segments]
+    return meter_segments
+
+
+def positional_encoding(position: int, d_model: int) -> np.ndarray:
+    """
+    Generate a positional encoding for a given position and model dimension.
+
+    Parameters:
+    - position (int): The position for which to generate the encoding.
+    - d_model (int): The dimension of the model.
+
+    Returns:
+    - np.ndarray: The positional encoding.
+    """
+    angle_rads = np.arange(position)[:, np.newaxis] / np.power(
+        10000, (2 * (np.arange(d_model)[np.newaxis, :] // 2)) / np.float32(d_model))
+    return np.concatenate([np.sin(angle_rads[:, 0::2]), np.cos(angle_rads[:, 1::2])], axis=-1)
+
+
+def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
+    """
+    Apply positional encoding at the meter and frame levels to a list of segments.
+
+    Parameters:
+    - segments (List[np.ndarray]): The list of segments to encode.
+
+    Returns:
+    - List[np.ndarray]: The list of segments with applied positional encoding.
+    """
+    n_features = segments[0].shape[1]
+    measure_level_encodings = positional_encoding(len(segments), n_features)
+    return [
+        seg + positional_encoding(len(seg), n_features) +
+        measure_level_encodings[i]
+        for i, seg in enumerate(segments)
+    ]
+
+
+def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES, max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
+    """
+    Pad or truncate the encoded segments to have the specified max_frames and max_meters dimensions.
+
+    Parameters:
+    - encoded_segments (List[np.ndarray]): The encoded segments to pad or truncate.
+    - max_frames (int): The maximum number of frames per segment.
+    - max_meters (int): The maximum number of meters.
+    - n_features (int): The number of features per frame.
+
+    Returns:
+    - np.ndarray: The padded or truncated song.
+    """
+    padded_meters = [
+        np.pad(meter[:max_frames], ((0, max(0, max_frames -
+               meter.shape[0])), (0, 0)), 'constant', constant_values=0)
+        for meter in encoded_segments
+    ]
+    padding_meter = np.zeros((max_frames, n_features))
+    padded_song = np.array(
+        padded_meters[:max_meters] + [padding_meter] * max(0, max_meters - len(padded_meters)))
+    return padded_song
+
+
+def process_audio(audio_path, trim_silence=True, sr=SR, hop_length=HOP_LENGTH):
+    """
+    Process an audio file, extracting features and applying positional encoding.
+
+    Parameters:
+    - audio_path (str): The path to the audio file.
+    - trim_silence (bool): Whether to trim silence from the audio.
+    - sr (int): The sample rate to use when loading the audio.
+    - hop_length (int): The hop length to use for feature extraction.
+
+    Returns:
+    - Tuple[np.ndarray, AudioFeature]: The processed audio and its features.
+    """
+    if trim_silence:
+        strip_silence(audio_path)
+
+    audio_features = AudioFeature(
+        audio_path=audio_path, sr=sr, hop_length=hop_length)
+    audio_features.extract_features()
+    audio_features.create_meter_grid()
+    audio_segments = segment_data_meters(
+        audio_features.combined_features, audio_features.meter_grid)
+    encoded_audio_segments = apply_hierarchical_positional_encoding(
+        audio_segments)
+    processed_audio = np.expand_dims(pad_song(encoded_audio_segments), axis=0)
+
+    return processed_audio, audio_features
 
 def load_model(model_path=MODEL_PATH):
+    # Placeholder functions for loading the model
+    def custom_binary_crossentropy(y_true, y_pred):
+        return y_pred
+
+    def custom_accuracy(y_true, y_pred):
+        return y_pred
+
+    custom_objects = {
+        'custom_binary_crossentropy': custom_binary_crossentropy,
+        'custom_accuracy': custom_accuracy
+    }
+    model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)
+    return model
+
+def smooth_predictions(data: np.ndarray) -> np.ndarray:
+    """
+    Smooth predictions by correcting isolated mispredictions and removing short sequences of 1s.
+
+    This function applies a smoothing algorithm to correct isolated zeros and ones in a sequence
+    of binary predictions. It also removes isolated sequences of 1s that are shorter than 5.
+
+    Parameters:
+    - data (np.ndarray): Array of binary predictions.
+
+    Returns:
+    - np.ndarray: Smoothed array of binary predictions.
+    """
+    if not isinstance(data, np.ndarray):
+        data = np.array(data)
+
+    # First pass: Correct isolated 0's
+    data_first_pass = data.copy()
+    for i in range(1, len(data) - 1):
+        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
+            data_first_pass[i] = 1
+
+    # Second pass: Correct isolated 1's
+    corrected_data = data_first_pass.copy()
+    for i in range(1, len(data_first_pass) - 1):
+        if data_first_pass[i] == 1 and data_first_pass[i - 1] == 0 and data_first_pass[i + 1] == 0:
+            corrected_data[i] = 0
+
+    # Third pass: Remove short sequences of 1s (less than 5)
+    smoothed_data = corrected_data.copy()
+    sequence_start = None
+    for i in range(len(corrected_data)):
+        if corrected_data[i] == 1:
+            if sequence_start is None:
+                sequence_start = i
+        else:
+            if sequence_start is not None:
+                sequence_length = i - sequence_start
+                if sequence_length < 5:
+                    smoothed_data[sequence_start:i] = 0
+                sequence_start = None
+
+    return smoothed_data
+
+def make_predictions(model, processed_audio, audio_features, url, video_name):
+    """
+    Generate predictions from the model and process them to binary and smoothed predictions.
+
+    Parameters:
+    - model: The loaded model for making predictions.
+    - processed_audio: The audio data that has been processed for prediction.
+    - audio_features: Audio features object containing necessary metadata like meter grid.
+    - url (str): YouTube URL of the audio file.
+    - video_name (str): Name of the video.
+
+    Returns:
+    - np.ndarray: The smoothed binary predictions.
+    """
+    predictions = model.predict(processed_audio)[0]
+    binary_predictions = np.round(
+        predictions[:(len(audio_features.meter_grid) - 1)]).flatten()
+    smoothed_predictions = smooth_predictions(binary_predictions)
+
+    meter_grid_times = librosa.frames_to_time(
+        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
+    chorus_start_times = [meter_grid_times[i] for i in range(len(
+        smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)]
+
+    youtube_links = [
+        f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\" for start_time in chorus_start_times
+    ]
+    max_length = max([len(link) for link in youtube_links] + [len(video_name), len(
+        f"Number of choruses identified: {len(chorus_start_times)}")] if chorus_start_times else [0])
+    header_footer = "=" * (max_length + 4)
+    print()
+    print()
+    print(header_footer)
+    print(f"{video_name.center(max_length + 2)}")
+    print(f"Number of choruses identified: {len(chorus_start_times)}".center(
+        max_length + 4))
+    print(header_footer)
+    for link in youtube_links:
+        print(link)
+    print(header_footer)
+
+    if len(chorus_start_times) == 0:
+        print("No choruses identified.")
+
+    return smoothed_predictions
+
+
+def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
+    """
+    Draw meter grid lines on the plot.
+
+    Parameters:
+    - ax (plt.Axes): The matplotlib axes object to draw on.
+    - meter_grid_times (np.ndarray): Array of times at which to draw the meter lines.
+    """
+    for time in meter_grid_times:
+        ax.axvline(x=time, color='grey', linestyle='--',
+                   linewidth=1, alpha=0.6)
+
+
+def plot_predictions(audio_features, predictions):
+    meter_grid_times = librosa.frames_to_time(
+        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
+    fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)
+
+    # Display harmonic and percussive components without adding them to the legend
+    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
+                             alpha=0.8, ax=ax, color='deepskyblue')
+    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
+                             alpha=0.7, ax=ax, color='plum')
+    plot_meter_lines(ax, meter_grid_times)
+
+    for i, prediction in enumerate(predictions):
+        start_time = meter_grid_times[i]
+        end_time = meter_grid_times[i + 1] if i < len(
+            meter_grid_times) - 1 else len(audio_features.y) / audio_features.sr
+        if prediction == 1:
+            ax.axvspan(start_time, end_time, color='green', alpha=0.3,
+                       label='Predicted Chorus' if i == 0 else None)
+
+    ax.set_xlim([0, len(audio_features.y) / audio_features.sr])
+    ax.set_ylabel('Amplitude')
+    audio_file_name = os.path.basename(audio_features.audio_path)
+    ax.set_title(
+        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')
+
+    # Add a green square patch to represent "Chorus" in the legend
+    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
+    handles, labels = ax.get_legend_handles_labels()
+    handles.append(chorus_patch)
+    labels.append('Chorus')
+    ax.legend(handles=handles, labels=labels)
+
+    # Set x-tick labels every 10 seconds in single-digit minutes format
+    duration = len(audio_features.y) / audio_features.sr
+    xticks = np.arange(0, duration, 10)
+    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
+    ax.set_xticks(xticks)
+    ax.set_xticklabels(xlabels)
+
+    plt.tight_layout()
     st.pyplot(plt)
 
+
 def main():
     st.title("Chorus Finder")
     st.write("Upload a YouTube URL to find the chorus in the song.")
 
     audio_file, video_title = extract_audio(url)
     if audio_file:
         strip_silence(audio_file)
+        processed_audio, audio_features = process_audio(audio_path=AUDIO_TEMP_PATH)
         model = load_model()
+        smoothed_predictions = make_predictions(model, processed_audio, audio_features, url, video_title)
+        plot_predictions(audio_features=audio_features, predictions=smoothed_predictions)
         shutil.rmtree(AUDIO_TEMP_PATH)
     else:
         st.error("Please enter a valid YouTube URL")
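For a quick sanity check of the refactored pipeline outside Streamlit, a minimal sketch is shown below. The local file path, URL, and title are placeholder assumptions for illustration; SR, HOP_LENGTH, and MODEL_PATH are the constants already defined near the top of app.py, and the functions called are the ones introduced in this commit.

    # Hypothetical standalone usage of the new helpers (not part of this commit).
    # Assumes "example.mp3" exists locally and app.py's imports and constants are in scope.
    processed_audio, audio_features = process_audio(audio_path="example.mp3")
    model = load_model(MODEL_PATH)
    smoothed = make_predictions(model, processed_audio, audio_features,
                                url="https://www.youtube.com/watch?v=PLACEHOLDER",  # placeholder URL
                                video_name="Example Song")                          # placeholder title
    # make_predictions prints the identified chorus start times to the console
    # and returns the smoothed per-measure binary predictions.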