File size: 2,914 Bytes
ad0da04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Module for visualizing audio data and chorus predictions."""

from typing import List

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import os

from chorus_detection.audio.processor import AudioFeature


def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
    """Overlay a dashed grey vertical line at every meter-grid time.

    Args:
        ax: The matplotlib axes object that receives the lines
        meter_grid_times: Array of times (x positions) at which a line is drawn
    """
    # Shared appearance for every grid line.
    line_style = {
        'color': 'grey',
        'linestyle': '--',
        'linewidth': 1,
        'alpha': 0.6,
    }
    for grid_time in meter_grid_times:
        ax.axvline(x=grid_time, **line_style)


def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
    """Plot the audio waveform and shade the segments predicted as chorus.

    The harmonic and percussive waveforms are drawn on a single axes, a dashed
    line marks every meter-grid time, and each segment whose prediction equals
    1 is shaded green. Segment ``i`` starts at meter-grid time ``i`` and ends
    at the next meter-grid time, or at the end of the audio for the last grid
    entry. The figure is shown without blocking.

    Args:
        audio_features: An object containing audio features and components;
            this function reads ``y``, ``y_harm``, ``y_perc``, ``sr``,
            ``hop_length``, ``meter_grid`` and ``audio_path``
        binary_predictions: Array of binary predictions indicating chorus
            locations, one per meter-grid segment; predictions beyond the
            number of meter-grid times are ignored
    """
    # Total length of the audio in seconds; reused for limits, ticks and the
    # end of the final segment.
    duration = len(audio_features.y) / audio_features.sr
    meter_grid_times = librosa.frames_to_time(
        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
    _, ax = plt.subplots(figsize=(12.5, 3), dpi=96)

    # Display harmonic and percussive components
    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
                             alpha=0.8, ax=ax, color='deepskyblue')
    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
                             alpha=0.7, ax=ax, color='plum')
    plot_meter_lines(ax, meter_grid_times)

    # Each segment ends where the next one starts; the last ends with the audio.
    segment_ends = np.append(meter_grid_times[1:], duration)

    # Highlight chorus sections. zip() pairs every prediction with its segment
    # and stops at the shorter input, so surplus predictions cannot index past
    # the meter grid. The spans carry no label: the legend entry comes from the
    # single proxy patch below, which keeps the legend identical no matter
    # which segments are choruses.
    for start_time, end_time, prediction in zip(
            meter_grid_times, segment_ends, binary_predictions):
        if prediction == 1:
            ax.axvspan(start_time, end_time, color='green', alpha=0.3)

    # Set plot limits and labels
    ax.set_xlim([0, duration])
    ax.set_ylabel('Amplitude')
    audio_file_name = os.path.basename(audio_features.audio_path)
    ax.set_title(
        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')

    # Add legend: one proxy patch stands in for all shaded chorus spans
    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
    handles, labels = ax.get_legend_handles_labels()
    handles.append(chorus_patch)
    labels.append('Chorus')
    ax.legend(handles=handles, labels=labels)

    # Set x-tick labels in minutes:seconds format, one tick every 10 seconds
    xticks = np.arange(0, duration, 10)
    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
    ax.set_xticks(xticks)
    ax.set_xticklabels(xlabels)

    plt.tight_layout()
    plt.show(block=False)