#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Module for loading and managing the CRNN model for chorus detection."""

import os
from typing import Any, Optional, List, Tuple, Union

import numpy as np
import tensorflow as tf

from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
from chorus_detection.utils.logging import logger


def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
    """Load the CRNN model and compile it for prediction.

    If loading from ``model_path`` fails and ``DOCKER_MODEL_PATH`` exists and
    is a different path, loading is retried once from ``DOCKER_MODEL_PATH``.

    Args:
        model_path: Path to the saved model.

    Returns:
        The loaded Keras model, compiled with the ``adam`` optimizer and
        ``binary_crossentropy`` loss.

    Raises:
        RuntimeError: If the model cannot be loaded (the original exception
            is attached as ``__cause__``).
    """
    try:
        # The saved model refers to these names. The model is loaded with
        # compile=False and recompiled below with a built-in loss, so simple
        # stand-in callables are registered just so the names resolve.
        custom_objects = {
            'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
            'custom_accuracy': lambda y_true, y_pred: y_pred
        }

        logger.info(f"Loading model from: {model_path}")
        model = tf.keras.models.load_model(
            model_path, custom_objects=custom_objects, compile=False)

        # Compile with a default optimizer and loss; used for prediction only.
        model.compile(optimizer='adam', loss='binary_crossentropy')
        return model
    except Exception as e:
        logger.error(f"Error loading model from {model_path}: {e}")

        # Compare as strings: model_path is a str while DOCKER_MODEL_PATH may
        # be a path object, in which case a direct != would always be True
        # and the same location would be tried a second time.
        docker_path = str(DOCKER_MODEL_PATH)
        if str(model_path) != docker_path and os.path.exists(docker_path):
            logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
            return load_CRNN_model(docker_path)

        raise RuntimeError(f"Failed to load model: {e}") from e


def smooth_predictions(predictions: np.ndarray,
                       min_sequence_length: int = 5) -> np.ndarray:
    """Smooth binary predictions by removing isolated flips and short runs.

    Three passes are applied in order:

    1. A 0 whose two neighbours are both 1 becomes 1.
    2. A 1 whose two neighbours are both 0 becomes 0.
    3. Any run of consecutive 1's shorter than ``min_sequence_length`` is
       set to 0.

    The input is never modified; a new array is returned.

    Args:
        predictions: Array (or array-like) of binary predictions.
        min_sequence_length: Shortest run of 1's that is kept. Defaults to 5.

    Returns:
        Smoothed array of binary predictions, same length as the input.
    """
    data = np.array(predictions, copy=True)
    length = len(data)

    # Pass 1 works in place on purpose: a 0 that has just been filled counts
    # as a 1 for the next position, so patterns like 1 0 1 0 1 fill completely.
    for i in range(1, length - 1):
        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
            data[i] = 1

    # Pass 2 reads from `data` and writes to a copy, so every decision is
    # based on the state after pass 1 rather than on earlier pass-2 edits.
    corrected = data.copy()
    for i in range(1, length - 1):
        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
            corrected[i] = 0

    # Pass 3: scan for runs of 1's and blank out the ones that are too short.
    smoothed = corrected.copy()
    run_start = None
    for i, value in enumerate(corrected):
        if value == 1:
            if run_start is None:
                run_start = i
        elif run_start is not None:
            if i - run_start < min_sequence_length:
                smoothed[run_start:i] = 0
            run_start = None

    # A run that is still open here extends to the end of the array.
    if run_start is not None and length - run_start < min_sequence_length:
        smoothed[run_start:] = 0

    return smoothed


def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
                     audio_features: Any, url: Optional[str] = None,
                     video_name: Optional[str] = None) -> np.ndarray:
    """Generate predictions from the model and post-process them.

    The first element of ``model.predict(processed_audio)`` is truncated to
    the meter-grid length, rounded to 0/1, flattened and smoothed with
    ``smooth_predictions``. When both ``url`` and ``video_name`` are given
    (CLI mode), the chorus start times are also computed and printed.

    Args:
        model: The loaded model for making predictions.
        processed_audio: The audio data that has been processed for prediction.
        audio_features: Object exposing ``meter_grid``, ``sr`` and
            ``hop_length`` attributes.
        url: YouTube URL of the audio file (optional).
        video_name: Name of the video (optional).

    Returns:
        The smoothed binary predictions.
    """
    logger.info("Generating predictions...")

    predictions = model.predict(processed_audio)[0]

    # Keep at most len(meter_grid) - 1 predictions (presumably one per
    # interval between consecutive grid points -- TODO confirm). Clamp at 0 so
    # an empty grid yields no predictions instead of a [: -1] slice.
    meter_grid_length = max(len(audio_features.meter_grid) - 1, 0)
    if len(predictions) > meter_grid_length:
        predictions = predictions[:meter_grid_length]

    binary_predictions = np.round(predictions).flatten()
    smoothed_predictions = smooth_predictions(binary_predictions)

    # Start times are only needed for the CLI report, so librosa is imported
    # and the time conversion is done only when that report is requested.
    if url and video_name:
        import librosa

        meter_grid_times = librosa.frames_to_time(
            audio_features.meter_grid,
            sr=audio_features.sr,
            hop_length=audio_features.hop_length
        )

        # A chorus starts wherever a 1 is not preceded by another 1.
        chorus_start_times = [
            meter_grid_times[i]
            for i, value in enumerate(smoothed_predictions)
            if value == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
        ]

        _print_chorus_results(url, video_name, chorus_start_times)

    return smoothed_predictions


def _print_chorus_results(url: str, video_name: str,
                          chorus_start_times: List[float]) -> None:
    """Print a banner listing the identified choruses as clickable links.

    Each link is wrapped in OSC 8 terminal hyperlink escape sequences, with
    the timestamped URL used both as the target and as the visible text.

    Args:
        url: YouTube URL of the analyzed video.
        video_name: Name of the video.
        chorus_start_times: List of start times (in seconds) for identified
            choruses.
    """
    # Append the timestamp as a new query string when the URL has none
    # (e.g. https://youtu.be/<id>), otherwise as an additional parameter.
    separator = "&" if "?" in url else "?"
    timestamped_urls = [
        f"{url}{separator}t={int(start_time)}s"
        for start_time in chorus_start_times
    ]
    youtube_links = [
        f"\033]8;;{target}\033\\{target}\033]8;;\033\\"
        for target in timestamped_urls
    ]

    summary = f"Number of choruses identified: {len(chorus_start_times)}"

    # Size the banner from what is actually displayed: the escape sequences
    # and the hidden link target take no columns on screen.
    widths = [len(video_name), len(summary)]
    widths.extend(len(target) for target in timestamped_urls)
    if not timestamped_urls:
        # Keep the default minimum width when there is nothing to list.
        widths.append(50)
    max_length = max(widths)
    header_footer = "=" * (max_length + 4)

    print("\n\n")
    print(header_footer)
    print(video_name.center(max_length + 4))
    print(summary.center(max_length + 4))
    print(header_footer)

    if chorus_start_times:
        for link in youtube_links:
            print(link)
    else:
        print("No choruses identified.")

    print(header_footer)