Spaces:
Running
Running
Commit
·
ad0da04
1
Parent(s):
88140b5
Flatten directory structure for simpler imports
Browse files- .space/app-entrypoint.sh +0 -0
- Dockerfile +3 -7
- app.py +2 -30
- chorus_detection/__init__.py +10 -0
- chorus_detection/audio/__init__.py +0 -0
- chorus_detection/audio/data_processing.py +180 -0
- chorus_detection/audio/processor.py +409 -0
- chorus_detection/config.py +54 -0
- chorus_detection/models/__init__.py +0 -0
- chorus_detection/models/crnn.py +186 -0
- chorus_detection/utils/__init__.py +0 -0
- chorus_detection/utils/cli.py +107 -0
- chorus_detection/utils/logging.py +53 -0
- chorus_detection/visualization/__init__.py +0 -0
- chorus_detection/visualization/plotter.py +78 -0
- setup.py +1 -2
- streamlit_app.py +484 -0
.space/app-entrypoint.sh
CHANGED
Binary files a/.space/app-entrypoint.sh and b/.space/app-entrypoint.sh differ
|
|
Dockerfile
CHANGED
@@ -3,7 +3,7 @@ FROM python:3.9-slim
|
|
3 |
# Set environment variables
|
4 |
ENV PYTHONDONTWRITEBYTECODE=1
|
5 |
ENV PYTHONUNBUFFERED=1
|
6 |
-
ENV PYTHONPATH="${PYTHONPATH}:/app:/
|
7 |
ENV MODEL_REVISION="20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32"
|
8 |
ENV MODEL_HF_REPO="dennisvdang/chorus-detection"
|
9 |
ENV HF_MODEL_FILENAME="chorus_detection_crnn.h5"
|
@@ -27,21 +27,17 @@ RUN pip install -e .
|
|
27 |
# Make the entry point script executable
|
28 |
RUN chmod +x .space/app-entrypoint.sh || echo "Could not chmod app-entrypoint.sh"
|
29 |
|
30 |
-
# Create symlinks to ensure proper imports
|
31 |
-
RUN ln -sf /app/src /src || echo "Could not create symlink"
|
32 |
-
|
33 |
# Ensure chorus_detection package is properly installed
|
34 |
RUN cd /app && \
|
35 |
python -c "import chorus_detection; print(f'Successfully imported chorus_detection module from {chorus_detection.__file__}')" || \
|
36 |
-
echo "Warning: chorus_detection module not properly installed
|
37 |
|
38 |
# Ensure model exists and debug info
|
39 |
RUN echo "Debug: ls -la /app" && ls -la /app && \
|
40 |
-
echo "Debug: ls -la /app/src" && ls -la /app/src && \
|
41 |
echo "Debug: PYTHONPATH=$PYTHONPATH" && \
|
42 |
python -c "import sys; print(f'Python path: {sys.path}')" && \
|
43 |
python -c "import os; print(f'Working directory: {os.getcwd()}')" && \
|
44 |
-
python -c "from
|
45 |
|
46 |
# Expose port for Streamlit
|
47 |
EXPOSE 7860
|
|
|
3 |
# Set environment variables
|
4 |
ENV PYTHONDONTWRITEBYTECODE=1
|
5 |
ENV PYTHONUNBUFFERED=1
|
6 |
+
ENV PYTHONPATH="${PYTHONPATH}:/app:/home/user/app:."
|
7 |
ENV MODEL_REVISION="20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32"
|
8 |
ENV MODEL_HF_REPO="dennisvdang/chorus-detection"
|
9 |
ENV HF_MODEL_FILENAME="chorus_detection_crnn.h5"
|
|
|
27 |
# Make the entry point script executable
|
28 |
RUN chmod +x .space/app-entrypoint.sh || echo "Could not chmod app-entrypoint.sh"
|
29 |
|
|
|
|
|
|
|
30 |
# Ensure chorus_detection package is properly installed
|
31 |
RUN cd /app && \
|
32 |
python -c "import chorus_detection; print(f'Successfully imported chorus_detection module from {chorus_detection.__file__}')" || \
|
33 |
+
echo "Warning: chorus_detection module not properly installed"
|
34 |
|
35 |
# Ensure model exists and debug info
|
36 |
RUN echo "Debug: ls -la /app" && ls -la /app && \
|
|
|
37 |
echo "Debug: PYTHONPATH=$PYTHONPATH" && \
|
38 |
python -c "import sys; print(f'Python path: {sys.path}')" && \
|
39 |
python -c "import os; print(f'Working directory: {os.getcwd()}')" && \
|
40 |
+
python -c "from download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || echo "Warning: Model download failed during build"
|
41 |
|
42 |
# Expose port for Streamlit
|
43 |
EXPOSE 7860
|
app.py
CHANGED
@@ -10,14 +10,6 @@ without circular imports.
|
|
10 |
import os
|
11 |
import sys
|
12 |
import logging
|
13 |
-
import subprocess
|
14 |
-
|
15 |
-
# Add src directory to path
|
16 |
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
17 |
-
src_dir = os.path.join(current_dir, "src")
|
18 |
-
if src_dir not in sys.path:
|
19 |
-
sys.path.insert(0, src_dir)
|
20 |
-
sys.path.insert(0, current_dir)
|
21 |
|
22 |
# Configure logging
|
23 |
logging.basicConfig(
|
@@ -33,34 +25,14 @@ if os.environ.get("SPACE_ID"):
|
|
33 |
logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
|
34 |
logger.info(f"Current working directory: {os.getcwd()}")
|
35 |
logger.info(f"Directory contents: {os.listdir()}")
|
36 |
-
if os.path.exists('src'):
|
37 |
-
logger.info(f"src directory contents: {os.listdir('src')}")
|
38 |
-
|
39 |
-
# Install package in development mode
|
40 |
-
try:
|
41 |
-
logger.info("Installing package in development mode...")
|
42 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
|
43 |
-
logger.info("Package installation successful")
|
44 |
-
except Exception as e:
|
45 |
-
logger.error(f"Error installing package: {e}")
|
46 |
-
|
47 |
-
# Verify python path
|
48 |
-
logger.info(f"Python path after update: {sys.path}")
|
49 |
-
|
50 |
-
# Try importing chorus_detection to verify
|
51 |
-
try:
|
52 |
-
import chorus_detection
|
53 |
-
logger.info(f"Successfully imported chorus_detection from {chorus_detection.__file__}")
|
54 |
-
except ImportError as e:
|
55 |
-
logger.error(f"Failed to import chorus_detection: {e}")
|
56 |
|
57 |
def main():
|
58 |
"""Main entry point for the Streamlit app."""
|
59 |
logger.info("Starting Streamlit app...")
|
60 |
# Import the Streamlit app module directly
|
61 |
-
import
|
62 |
# Run the Streamlit app
|
63 |
-
|
64 |
|
65 |
if __name__ == "__main__":
|
66 |
try:
|
|
|
10 |
import os
|
11 |
import sys
|
12 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Configure logging
|
15 |
logging.basicConfig(
|
|
|
25 |
logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
|
26 |
logger.info(f"Current working directory: {os.getcwd()}")
|
27 |
logger.info(f"Directory contents: {os.listdir()}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def main():
|
30 |
"""Main entry point for the Streamlit app."""
|
31 |
logger.info("Starting Streamlit app...")
|
32 |
# Import the Streamlit app module directly
|
33 |
+
import streamlit_app
|
34 |
# Run the Streamlit app
|
35 |
+
streamlit_app.main()
|
36 |
|
37 |
if __name__ == "__main__":
|
38 |
try:
|
chorus_detection/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Chorus Detection package for identifying choruses in music.
|
2 |
+
|
3 |
+
This package contains modules for:
|
4 |
+
- Audio processing and feature extraction
|
5 |
+
- Machine learning models for chorus detection
|
6 |
+
- Visualization tools for audio analysis
|
7 |
+
- Utility functions
|
8 |
+
"""
|
9 |
+
|
10 |
+
__version__ = "0.1.0"
|
chorus_detection/audio/__init__.py
ADDED
File without changes
|
chorus_detection/audio/data_processing.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Module for audio data processing including segmentation and positional encoding."""
|
5 |
+
|
6 |
+
from typing import List, Optional, Tuple, Any
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from chorus_detection.audio.processor import AudioFeature
|
12 |
+
from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
|
13 |
+
from chorus_detection.utils.logging import logger
|
14 |
+
|
15 |
+
|
16 |
+
def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
|
17 |
+
"""Divide song data into segments based on measure grid frames.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
data: The song data to be segmented
|
21 |
+
meter_grid: The grid indicating the start of each measure
|
22 |
+
|
23 |
+
Returns:
|
24 |
+
A list of song data segments
|
25 |
+
"""
|
26 |
+
# Create segments using vectorized operations
|
27 |
+
meter_segments = [data[s:e] for s, e in zip(meter_grid[:-1], meter_grid[1:])]
|
28 |
+
|
29 |
+
# Convert all segments to float32 for consistent processing
|
30 |
+
meter_segments = [segment.astype(np.float32) for segment in meter_segments]
|
31 |
+
|
32 |
+
return meter_segments
|
33 |
+
|
34 |
+
|
35 |
+
def positional_encoding(position: int, d_model: int) -> np.ndarray:
|
36 |
+
"""Generate a positional encoding for a given position and model dimension.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
position: The position for which to generate the encoding
|
40 |
+
d_model: The dimension of the model
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
The positional encoding
|
44 |
+
"""
|
45 |
+
# Create position array
|
46 |
+
positions = np.arange(position)[:, np.newaxis]
|
47 |
+
|
48 |
+
# Calculate dimension-based scaling factors
|
49 |
+
dim_indices = np.arange(d_model)[np.newaxis, :]
|
50 |
+
angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
|
51 |
+
|
52 |
+
# Apply sine to even indices and cosine to odd indices
|
53 |
+
encodings = np.zeros((position, d_model), dtype=np.float32)
|
54 |
+
encodings[:, 0::2] = np.sin(angles[:, 0::2])
|
55 |
+
encodings[:, 1::2] = np.cos(angles[:, 1::2])
|
56 |
+
|
57 |
+
return encodings
|
58 |
+
|
59 |
+
|
60 |
+
def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
|
61 |
+
"""Apply positional encoding at the meter and frame levels to a list of segments.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
segments: The list of segments to encode
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
The list of segments with applied positional encoding
|
68 |
+
"""
|
69 |
+
if not segments:
|
70 |
+
logger.warning("No segments to encode")
|
71 |
+
return []
|
72 |
+
|
73 |
+
n_features = segments[0].shape[1]
|
74 |
+
|
75 |
+
# Generate measure-level positional encodings
|
76 |
+
measure_level_encodings = positional_encoding(len(segments), n_features)
|
77 |
+
|
78 |
+
# Apply hierarchical encodings to each segment
|
79 |
+
encoded_segments = []
|
80 |
+
for i, segment in enumerate(segments):
|
81 |
+
# Generate frame-level positional encoding
|
82 |
+
frame_level_encoding = positional_encoding(len(segment), n_features)
|
83 |
+
|
84 |
+
# Combine frame-level and measure-level encodings
|
85 |
+
encoded_segment = segment + frame_level_encoding + measure_level_encodings[i]
|
86 |
+
encoded_segments.append(encoded_segment)
|
87 |
+
|
88 |
+
return encoded_segments
|
89 |
+
|
90 |
+
|
91 |
+
def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
|
92 |
+
max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
|
93 |
+
"""Pad or truncate the encoded segments to have the specified dimensions.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
encoded_segments: The encoded segments to pad or truncate
|
97 |
+
max_frames: The maximum number of frames per segment
|
98 |
+
max_meters: The maximum number of meters
|
99 |
+
n_features: The number of features per frame
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
The padded or truncated song as a numpy array
|
103 |
+
"""
|
104 |
+
if not encoded_segments:
|
105 |
+
logger.warning("No encoded segments to pad")
|
106 |
+
return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)
|
107 |
+
|
108 |
+
# Pad or truncate each meter/segment to max_frames
|
109 |
+
padded_meters = []
|
110 |
+
for meter in encoded_segments:
|
111 |
+
# Truncate if longer than max_frames
|
112 |
+
truncated_meter = meter[:max_frames] if meter.shape[0] > max_frames else meter
|
113 |
+
|
114 |
+
# Pad if shorter than max_frames
|
115 |
+
if truncated_meter.shape[0] < max_frames:
|
116 |
+
padding = ((0, max_frames - truncated_meter.shape[0]), (0, 0))
|
117 |
+
padded_meter = np.pad(truncated_meter, padding, 'constant', constant_values=0)
|
118 |
+
else:
|
119 |
+
padded_meter = truncated_meter
|
120 |
+
|
121 |
+
padded_meters.append(padded_meter)
|
122 |
+
|
123 |
+
# Create padding meter (all zeros)
|
124 |
+
padding_meter = np.zeros((max_frames, n_features), dtype=np.float32)
|
125 |
+
|
126 |
+
# Truncate or pad to max_meters
|
127 |
+
if len(padded_meters) > max_meters:
|
128 |
+
padded_song = np.array(padded_meters[:max_meters])
|
129 |
+
else:
|
130 |
+
padded_song = np.array(padded_meters + [padding_meter] * (max_meters - len(padded_meters)))
|
131 |
+
|
132 |
+
return padded_song
|
133 |
+
|
134 |
+
|
135 |
+
def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
|
136 |
+
hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
|
137 |
+
"""Process an audio file, extracting features and applying positional encoding.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
audio_path: The path to the audio file
|
141 |
+
trim_silence: Whether to trim silence from the audio
|
142 |
+
sr: The sample rate to use when loading the audio
|
143 |
+
hop_length: The hop length to use for feature extraction
|
144 |
+
|
145 |
+
Returns:
|
146 |
+
A tuple containing the processed audio and its features
|
147 |
+
"""
|
148 |
+
logger.info(f"Processing audio file: {audio_path}")
|
149 |
+
|
150 |
+
try:
|
151 |
+
# First optionally strip silence
|
152 |
+
if trim_silence:
|
153 |
+
from chorus_detection.audio.processor import strip_silence
|
154 |
+
strip_silence(audio_path)
|
155 |
+
|
156 |
+
# Create audio feature object and extract features
|
157 |
+
audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
|
158 |
+
audio_features.extract_features()
|
159 |
+
audio_features.create_meter_grid()
|
160 |
+
|
161 |
+
# Segment the audio data by meter grid
|
162 |
+
audio_segments = segment_data_meters(
|
163 |
+
audio_features.combined_features, audio_features.meter_grid)
|
164 |
+
|
165 |
+
# Apply positional encoding
|
166 |
+
encoded_audio_segments = apply_hierarchical_positional_encoding(audio_segments)
|
167 |
+
|
168 |
+
# Pad song to fixed dimensions and add batch dimension
|
169 |
+
processed_audio = np.expand_dims(pad_song(encoded_audio_segments), axis=0)
|
170 |
+
|
171 |
+
logger.info(f"Audio processing complete: {processed_audio.shape}")
|
172 |
+
return processed_audio, audio_features
|
173 |
+
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"Error processing audio: {e}")
|
176 |
+
|
177 |
+
import traceback
|
178 |
+
logger.debug(traceback.format_exc())
|
179 |
+
|
180 |
+
return None, None
|
chorus_detection/audio/processor.py
ADDED
@@ -0,0 +1,409 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Module for audio feature extraction and processing."""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import subprocess
|
8 |
+
import time
|
9 |
+
from functools import reduce
|
10 |
+
from pathlib import Path
|
11 |
+
from typing import List, Tuple, Optional, Dict, Any, Union
|
12 |
+
|
13 |
+
import librosa
|
14 |
+
import numpy as np
|
15 |
+
from pydub import AudioSegment
|
16 |
+
from pydub.silence import detect_nonsilent
|
17 |
+
from sklearn.preprocessing import StandardScaler
|
18 |
+
|
19 |
+
from chorus_detection.config import SR, HOP_LENGTH, AUDIO_TEMP_PATH
|
20 |
+
from chorus_detection.utils.logging import logger
|
21 |
+
|
22 |
+
|
23 |
+
def extract_audio(url: str, output_path: str = str(AUDIO_TEMP_PATH)) -> Tuple[Optional[str], Optional[str]]:
|
24 |
+
"""Download audio from YouTube URL and save as MP3 using yt-dlp.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
url: YouTube URL of the audio file
|
28 |
+
output_path: Path to save the downloaded audio file
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
Tuple containing path to the downloaded audio file and the video title, or None if download fails
|
32 |
+
"""
|
33 |
+
try:
|
34 |
+
# Create output directory if it doesn't exist
|
35 |
+
os.makedirs(output_path, exist_ok=True)
|
36 |
+
|
37 |
+
# Create a unique filename using timestamp
|
38 |
+
timestamp = int(time.time())
|
39 |
+
output_file = os.path.join(output_path, f"audio_{timestamp}.mp3")
|
40 |
+
|
41 |
+
# Get the video title first
|
42 |
+
video_title = get_video_title(url) or f"Video_{timestamp}"
|
43 |
+
|
44 |
+
# Download the audio
|
45 |
+
success, error_msg = download_audio(url, output_file)
|
46 |
+
|
47 |
+
if not success:
|
48 |
+
handle_download_error(error_msg)
|
49 |
+
return None, None
|
50 |
+
|
51 |
+
# Check if file exists and is valid
|
52 |
+
if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
|
53 |
+
logger.info(f"Successfully downloaded: {video_title}")
|
54 |
+
return output_file, video_title
|
55 |
+
else:
|
56 |
+
logger.error("Download completed but file not found or empty")
|
57 |
+
return None, None
|
58 |
+
|
59 |
+
except Exception as e:
|
60 |
+
import traceback
|
61 |
+
error_details = traceback.format_exc()
|
62 |
+
logger.error(f"An error occurred during YouTube download: {e}")
|
63 |
+
logger.debug(f"Error details: {error_details}")
|
64 |
+
|
65 |
+
check_yt_dlp_installation()
|
66 |
+
return None, None
|
67 |
+
|
68 |
+
|
69 |
+
def get_video_title(url: str) -> Optional[str]:
|
70 |
+
"""Get the title of a YouTube video.
|
71 |
+
|
72 |
+
Args:
|
73 |
+
url: YouTube URL
|
74 |
+
|
75 |
+
Returns:
|
76 |
+
Video title if successful, None otherwise
|
77 |
+
"""
|
78 |
+
try:
|
79 |
+
title_command = ['yt-dlp', '--get-title', '--no-warnings', url]
|
80 |
+
video_title = subprocess.check_output(title_command, universal_newlines=True).strip()
|
81 |
+
return video_title
|
82 |
+
except subprocess.CalledProcessError as e:
|
83 |
+
logger.warning(f"Could not retrieve video title: {str(e)}")
|
84 |
+
return None
|
85 |
+
|
86 |
+
|
87 |
+
def download_audio(url: str, output_file: str) -> Tuple[bool, str]:
|
88 |
+
"""Download audio from YouTube URL using yt-dlp.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
url: YouTube URL
|
92 |
+
output_file: Output file path
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
Tuple containing (success, error_message)
|
96 |
+
"""
|
97 |
+
command = [
|
98 |
+
'yt-dlp',
|
99 |
+
'-f', 'bestaudio',
|
100 |
+
'--extract-audio',
|
101 |
+
'--audio-format', 'mp3',
|
102 |
+
'--audio-quality', '0', # Best quality
|
103 |
+
'--output', output_file,
|
104 |
+
'--no-playlist',
|
105 |
+
'--verbose',
|
106 |
+
url
|
107 |
+
]
|
108 |
+
|
109 |
+
process = subprocess.Popen(
|
110 |
+
command,
|
111 |
+
stdout=subprocess.PIPE,
|
112 |
+
stderr=subprocess.PIPE,
|
113 |
+
universal_newlines=True
|
114 |
+
)
|
115 |
+
stdout, stderr = process.communicate()
|
116 |
+
|
117 |
+
if process.returncode != 0:
|
118 |
+
error_msg = f"Error downloading from YouTube (code {process.returncode}): {stderr}"
|
119 |
+
return False, error_msg
|
120 |
+
|
121 |
+
return True, ""
|
122 |
+
|
123 |
+
|
124 |
+
def handle_download_error(error_msg: str) -> None:
|
125 |
+
"""Handle common YouTube download errors with helpful messages.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
error_msg: Error message from yt-dlp
|
129 |
+
"""
|
130 |
+
logger.error(error_msg)
|
131 |
+
|
132 |
+
if "Sign in to confirm you're not a bot" in error_msg:
|
133 |
+
logger.error("YouTube is detecting automated access. Try using a local file instead.")
|
134 |
+
elif any(x in error_msg.lower() for x in ["unavailable video", "private video"]):
|
135 |
+
logger.error("The video appears to be private or unavailable. Please try another URL.")
|
136 |
+
elif "copyright" in error_msg.lower():
|
137 |
+
logger.error("The video may be blocked due to copyright restrictions.")
|
138 |
+
elif any(x in error_msg.lower() for x in ["rate limit", "429"]):
|
139 |
+
logger.error("YouTube rate limit reached. Please try again later.")
|
140 |
+
|
141 |
+
|
142 |
+
def check_yt_dlp_installation() -> None:
|
143 |
+
"""Check if yt-dlp is installed and provide guidance if it's not."""
|
144 |
+
try:
|
145 |
+
subprocess.run(['yt-dlp', '--version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
146 |
+
except FileNotFoundError:
|
147 |
+
logger.error("yt-dlp is not installed or not in PATH. Please install it with: pip install yt-dlp")
|
148 |
+
|
149 |
+
|
150 |
+
def strip_silence(audio_path: str) -> None:
|
151 |
+
"""Remove silent parts from an audio file.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
audio_path: Path to the audio file
|
155 |
+
"""
|
156 |
+
try:
|
157 |
+
sound = AudioSegment.from_file(audio_path)
|
158 |
+
nonsilent_ranges = detect_nonsilent(
|
159 |
+
sound, min_silence_len=500, silence_thresh=-50)
|
160 |
+
|
161 |
+
if not nonsilent_ranges:
|
162 |
+
logger.warning("No non-silent parts detected in the audio. Using original file.")
|
163 |
+
return
|
164 |
+
|
165 |
+
stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]],
|
166 |
+
nonsilent_ranges, AudioSegment.empty())
|
167 |
+
stripped.export(audio_path, format='mp3')
|
168 |
+
except Exception as e:
|
169 |
+
logger.error(f"Error stripping silence: {e}")
|
170 |
+
logger.info("Proceeding with original audio file")
|
171 |
+
|
172 |
+
|
173 |
+
class AudioFeature:
|
174 |
+
"""Class for extracting and processing audio features."""
|
175 |
+
|
176 |
+
def __init__(self, audio_path: str, sr: int = SR, hop_length: int = HOP_LENGTH):
|
177 |
+
"""Initialize the AudioFeature class.
|
178 |
+
|
179 |
+
Args:
|
180 |
+
audio_path: Path to the audio file
|
181 |
+
sr: Sample rate for audio processing
|
182 |
+
hop_length: Hop length for feature extraction
|
183 |
+
"""
|
184 |
+
self.audio_path: str = audio_path
|
185 |
+
self.sr: int = sr
|
186 |
+
self.hop_length: int = hop_length
|
187 |
+
self.time_signature: int = 4
|
188 |
+
|
189 |
+
# Initialize all features as None
|
190 |
+
self.y: Optional[np.ndarray] = None
|
191 |
+
self.y_harm: Optional[np.ndarray] = None
|
192 |
+
self.y_perc: Optional[np.ndarray] = None
|
193 |
+
self.beats: Optional[np.ndarray] = None
|
194 |
+
self.chroma_acts: Optional[np.ndarray] = None
|
195 |
+
self.chromagram: Optional[np.ndarray] = None
|
196 |
+
self.combined_features: Optional[np.ndarray] = None
|
197 |
+
self.key: Optional[str] = None
|
198 |
+
self.mode: Optional[str] = None
|
199 |
+
self.mel_acts: Optional[np.ndarray] = None
|
200 |
+
self.melspectrogram: Optional[np.ndarray] = None
|
201 |
+
self.meter_grid: Optional[np.ndarray] = None
|
202 |
+
self.mfccs: Optional[np.ndarray] = None
|
203 |
+
self.mfcc_acts: Optional[np.ndarray] = None
|
204 |
+
self.n_frames: Optional[int] = None
|
205 |
+
self.onset_env: Optional[np.ndarray] = None
|
206 |
+
self.rms: Optional[np.ndarray] = None
|
207 |
+
self.spectrogram: Optional[np.ndarray] = None
|
208 |
+
self.tempo: Optional[float] = None
|
209 |
+
self.tempogram: Optional[np.ndarray] = None
|
210 |
+
self.tempogram_acts: Optional[np.ndarray] = None
|
211 |
+
|
212 |
+
def detect_key(self, chroma_vals: np.ndarray) -> Tuple[str, str]:
|
213 |
+
"""Detect the key and mode (major or minor) of the audio segment.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
chroma_vals: Chromagram values to analyze for key detection
|
217 |
+
|
218 |
+
Returns:
|
219 |
+
Tuple containing the detected key and mode
|
220 |
+
"""
|
221 |
+
note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
|
222 |
+
|
223 |
+
# Key profiles (Krumhansl-Kessler profiles)
|
224 |
+
major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
|
225 |
+
minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
|
226 |
+
|
227 |
+
# Normalize profiles
|
228 |
+
major_profile /= np.linalg.norm(major_profile)
|
229 |
+
minor_profile /= np.linalg.norm(minor_profile)
|
230 |
+
|
231 |
+
# Calculate correlations for all possible rotations
|
232 |
+
major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
|
233 |
+
minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]
|
234 |
+
|
235 |
+
# Find max correlation
|
236 |
+
max_major_idx = np.argmax(major_correlations)
|
237 |
+
max_minor_idx = np.argmax(minor_correlations)
|
238 |
+
|
239 |
+
# Determine mode
|
240 |
+
self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
|
241 |
+
self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
|
242 |
+
|
243 |
+
return self.key, self.mode
|
244 |
+
|
245 |
+
def calculate_ki_chroma(self, waveform: np.ndarray, sr: int, hop_length: int) -> np.ndarray:
|
246 |
+
"""Calculate a normalized, key-invariant chromagram for the given audio waveform.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
waveform: Audio waveform to analyze
|
250 |
+
sr: Sample rate of the waveform
|
251 |
+
hop_length: Hop length for feature extraction
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
The key-invariant chromagram as a numpy array
|
255 |
+
"""
|
256 |
+
# Calculate chromagram
|
257 |
+
chromagram = librosa.feature.chroma_cqt(
|
258 |
+
y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
|
259 |
+
|
260 |
+
# Normalize to [0, 1]
|
261 |
+
chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min() + 1e-8)
|
262 |
+
|
263 |
+
# Detect key
|
264 |
+
chroma_vals = np.sum(chromagram, axis=1)
|
265 |
+
key, mode = self.detect_key(chroma_vals)
|
266 |
+
|
267 |
+
# Make key-invariant
|
268 |
+
key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
|
269 |
+
shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
|
270 |
+
|
271 |
+
return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
|
272 |
+
|
273 |
+
def extract_features(self) -> None:
|
274 |
+
"""Extract various audio features from the loaded audio."""
|
275 |
+
# Load audio
|
276 |
+
self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
|
277 |
+
|
278 |
+
# Harmonic-percussive source separation
|
279 |
+
self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
|
280 |
+
|
281 |
+
# Extract spectrogram
|
282 |
+
self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
|
283 |
+
|
284 |
+
# RMS energy
|
285 |
+
self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
|
286 |
+
|
287 |
+
# Mel spectrogram and activations
|
288 |
+
self.melspectrogram = librosa.feature.melspectrogram(
|
289 |
+
y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
|
290 |
+
self.mel_acts = librosa.decompose.decompose(self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
|
291 |
+
|
292 |
+
# Chromagram and activations
|
293 |
+
self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
|
294 |
+
self.chroma_acts = librosa.decompose.decompose(self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
|
295 |
+
|
296 |
+
# Onset detection and tempogram
|
297 |
+
self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
|
298 |
+
self.tempogram = np.clip(librosa.feature.tempogram(
|
299 |
+
onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, None)
|
300 |
+
self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1]
|
301 |
+
|
302 |
+
# MFCCs and activations
|
303 |
+
self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=20, hop_length=self.hop_length)
|
304 |
+
self.mfccs += abs(np.min(self.mfccs) or 0) # Handle negative values
|
305 |
+
self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=4, sort=True)[1].astype(np.float32)
|
306 |
+
|
307 |
+
# Combine features with weighted normalization
|
308 |
+
self._combine_features()
|
309 |
+
|
310 |
+
def _combine_features(self) -> None:
|
311 |
+
"""Combine all extracted features with balanced weights."""
|
312 |
+
features = [self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts]
|
313 |
+
feature_names = ['rms', 'mel_acts', 'chroma_acts', 'tempogram_acts', 'mfcc_acts']
|
314 |
+
|
315 |
+
# Calculate dimension-based weights
|
316 |
+
dims = {name: feature.shape[0] for feature, name in zip(features, feature_names)}
|
317 |
+
total_inv_dim = sum(1 / dim for dim in dims.values())
|
318 |
+
weights = {name: 1 / (dims[name] * total_inv_dim) for name in feature_names}
|
319 |
+
|
320 |
+
# Normalize and weight each feature
|
321 |
+
std_weighted_features = [
|
322 |
+
StandardScaler().fit_transform(feature.T).T * weights[name]
|
323 |
+
for feature, name in zip(features, feature_names)
|
324 |
+
]
|
325 |
+
|
326 |
+
# Combine features
|
327 |
+
self.combined_features = np.concatenate(std_weighted_features, axis=0).T.astype(np.float32)
|
328 |
+
self.n_frames = len(self.combined_features)
|
329 |
+
|
330 |
+
def create_meter_grid(self) -> np.ndarray:
|
331 |
+
"""Create a grid based on the meter of the song using tempo and beats.
|
332 |
+
|
333 |
+
Returns:
|
334 |
+
Numpy array containing the meter grid frame positions
|
335 |
+
"""
|
336 |
+
# Extract tempo and beat information
|
337 |
+
self.tempo, self.beats = librosa.beat.beat_track(
|
338 |
+
onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
|
339 |
+
|
340 |
+
# Adjust tempo if it's too slow or too fast
|
341 |
+
self.tempo = self._adjust_tempo(self.tempo)
|
342 |
+
|
343 |
+
# Create meter grid
|
344 |
+
self.meter_grid = self._create_meter_grid()
|
345 |
+
return self.meter_grid
|
346 |
+
|
347 |
+
def _adjust_tempo(self, tempo: float) -> float:
|
348 |
+
"""Adjust tempo to a reasonable range.
|
349 |
+
|
350 |
+
Args:
|
351 |
+
tempo: Detected tempo
|
352 |
+
|
353 |
+
Returns:
|
354 |
+
Adjusted tempo
|
355 |
+
"""
|
356 |
+
if tempo < 70:
|
357 |
+
return tempo * 2
|
358 |
+
elif tempo > 140:
|
359 |
+
return tempo / 2
|
360 |
+
return tempo
|
361 |
+
|
362 |
+
def _create_meter_grid(self) -> np.ndarray:
|
363 |
+
"""Helper function to create a meter grid for the song.
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
Numpy array containing the meter grid frame positions
|
367 |
+
"""
|
368 |
+
# Calculate beat interval
|
369 |
+
seconds_per_beat = 60 / self.tempo
|
370 |
+
beat_interval = int(librosa.time_to_frames(seconds_per_beat, sr=self.sr, hop_length=self.hop_length))
|
371 |
+
|
372 |
+
# Find best matching start beat
|
373 |
+
if len(self.beats) >= 3:
|
374 |
+
best_match = max(
|
375 |
+
(1 - abs(np.mean(self.beats[i:i+3]) - beat_interval) / beat_interval, self.beats[i])
|
376 |
+
for i in range(len(self.beats) - 2)
|
377 |
+
)
|
378 |
+
anchor_frame = best_match[1] if best_match[0] > 0.95 else self.beats[0]
|
379 |
+
else:
|
380 |
+
anchor_frame = self.beats[0] if len(self.beats) > 0 else 0
|
381 |
+
|
382 |
+
first_beat_time = librosa.frames_to_time(anchor_frame, sr=self.sr, hop_length=self.hop_length)
|
383 |
+
|
384 |
+
# Calculate beats forward and backward
|
385 |
+
time_duration = librosa.frames_to_time(self.n_frames, sr=self.sr, hop_length=self.hop_length)
|
386 |
+
num_beats_forward = int((time_duration - first_beat_time) / seconds_per_beat)
|
387 |
+
num_beats_backward = int(first_beat_time / seconds_per_beat) + 1
|
388 |
+
|
389 |
+
# Create beat times
|
390 |
+
beat_times_forward = first_beat_time + np.arange(num_beats_forward) * seconds_per_beat
|
391 |
+
beat_times_backward = first_beat_time - np.arange(1, num_beats_backward) * seconds_per_beat
|
392 |
+
|
393 |
+
# Combine and create meter grid
|
394 |
+
beat_grid = np.concatenate((np.array([0.0]), beat_times_backward[::-1], beat_times_forward))
|
395 |
+
meter_indices = np.arange(0, len(beat_grid), self.time_signature)
|
396 |
+
meter_grid = beat_grid[meter_indices]
|
397 |
+
|
398 |
+
# Ensure grid starts at 0 and ends at frame duration
|
399 |
+
if meter_grid[0] != 0.0:
|
400 |
+
meter_grid = np.insert(meter_grid, 0, 0.0)
|
401 |
+
|
402 |
+
# Convert to frames
|
403 |
+
meter_grid = librosa.time_to_frames(meter_grid, sr=self.sr, hop_length=self.hop_length)
|
404 |
+
|
405 |
+
# Ensure grid ends at the last frame
|
406 |
+
if meter_grid[-1] != self.n_frames:
|
407 |
+
meter_grid = np.append(meter_grid, self.n_frames)
|
408 |
+
|
409 |
+
return meter_grid
|
chorus_detection/config.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Configuration settings for the chorus detection system."""
|
5 |
+
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Dict, Any, Union, Optional
|
9 |
+
|
10 |
+
# Audio processing settings
|
11 |
+
SR: int = 12000 # Sample rate
|
12 |
+
HOP_LENGTH: int = 128 # Hop length for signal processing
|
13 |
+
MAX_FRAMES: int = 300 # Maximum frames per segment
|
14 |
+
MAX_METERS: int = 201 # Maximum meters per song
|
15 |
+
N_FEATURES: int = 15 # Number of features
|
16 |
+
|
17 |
+
# Project paths
|
18 |
+
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.resolve()
|
19 |
+
MODEL_DIR: Path = PROJECT_ROOT / "models" / "CRNN"
|
20 |
+
MODEL_PATH: Path = MODEL_DIR / "best_model_V3.h5"
|
21 |
+
AUDIO_TEMP_PATH: Path = PROJECT_ROOT / "output" / "temp"
|
22 |
+
LOG_DIR: Path = PROJECT_ROOT / "logs"
|
23 |
+
|
24 |
+
# Alternative Docker paths
|
25 |
+
DOCKER_MODEL_PATH: str = "/app/models/CRNN/best_model_V3.h5"
|
26 |
+
DOCKER_TEMP_PATH: str = "/app/output/temp"
|
27 |
+
|
28 |
+
|
29 |
+
def get_env_path(env_var: str, default_path: Path) -> Path:
|
30 |
+
"""Get a path from environment variable or use the default.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
env_var: Name of the environment variable
|
34 |
+
default_path: Default path to use if environment variable is not set
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
Path object for the specified location
|
38 |
+
"""
|
39 |
+
env_value = os.environ.get(env_var)
|
40 |
+
if env_value:
|
41 |
+
return Path(env_value).resolve()
|
42 |
+
return default_path
|
43 |
+
|
44 |
+
|
45 |
+
# Override paths with environment variables if provided
|
46 |
+
MODEL_PATH = get_env_path("CHORUS_MODEL_PATH", MODEL_PATH)
|
47 |
+
AUDIO_TEMP_PATH = get_env_path("CHORUS_TEMP_PATH", AUDIO_TEMP_PATH)
|
48 |
+
LOG_DIR = get_env_path("CHORUS_LOG_DIR", LOG_DIR)
|
49 |
+
|
50 |
+
# Create necessary directories
|
51 |
+
os.makedirs(MODEL_DIR, exist_ok=True)
|
52 |
+
os.makedirs(AUDIO_TEMP_PATH, exist_ok=True)
|
53 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
54 |
+
os.makedirs(PROJECT_ROOT / "output", exist_ok=True)
|
chorus_detection/models/__init__.py
ADDED
File without changes
|
chorus_detection/models/crnn.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Module for loading and managing the CRNN model for chorus detection."""
|
5 |
+
|
6 |
+
import os
|
7 |
+
from typing import Any, Optional, List, Tuple, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import tensorflow as tf
|
11 |
+
|
12 |
+
from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
|
13 |
+
from chorus_detection.utils.logging import logger
|
14 |
+
|
15 |
+
|
16 |
+
def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
|
17 |
+
"""Load a CRNN model with custom loss and accuracy functions.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
model_path: Path to the saved model
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
Loaded Keras model
|
24 |
+
|
25 |
+
Raises:
|
26 |
+
RuntimeError: If the model cannot be loaded
|
27 |
+
"""
|
28 |
+
try:
|
29 |
+
# Define custom objects required for model loading
|
30 |
+
custom_objects = {
|
31 |
+
'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
|
32 |
+
'custom_accuracy': lambda y_true, y_pred: y_pred
|
33 |
+
}
|
34 |
+
|
35 |
+
# Try to load the model with custom objects
|
36 |
+
logger.info(f"Loading model from: {model_path}")
|
37 |
+
model = tf.keras.models.load_model(
|
38 |
+
model_path, custom_objects=custom_objects, compile=False)
|
39 |
+
|
40 |
+
# Compile the model with default optimizer and loss for prediction only
|
41 |
+
model.compile(optimizer='adam', loss='binary_crossentropy')
|
42 |
+
|
43 |
+
return model
|
44 |
+
except Exception as e:
|
45 |
+
logger.error(f"Error loading model from {model_path}: {e}")
|
46 |
+
|
47 |
+
# Try Docker container path as fallback
|
48 |
+
if model_path != DOCKER_MODEL_PATH and os.path.exists(DOCKER_MODEL_PATH):
|
49 |
+
logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
|
50 |
+
return load_CRNN_model(DOCKER_MODEL_PATH)
|
51 |
+
|
52 |
+
raise RuntimeError(f"Failed to load model: {e}")
|
53 |
+
|
54 |
+
|
55 |
+
def smooth_predictions(predictions: np.ndarray) -> np.ndarray:
|
56 |
+
"""Smooth predictions by correcting isolated mispredictions and removing short sequences.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
predictions: Array of binary predictions
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
Smoothed array of binary predictions
|
63 |
+
"""
|
64 |
+
# Convert to numpy array if not already
|
65 |
+
data = np.array(predictions, copy=True) if not isinstance(predictions, np.ndarray) else predictions.copy()
|
66 |
+
|
67 |
+
# First pass: Correct isolated 0's (handle 0's surrounded by 1's)
|
68 |
+
for i in range(1, len(data) - 1):
|
69 |
+
if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
|
70 |
+
data[i] = 1
|
71 |
+
|
72 |
+
# Second pass: Correct isolated 1's (handle 1's surrounded by 0's)
|
73 |
+
corrected_data = data.copy()
|
74 |
+
for i in range(1, len(data) - 1):
|
75 |
+
if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
|
76 |
+
corrected_data[i] = 0
|
77 |
+
|
78 |
+
# Third pass: Remove short sequences of 1s (less than 5 consecutive 1's)
|
79 |
+
smoothed_data = corrected_data.copy()
|
80 |
+
sequence_start = None
|
81 |
+
|
82 |
+
for i in range(len(corrected_data)):
|
83 |
+
if corrected_data[i] == 1:
|
84 |
+
if sequence_start is None:
|
85 |
+
sequence_start = i
|
86 |
+
else:
|
87 |
+
if sequence_start is not None:
|
88 |
+
sequence_length = i - sequence_start
|
89 |
+
if sequence_length < 5:
|
90 |
+
smoothed_data[sequence_start:i] = 0
|
91 |
+
sequence_start = None
|
92 |
+
|
93 |
+
# Handle the case where the sequence extends to the end
|
94 |
+
if sequence_start is not None:
|
95 |
+
sequence_length = len(corrected_data) - sequence_start
|
96 |
+
if sequence_length < 5:
|
97 |
+
smoothed_data[sequence_start:] = 0
|
98 |
+
|
99 |
+
return smoothed_data
|
100 |
+
|
101 |
+
|
102 |
+
def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
|
103 |
+
audio_features: Any, url: Optional[str] = None,
|
104 |
+
video_name: Optional[str] = None) -> np.ndarray:
|
105 |
+
"""Generate predictions from the model and process them.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
model: The loaded model for making predictions
|
109 |
+
processed_audio: The audio data that has been processed for prediction
|
110 |
+
audio_features: Audio features object containing necessary metadata
|
111 |
+
url: YouTube URL of the audio file (optional)
|
112 |
+
video_name: Name of the video (optional)
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
The smoothed binary predictions
|
116 |
+
"""
|
117 |
+
import librosa
|
118 |
+
|
119 |
+
logger.info("Generating predictions...")
|
120 |
+
|
121 |
+
# Make predictions
|
122 |
+
predictions = model.predict(processed_audio)[0]
|
123 |
+
|
124 |
+
# Convert to binary predictions and handle potential size mismatch
|
125 |
+
meter_grid_length = len(audio_features.meter_grid) - 1
|
126 |
+
if len(predictions) > meter_grid_length:
|
127 |
+
predictions = predictions[:meter_grid_length]
|
128 |
+
|
129 |
+
binary_predictions = np.round(predictions).flatten()
|
130 |
+
|
131 |
+
# Apply smoothing to improve prediction quality
|
132 |
+
smoothed_predictions = smooth_predictions(binary_predictions)
|
133 |
+
|
134 |
+
# Get times for identified chorus sections
|
135 |
+
meter_grid_times = librosa.frames_to_time(
|
136 |
+
audio_features.meter_grid,
|
137 |
+
sr=audio_features.sr,
|
138 |
+
hop_length=audio_features.hop_length
|
139 |
+
)
|
140 |
+
|
141 |
+
# Identify where choruses start
|
142 |
+
chorus_start_times = [
|
143 |
+
meter_grid_times[i] for i in range(len(smoothed_predictions))
|
144 |
+
if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
|
145 |
+
]
|
146 |
+
|
147 |
+
# Print results if URL and video name are provided (CLI mode)
|
148 |
+
if url and video_name:
|
149 |
+
_print_chorus_results(url, video_name, chorus_start_times)
|
150 |
+
|
151 |
+
return smoothed_predictions
|
152 |
+
|
153 |
+
|
154 |
+
def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
|
155 |
+
"""Print formatted results showing identified choruses with links.
|
156 |
+
|
157 |
+
Args:
|
158 |
+
url: YouTube URL of the analyzed video
|
159 |
+
video_name: Name of the video
|
160 |
+
chorus_start_times: List of start times (in seconds) for identified choruses
|
161 |
+
"""
|
162 |
+
# Create YouTube links with time stamps
|
163 |
+
youtube_links = [
|
164 |
+
f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\"
|
165 |
+
for start_time in chorus_start_times
|
166 |
+
]
|
167 |
+
|
168 |
+
# Format the output
|
169 |
+
link_lengths = [len(link) for link in youtube_links]
|
170 |
+
max_length = max(link_lengths + [len(video_name), len(f"Number of choruses identified: {len(chorus_start_times)}")]) if link_lengths else 50
|
171 |
+
header_footer = "=" * (max_length + 4)
|
172 |
+
|
173 |
+
# Print the results
|
174 |
+
print("\n\n")
|
175 |
+
print(header_footer)
|
176 |
+
print(f"{video_name.center(max_length + 2)}")
|
177 |
+
print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
|
178 |
+
print(header_footer)
|
179 |
+
|
180 |
+
if chorus_start_times:
|
181 |
+
for link in youtube_links:
|
182 |
+
print(link)
|
183 |
+
else:
|
184 |
+
print("No choruses identified.")
|
185 |
+
|
186 |
+
print(header_footer)
|
chorus_detection/utils/__init__.py
ADDED
File without changes
|
chorus_detection/utils/cli.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Command-line interface utilities for the chorus detection system."""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
from typing import Dict, Any, Optional, Tuple
|
11 |
+
|
12 |
+
from chorus_detection.config import MODEL_PATH
|
13 |
+
from chorus_detection.utils.logging import logger
|
14 |
+
|
15 |
+
|
16 |
+
def parse_arguments() -> argparse.Namespace:
|
17 |
+
"""Parse command-line arguments.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Parsed command-line arguments
|
21 |
+
"""
|
22 |
+
parser = argparse.ArgumentParser(
|
23 |
+
description="Chorus Finder - Detect choruses in songs from YouTube URLs or local audio files")
|
24 |
+
|
25 |
+
input_group = parser.add_mutually_exclusive_group()
|
26 |
+
input_group.add_argument("--url", type=str,
|
27 |
+
help="YouTube URL of a song")
|
28 |
+
input_group.add_argument("--file", type=str,
|
29 |
+
help="Path to a local audio file")
|
30 |
+
|
31 |
+
parser.add_argument("--model_path", type=str, default=str(MODEL_PATH),
|
32 |
+
help=f"Path to the pretrained model (default: {MODEL_PATH})")
|
33 |
+
parser.add_argument("--verbose", action="store_true",
|
34 |
+
help="Verbose output", default=True)
|
35 |
+
parser.add_argument("--plot", action="store_true",
|
36 |
+
help="Display plot of the audio waveform", default=True)
|
37 |
+
parser.add_argument("--no-plot", dest="plot", action="store_false",
|
38 |
+
help="Disable plot display (useful for headless environments)")
|
39 |
+
|
40 |
+
return parser.parse_args()
|
41 |
+
|
42 |
+
|
43 |
+
def get_input_source(args: argparse.Namespace) -> Optional[str]:
|
44 |
+
"""Get input source from arguments or user input.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
args: Parsed command-line arguments
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
Input source (URL or file path)
|
51 |
+
"""
|
52 |
+
input_source = args.url or args.file
|
53 |
+
if not input_source:
|
54 |
+
print("\nChorus Detection Tool")
|
55 |
+
print("====================")
|
56 |
+
print("\nNote: YouTube download functionality may be temporarily unavailable")
|
57 |
+
print("due to YouTube's restrictions. If download fails, please use a local audio file.\n")
|
58 |
+
print("Choose input method:")
|
59 |
+
print("1. YouTube URL")
|
60 |
+
print("2. Local audio file")
|
61 |
+
choice = input("Enter choice (1 or 2): ")
|
62 |
+
|
63 |
+
if choice == "1":
|
64 |
+
input_source = input("Please enter the YouTube URL of the song: ")
|
65 |
+
elif choice == "2":
|
66 |
+
input_source = input("Please enter the path to the audio file: ")
|
67 |
+
else:
|
68 |
+
logger.error("Invalid choice")
|
69 |
+
sys.exit(1)
|
70 |
+
|
71 |
+
return input_source
|
72 |
+
|
73 |
+
|
74 |
+
def is_youtube_url(input_source: str) -> bool:
|
75 |
+
"""Check if the input source is a YouTube URL.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
input_source: Input source to check
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
True if the input source is a YouTube URL, False otherwise
|
82 |
+
"""
|
83 |
+
return input_source.startswith(('http://', 'https://'))
|
84 |
+
|
85 |
+
|
86 |
+
def validate_input_file(file_path: str) -> bool:
|
87 |
+
"""Validate that the input file exists and is readable.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
file_path: Path to the input file
|
91 |
+
|
92 |
+
Returns:
|
93 |
+
True if the file is valid, False otherwise
|
94 |
+
"""
|
95 |
+
if not os.path.exists(file_path):
|
96 |
+
logger.error(f"Error: File not found at {file_path}")
|
97 |
+
return False
|
98 |
+
|
99 |
+
if not os.path.isfile(file_path):
|
100 |
+
logger.error(f"Error: {file_path} is not a file")
|
101 |
+
return False
|
102 |
+
|
103 |
+
if not os.access(file_path, os.R_OK):
|
104 |
+
logger.error(f"Error: No permission to read {file_path}")
|
105 |
+
return False
|
106 |
+
|
107 |
+
return True
|
chorus_detection/utils/logging.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Logging configuration for the chorus detection system."""
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
from chorus_detection.config import PROJECT_ROOT
|
12 |
+
|
13 |
+
# Create logs directory
|
14 |
+
LOGS_DIR = PROJECT_ROOT / "logs"
|
15 |
+
os.makedirs(LOGS_DIR, exist_ok=True)
|
16 |
+
|
17 |
+
|
18 |
+
def setup_logger(name: str = "chorus_detection", level: int = logging.INFO,
|
19 |
+
log_file: Optional[str] = None) -> logging.Logger:
|
20 |
+
"""Configure and return a logger with the specified name and level.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
name: Name of the logger
|
24 |
+
level: Logging level (default: INFO)
|
25 |
+
log_file: Path to the log file (default: None)
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
Configured logger instance
|
29 |
+
"""
|
30 |
+
logger = logging.getLogger(name)
|
31 |
+
logger.setLevel(level)
|
32 |
+
|
33 |
+
# Create formatter
|
34 |
+
formatter = logging.Formatter(
|
35 |
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
36 |
+
)
|
37 |
+
|
38 |
+
# Create console handler
|
39 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
40 |
+
console_handler.setFormatter(formatter)
|
41 |
+
logger.addHandler(console_handler)
|
42 |
+
|
43 |
+
# Create file handler if log_file is specified
|
44 |
+
if log_file:
|
45 |
+
file_handler = logging.FileHandler(log_file)
|
46 |
+
file_handler.setFormatter(formatter)
|
47 |
+
logger.addHandler(file_handler)
|
48 |
+
|
49 |
+
return logger
|
50 |
+
|
51 |
+
|
52 |
+
# Create default logger
|
53 |
+
logger = setup_logger()
|
chorus_detection/visualization/__init__.py
ADDED
File without changes
|
chorus_detection/visualization/plotter.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Module for visualizing audio data and chorus predictions."""
|
5 |
+
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
import librosa
|
9 |
+
import librosa.display
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import numpy as np
|
12 |
+
import os
|
13 |
+
|
14 |
+
from chorus_detection.audio.processor import AudioFeature
|
15 |
+
|
16 |
+
|
17 |
+
def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
|
18 |
+
"""Draw meter grid lines on the plot.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
ax: The matplotlib axes object to draw on
|
22 |
+
meter_grid_times: Array of times at which to draw the meter lines
|
23 |
+
"""
|
24 |
+
for time in meter_grid_times:
|
25 |
+
ax.axvline(x=time, color='grey', linestyle='--',
|
26 |
+
linewidth=1, alpha=0.6)
|
27 |
+
|
28 |
+
|
29 |
+
def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
|
30 |
+
"""Plot the audio waveform and overlay the predicted chorus locations.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
audio_features: An object containing audio features and components
|
34 |
+
binary_predictions: Array of binary predictions indicating chorus locations
|
35 |
+
"""
|
36 |
+
meter_grid_times = librosa.frames_to_time(
|
37 |
+
audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
|
38 |
+
fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)
|
39 |
+
|
40 |
+
# Display harmonic and percussive components
|
41 |
+
librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
|
42 |
+
alpha=0.8, ax=ax, color='deepskyblue')
|
43 |
+
librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
|
44 |
+
alpha=0.7, ax=ax, color='plum')
|
45 |
+
plot_meter_lines(ax, meter_grid_times)
|
46 |
+
|
47 |
+
# Highlight chorus sections
|
48 |
+
for i, prediction in enumerate(binary_predictions):
|
49 |
+
start_time = meter_grid_times[i]
|
50 |
+
end_time = meter_grid_times[i + 1] if i < len(
|
51 |
+
meter_grid_times) - 1 else len(audio_features.y) / audio_features.sr
|
52 |
+
if prediction == 1:
|
53 |
+
ax.axvspan(start_time, end_time, color='green', alpha=0.3,
|
54 |
+
label='Predicted Chorus' if i == 0 else None)
|
55 |
+
|
56 |
+
# Set plot limits and labels
|
57 |
+
ax.set_xlim([0, len(audio_features.y) / audio_features.sr])
|
58 |
+
ax.set_ylabel('Amplitude')
|
59 |
+
audio_file_name = os.path.basename(audio_features.audio_path)
|
60 |
+
ax.set_title(
|
61 |
+
f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')
|
62 |
+
|
63 |
+
# Add legend
|
64 |
+
chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
|
65 |
+
handles, labels = ax.get_legend_handles_labels()
|
66 |
+
handles.append(chorus_patch)
|
67 |
+
labels.append('Chorus')
|
68 |
+
ax.legend(handles=handles, labels=labels)
|
69 |
+
|
70 |
+
# Set x-tick labels in minutes:seconds format
|
71 |
+
duration = len(audio_features.y) / audio_features.sr
|
72 |
+
xticks = np.arange(0, duration, 10)
|
73 |
+
xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
|
74 |
+
ax.set_xticks(xticks)
|
75 |
+
ax.set_xticklabels(xlabels)
|
76 |
+
|
77 |
+
plt.tight_layout()
|
78 |
+
plt.show(block=False)
|
setup.py
CHANGED
@@ -6,8 +6,7 @@ from setuptools import setup, find_packages
|
|
6 |
setup(
|
7 |
name="chorus_detection",
|
8 |
version="0.1.0",
|
9 |
-
packages=find_packages(
|
10 |
-
package_dir={"": "src"},
|
11 |
install_requires=[
|
12 |
# These are already in requirements.txt so no need to specify versions
|
13 |
"numpy",
|
|
|
6 |
setup(
|
7 |
name="chorus_detection",
|
8 |
version="0.1.0",
|
9 |
+
packages=find_packages(),
|
|
|
10 |
install_requires=[
|
11 |
# These are already in requirements.txt so no need to specify versions
|
12 |
"numpy",
|
streamlit_app.py
ADDED
@@ -0,0 +1,484 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
"""Streamlit web app for chorus detection in audio files.
|
5 |
+
|
6 |
+
This module provides a web-based interface for the chorus detection system,
|
7 |
+
allowing users to upload audio files or provide YouTube URLs for analysis.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import logging
|
13 |
+
|
14 |
+
# Configure logging
|
15 |
+
logger = logging.getLogger("streamlit-app")
|
16 |
+
|
17 |
+
# Configure TensorFlow logging before importing TensorFlow
|
18 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logs
|
19 |
+
|
20 |
+
# Import model downloader to ensure model is available
|
21 |
+
try:
|
22 |
+
from download_model import ensure_model_exists
|
23 |
+
except ImportError as e:
|
24 |
+
logger.error(f"Error importing ensure_model_exists: {e}")
|
25 |
+
raise
|
26 |
+
|
27 |
+
import base64
|
28 |
+
import tempfile
|
29 |
+
import warnings
|
30 |
+
from typing import Optional, Tuple, List
|
31 |
+
import time
|
32 |
+
import io
|
33 |
+
|
34 |
+
import matplotlib.pyplot as plt
|
35 |
+
import streamlit as st
|
36 |
+
import tensorflow as tf
|
37 |
+
import librosa
|
38 |
+
import soundfile as sf
|
39 |
+
import numpy as np
|
40 |
+
from pydub import AudioSegment
|
41 |
+
|
42 |
+
# Suppress warnings
|
43 |
+
warnings.filterwarnings("ignore") # Suppress all warnings
|
44 |
+
tf.get_logger().setLevel('ERROR') # Suppress TensorFlow ERROR logs
|
45 |
+
|
46 |
+
# Debug import paths
|
47 |
+
logger.info(f"Python path: {sys.path}")
|
48 |
+
logger.info(f"Current working directory: {os.getcwd()}")
|
49 |
+
|
50 |
+
# Import modules
|
51 |
+
try:
|
52 |
+
from chorus_detection.audio.data_processing import process_audio
|
53 |
+
from chorus_detection.audio.processor import extract_audio
|
54 |
+
from chorus_detection.models.crnn import load_CRNN_model, make_predictions
|
55 |
+
from chorus_detection.utils.cli import is_youtube_url
|
56 |
+
from chorus_detection.utils.logging import logger
|
57 |
+
|
58 |
+
logger.info("Successfully imported chorus_detection modules")
|
59 |
+
except ImportError as e:
|
60 |
+
logger.error(f"Error importing chorus_detection modules: {e}")
|
61 |
+
raise
|
62 |
+
|
63 |
+
# Define the MODEL_PATH directly
|
64 |
+
MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
|
65 |
+
if not os.path.exists(MODEL_PATH):
|
66 |
+
MODEL_PATH = ensure_model_exists()
|
67 |
+
|
68 |
+
# Define color scheme
|
69 |
+
THEME_COLORS = {
|
70 |
+
'background': '#121212',
|
71 |
+
'card_bg': '#181818',
|
72 |
+
'primary': '#1DB954',
|
73 |
+
'secondary': '#1ED760',
|
74 |
+
'text': '#FFFFFF',
|
75 |
+
'subtext': '#B3B3B3',
|
76 |
+
'highlight': '#1DB954',
|
77 |
+
'border': '#333333',
|
78 |
+
}
|
79 |
+
|
80 |
+
|
81 |
+
def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
|
82 |
+
"""Generate HTML for file download link.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
bin_file: Path to the binary file
|
86 |
+
file_label: Label for the download link
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
HTML string for the download link
|
90 |
+
"""
|
91 |
+
with open(bin_file, 'rb') as f:
|
92 |
+
data = f.read()
|
93 |
+
b64 = base64.b64encode(data).decode()
|
94 |
+
return f'<a href="data:application/octet-stream;base64,{b64}" download="{os.path.basename(bin_file)}">{file_label}</a>'
|
95 |
+
|
96 |
+
|
97 |
+
def set_custom_theme() -> None:
|
98 |
+
"""Apply custom Spotify-inspired theme to Streamlit UI."""
|
99 |
+
custom_theme = f"""
|
100 |
+
<style>
|
101 |
+
.stApp {{
|
102 |
+
background-color: {THEME_COLORS['background']};
|
103 |
+
color: {THEME_COLORS['text']};
|
104 |
+
}}
|
105 |
+
.css-18e3th9 {{
|
106 |
+
padding-top: 2rem;
|
107 |
+
padding-bottom: 10rem;
|
108 |
+
padding-left: 5rem;
|
109 |
+
padding-right: 5rem;
|
110 |
+
}}
|
111 |
+
h1, h2, h3, h4, h5, h6 {{
|
112 |
+
color: {THEME_COLORS['text']} !important;
|
113 |
+
font-weight: 700 !important;
|
114 |
+
}}
|
115 |
+
.stSidebar .sidebar-content {{
|
116 |
+
background-color: {THEME_COLORS['card_bg']};
|
117 |
+
}}
|
118 |
+
.stButton>button {{
|
119 |
+
background-color: {THEME_COLORS['primary']};
|
120 |
+
color: white;
|
121 |
+
border-radius: 500px;
|
122 |
+
padding: 8px 32px;
|
123 |
+
font-weight: 600;
|
124 |
+
border: none;
|
125 |
+
transition: all 0.3s ease;
|
126 |
+
}}
|
127 |
+
.stButton>button:hover {{
|
128 |
+
background-color: {THEME_COLORS['secondary']};
|
129 |
+
transform: scale(1.04);
|
130 |
+
}}
|
131 |
+
</style>
|
132 |
+
"""
|
133 |
+
st.markdown(custom_theme, unsafe_allow_html=True)
|
134 |
+
|
135 |
+
|
136 |
+
def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
|
137 |
+
"""Process a YouTube URL and extract audio.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
url: YouTube URL
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
Tuple of (audio_path, video_name)
|
144 |
+
"""
|
145 |
+
try:
|
146 |
+
with st.spinner('Downloading audio from YouTube...'):
|
147 |
+
audio_path, video_name = extract_audio(url)
|
148 |
+
return audio_path, video_name
|
149 |
+
except Exception as e:
|
150 |
+
st.error(f"Error processing YouTube URL: {e}")
|
151 |
+
logger.error(f"Error processing YouTube URL: {e}", exc_info=True)
|
152 |
+
return None, None
|
153 |
+
|
154 |
+
|
155 |
+
def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
|
156 |
+
"""Process an uploaded audio file.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
uploaded_file: Streamlit UploadedFile object
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Tuple of (audio_path, file_name)
|
163 |
+
"""
|
164 |
+
try:
|
165 |
+
with st.spinner('Processing uploaded file...'):
|
166 |
+
# Save the uploaded file to a temporary location
|
167 |
+
temp_dir = tempfile.mkdtemp()
|
168 |
+
file_name = uploaded_file.name
|
169 |
+
temp_path = os.path.join(temp_dir, file_name)
|
170 |
+
|
171 |
+
with open(temp_path, 'wb') as f:
|
172 |
+
f.write(uploaded_file.getbuffer())
|
173 |
+
|
174 |
+
return temp_path, file_name.split('.')[0]
|
175 |
+
except Exception as e:
|
176 |
+
st.error(f"Error processing uploaded file: {e}")
|
177 |
+
logger.error(f"Error processing uploaded file: {e}", exc_info=True)
|
178 |
+
return None, None
|
179 |
+
|
180 |
+
|
181 |
+
def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
|
182 |
+
meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
|
183 |
+
"""Extract chorus segments from predictions.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
y: Audio data
|
187 |
+
sr: Sample rate
|
188 |
+
smoothed_predictions: Smoothed model predictions
|
189 |
+
meter_grid_times: Time grid for predictions
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
List of (start_time, end_time, audio_segment) tuples
|
193 |
+
"""
|
194 |
+
# Define threshold for chorus detection (probability > 0.5)
|
195 |
+
threshold = 0.5
|
196 |
+
|
197 |
+
# Find the segments where the predictions are above the threshold
|
198 |
+
chorus_mask = smoothed_predictions > threshold
|
199 |
+
|
200 |
+
# Group consecutive True values to identify segments
|
201 |
+
segments = []
|
202 |
+
current_segment = None
|
203 |
+
|
204 |
+
for i, is_chorus in enumerate(chorus_mask):
|
205 |
+
time = meter_grid_times[i]
|
206 |
+
|
207 |
+
if is_chorus and current_segment is None:
|
208 |
+
# Start a new segment
|
209 |
+
current_segment = (time, None, None)
|
210 |
+
elif not is_chorus and current_segment is not None:
|
211 |
+
# End the current segment
|
212 |
+
start_time = current_segment[0]
|
213 |
+
current_segment = (start_time, time, None)
|
214 |
+
segments.append(current_segment)
|
215 |
+
current_segment = None
|
216 |
+
|
217 |
+
# Handle the case where the last segment extends to the end of the song
|
218 |
+
if current_segment is not None:
|
219 |
+
start_time = current_segment[0]
|
220 |
+
segments.append((start_time, meter_grid_times[-1], None))
|
221 |
+
|
222 |
+
# Extract the actual audio for each segment
|
223 |
+
segments_with_audio = []
|
224 |
+
for start_time, end_time, _ in segments:
|
225 |
+
# Convert times to sample indices
|
226 |
+
start_idx = int(start_time * sr)
|
227 |
+
end_idx = int(end_time * sr)
|
228 |
+
|
229 |
+
# Extract the audio segment
|
230 |
+
segment_audio = y[start_idx:end_idx]
|
231 |
+
|
232 |
+
segments_with_audio.append((start_time, end_time, segment_audio))
|
233 |
+
|
234 |
+
return segments_with_audio
|
235 |
+
|
236 |
+
|
237 |
+
def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
|
238 |
+
sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
|
239 |
+
"""Create a compilation of chorus segments.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
segments: List of (start_time, end_time, audio_data) tuples
|
243 |
+
sr: Sample rate
|
244 |
+
fade_duration: Duration of fade in/out in seconds
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
Tuple of (compilation_audio, description)
|
248 |
+
"""
|
249 |
+
if not segments:
|
250 |
+
return np.array([]), "No chorus segments found"
|
251 |
+
|
252 |
+
# Calculate the number of samples for fading
|
253 |
+
fade_samples = int(fade_duration * sr)
|
254 |
+
|
255 |
+
# Prepare a list to store the processed segments
|
256 |
+
processed_segments = []
|
257 |
+
|
258 |
+
# Description of segments
|
259 |
+
segment_descriptions = []
|
260 |
+
|
261 |
+
# Process each segment
|
262 |
+
for i, (start_time, end_time, audio) in enumerate(segments):
|
263 |
+
# Apply fade in and fade out
|
264 |
+
segment_length = len(audio)
|
265 |
+
|
266 |
+
if segment_length <= 2 * fade_samples:
|
267 |
+
# Segment is too short for fading, skip it
|
268 |
+
continue
|
269 |
+
|
270 |
+
# Create a linear fade in and fade out
|
271 |
+
fade_in = np.linspace(0, 1, fade_samples)
|
272 |
+
fade_out = np.linspace(1, 0, fade_samples)
|
273 |
+
|
274 |
+
# Apply the fades
|
275 |
+
audio_faded = audio.copy()
|
276 |
+
audio_faded[:fade_samples] *= fade_in
|
277 |
+
audio_faded[-fade_samples:] *= fade_out
|
278 |
+
|
279 |
+
processed_segments.append(audio_faded)
|
280 |
+
|
281 |
+
# Format the times for the description
|
282 |
+
start_fmt = format_time(start_time)
|
283 |
+
end_fmt = format_time(end_time)
|
284 |
+
segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")
|
285 |
+
|
286 |
+
if not processed_segments:
|
287 |
+
return np.array([]), "No chorus segments long enough for compilation"
|
288 |
+
|
289 |
+
# Concatenate all the processed segments
|
290 |
+
compilation = np.concatenate(processed_segments)
|
291 |
+
|
292 |
+
# Join the descriptions
|
293 |
+
description = "\n".join(segment_descriptions)
|
294 |
+
|
295 |
+
return compilation, description
|
296 |
+
|
297 |
+
|
298 |
+
def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
|
299 |
+
"""Save audio data to a format suitable for Streamlit audio playback.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
audio_data: Audio samples
|
303 |
+
sr: Sample rate
|
304 |
+
file_format: Output format ('mp3', 'wav', etc.)
|
305 |
+
|
306 |
+
Returns:
|
307 |
+
Audio bytes
|
308 |
+
"""
|
309 |
+
with io.BytesIO() as buffer:
|
310 |
+
sf.write(buffer, audio_data, sr, format=file_format)
|
311 |
+
buffer.seek(0)
|
312 |
+
return buffer.read()
|
313 |
+
|
314 |
+
|
315 |
+
def format_time(seconds: float) -> str:
|
316 |
+
"""Format seconds as MM:SS.
|
317 |
+
|
318 |
+
Args:
|
319 |
+
seconds: Time in seconds
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
Formatted time string
|
323 |
+
"""
|
324 |
+
minutes = int(seconds // 60)
|
325 |
+
seconds = int(seconds % 60)
|
326 |
+
return f"{minutes:02d}:{seconds:02d}"
|
327 |
+
|
328 |
+
|
329 |
+
def main() -> None:
|
330 |
+
"""Main function for the Streamlit app."""
|
331 |
+
# Set page config
|
332 |
+
st.set_page_config(
|
333 |
+
page_title="Chorus Detection",
|
334 |
+
page_icon="🎵",
|
335 |
+
layout="wide",
|
336 |
+
initial_sidebar_state="collapsed",
|
337 |
+
)
|
338 |
+
|
339 |
+
# Apply custom theme
|
340 |
+
set_custom_theme()
|
341 |
+
|
342 |
+
# App title and description
|
343 |
+
st.title("Chorus Detection")
|
344 |
+
st.markdown("""
|
345 |
+
<div class="subheader">
|
346 |
+
Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
|
347 |
+
</div>
|
348 |
+
""", unsafe_allow_html=True)
|
349 |
+
|
350 |
+
# User input section
|
351 |
+
col1, col2 = st.columns(2)
|
352 |
+
|
353 |
+
with col1:
|
354 |
+
st.markdown('<div class="input-option">', unsafe_allow_html=True)
|
355 |
+
st.subheader("Option 1: Upload an audio file")
|
356 |
+
uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
|
357 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
358 |
+
|
359 |
+
with col2:
|
360 |
+
st.markdown('<div class="input-option">', unsafe_allow_html=True)
|
361 |
+
st.subheader("Option 2: YouTube URL")
|
362 |
+
youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
|
363 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
364 |
+
|
365 |
+
# Process button
|
366 |
+
if st.button("Analyze"):
|
367 |
+
# Check the input method
|
368 |
+
audio_path = None
|
369 |
+
file_name = None
|
370 |
+
|
371 |
+
if uploaded_file is not None:
|
372 |
+
audio_path, file_name = process_uploaded_file(uploaded_file)
|
373 |
+
elif youtube_url:
|
374 |
+
if is_youtube_url(youtube_url):
|
375 |
+
audio_path, file_name = process_youtube(youtube_url)
|
376 |
+
else:
|
377 |
+
st.error("Invalid YouTube URL. Please enter a valid YouTube URL.")
|
378 |
+
else:
|
379 |
+
st.error("Please upload an audio file or enter a YouTube URL.")
|
380 |
+
|
381 |
+
# If we have a valid audio path, process it
|
382 |
+
if audio_path and file_name:
|
383 |
+
try:
|
384 |
+
# Load and process the audio file
|
385 |
+
with st.spinner('Processing audio...'):
|
386 |
+
# Load audio and extract features
|
387 |
+
y, sr = librosa.load(audio_path, sr=22050)
|
388 |
+
|
389 |
+
# Create a temporary directory for model output
|
390 |
+
temp_output_dir = tempfile.mkdtemp()
|
391 |
+
|
392 |
+
# Load the model
|
393 |
+
model = load_CRNN_model(MODEL_PATH)
|
394 |
+
|
395 |
+
# Process audio and make predictions
|
396 |
+
audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
|
397 |
+
meter_grid_times, predictions = make_predictions(model, audio_features)
|
398 |
+
|
399 |
+
# Smooth predictions to avoid rapid transitions
|
400 |
+
smoothed_predictions = np.convolve(predictions,
|
401 |
+
np.ones(5)/5,
|
402 |
+
mode='same')
|
403 |
+
|
404 |
+
# Extract chorus segments
|
405 |
+
chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
|
406 |
+
|
407 |
+
# Create a chorus compilation
|
408 |
+
compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)
|
409 |
+
|
410 |
+
# Display results
|
411 |
+
st.markdown(f"""
|
412 |
+
<div class="result-container">
|
413 |
+
<div class="song-title">{file_name}</div>
|
414 |
+
</div>
|
415 |
+
""", unsafe_allow_html=True)
|
416 |
+
|
417 |
+
# Display waveform with highlighted chorus sections
|
418 |
+
fig, ax = plt.subplots(figsize=(14, 5))
|
419 |
+
|
420 |
+
# Plot the waveform
|
421 |
+
times = np.linspace(0, len(y)/sr, len(y))
|
422 |
+
ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
|
423 |
+
ax.set_xlabel('Time (s)')
|
424 |
+
ax.set_ylabel('Amplitude')
|
425 |
+
ax.set_title('Audio Waveform with Chorus Sections Highlighted')
|
426 |
+
|
427 |
+
# Highlight chorus sections
|
428 |
+
for start_time, end_time, _ in chorus_segments:
|
429 |
+
ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
|
430 |
+
|
431 |
+
# Add a label at the start of each chorus
|
432 |
+
ax.annotate('Chorus',
|
433 |
+
xy=(start_time, 0.8 * max(y)),
|
434 |
+
xytext=(start_time + 0.5, 0.9 * max(y)),
|
435 |
+
color=THEME_COLORS['primary'],
|
436 |
+
weight='bold')
|
437 |
+
|
438 |
+
# Customize plot appearance
|
439 |
+
ax.set_facecolor(THEME_COLORS['card_bg'])
|
440 |
+
fig.patch.set_facecolor(THEME_COLORS['background'])
|
441 |
+
ax.spines['top'].set_visible(False)
|
442 |
+
ax.spines['right'].set_visible(False)
|
443 |
+
ax.spines['bottom'].set_color(THEME_COLORS['border'])
|
444 |
+
ax.spines['left'].set_color(THEME_COLORS['border'])
|
445 |
+
ax.tick_params(axis='x', colors=THEME_COLORS['text'])
|
446 |
+
ax.tick_params(axis='y', colors=THEME_COLORS['text'])
|
447 |
+
ax.xaxis.label.set_color(THEME_COLORS['text'])
|
448 |
+
ax.yaxis.label.set_color(THEME_COLORS['text'])
|
449 |
+
ax.title.set_color(THEME_COLORS['text'])
|
450 |
+
|
451 |
+
st.pyplot(fig)
|
452 |
+
|
453 |
+
# Display chorus segments
|
454 |
+
if chorus_segments:
|
455 |
+
st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
|
456 |
+
st.subheader("Chorus Segments")
|
457 |
+
for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
|
458 |
+
st.markdown(f"""
|
459 |
+
<div class="time-stamp">Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}</div>
|
460 |
+
""", unsafe_allow_html=True)
|
461 |
+
|
462 |
+
# Convert segment audio to bytes for playback
|
463 |
+
audio_bytes = save_audio_for_streamlit(segment_audio, sr)
|
464 |
+
st.audio(audio_bytes, format='audio/mp3')
|
465 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
466 |
+
|
467 |
+
# Chorus compilation
|
468 |
+
if len(compilation_audio) > 0:
|
469 |
+
st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
|
470 |
+
st.subheader("Chorus Compilation")
|
471 |
+
st.markdown("All chorus segments combined into one track:")
|
472 |
+
|
473 |
+
compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
|
474 |
+
st.audio(compilation_bytes, format='audio/mp3')
|
475 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
476 |
+
else:
|
477 |
+
st.info("No chorus sections detected in this audio.")
|
478 |
+
|
479 |
+
except Exception as e:
|
480 |
+
st.error(f"Error processing audio: {e}")
|
481 |
+
logger.error(f"Error processing audio: {e}", exc_info=True)
|
482 |
+
|
483 |
+
if __name__ == "__main__":
|
484 |
+
main()
|