Spaces:
Sleeping
Sleeping
File size: 2,914 Bytes
ad0da04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for visualizing audio data and chorus predictions."""
from typing import List
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import os
from chorus_detection.audio.processor import AudioFeature
def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
    """Mark every meter-grid position with a dashed vertical line.

    Args:
        ax: The matplotlib axes that receive the lines.
        meter_grid_times: Positions along the x axis (times); one line is
            drawn for each entry.
    """
    # One shared style so every grid line looks the same.
    line_style = {
        'color': 'grey',
        'linestyle': '--',
        'linewidth': 1,
        'alpha': 0.6,
    }
    for grid_time in meter_grid_times:
        ax.axvline(x=grid_time, **line_style)
def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
    """Plot the audio waveform and shade the segments predicted as chorus.

    The harmonic and percussive components are drawn as waveforms, the meter
    grid is overlaid as dashed vertical lines, and every meter segment whose
    prediction equals 1 is shaded green. The figure is shown without
    blocking.

    Args:
        audio_features: Object providing ``y``, ``y_harm``, ``y_perc``,
            ``sr``, ``hop_length``, ``meter_grid`` (frame indices) and
            ``audio_path``.
        binary_predictions: One value per meter segment; a value of 1 marks
            the segment as chorus. Entries beyond the number of meter grid
            positions are ignored.
    """
    sample_rate = audio_features.sr
    # Total length of the audio in seconds; reused for the last segment end,
    # the x-axis limit and the tick positions.
    duration = len(audio_features.y) / sample_rate

    meter_grid_times = librosa.frames_to_time(
        audio_features.meter_grid, sr=sample_rate,
        hop_length=audio_features.hop_length)

    _, ax = plt.subplots(figsize=(12.5, 3), dpi=96)

    # Display harmonic and percussive components
    librosa.display.waveshow(audio_features.y_harm, sr=sample_rate,
                             alpha=0.8, ax=ax, color='deepskyblue')
    librosa.display.waveshow(audio_features.y_perc, sr=sample_rate,
                             alpha=0.7, ax=ax, color='plum')
    plot_meter_lines(ax, meter_grid_times)

    # Highlight chorus sections. Segment i spans grid time i to grid time
    # i + 1; the final grid time is closed by the end of the audio. zip()
    # stops at the shorter input, so surplus predictions cannot index past
    # the meter grid.
    segment_ends = np.append(meter_grid_times[1:], duration)
    for start_time, end_time, prediction in zip(
            meter_grid_times, segment_ends, binary_predictions):
        if prediction == 1:
            # Spans stay unlabelled: the legend entry comes from the single
            # proxy patch below, which avoids a duplicate green entry when
            # the very first segment is a chorus.
            ax.axvspan(start_time, end_time, color='green', alpha=0.3)

    # Set plot limits and labels
    ax.set_xlim([0, duration])
    ax.set_ylabel('Amplitude')
    audio_file_name = os.path.basename(audio_features.audio_path)
    ax.set_title(
        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')

    # Add legend: keep any labelled artists already on the axes and append
    # one proxy rectangle standing in for all chorus spans.
    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
    handles, labels = ax.get_legend_handles_labels()
    handles.append(chorus_patch)
    labels.append('Chorus')
    ax.legend(handles=handles, labels=labels)

    # Set x-tick labels in minutes:seconds format, one tick every 10 seconds
    xticks = np.arange(0, duration, 10)
    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)

    plt.tight_layout()
    plt.show(block=False)