#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Module for loading and managing the CRNN model for chorus detection."""

import os
from typing import Any, List, Optional, Union

import numpy as np
import tensorflow as tf

from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
from chorus_detection.utils.logging import logger


def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
    """Load a CRNN model with custom loss and accuracy functions.

    Args:
        model_path: Path to the saved model

    Returns:
        Loaded Keras model
        
    Raises:
        RuntimeError: If the model cannot be loaded
    """
    try:
        # Stub implementations of the training-time custom objects; with
        # compile=False below they are registered but never invoked
        custom_objects = {
            'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
            'custom_accuracy': lambda y_true, y_pred: y_pred
        }

        # Try to load the model with custom objects
        logger.info(f"Loading model from: {model_path}")
        model = tf.keras.models.load_model(
            model_path, custom_objects=custom_objects, compile=False)
        
        # Re-compile with stock settings; only predict() is used downstream,
        # so the optimizer/loss choice has no effect on inference
        model.compile(optimizer='adam', loss='binary_crossentropy')
        
        return model
    except Exception as e:
        logger.error(f"Error loading model from {model_path}: {e}")
        
        # Fall back to the Docker container path if it differs and exists
        docker_path = str(DOCKER_MODEL_PATH)
        if model_path != docker_path and os.path.exists(docker_path):
            logger.info(f"Trying Docker path: {docker_path}")
            return load_CRNN_model(docker_path)

        raise RuntimeError(f"Failed to load model: {e}") from e

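# A minimal usage sketch for load_CRNN_model (assuming the default MODEL_PATH
# from chorus_detection.config points at a saved Keras model):
#
#     model = load_CRNN_model()
#     model.summary()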

def smooth_predictions(predictions: Union[np.ndarray, List[int]]) -> np.ndarray:
    """Smooth binary predictions by correcting isolated mispredictions and
    removing runs of 1's shorter than five segments.

    Args:
        predictions: Array-like of binary (0/1) predictions

    Returns:
        Smoothed array of binary predictions
    """
    # Work on a copy so the caller's array is never mutated
    data = np.array(predictions, copy=True) if not isinstance(predictions, np.ndarray) else predictions.copy()
    
    # First pass: fill isolated 0's that are surrounded by 1's
    for i in range(1, len(data) - 1):
        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
            data[i] = 1

    # Second pass: drop isolated 1's that are surrounded by 0's (written to a
    # copy so every check sees the unmodified first-pass values)
    corrected_data = data.copy()
    for i in range(1, len(data) - 1):
        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
            corrected_data[i] = 0

    # Third pass: Remove short sequences of 1s (less than 5 consecutive 1's)
    smoothed_data = corrected_data.copy()
    sequence_start = None
    
    for i in range(len(corrected_data)):
        if corrected_data[i] == 1:
            if sequence_start is None:
                sequence_start = i
        else:
            if sequence_start is not None:
                sequence_length = i - sequence_start
                if sequence_length < 5:
                    smoothed_data[sequence_start:i] = 0
                sequence_start = None
    
    # Handle the case where the sequence extends to the end
    if sequence_start is not None:
        sequence_length = len(corrected_data) - sequence_start
        if sequence_length < 5:
            smoothed_data[sequence_start:] = 0

    return smoothed_data

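# Worked example (illustrative) for smooth_predictions: the isolated 0 at
# index 2 is bridged by the first pass, the isolated 1 at index 8 is dropped
# by the second, and the surviving run of six 1's is long enough for the
# third pass to keep:
#
#     >>> smooth_predictions(np.array([1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0]))
#     array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0])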

def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray, 
                     audio_features: Any, url: Optional[str] = None, 
                     video_name: Optional[str] = None) -> np.ndarray:
    """Generate predictions from the model and process them.

    Args:
        model: The loaded model for making predictions
        processed_audio: The audio data that has been processed for prediction
        audio_features: Audio features object containing necessary metadata
        url: YouTube URL of the audio file (optional)
        video_name: Name of the video (optional)

    Returns:
        The smoothed binary predictions
    """
    # Local import so librosa is only required when predictions are made
    import librosa
    
    logger.info("Generating predictions...")
    
    # Make predictions
    predictions = model.predict(processed_audio)[0]
    
    # Trim any excess model outputs so predictions align with the meter-grid
    # segments (one prediction per interval between adjacent grid points)
    meter_grid_length = len(audio_features.meter_grid) - 1
    if len(predictions) > meter_grid_length:
        predictions = predictions[:meter_grid_length]
    
    # Threshold the raw outputs into binary (0/1) predictions
    binary_predictions = np.round(predictions).flatten()
    
    # Apply smoothing to improve prediction quality
    smoothed_predictions = smooth_predictions(binary_predictions)
    
    # Get times for identified chorus sections
    meter_grid_times = librosa.frames_to_time(
        audio_features.meter_grid, 
        sr=audio_features.sr, 
        hop_length=audio_features.hop_length
    )
    
    # A chorus starts at any predicted 1 whose predecessor is 0 (or at index 0)
    chorus_start_times = [
        meter_grid_times[i] for i in range(len(smoothed_predictions)) 
        if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
    ]
    
    # Print results if URL and video name are provided (CLI mode)
    if url and video_name:
        _print_chorus_results(url, video_name, chorus_start_times)

    return smoothed_predictions

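# Illustrative end-to-end call for make_predictions; processed_audio and
# audio_features are hypothetical stand-ins for the outputs of the project's
# preprocessing step:
#
#     model = load_CRNN_model()
#     smoothed = make_predictions(model, processed_audio, audio_features,
#                                 url="https://www.youtube.com/watch?v=example",
#                                 video_name="Example Song")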

def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
    """Print formatted results showing identified choruses with links.
    
    Args:
        url: YouTube URL of the analyzed video
        video_name: Name of the video
        chorus_start_times: List of start times (in seconds) for identified choruses
    """
    # Build clickable timestamped links using OSC 8 terminal hyperlink escapes
    # (the visible text is the URL with a "&t=<seconds>s" suffix)
    youtube_links = [
        f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\" 
        for start_time in chorus_start_times
    ]
    
    # Size the banner from the visible link text; the escape sequences in
    # youtube_links occupy no columns on screen, so their lengths would
    # inflate the width
    visible_lengths = [len(f"{url}&t={int(t)}s") for t in chorus_start_times]
    max_length = max(visible_lengths + [
        len(video_name),
        len(f"Number of choruses identified: {len(chorus_start_times)}"),
    ])
    header_footer = "=" * (max_length + 4)
    
    # Print the results, centering every line within the banner width
    print("\n\n")
    print(header_footer)
    print(video_name.center(max_length + 4))
    print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
    print(header_footer)
    
    if chorus_start_times:
        for link in youtube_links:
            print(link)
    else:
        print("No choruses identified.")
        
    print(header_footer)
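

if __name__ == "__main__":
    # Minimal smoke test of the smoothing logic; it needs no model or audio
    demo = np.array([0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0])
    print("raw:     ", demo)
    print("smoothed:", smooth_predictions(demo))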