#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for loading and managing the CRNN model for chorus detection."""
import os
from typing import Any, Optional, List, Tuple, Union
import numpy as np
import tensorflow as tf
from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
from chorus_detection.utils.logging import logger
def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
    """Load a CRNN model with custom loss and accuracy functions.

    The model was trained with custom loss/metric callables; placeholder
    lambdas are supplied so deserialization succeeds, and the model is then
    recompiled with stock settings since it is only used for inference.

    Args:
        model_path: Path to the saved model.

    Returns:
        Loaded Keras model, compiled for prediction only.

    Raises:
        RuntimeError: If the model cannot be loaded from either the given
            path or the Docker fallback path.
    """
    try:
        # Placeholder custom objects: the real training-time implementations
        # are not needed for inference, only for successful deserialization.
        custom_objects = {
            'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
            'custom_accuracy': lambda y_true, y_pred: y_pred,
        }

        logger.info(f"Loading model from: {model_path}")
        model = tf.keras.models.load_model(
            model_path, custom_objects=custom_objects, compile=False)

        # Recompile with a stock optimizer/loss; the exact values are
        # irrelevant for predict-only use.
        model.compile(optimizer='adam', loss='binary_crossentropy')
        return model
    except Exception as e:
        logger.error(f"Error loading model from {model_path}: {e}")

        # Fall back to the Docker container path. Normalize to str so the
        # inequality check cannot silently misfire if DOCKER_MODEL_PATH is a
        # pathlib.Path while model_path is a str.
        docker_path = str(DOCKER_MODEL_PATH)
        if model_path != docker_path and os.path.exists(docker_path):
            logger.info(f"Trying Docker path: {docker_path}")
            return load_CRNN_model(docker_path)

        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Failed to load model: {e}") from e
def smooth_predictions(predictions: np.ndarray,
                       min_sequence_length: int = 5) -> np.ndarray:
    """Smooth binary predictions by fixing isolated flips and culling short runs.

    Three passes are applied in order:
      1. Isolated 0's surrounded by 1's become 1.
      2. Isolated 1's surrounded by 0's become 0.
      3. Runs of 1's shorter than ``min_sequence_length`` are zeroed out.

    Args:
        predictions: Array-like of binary (0/1) predictions.
        min_sequence_length: Minimum run length of consecutive 1's to keep.
            Defaults to 5, matching the previous hard-coded behavior.

    Returns:
        Smoothed numpy array of binary predictions.
    """
    # Work on a copy so the caller's array is never mutated.
    data = np.array(predictions, copy=True) if not isinstance(predictions, np.ndarray) else predictions.copy()

    # First pass: fill isolated 0's (a 0 with a 1 on both sides).
    for i in range(1, len(data) - 1):
        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
            data[i] = 1

    # Second pass: clear isolated 1's (a 1 with a 0 on both sides).
    # Read from `data` but write into a copy so corrections made during this
    # pass do not cascade into later decisions.
    corrected_data = data.copy()
    for i in range(1, len(data) - 1):
        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
            corrected_data[i] = 0

    # Third pass: zero out runs of 1's shorter than min_sequence_length.
    smoothed_data = corrected_data.copy()
    sequence_start = None
    for i in range(len(corrected_data)):
        if corrected_data[i] == 1:
            if sequence_start is None:
                sequence_start = i
        elif sequence_start is not None:
            # Run just ended at index i (exclusive).
            if i - sequence_start < min_sequence_length:
                smoothed_data[sequence_start:i] = 0
            sequence_start = None

    # Handle a run that extends to the very end of the array.
    if sequence_start is not None:
        if len(corrected_data) - sequence_start < min_sequence_length:
            smoothed_data[sequence_start:] = 0

    return smoothed_data
def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
                     audio_features: Any, url: Optional[str] = None,
                     video_name: Optional[str] = None) -> np.ndarray:
    """Run the model on processed audio and post-process its output.

    Args:
        model: The loaded model used for inference.
        processed_audio: Audio data prepared for the model.
        audio_features: Object exposing ``meter_grid``, ``sr`` and
            ``hop_length`` metadata.
        url: YouTube URL of the audio file (optional).
        video_name: Name of the video (optional).

    Returns:
        The smoothed binary predictions.
    """
    import librosa

    logger.info("Generating predictions...")

    raw_output = model.predict(processed_audio)[0]

    # The meter grid has one more boundary than segments; trim any surplus
    # model outputs so predictions stay aligned with the segments.
    n_segments = len(audio_features.meter_grid) - 1
    if len(raw_output) > n_segments:
        raw_output = raw_output[:n_segments]

    binary = np.round(raw_output).flatten()

    # Smooth to remove isolated mispredictions and too-short chorus runs.
    smoothed = smooth_predictions(binary)

    # Convert meter-grid frame indices into times (seconds).
    grid_times = librosa.frames_to_time(
        audio_features.meter_grid,
        sr=audio_features.sr,
        hop_length=audio_features.hop_length,
    )

    # A chorus starts wherever a 1 follows a 0 (or opens the sequence).
    chorus_start_times = [
        grid_times[idx]
        for idx, value in enumerate(smoothed)
        if value == 1 and (idx == 0 or smoothed[idx - 1] == 0)
    ]

    # CLI mode: show formatted results when both identifiers are supplied.
    if url and video_name:
        _print_chorus_results(url, video_name, chorus_start_times)

    return smoothed
def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
"""Print formatted results showing identified choruses with links.
Args:
url: YouTube URL of the analyzed video
video_name: Name of the video
chorus_start_times: List of start times (in seconds) for identified choruses
"""
# Create YouTube links with time stamps
youtube_links = [
f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\"
for start_time in chorus_start_times
]
# Format the output
link_lengths = [len(link) for link in youtube_links]
max_length = max(link_lengths + [len(video_name), len(f"Number of choruses identified: {len(chorus_start_times)}")]) if link_lengths else 50
header_footer = "=" * (max_length + 4)
# Print the results
print("\n\n")
print(header_footer)
print(f"{video_name.center(max_length + 2)}")
print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
print(header_footer)
if chorus_start_times:
for link in youtube_links:
print(link)
else:
print("No choruses identified.")
print(header_footer) |