#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for visualizing audio data and chorus predictions."""

import os
from typing import List

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

from chorus_detection.audio.processor import AudioFeature


def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
    """Draw meter grid lines on the plot.

    One dashed, semi-transparent grey vertical line is drawn per entry of
    ``meter_grid_times``.

    Args:
        ax: The matplotlib axes object to draw on
        meter_grid_times: Array of times at which to draw the meter lines
    """
    for time in meter_grid_times:
        ax.axvline(x=time, color='grey', linestyle='--',
                   linewidth=1, alpha=0.6)


def plot_predictions(audio_features: AudioFeature,
                     binary_predictions: np.ndarray) -> None:
    """Plot the audio waveform and overlay the predicted chorus locations.

    The harmonic and percussive components are drawn as waveforms, the meter
    grid as dashed vertical lines, and every meter segment whose prediction
    equals 1 is shaded green. The figure is shown without blocking.

    Prediction ``i`` covers the span from meter grid time ``i`` to meter grid
    time ``i + 1``; the last grid segment extends to the end of the audio.
    Predictions without a corresponding meter grid entry are ignored rather
    than raising ``IndexError``.

    Args:
        audio_features: An object containing audio features and components.
            The attributes read here are ``meter_grid``, ``sr``,
            ``hop_length``, ``y``, ``y_harm``, ``y_perc`` and ``audio_path``.
        binary_predictions: Array of binary predictions indicating chorus
            locations
    """
    # Total length of the track in seconds; reused for the last segment
    # boundary, the x-axis limit and the tick positions.
    duration = len(audio_features.y) / audio_features.sr

    meter_grid_times = librosa.frames_to_time(
        audio_features.meter_grid,
        sr=audio_features.sr,
        hop_length=audio_features.hop_length)

    fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)

    # Display harmonic and percussive components
    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
                             alpha=0.8, ax=ax, color='deepskyblue')
    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
                             alpha=0.7, ax=ax, color='plum')
    plot_meter_lines(ax, meter_grid_times)

    # Each segment ends where the next one starts; the final segment runs to
    # the end of the audio.
    segment_ends = np.append(meter_grid_times[1:], duration)

    # Highlight chorus sections. zip() stops at the shortest input, so a
    # prediction array longer than the meter grid cannot index out of range.
    for prediction, start_time, end_time in zip(
            binary_predictions, meter_grid_times, segment_ends):
        if prediction == 1:
            ax.axvspan(start_time, end_time, color='green', alpha=0.3)

    # Set plot limits and labels
    ax.set_xlim([0, duration])
    ax.set_ylabel('Amplitude')
    audio_file_name = os.path.basename(audio_features.audio_path)
    ax.set_title(
        f'Chorus Predictions for \n{os.path.splitext(audio_file_name)[0]}')

    # Add legend. The spans themselves carry no label, so the proxy patch is
    # the single chorus entry regardless of which segments were predicted.
    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
    handles, labels = ax.get_legend_handles_labels()
    handles.append(chorus_patch)
    labels.append('Chorus')
    ax.legend(handles=handles, labels=labels)

    # Set x-tick labels in minutes:seconds format, one tick every 10 seconds
    xticks = np.arange(0, duration, 10)
    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)

    plt.tight_layout()
    plt.show(block=False)