Commit · 606184e
Parent(s): ad0da04

Refactor code and remove unnecessary files

Files changed:
- .space/app-entrypoint.sh +0 -0
- Dockerfile +5 -8
- app.py +0 -4
- download_model.py +2 -4
- setup.py +0 -1
- src/chorus_detection/__init__.py +0 -10
- src/chorus_detection/audio/__init__.py +0 -0
- src/chorus_detection/audio/data_processing.py +0 -180
- src/chorus_detection/audio/processor.py +0 -409
- src/chorus_detection/config.py +0 -54
- src/chorus_detection/models/__init__.py +0 -0
- src/chorus_detection/models/crnn.py +0 -186
- src/chorus_detection/utils/__init__.py +0 -0
- src/chorus_detection/utils/cli.py +0 -107
- src/chorus_detection/utils/logging.py +0 -53
- src/chorus_detection/visualization/__init__.py +0 -0
- src/chorus_detection/visualization/plotter.py +0 -78
- src/download_model.py +0 -188
- src/streamlit_app.py +0 -536
- streamlit_app.py +23 -133
.space/app-entrypoint.sh
CHANGED
Binary files a/.space/app-entrypoint.sh and b/.space/app-entrypoint.sh differ
Dockerfile
CHANGED
@@ -27,17 +27,14 @@ RUN pip install -e .
 # Make the entry point script executable
 RUN chmod +x .space/app-entrypoint.sh || echo "Could not chmod app-entrypoint.sh"
 
-#
+# Verify chorus_detection package installation
 RUN cd /app && \
-    python -c "import chorus_detection; print(f'Successfully imported chorus_detection
+    python -c "import chorus_detection; print(f'Successfully imported chorus_detection')" || \
     echo "Warning: chorus_detection module not properly installed"
 
-# Ensure model exists
-RUN
-    echo "
-    python -c "import sys; print(f'Python path: {sys.path}')" && \
-    python -c "import os; print(f'Working directory: {os.getcwd()}')" && \
-    python -c "from download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || echo "Warning: Model download failed during build"
+# Ensure model exists
+RUN python -c "from download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || \
+    echo "Warning: Model download failed during build"
 
 # Expose port for Streamlit
 EXPOSE 7860
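For local debugging, the two build-time checks above can be reproduced outside Docker. A minimal sketch, assuming chorus_detection and download_model.py are importable and MODEL_REVISION is set in the environment (all values illustrative):

    # Sketch: mirror the Dockerfile's verification and model-download steps.
    import os

    import chorus_detection  # raises ImportError if the package is not installed
    from download_model import ensure_model_exists

    # Same call the RUN step makes; falls back to the script's default revision
    # when MODEL_REVISION is unset.
    model_path = ensure_model_exists(revision=os.environ.get("MODEL_REVISION"))
    print(f"Model available at: {model_path}")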
app.py
CHANGED
@@ -3,8 +3,6 @@
 
 """
 Main entry point for the Chorus Detection Streamlit app.
-This file is a simple wrapper that starts the Streamlit app
-without circular imports.
 """
 
 import os
@@ -29,9 +27,7 @@ if os.environ.get("SPACE_ID"):
 def main():
     """Main entry point for the Streamlit app."""
     logger.info("Starting Streamlit app...")
-    # Import the Streamlit app module directly
     import streamlit_app
-    # Run the Streamlit app
     streamlit_app.main()
 
 if __name__ == "__main__":
download_model.py
CHANGED
@@ -148,17 +148,15 @@ def ensure_model_exists(
     try:
         if HF_HUB_AVAILABLE:
             # Use huggingface_hub to download the model
-            logger.info(f"Downloading model from {repo_id}/{hf_model_filename}
+            logger.info(f"Downloading model from {repo_id}/{hf_model_filename}")
             downloaded_path = hf_hub_download(
                 repo_id=repo_id,
                 filename=hf_model_filename,
                 local_dir=model_dir,
                 local_dir_use_symlinks=False,
-                revision=revision
+                revision=revision
             )
 
-            logger.info(f"Downloaded to: {downloaded_path}")
-
             # Rename if necessary
             if os.path.basename(downloaded_path) != model_filename:
                 downloaded_path_obj = Path(downloaded_path)
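The revision argument pins hf_hub_download to one exact snapshot, so rebuilds fetch the same bytes even if the repository's main branch moves. A standalone sketch using the defaults this script falls back to (see the deleted src/download_model.py below; local path illustrative):

    from huggingface_hub import hf_hub_download

    # Defaults taken from ensure_model_exists; the pinned revision makes the
    # download reproducible across builds.
    path = hf_hub_download(
        repo_id="dennisvdang/chorus-detection",
        filename="chorus_detection_crnn.h5",
        revision="20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32",
        local_dir="models/CRNN",
    )
    print(path)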
setup.py
CHANGED
@@ -8,7 +8,6 @@ setup(
     version="0.1.0",
     packages=find_packages(),
     install_requires=[
-        # These are already in requirements.txt so no need to specify versions
         "numpy",
        "scipy",
        "tqdm",
src/chorus_detection/__init__.py
DELETED
@@ -1,10 +0,0 @@
-"""Chorus Detection package for identifying choruses in music.
-
-This package contains modules for:
-- Audio processing and feature extraction
-- Machine learning models for chorus detection
-- Visualization tools for audio analysis
-- Utility functions
-"""
-
-__version__ = "0.1.0"
src/chorus_detection/audio/__init__.py
DELETED
File without changes
src/chorus_detection/audio/data_processing.py
DELETED
@@ -1,180 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for audio data processing including segmentation and positional encoding."""
-
-from typing import List, Optional, Tuple, Any
-
-import librosa
-import numpy as np
-
-from chorus_detection.audio.processor import AudioFeature
-from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
-from chorus_detection.utils.logging import logger
-
-
-def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
-    """Divide song data into segments based on measure grid frames.
-
-    Args:
-        data: The song data to be segmented
-        meter_grid: The grid indicating the start of each measure
-
-    Returns:
-        A list of song data segments
-    """
-    # Create segments using vectorized operations
-    meter_segments = [data[s:e] for s, e in zip(meter_grid[:-1], meter_grid[1:])]
-
-    # Convert all segments to float32 for consistent processing
-    meter_segments = [segment.astype(np.float32) for segment in meter_segments]
-
-    return meter_segments
-
-
-def positional_encoding(position: int, d_model: int) -> np.ndarray:
-    """Generate a positional encoding for a given position and model dimension.
-
-    Args:
-        position: The position for which to generate the encoding
-        d_model: The dimension of the model
-
-    Returns:
-        The positional encoding
-    """
-    # Create position array
-    positions = np.arange(position)[:, np.newaxis]
-
-    # Calculate dimension-based scaling factors
-    dim_indices = np.arange(d_model)[np.newaxis, :]
-    angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
-
-    # Apply sine to even indices and cosine to odd indices
-    encodings = np.zeros((position, d_model), dtype=np.float32)
-    encodings[:, 0::2] = np.sin(angles[:, 0::2])
-    encodings[:, 1::2] = np.cos(angles[:, 1::2])
-
-    return encodings
-
-
-def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
-    """Apply positional encoding at the meter and frame levels to a list of segments.
-
-    Args:
-        segments: The list of segments to encode
-
-    Returns:
-        The list of segments with applied positional encoding
-    """
-    if not segments:
-        logger.warning("No segments to encode")
-        return []
-
-    n_features = segments[0].shape[1]
-
-    # Generate measure-level positional encodings
-    measure_level_encodings = positional_encoding(len(segments), n_features)
-
-    # Apply hierarchical encodings to each segment
-    encoded_segments = []
-    for i, segment in enumerate(segments):
-        # Generate frame-level positional encoding
-        frame_level_encoding = positional_encoding(len(segment), n_features)
-
-        # Combine frame-level and measure-level encodings
-        encoded_segment = segment + frame_level_encoding + measure_level_encodings[i]
-        encoded_segments.append(encoded_segment)
-
-    return encoded_segments
-
-
-def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
-             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
-    """Pad or truncate the encoded segments to have the specified dimensions.
-
-    Args:
-        encoded_segments: The encoded segments to pad or truncate
-        max_frames: The maximum number of frames per segment
-        max_meters: The maximum number of meters
-        n_features: The number of features per frame
-
-    Returns:
-        The padded or truncated song as a numpy array
-    """
-    if not encoded_segments:
-        logger.warning("No encoded segments to pad")
-        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)
-
-    # Pad or truncate each meter/segment to max_frames
-    padded_meters = []
-    for meter in encoded_segments:
-        # Truncate if longer than max_frames
-        truncated_meter = meter[:max_frames] if meter.shape[0] > max_frames else meter
-
-        # Pad if shorter than max_frames
-        if truncated_meter.shape[0] < max_frames:
-            padding = ((0, max_frames - truncated_meter.shape[0]), (0, 0))
-            padded_meter = np.pad(truncated_meter, padding, 'constant', constant_values=0)
-        else:
-            padded_meter = truncated_meter
-
-        padded_meters.append(padded_meter)
-
-    # Create padding meter (all zeros)
-    padding_meter = np.zeros((max_frames, n_features), dtype=np.float32)
-
-    # Truncate or pad to max_meters
-    if len(padded_meters) > max_meters:
-        padded_song = np.array(padded_meters[:max_meters])
-    else:
-        padded_song = np.array(padded_meters + [padding_meter] * (max_meters - len(padded_meters)))
-
-    return padded_song
-
-
-def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
-                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
-    """Process an audio file, extracting features and applying positional encoding.
-
-    Args:
-        audio_path: The path to the audio file
-        trim_silence: Whether to trim silence from the audio
-        sr: The sample rate to use when loading the audio
-        hop_length: The hop length to use for feature extraction
-
-    Returns:
-        A tuple containing the processed audio and its features
-    """
-    logger.info(f"Processing audio file: {audio_path}")
-
-    try:
-        # First optionally strip silence
-        if trim_silence:
-            from chorus_detection.audio.processor import strip_silence
-            strip_silence(audio_path)
-
-        # Create audio feature object and extract features
-        audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
-        audio_features.extract_features()
-        audio_features.create_meter_grid()
-
-        # Segment the audio data by meter grid
-        audio_segments = segment_data_meters(
-            audio_features.combined_features, audio_features.meter_grid)
-
-        # Apply positional encoding
-        encoded_audio_segments = apply_hierarchical_positional_encoding(audio_segments)
-
-        # Pad song to fixed dimensions and add batch dimension
-        processed_audio = np.expand_dims(pad_song(encoded_audio_segments), axis=0)
-
-        logger.info(f"Audio processing complete: {processed_audio.shape}")
-        return processed_audio, audio_features
-
-    except Exception as e:
-        logger.error(f"Error processing audio: {e}")
-
-        import traceback
-        logger.debug(traceback.format_exc())
-
-        return None, None
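The deleted positional_encoding implements the standard transformer sine/cosine scheme: sine on even feature indices, cosine on odd ones. A self-contained sanity check of the same formula:

    import numpy as np

    def positional_encoding(position: int, d_model: int) -> np.ndarray:
        # Same computation as the deleted helper.
        positions = np.arange(position)[:, np.newaxis]
        dim_indices = np.arange(d_model)[np.newaxis, :]
        angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
        enc = np.zeros((position, d_model), dtype=np.float32)
        enc[:, 0::2] = np.sin(angles[:, 0::2])
        enc[:, 1::2] = np.cos(angles[:, 1::2])
        return enc

    enc = positional_encoding(4, 6)
    print(enc.shape)  # (4, 6)
    print(enc[0])     # position 0: sin(0)=0 on even dims, cos(0)=1 on odd dims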
src/chorus_detection/audio/processor.py
DELETED
@@ -1,409 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for audio feature extraction and processing."""
-
-import os
-import subprocess
-import time
-from functools import reduce
-from pathlib import Path
-from typing import List, Tuple, Optional, Dict, Any, Union
-
-import librosa
-import numpy as np
-from pydub import AudioSegment
-from pydub.silence import detect_nonsilent
-from sklearn.preprocessing import StandardScaler
-
-from chorus_detection.config import SR, HOP_LENGTH, AUDIO_TEMP_PATH
-from chorus_detection.utils.logging import logger
-
-
-def extract_audio(url: str, output_path: str = str(AUDIO_TEMP_PATH)) -> Tuple[Optional[str], Optional[str]]:
-    """Download audio from YouTube URL and save as MP3 using yt-dlp.
-
-    Args:
-        url: YouTube URL of the audio file
-        output_path: Path to save the downloaded audio file
-
-    Returns:
-        Tuple containing path to the downloaded audio file and the video title, or None if download fails
-    """
-    try:
-        # Create output directory if it doesn't exist
-        os.makedirs(output_path, exist_ok=True)
-
-        # Create a unique filename using timestamp
-        timestamp = int(time.time())
-        output_file = os.path.join(output_path, f"audio_{timestamp}.mp3")
-
-        # Get the video title first
-        video_title = get_video_title(url) or f"Video_{timestamp}"
-
-        # Download the audio
-        success, error_msg = download_audio(url, output_file)
-
-        if not success:
-            handle_download_error(error_msg)
-            return None, None
-
-        # Check if file exists and is valid
-        if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
-            logger.info(f"Successfully downloaded: {video_title}")
-            return output_file, video_title
-        else:
-            logger.error("Download completed but file not found or empty")
-            return None, None
-
-    except Exception as e:
-        import traceback
-        error_details = traceback.format_exc()
-        logger.error(f"An error occurred during YouTube download: {e}")
-        logger.debug(f"Error details: {error_details}")
-
-        check_yt_dlp_installation()
-        return None, None
-
-
-def get_video_title(url: str) -> Optional[str]:
-    """Get the title of a YouTube video.
-
-    Args:
-        url: YouTube URL
-
-    Returns:
-        Video title if successful, None otherwise
-    """
-    try:
-        title_command = ['yt-dlp', '--get-title', '--no-warnings', url]
-        video_title = subprocess.check_output(title_command, universal_newlines=True).strip()
-        return video_title
-    except subprocess.CalledProcessError as e:
-        logger.warning(f"Could not retrieve video title: {str(e)}")
-        return None
-
-
-def download_audio(url: str, output_file: str) -> Tuple[bool, str]:
-    """Download audio from YouTube URL using yt-dlp.
-
-    Args:
-        url: YouTube URL
-        output_file: Output file path
-
-    Returns:
-        Tuple containing (success, error_message)
-    """
-    command = [
-        'yt-dlp',
-        '-f', 'bestaudio',
-        '--extract-audio',
-        '--audio-format', 'mp3',
-        '--audio-quality', '0',  # Best quality
-        '--output', output_file,
-        '--no-playlist',
-        '--verbose',
-        url
-    ]
-
-    process = subprocess.Popen(
-        command,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        universal_newlines=True
-    )
-    stdout, stderr = process.communicate()
-
-    if process.returncode != 0:
-        error_msg = f"Error downloading from YouTube (code {process.returncode}): {stderr}"
-        return False, error_msg
-
-    return True, ""
-
-
-def handle_download_error(error_msg: str) -> None:
-    """Handle common YouTube download errors with helpful messages.
-
-    Args:
-        error_msg: Error message from yt-dlp
-    """
-    logger.error(error_msg)
-
-    if "Sign in to confirm you're not a bot" in error_msg:
-        logger.error("YouTube is detecting automated access. Try using a local file instead.")
-    elif any(x in error_msg.lower() for x in ["unavailable video", "private video"]):
-        logger.error("The video appears to be private or unavailable. Please try another URL.")
-    elif "copyright" in error_msg.lower():
-        logger.error("The video may be blocked due to copyright restrictions.")
-    elif any(x in error_msg.lower() for x in ["rate limit", "429"]):
-        logger.error("YouTube rate limit reached. Please try again later.")
-
-
-def check_yt_dlp_installation() -> None:
-    """Check if yt-dlp is installed and provide guidance if it's not."""
-    try:
-        subprocess.run(['yt-dlp', '--version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    except FileNotFoundError:
-        logger.error("yt-dlp is not installed or not in PATH. Please install it with: pip install yt-dlp")
-
-
-def strip_silence(audio_path: str) -> None:
-    """Remove silent parts from an audio file.
-
-    Args:
-        audio_path: Path to the audio file
-    """
-    try:
-        sound = AudioSegment.from_file(audio_path)
-        nonsilent_ranges = detect_nonsilent(
-            sound, min_silence_len=500, silence_thresh=-50)
-
-        if not nonsilent_ranges:
-            logger.warning("No non-silent parts detected in the audio. Using original file.")
-            return
-
-        stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]],
-                          nonsilent_ranges, AudioSegment.empty())
-        stripped.export(audio_path, format='mp3')
-    except Exception as e:
-        logger.error(f"Error stripping silence: {e}")
-        logger.info("Proceeding with original audio file")
-
-
-class AudioFeature:
-    """Class for extracting and processing audio features."""
-
-    def __init__(self, audio_path: str, sr: int = SR, hop_length: int = HOP_LENGTH):
-        """Initialize the AudioFeature class.
-
-        Args:
-            audio_path: Path to the audio file
-            sr: Sample rate for audio processing
-            hop_length: Hop length for feature extraction
-        """
-        self.audio_path: str = audio_path
-        self.sr: int = sr
-        self.hop_length: int = hop_length
-        self.time_signature: int = 4
-
-        # Initialize all features as None
-        self.y: Optional[np.ndarray] = None
-        self.y_harm: Optional[np.ndarray] = None
-        self.y_perc: Optional[np.ndarray] = None
-        self.beats: Optional[np.ndarray] = None
-        self.chroma_acts: Optional[np.ndarray] = None
-        self.chromagram: Optional[np.ndarray] = None
-        self.combined_features: Optional[np.ndarray] = None
-        self.key: Optional[str] = None
-        self.mode: Optional[str] = None
-        self.mel_acts: Optional[np.ndarray] = None
-        self.melspectrogram: Optional[np.ndarray] = None
-        self.meter_grid: Optional[np.ndarray] = None
-        self.mfccs: Optional[np.ndarray] = None
-        self.mfcc_acts: Optional[np.ndarray] = None
-        self.n_frames: Optional[int] = None
-        self.onset_env: Optional[np.ndarray] = None
-        self.rms: Optional[np.ndarray] = None
-        self.spectrogram: Optional[np.ndarray] = None
-        self.tempo: Optional[float] = None
-        self.tempogram: Optional[np.ndarray] = None
-        self.tempogram_acts: Optional[np.ndarray] = None
-
-    def detect_key(self, chroma_vals: np.ndarray) -> Tuple[str, str]:
-        """Detect the key and mode (major or minor) of the audio segment.
-
-        Args:
-            chroma_vals: Chromagram values to analyze for key detection
-
-        Returns:
-            Tuple containing the detected key and mode
-        """
-        note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
-
-        # Key profiles (Krumhansl-Kessler profiles)
-        major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
-        minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
-
-        # Normalize profiles
-        major_profile /= np.linalg.norm(major_profile)
-        minor_profile /= np.linalg.norm(minor_profile)
-
-        # Calculate correlations for all possible rotations
-        major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
-        minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]
-
-        # Find max correlation
-        max_major_idx = np.argmax(major_correlations)
-        max_minor_idx = np.argmax(minor_correlations)
-
-        # Determine mode
-        self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
-        self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
-
-        return self.key, self.mode
-
-    def calculate_ki_chroma(self, waveform: np.ndarray, sr: int, hop_length: int) -> np.ndarray:
-        """Calculate a normalized, key-invariant chromagram for the given audio waveform.
-
-        Args:
-            waveform: Audio waveform to analyze
-            sr: Sample rate of the waveform
-            hop_length: Hop length for feature extraction
-
-        Returns:
-            The key-invariant chromagram as a numpy array
-        """
-        # Calculate chromagram
-        chromagram = librosa.feature.chroma_cqt(
-            y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
-
-        # Normalize to [0, 1]
-        chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min() + 1e-8)
-
-        # Detect key
-        chroma_vals = np.sum(chromagram, axis=1)
-        key, mode = self.detect_key(chroma_vals)
-
-        # Make key-invariant
-        key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
-        shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
-
-        return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
-
-    def extract_features(self) -> None:
-        """Extract various audio features from the loaded audio."""
-        # Load audio
-        self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
-
-        # Harmonic-percussive source separation
-        self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
-
-        # Extract spectrogram
-        self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
-
-        # RMS energy
-        self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
-
-        # Mel spectrogram and activations
-        self.melspectrogram = librosa.feature.melspectrogram(
-            y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
-        self.mel_acts = librosa.decompose.decompose(self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
-
-        # Chromagram and activations
-        self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
-        self.chroma_acts = librosa.decompose.decompose(self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
-
-        # Onset detection and tempogram
-        self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
-        self.tempogram = np.clip(librosa.feature.tempogram(
-            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, None)
-        self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1]
-
-        # MFCCs and activations
-        self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=20, hop_length=self.hop_length)
-        self.mfccs += abs(np.min(self.mfccs) or 0)  # Handle negative values
-        self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=4, sort=True)[1].astype(np.float32)
-
-        # Combine features with weighted normalization
-        self._combine_features()
-
-    def _combine_features(self) -> None:
-        """Combine all extracted features with balanced weights."""
-        features = [self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts]
-        feature_names = ['rms', 'mel_acts', 'chroma_acts', 'tempogram_acts', 'mfcc_acts']
-
-        # Calculate dimension-based weights
-        dims = {name: feature.shape[0] for feature, name in zip(features, feature_names)}
-        total_inv_dim = sum(1 / dim for dim in dims.values())
-        weights = {name: 1 / (dims[name] * total_inv_dim) for name in feature_names}
-
-        # Normalize and weight each feature
-        std_weighted_features = [
-            StandardScaler().fit_transform(feature.T).T * weights[name]
-            for feature, name in zip(features, feature_names)
-        ]
-
-        # Combine features
-        self.combined_features = np.concatenate(std_weighted_features, axis=0).T.astype(np.float32)
-        self.n_frames = len(self.combined_features)
-
-    def create_meter_grid(self) -> np.ndarray:
-        """Create a grid based on the meter of the song using tempo and beats.
-
-        Returns:
-            Numpy array containing the meter grid frame positions
-        """
-        # Extract tempo and beat information
-        self.tempo, self.beats = librosa.beat.beat_track(
-            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
-
-        # Adjust tempo if it's too slow or too fast
-        self.tempo = self._adjust_tempo(self.tempo)
-
-        # Create meter grid
-        self.meter_grid = self._create_meter_grid()
-        return self.meter_grid
-
-    def _adjust_tempo(self, tempo: float) -> float:
-        """Adjust tempo to a reasonable range.
-
-        Args:
-            tempo: Detected tempo
-
-        Returns:
-            Adjusted tempo
-        """
-        if tempo < 70:
-            return tempo * 2
-        elif tempo > 140:
-            return tempo / 2
-        return tempo
-
-    def _create_meter_grid(self) -> np.ndarray:
-        """Helper function to create a meter grid for the song.
-
-        Returns:
-            Numpy array containing the meter grid frame positions
-        """
-        # Calculate beat interval
-        seconds_per_beat = 60 / self.tempo
-        beat_interval = int(librosa.time_to_frames(seconds_per_beat, sr=self.sr, hop_length=self.hop_length))
-
-        # Find best matching start beat
-        if len(self.beats) >= 3:
-            best_match = max(
-                (1 - abs(np.mean(self.beats[i:i+3]) - beat_interval) / beat_interval, self.beats[i])
-                for i in range(len(self.beats) - 2)
-            )
-            anchor_frame = best_match[1] if best_match[0] > 0.95 else self.beats[0]
-        else:
-            anchor_frame = self.beats[0] if len(self.beats) > 0 else 0
-
-        first_beat_time = librosa.frames_to_time(anchor_frame, sr=self.sr, hop_length=self.hop_length)
-
-        # Calculate beats forward and backward
-        time_duration = librosa.frames_to_time(self.n_frames, sr=self.sr, hop_length=self.hop_length)
-        num_beats_forward = int((time_duration - first_beat_time) / seconds_per_beat)
-        num_beats_backward = int(first_beat_time / seconds_per_beat) + 1
-
-        # Create beat times
-        beat_times_forward = first_beat_time + np.arange(num_beats_forward) * seconds_per_beat
-        beat_times_backward = first_beat_time - np.arange(1, num_beats_backward) * seconds_per_beat
-
-        # Combine and create meter grid
-        beat_grid = np.concatenate((np.array([0.0]), beat_times_backward[::-1], beat_times_forward))
-        meter_indices = np.arange(0, len(beat_grid), self.time_signature)
-        meter_grid = beat_grid[meter_indices]
-
-        # Ensure grid starts at 0 and ends at frame duration
-        if meter_grid[0] != 0.0:
-            meter_grid = np.insert(meter_grid, 0, 0.0)
-
-        # Convert to frames
-        meter_grid = librosa.time_to_frames(meter_grid, sr=self.sr, hop_length=self.hop_length)
-
-        # Ensure grid ends at the last frame
-        if meter_grid[-1] != self.n_frames:
-            meter_grid = np.append(meter_grid, self.n_frames)
-
-        return meter_grid
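AudioFeature.detect_key above correlates the summed chroma energy against all twelve rotations of a Krumhansl-Kessler profile and keeps the best-scoring rotation. The idea in isolation, with a hypothetical C-major-heavy chroma vector:

    import numpy as np

    note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
    major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
                              2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
    major_profile /= np.linalg.norm(major_profile)

    # Hypothetical per-pitch-class energy: strong C, E, G (a C-major triad).
    chroma_vals = np.array([5.0, 0.5, 2.0, 0.5, 3.0, 2.5, 0.5, 4.0, 0.5, 2.0, 0.5, 1.5])

    # The best-correlating rotation of the profile names the key.
    correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1]
                    for i in range(12)]
    print(note_names[int(np.argmax(correlations))])  # C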
src/chorus_detection/config.py
DELETED
@@ -1,54 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Configuration settings for the chorus detection system."""
-
-import os
-from pathlib import Path
-from typing import Dict, Any, Union, Optional
-
-# Audio processing settings
-SR: int = 12000  # Sample rate
-HOP_LENGTH: int = 128  # Hop length for signal processing
-MAX_FRAMES: int = 300  # Maximum frames per segment
-MAX_METERS: int = 201  # Maximum meters per song
-N_FEATURES: int = 15  # Number of features
-
-# Project paths
-PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.resolve()
-MODEL_DIR: Path = PROJECT_ROOT / "models" / "CRNN"
-MODEL_PATH: Path = MODEL_DIR / "best_model_V3.h5"
-AUDIO_TEMP_PATH: Path = PROJECT_ROOT / "output" / "temp"
-LOG_DIR: Path = PROJECT_ROOT / "logs"
-
-# Alternative Docker paths
-DOCKER_MODEL_PATH: str = "/app/models/CRNN/best_model_V3.h5"
-DOCKER_TEMP_PATH: str = "/app/output/temp"
-
-
-def get_env_path(env_var: str, default_path: Path) -> Path:
-    """Get a path from environment variable or use the default.
-
-    Args:
-        env_var: Name of the environment variable
-        default_path: Default path to use if environment variable is not set
-
-    Returns:
-        Path object for the specified location
-    """
-    env_value = os.environ.get(env_var)
-    if env_value:
-        return Path(env_value).resolve()
-    return default_path
-
-
-# Override paths with environment variables if provided
-MODEL_PATH = get_env_path("CHORUS_MODEL_PATH", MODEL_PATH)
-AUDIO_TEMP_PATH = get_env_path("CHORUS_TEMP_PATH", AUDIO_TEMP_PATH)
-LOG_DIR = get_env_path("CHORUS_LOG_DIR", LOG_DIR)
-
-# Create necessary directories
-os.makedirs(MODEL_DIR, exist_ok=True)
-os.makedirs(AUDIO_TEMP_PATH, exist_ok=True)
-os.makedirs(LOG_DIR, exist_ok=True)
-os.makedirs(PROJECT_ROOT / "output", exist_ok=True)
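Every path in the deleted config could be overridden through get_env_path; a minimal sketch of the resolution order (environment value hypothetical):

    import os
    from pathlib import Path

    def get_env_path(env_var: str, default_path: Path) -> Path:
        # The environment wins; otherwise fall back to the default.
        env_value = os.environ.get(env_var)
        return Path(env_value).resolve() if env_value else default_path

    os.environ["CHORUS_MODEL_PATH"] = "/tmp/models/best_model_V3.h5"  # hypothetical
    print(get_env_path("CHORUS_MODEL_PATH", Path("models/CRNN/best_model_V3.h5")))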
src/chorus_detection/models/__init__.py
DELETED
File without changes
src/chorus_detection/models/crnn.py
DELETED
@@ -1,186 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for loading and managing the CRNN model for chorus detection."""
-
-import os
-from typing import Any, Optional, List, Tuple, Union
-
-import numpy as np
-import tensorflow as tf
-
-from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
-from chorus_detection.utils.logging import logger
-
-
-def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
-    """Load a CRNN model with custom loss and accuracy functions.
-
-    Args:
-        model_path: Path to the saved model
-
-    Returns:
-        Loaded Keras model
-
-    Raises:
-        RuntimeError: If the model cannot be loaded
-    """
-    try:
-        # Define custom objects required for model loading
-        custom_objects = {
-            'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
-            'custom_accuracy': lambda y_true, y_pred: y_pred
-        }
-
-        # Try to load the model with custom objects
-        logger.info(f"Loading model from: {model_path}")
-        model = tf.keras.models.load_model(
-            model_path, custom_objects=custom_objects, compile=False)
-
-        # Compile the model with default optimizer and loss for prediction only
-        model.compile(optimizer='adam', loss='binary_crossentropy')
-
-        return model
-    except Exception as e:
-        logger.error(f"Error loading model from {model_path}: {e}")
-
-        # Try Docker container path as fallback
-        if model_path != DOCKER_MODEL_PATH and os.path.exists(DOCKER_MODEL_PATH):
-            logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
-            return load_CRNN_model(DOCKER_MODEL_PATH)
-
-        raise RuntimeError(f"Failed to load model: {e}")
-
-
-def smooth_predictions(predictions: np.ndarray) -> np.ndarray:
-    """Smooth predictions by correcting isolated mispredictions and removing short sequences.
-
-    Args:
-        predictions: Array of binary predictions
-
-    Returns:
-        Smoothed array of binary predictions
-    """
-    # Convert to numpy array if not already
-    data = np.array(predictions, copy=True) if not isinstance(predictions, np.ndarray) else predictions.copy()
-
-    # First pass: Correct isolated 0's (handle 0's surrounded by 1's)
-    for i in range(1, len(data) - 1):
-        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
-            data[i] = 1
-
-    # Second pass: Correct isolated 1's (handle 1's surrounded by 0's)
-    corrected_data = data.copy()
-    for i in range(1, len(data) - 1):
-        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
-            corrected_data[i] = 0
-
-    # Third pass: Remove short sequences of 1s (less than 5 consecutive 1's)
-    smoothed_data = corrected_data.copy()
-    sequence_start = None
-
-    for i in range(len(corrected_data)):
-        if corrected_data[i] == 1:
-            if sequence_start is None:
-                sequence_start = i
-        else:
-            if sequence_start is not None:
-                sequence_length = i - sequence_start
-                if sequence_length < 5:
-                    smoothed_data[sequence_start:i] = 0
-                sequence_start = None
-
-    # Handle the case where the sequence extends to the end
-    if sequence_start is not None:
-        sequence_length = len(corrected_data) - sequence_start
-        if sequence_length < 5:
-            smoothed_data[sequence_start:] = 0
-
-    return smoothed_data
-
-
-def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
-                     audio_features: Any, url: Optional[str] = None,
-                     video_name: Optional[str] = None) -> np.ndarray:
-    """Generate predictions from the model and process them.
-
-    Args:
-        model: The loaded model for making predictions
-        processed_audio: The audio data that has been processed for prediction
-        audio_features: Audio features object containing necessary metadata
-        url: YouTube URL of the audio file (optional)
-        video_name: Name of the video (optional)
-
-    Returns:
-        The smoothed binary predictions
-    """
-    import librosa
-
-    logger.info("Generating predictions...")
-
-    # Make predictions
-    predictions = model.predict(processed_audio)[0]
-
-    # Convert to binary predictions and handle potential size mismatch
-    meter_grid_length = len(audio_features.meter_grid) - 1
-    if len(predictions) > meter_grid_length:
-        predictions = predictions[:meter_grid_length]
-
-    binary_predictions = np.round(predictions).flatten()
-
-    # Apply smoothing to improve prediction quality
-    smoothed_predictions = smooth_predictions(binary_predictions)
-
-    # Get times for identified chorus sections
-    meter_grid_times = librosa.frames_to_time(
-        audio_features.meter_grid,
-        sr=audio_features.sr,
-        hop_length=audio_features.hop_length
-    )
-
-    # Identify where choruses start
-    chorus_start_times = [
-        meter_grid_times[i] for i in range(len(smoothed_predictions))
-        if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
-    ]
-
-    # Print results if URL and video name are provided (CLI mode)
-    if url and video_name:
-        _print_chorus_results(url, video_name, chorus_start_times)
-
-    return smoothed_predictions
-
-
-def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
-    """Print formatted results showing identified choruses with links.
-
-    Args:
-        url: YouTube URL of the analyzed video
-        video_name: Name of the video
-        chorus_start_times: List of start times (in seconds) for identified choruses
-    """
-    # Create YouTube links with time stamps
-    youtube_links = [
-        f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\"
-        for start_time in chorus_start_times
-    ]
-
-    # Format the output
-    link_lengths = [len(link) for link in youtube_links]
-    max_length = max(link_lengths + [len(video_name), len(f"Number of choruses identified: {len(chorus_start_times)}")]) if link_lengths else 50
-    header_footer = "=" * (max_length + 4)
-
-    # Print the results
-    print("\n\n")
-    print(header_footer)
-    print(f"{video_name.center(max_length + 2)}")
-    print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
-    print(header_footer)
-
-    if chorus_start_times:
-        for link in youtube_links:
-            print(link)
-    else:
-        print("No choruses identified.")
-
-    print(header_footer)
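smooth_predictions above makes three passes: fill isolated 0s, clear isolated 1s, then drop any run of 1s shorter than five. A compact runnable equivalent with a worked example:

    import numpy as np

    def smooth_predictions(predictions: np.ndarray) -> np.ndarray:
        # Compact restatement of the deleted helper's three passes.
        data = predictions.copy()
        for i in range(1, len(data) - 1):          # pass 1: fill isolated 0s
            if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
                data[i] = 1
        out = data.copy()
        for i in range(1, len(data) - 1):          # pass 2: clear isolated 1s
            if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
                out[i] = 0
        smoothed, start = out.copy(), None         # pass 3: drop runs shorter than 5
        for i, v in enumerate(out):
            if v == 1 and start is None:
                start = i
            elif v == 0 and start is not None:
                if i - start < 5:
                    smoothed[start:i] = 0
                start = None
        if start is not None and len(out) - start < 5:
            smoothed[start:] = 0
        return smoothed

    raw = np.array([1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0])
    print(smooth_predictions(raw))  # [1 1 1 1 1 1 0 0 0 0 0 0]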
src/chorus_detection/utils/__init__.py
DELETED
File without changes
src/chorus_detection/utils/cli.py
DELETED
@@ -1,107 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Command-line interface utilities for the chorus detection system."""
-
-import argparse
-import os
-import sys
-from pathlib import Path
-from typing import Dict, Any, Optional, Tuple
-
-from chorus_detection.config import MODEL_PATH
-from chorus_detection.utils.logging import logger
-
-
-def parse_arguments() -> argparse.Namespace:
-    """Parse command-line arguments.
-
-    Returns:
-        Parsed command-line arguments
-    """
-    parser = argparse.ArgumentParser(
-        description="Chorus Finder - Detect choruses in songs from YouTube URLs or local audio files")
-
-    input_group = parser.add_mutually_exclusive_group()
-    input_group.add_argument("--url", type=str,
-                             help="YouTube URL of a song")
-    input_group.add_argument("--file", type=str,
-                             help="Path to a local audio file")
-
-    parser.add_argument("--model_path", type=str, default=str(MODEL_PATH),
-                        help=f"Path to the pretrained model (default: {MODEL_PATH})")
-    parser.add_argument("--verbose", action="store_true",
-                        help="Verbose output", default=True)
-    parser.add_argument("--plot", action="store_true",
-                        help="Display plot of the audio waveform", default=True)
-    parser.add_argument("--no-plot", dest="plot", action="store_false",
-                        help="Disable plot display (useful for headless environments)")
-
-    return parser.parse_args()
-
-
-def get_input_source(args: argparse.Namespace) -> Optional[str]:
-    """Get input source from arguments or user input.
-
-    Args:
-        args: Parsed command-line arguments
-
-    Returns:
-        Input source (URL or file path)
-    """
-    input_source = args.url or args.file
-    if not input_source:
-        print("\nChorus Detection Tool")
-        print("====================")
-        print("\nNote: YouTube download functionality may be temporarily unavailable")
-        print("due to YouTube's restrictions. If download fails, please use a local audio file.\n")
-        print("Choose input method:")
-        print("1. YouTube URL")
-        print("2. Local audio file")
-        choice = input("Enter choice (1 or 2): ")
-
-        if choice == "1":
-            input_source = input("Please enter the YouTube URL of the song: ")
-        elif choice == "2":
-            input_source = input("Please enter the path to the audio file: ")
-        else:
-            logger.error("Invalid choice")
-            sys.exit(1)
-
-    return input_source
-
-
-def is_youtube_url(input_source: str) -> bool:
-    """Check if the input source is a YouTube URL.
-
-    Args:
-        input_source: Input source to check
-
-    Returns:
-        True if the input source is a YouTube URL, False otherwise
-    """
-    return input_source.startswith(('http://', 'https://'))
-
-
-def validate_input_file(file_path: str) -> bool:
-    """Validate that the input file exists and is readable.
-
-    Args:
-        file_path: Path to the input file
-
-    Returns:
-        True if the file is valid, False otherwise
-    """
-    if not os.path.exists(file_path):
-        logger.error(f"Error: File not found at {file_path}")
-        return False
-
-    if not os.path.isfile(file_path):
-        logger.error(f"Error: {file_path} is not a file")
-        return False
-
-    if not os.access(file_path, os.R_OK):
-        logger.error(f"Error: No permission to read {file_path}")
-        return False
-
-    return True
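The deleted CLI routed input with a plain prefix check (is_youtube_url) plus a three-part readability check (validate_input_file); the routing in isolation, with a hypothetical source:

    import os

    source = "https://youtube.com/watch?v=abc123"  # hypothetical input

    if source.startswith(('http://', 'https://')):     # is_youtube_url
        print("treat as YouTube URL")
    else:                                              # validate_input_file
        ok = (os.path.exists(source) and os.path.isfile(source)
              and os.access(source, os.R_OK))
        print("local file readable:", ok)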
src/chorus_detection/utils/logging.py
DELETED
@@ -1,53 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Logging configuration for the chorus detection system."""
-
-import logging
-import os
-import sys
-from typing import Optional
-
-from chorus_detection.config import PROJECT_ROOT
-
-# Create logs directory
-LOGS_DIR = PROJECT_ROOT / "logs"
-os.makedirs(LOGS_DIR, exist_ok=True)
-
-
-def setup_logger(name: str = "chorus_detection", level: int = logging.INFO,
-                 log_file: Optional[str] = None) -> logging.Logger:
-    """Configure and return a logger with the specified name and level.
-
-    Args:
-        name: Name of the logger
-        level: Logging level (default: INFO)
-        log_file: Path to the log file (default: None)
-
-    Returns:
-        Configured logger instance
-    """
-    logger = logging.getLogger(name)
-    logger.setLevel(level)
-
-    # Create formatter
-    formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
-
-    # Create console handler
-    console_handler = logging.StreamHandler(sys.stdout)
-    console_handler.setFormatter(formatter)
-    logger.addHandler(console_handler)
-
-    # Create file handler if log_file is specified
-    if log_file:
-        file_handler = logging.FileHandler(log_file)
-        file_handler.setFormatter(formatter)
-        logger.addHandler(file_handler)
-
-    return logger
-
-
-# Create default logger
-logger = setup_logger()
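setup_logger above is a thin wrapper over the standard library; its default (console-only) effect, reproduced directly:

    import logging
    import sys

    logger = logging.getLogger("chorus_detection")
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
    logger.info("equivalent of setup_logger() without a log_file")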
src/chorus_detection/visualization/__init__.py
DELETED
File without changes
src/chorus_detection/visualization/plotter.py
DELETED
@@ -1,78 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for visualizing audio data and chorus predictions."""
-
-from typing import List
-
-import librosa
-import librosa.display
-import matplotlib.pyplot as plt
-import numpy as np
-import os
-
-from chorus_detection.audio.processor import AudioFeature
-
-
-def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
-    """Draw meter grid lines on the plot.
-
-    Args:
-        ax: The matplotlib axes object to draw on
-        meter_grid_times: Array of times at which to draw the meter lines
-    """
-    for time in meter_grid_times:
-        ax.axvline(x=time, color='grey', linestyle='--',
-                   linewidth=1, alpha=0.6)
-
-
-def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
-    """Plot the audio waveform and overlay the predicted chorus locations.
-
-    Args:
-        audio_features: An object containing audio features and components
-        binary_predictions: Array of binary predictions indicating chorus locations
-    """
-    meter_grid_times = librosa.frames_to_time(
-        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
-    fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)
-
-    # Display harmonic and percussive components
-    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
-                             alpha=0.8, ax=ax, color='deepskyblue')
-    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
-                             alpha=0.7, ax=ax, color='plum')
-    plot_meter_lines(ax, meter_grid_times)
-
-    # Highlight chorus sections
-    for i, prediction in enumerate(binary_predictions):
-        start_time = meter_grid_times[i]
-        end_time = meter_grid_times[i + 1] if i < len(
-            meter_grid_times) - 1 else len(audio_features.y) / audio_features.sr
-        if prediction == 1:
-            ax.axvspan(start_time, end_time, color='green', alpha=0.3,
-                       label='Predicted Chorus' if i == 0 else None)
-
-    # Set plot limits and labels
-    ax.set_xlim([0, len(audio_features.y) / audio_features.sr])
-    ax.set_ylabel('Amplitude')
-    audio_file_name = os.path.basename(audio_features.audio_path)
-    ax.set_title(
-        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')
-
-    # Add legend
-    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
-    handles, labels = ax.get_legend_handles_labels()
-    handles.append(chorus_patch)
-    labels.append('Chorus')
-    ax.legend(handles=handles, labels=labels)
-
-    # Set x-tick labels in minutes:seconds format
-    duration = len(audio_features.y) / audio_features.sr
-    xticks = np.arange(0, duration, 10)
-    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
-    ax.set_xticks(xticks)
-    ax.set_xticklabels(xlabels)
-
-    plt.tight_layout()
-    plt.show(block=False)
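The plotting helpers above reduce to standard matplotlib calls; a standalone sketch of the meter lines and chorus highlighting with hypothetical times:

    import matplotlib.pyplot as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)
    for time in np.arange(0, 30, 2.0):      # meter grid every 2 s (hypothetical)
        ax.axvline(x=time, color='grey', linestyle='--', linewidth=1, alpha=0.6)
    ax.axvspan(10, 18, color='green', alpha=0.3, label='Predicted Chorus')
    ax.set_xlim([0, 30])
    ax.set_ylabel('Amplitude')
    ax.legend()
    plt.tight_layout()
    plt.show(block=False)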
src/download_model.py
DELETED
@@ -1,188 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Script to download the chorus detection model from HuggingFace.

This script checks if the model file exists locally, and if not, downloads it
from the specified HuggingFace repository.
"""

import os
import sys
from pathlib import Path
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("model-downloader")

# Debug environment info
logger.info(f"Current working directory: {os.getcwd()}")
logger.info(f"Python path: {sys.path}")
logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
logger.info(f"MODEL_HF_REPO: {os.environ.get('MODEL_HF_REPO')}")
logger.info(f"HF_MODEL_FILENAME: {os.environ.get('HF_MODEL_FILENAME')}")

# Use huggingface_hub for better integration with HF ecosystem
try:
    from huggingface_hub import hf_hub_download
    HF_HUB_AVAILABLE = True
    logger.info("huggingface_hub is available")
except ImportError:
    HF_HUB_AVAILABLE = False
    logger.warning("huggingface_hub is not available, falling back to direct download")
    import requests
    from tqdm import tqdm

def download_file_with_progress(url: str, destination: Path) -> None:
    """Download a file with a progress bar.

    Args:
        url: URL to download from
        destination: Path to save the file to
    """
    # Create parent directories if they don't exist
    destination.parent.mkdir(parents=True, exist_ok=True)

    # Stream the download with progress bar
    response = requests.get(url, stream=True)
    response.raise_for_status()

    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte

    logger.info(f"Downloading model from {url}")
    logger.info(f"File size: {total_size / (1024*1024):.1f} MB")

    with open(destination, 'wb') as file, tqdm(
        desc=destination.name,
        total=total_size,
        unit='iB',
        unit_scale=True,
        unit_divisor=1024,
    ) as bar:
        for data in response.iter_content(block_size):
            size = file.write(data)
            bar.update(size)

def ensure_model_exists(
    model_filename: str = "best_model_V3.h5",
    repo_id: str = None,
    model_dir: Path = None,
    hf_model_filename: str = None,
    revision: str = None
) -> Path:
    """Ensure the model file exists, downloading it if necessary.

    Args:
        model_filename: Local filename for the model
        repo_id: HuggingFace repository ID
        model_dir: Directory to save the model to
        hf_model_filename: Filename of the model in the HuggingFace repo
        revision: Specific version of the model to use (SHA-256 hash)

    Returns:
        Path to the model file
    """
    # Get parameters from environment variables if not provided
    if repo_id is None:
        repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")

    if hf_model_filename is None:
        hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")

    if revision is None:
        revision = os.environ.get("MODEL_REVISION", "20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32")

    # Handle model directory paths for different environments
    if model_dir is None:
        # Check if we're in HF Spaces
        if os.environ.get("SPACE_ID"):
            # Try several possible locations
            possible_dirs = [
                Path("models/CRNN"),
                Path("/home/user/app/models/CRNN"),
                Path("/app/models/CRNN"),
                Path(os.getcwd()) / "models" / "CRNN"
            ]

            for directory in possible_dirs:
                if directory.exists() or directory.parent.exists():
                    model_dir = directory
                    break

            # If none exist, use the first option and create it
            if model_dir is None:
                model_dir = possible_dirs[0]
        else:
            model_dir = Path("models/CRNN")

    # Make sure model_dir is a Path object
    if isinstance(model_dir, str):
        model_dir = Path(model_dir)

    logger.info(f"Using model directory: {model_dir}")

    model_path = model_dir / model_filename

    # Log environment info when running in HF Space
    if os.environ.get("SPACE_ID"):
        logger.info(f"Running in Hugging Face Space: {os.environ.get('SPACE_ID')}")
        logger.info(f"Using model repo: {repo_id}")
        logger.info(f"Using model file: {hf_model_filename}")
        logger.info(f"Using revision: {revision}")

    # Check if the model already exists
    if model_path.exists():
        logger.info(f"Model already exists at {model_path}")
        return model_path

    # Create model directory if it doesn't exist
    model_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"Model not found at {model_path}. Downloading...")

    try:
        if HF_HUB_AVAILABLE:
            # Use huggingface_hub to download the model
            logger.info(f"Downloading model from {repo_id}/{hf_model_filename} (revision: {revision}) using huggingface_hub")
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=hf_model_filename,
                local_dir=model_dir,
                local_dir_use_symlinks=False,
                revision=revision  # Specify the exact revision to use
            )

            logger.info(f"Downloaded to: {downloaded_path}")

            # Rename if necessary
            if os.path.basename(downloaded_path) != model_filename:
                downloaded_path_obj = Path(downloaded_path)
                model_path.parent.mkdir(parents=True, exist_ok=True)
                if model_path.exists():
                    model_path.unlink()
                downloaded_path_obj.rename(model_path)
                logger.info(f"Renamed {downloaded_path} to {model_path}")
        else:
            # Fallback to direct download if huggingface_hub is not available
            huggingface_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/{hf_model_filename}"
            download_file_with_progress(huggingface_url, model_path)

        logger.info(f"Successfully downloaded model to {model_path}")
        return model_path
    except Exception as e:
        logger.error(f"Failed to download model: {e}", exc_info=True)

        # Handle error more gracefully in production environment
        if os.environ.get("SPACE_ID"):
            logger.warning("Continuing despite model download failure")
            return model_path
        else:
            sys.exit(1)

if __name__ == "__main__":
    ensure_model_exists()
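For reference, a minimal usage sketch of the downloader removed above. This assumes the surviving root-level download_model.py keeps the same entry point (the new streamlit_app.py below still imports it); the argument values are simply the defaults from the deleted file.

from download_model import ensure_model_exists

# Pin the model to an exact revision so rebuilds stay reproducible.
model_path = ensure_model_exists(
    model_filename="best_model_V3.h5",
    repo_id="dennisvdang/chorus-detection",
    revision="20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32",
)
print(model_path)  # e.g. models/CRNN/best_model_V3.h5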
src/streamlit_app.py
DELETED
@@ -1,536 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Streamlit web app for chorus detection in audio files.

This module provides a web-based interface for the chorus detection system,
allowing users to upload audio files or provide YouTube URLs for analysis.
"""

import os
import sys
import logging

# Configure logging
logger = logging.getLogger("streamlit-app")

# Configure TensorFlow logging before importing TensorFlow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logs

# Ensure proper import paths
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.dirname(current_dir)
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
if root_dir not in sys.path:
    sys.path.insert(0, root_dir)

# Import model downloader to ensure model is available
try:
    if os.path.exists(os.path.join(os.getcwd(), "download_model.py")):
        # If in the root directory
        from download_model import ensure_model_exists
    else:
        # If in the src directory
        sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from download_model import ensure_model_exists
except ImportError as e:
    logger.error(f"Error importing ensure_model_exists: {e}")
    try:
        # Try alternative import
        from src.download_model import ensure_model_exists
    except ImportError as e2:
        logger.error(f"Alternative import failed: {e2}")
        raise

import base64
import tempfile
import warnings
from typing import Optional, Tuple, List
import time
import io

import matplotlib.pyplot as plt
import streamlit as st
import tensorflow as tf
import librosa
import soundfile as sf
import numpy as np
from pydub import AudioSegment

# Suppress warnings
warnings.filterwarnings("ignore")  # Suppress all warnings
tf.get_logger().setLevel('ERROR')  # Suppress TensorFlow ERROR logs

# Debug import paths
logger.info(f"Python path: {sys.path}")
logger.info(f"Current working directory: {os.getcwd()}")

# First try direct import with src in path
try:
    # Add src directory to Python path if not already there
    src_path = os.path.dirname(current_dir)
    if src_path not in sys.path:
        sys.path.insert(0, src_path)

    from chorus_detection.audio.data_processing import process_audio
    from chorus_detection.audio.processor import extract_audio
    from chorus_detection.models.crnn import load_CRNN_model, make_predictions
    from chorus_detection.utils.cli import is_youtube_url
    from chorus_detection.utils.logging import logger

    logger.info("Successfully imported chorus_detection modules")
except ImportError as e:
    logger.error(f"Error importing chorus_detection modules: {e}")
    logger.info("Trying alternative imports...")

    # Try with manual path adjustment
    try:
        # Adjust import paths - try different directories
        possible_paths = [
            os.path.join(os.getcwd(), "src"),
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            os.path.dirname(os.getcwd())
        ]

        for path in possible_paths:
            if path not in sys.path and os.path.exists(path):
                sys.path.insert(0, path)
                logger.info(f"Added path to sys.path: {path}")

        # Try importing directly from chorus_detection module path
        sys.path.insert(0, os.path.join(os.getcwd(), "src", "chorus_detection"))

        from chorus_detection.audio.data_processing import process_audio
        from chorus_detection.audio.processor import extract_audio
        from chorus_detection.models.crnn import load_CRNN_model, make_predictions
        from chorus_detection.utils.cli import is_youtube_url
        from chorus_detection.utils.logging import logger

        logger.info("Successfully imported chorus_detection modules after path adjustment")
    except ImportError as e2:
        logger.error(f"Alternative imports also failed: {e2}")
        raise

# Define the MODEL_PATH directly
MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
if not os.path.exists(MODEL_PATH):
    MODEL_PATH = ensure_model_exists()

# Define color scheme
THEME_COLORS = {
    'background': '#121212',
    'card_bg': '#181818',
    'primary': '#1DB954',
    'secondary': '#1ED760',
    'text': '#FFFFFF',
    'subtext': '#B3B3B3',
    'highlight': '#1DB954',
    'border': '#333333',
}


def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
    """Generate HTML for file download link.

    Args:
        bin_file: Path to the binary file
        file_label: Label for the download link

    Returns:
        HTML string for the download link
    """
    with open(bin_file, 'rb') as f:
        data = f.read()
    b64 = base64.b64encode(data).decode()
    return f'<a href="data:application/octet-stream;base64,{b64}" download="{os.path.basename(bin_file)}">{file_label}</a>'


def set_custom_theme() -> None:
    """Apply custom Spotify-inspired theme to Streamlit UI."""
    custom_theme = f"""
    <style>
    .stApp {{
        background-color: {THEME_COLORS['background']};
        color: {THEME_COLORS['text']};
    }}
    .css-18e3th9 {{
        padding-top: 2rem;
        padding-bottom: 10rem;
        padding-left: 5rem;
        padding-right: 5rem;
    }}
    h1, h2, h3, h4, h5, h6 {{
        color: {THEME_COLORS['text']} !important;
        font-weight: 700 !important;
    }}
    .stSidebar .sidebar-content {{
        background-color: {THEME_COLORS['card_bg']};
    }}
    .stButton>button {{
        background-color: {THEME_COLORS['primary']};
        color: white;
        border-radius: 500px;
        padding: 8px 32px;
        font-weight: 600;
        border: none;
        transition: all 0.3s ease;
    }}
    .stButton>button:hover {{
        background-color: {THEME_COLORS['secondary']};
        transform: scale(1.04);
    }}
    </style>
    """
    st.markdown(custom_theme, unsafe_allow_html=True)


def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
    """Process a YouTube URL and extract audio.

    Args:
        url: YouTube URL

    Returns:
        Tuple of (audio_path, video_name)
    """
    try:
        with st.spinner('Downloading audio from YouTube...'):
            audio_path, video_name = extract_audio(url)
            return audio_path, video_name
    except Exception as e:
        st.error(f"Error processing YouTube URL: {e}")
        logger.error(f"Error processing YouTube URL: {e}", exc_info=True)
        return None, None


def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
    """Process an uploaded audio file.

    Args:
        uploaded_file: Streamlit UploadedFile object

    Returns:
        Tuple of (audio_path, file_name)
    """
    try:
        with st.spinner('Processing uploaded file...'):
            # Save the uploaded file to a temporary location
            temp_dir = tempfile.mkdtemp()
            file_name = uploaded_file.name
            temp_path = os.path.join(temp_dir, file_name)

            with open(temp_path, 'wb') as f:
                f.write(uploaded_file.getbuffer())

            return temp_path, file_name.split('.')[0]
    except Exception as e:
        st.error(f"Error processing uploaded file: {e}")
        logger.error(f"Error processing uploaded file: {e}", exc_info=True)
        return None, None


def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
                            meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
    """Extract chorus segments from predictions.

    Args:
        y: Audio data
        sr: Sample rate
        smoothed_predictions: Smoothed model predictions
        meter_grid_times: Time grid for predictions

    Returns:
        List of (start_time, end_time, audio_segment) tuples
    """
    # Define threshold for chorus detection (probability > 0.5)
    threshold = 0.5

    # Find the segments where the predictions are above the threshold
    chorus_mask = smoothed_predictions > threshold

    # Group consecutive True values to identify segments
    segments = []
    current_segment = None

    for i, is_chorus in enumerate(chorus_mask):
        time = meter_grid_times[i]

        if is_chorus and current_segment is None:
            # Start a new segment
            current_segment = (time, None, None)
        elif not is_chorus and current_segment is not None:
            # End the current segment
            start_time = current_segment[0]
            current_segment = (start_time, time, None)
            segments.append(current_segment)
            current_segment = None

    # Handle the case where the last segment extends to the end of the song
    if current_segment is not None:
        start_time = current_segment[0]
        segments.append((start_time, meter_grid_times[-1], None))

    # Extract the actual audio for each segment
    segments_with_audio = []
    for start_time, end_time, _ in segments:
        # Convert times to sample indices
        start_idx = int(start_time * sr)
        end_idx = int(end_time * sr)

        # Extract the audio segment
        segment_audio = y[start_idx:end_idx]

        segments_with_audio.append((start_time, end_time, segment_audio))

    return segments_with_audio


def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
                              sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
    """Create a compilation of chorus segments.

    Args:
        segments: List of (start_time, end_time, audio_data) tuples
        sr: Sample rate
        fade_duration: Duration of fade in/out in seconds

    Returns:
        Tuple of (compilation_audio, description)
    """
    if not segments:
        return np.array([]), "No chorus segments found"

    # Calculate the number of samples for fading
    fade_samples = int(fade_duration * sr)

    # Prepare a list to store the processed segments
    processed_segments = []

    # Description of segments
    segment_descriptions = []

    # Process each segment
    for i, (start_time, end_time, audio) in enumerate(segments):
        # Apply fade in and fade out
        segment_length = len(audio)

        if segment_length <= 2 * fade_samples:
            # Segment is too short for fading, skip it
            continue

        # Create a linear fade in and fade out
        fade_in = np.linspace(0, 1, fade_samples)
        fade_out = np.linspace(1, 0, fade_samples)

        # Apply the fades
        audio_faded = audio.copy()
        audio_faded[:fade_samples] *= fade_in
        audio_faded[-fade_samples:] *= fade_out

        processed_segments.append(audio_faded)

        # Format the times for the description
        start_fmt = format_time(start_time)
        end_fmt = format_time(end_time)
        segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")

    if not processed_segments:
        return np.array([]), "No chorus segments long enough for compilation"

    # Concatenate all the processed segments
    compilation = np.concatenate(processed_segments)

    # Join the descriptions
    description = "\n".join(segment_descriptions)

    return compilation, description


def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
    """Save audio data to a format suitable for Streamlit audio playback.

    Args:
        audio_data: Audio samples
        sr: Sample rate
        file_format: Output format ('mp3', 'wav', etc.)

    Returns:
        Audio bytes
    """
    with io.BytesIO() as buffer:
        sf.write(buffer, audio_data, sr, format=file_format)
        buffer.seek(0)
        return buffer.read()


def format_time(seconds: float) -> str:
    """Format seconds as MM:SS.

    Args:
        seconds: Time in seconds

    Returns:
        Formatted time string
    """
    minutes = int(seconds // 60)
    seconds = int(seconds % 60)
    return f"{minutes:02d}:{seconds:02d}"


def main() -> None:
    """Main function for the Streamlit app."""
    # Set page config
    st.set_page_config(
        page_title="Chorus Detection",
        page_icon="🎵",
        layout="wide",
        initial_sidebar_state="collapsed",
    )

    # Apply custom theme
    set_custom_theme()

    # App title and description
    st.title("Chorus Detection")
    st.markdown("""
    <div class="subheader">
        Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
    </div>
    """, unsafe_allow_html=True)

    # User input section
    col1, col2 = st.columns(2)

    with col1:
        st.markdown('<div class="input-option">', unsafe_allow_html=True)
        st.subheader("Option 1: Upload an audio file")
        uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
        st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        st.markdown('<div class="input-option">', unsafe_allow_html=True)
        st.subheader("Option 2: YouTube URL")
        youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
        st.markdown('</div>', unsafe_allow_html=True)

    # Process button
    if st.button("Analyze"):
        # Check the input method
        audio_path = None
        file_name = None

        if uploaded_file is not None:
            audio_path, file_name = process_uploaded_file(uploaded_file)
        elif youtube_url:
            if is_youtube_url(youtube_url):
                audio_path, file_name = process_youtube(youtube_url)
            else:
                st.error("Invalid YouTube URL. Please enter a valid YouTube URL.")
        else:
            st.error("Please upload an audio file or enter a YouTube URL.")

        # If we have a valid audio path, process it
        if audio_path and file_name:
            try:
                # Load and process the audio file
                with st.spinner('Processing audio...'):
                    # Load audio and extract features
                    y, sr = librosa.load(audio_path, sr=22050)

                    # Create a temporary directory for model output
                    temp_output_dir = tempfile.mkdtemp()

                    # Load the model
                    model = load_CRNN_model(MODEL_PATH)

                    # Process audio and make predictions
                    audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
                    meter_grid_times, predictions = make_predictions(model, audio_features)

                    # Smooth predictions to avoid rapid transitions
                    smoothed_predictions = np.convolve(predictions,
                                                       np.ones(5)/5,
                                                       mode='same')

                    # Extract chorus segments
                    chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)

                    # Create a chorus compilation
                    compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)

                # Display results
                st.markdown(f"""
                <div class="result-container">
                    <div class="song-title">{file_name}</div>
                </div>
                """, unsafe_allow_html=True)

                # Display waveform with highlighted chorus sections
                fig, ax = plt.subplots(figsize=(14, 5))

                # Plot the waveform
                times = np.linspace(0, len(y)/sr, len(y))
                ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
                ax.set_xlabel('Time (s)')
                ax.set_ylabel('Amplitude')
                ax.set_title('Audio Waveform with Chorus Sections Highlighted')

                # Highlight chorus sections
                for start_time, end_time, _ in chorus_segments:
                    ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])

                    # Add a label at the start of each chorus
                    ax.annotate('Chorus',
                                xy=(start_time, 0.8 * max(y)),
                                xytext=(start_time + 0.5, 0.9 * max(y)),
                                color=THEME_COLORS['primary'],
                                weight='bold')

                # Customize plot appearance
                ax.set_facecolor(THEME_COLORS['card_bg'])
                fig.patch.set_facecolor(THEME_COLORS['background'])
                ax.spines['top'].set_visible(False)
                ax.spines['right'].set_visible(False)
                ax.spines['bottom'].set_color(THEME_COLORS['border'])
                ax.spines['left'].set_color(THEME_COLORS['border'])
                ax.tick_params(axis='x', colors=THEME_COLORS['text'])
                ax.tick_params(axis='y', colors=THEME_COLORS['text'])
                ax.xaxis.label.set_color(THEME_COLORS['text'])
                ax.yaxis.label.set_color(THEME_COLORS['text'])
                ax.title.set_color(THEME_COLORS['text'])

                st.pyplot(fig)

                # Display chorus segments
                if chorus_segments:
                    st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
                    st.subheader("Chorus Segments")
                    for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
                        st.markdown(f"""
                        <div class="time-stamp">Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}</div>
                        """, unsafe_allow_html=True)

                        # Convert segment audio to bytes for playback
                        audio_bytes = save_audio_for_streamlit(segment_audio, sr)
                        st.audio(audio_bytes, format='audio/mp3')
                    st.markdown('</div>', unsafe_allow_html=True)

                    # Chorus compilation
                    if len(compilation_audio) > 0:
                        st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
                        st.subheader("Chorus Compilation")
                        st.markdown("All chorus segments combined into one track:")

                        compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
                        st.audio(compilation_bytes, format='audio/mp3')
                        st.markdown('</div>', unsafe_allow_html=True)
                else:
                    st.info("No chorus sections detected in this audio.")

            except Exception as e:
                st.error(f"Error processing audio: {e}")
                logger.error(f"Error processing audio: {e}", exc_info=True)

if __name__ == "__main__":
    main()
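The segment grouping in extract_chorus_segments above is a plain run-length pass over a boolean mask. For clarity, an equivalent vectorized sketch (illustrative only, not part of this commit):

import numpy as np

# Group runs of True values by finding the indices where a padded mask flips.
mask = np.array([0, 0, 1, 1, 1, 0, 1, 1, 0], dtype=bool)
padded = np.concatenate(([False], mask, [False]))
flips = np.flatnonzero(padded[1:] != padded[:-1])
starts, ends = flips[::2], flips[1::2]  # half-open [start, end) runs of True
# starts -> [2, 6], ends -> [5, 8], matching what the explicit loop produces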
streamlit_app.py
CHANGED
@@ -1,35 +1,18 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-"""
-
-This module provides a web-based interface for the chorus detection system,
-allowing users to upload audio files or provide YouTube URLs for analysis.
+"""
+Streamlit web app for chorus detection in audio files.
 """
 
 import os
 import sys
 import logging
-
-# Configure logging
-logger = logging.getLogger("streamlit-app")
-
-# Configure TensorFlow logging before importing TensorFlow
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logs
-
-# Import model downloader to ensure model is available
-try:
-    from download_model import ensure_model_exists
-except ImportError as e:
-    logger.error(f"Error importing ensure_model_exists: {e}")
-    raise
-
 import base64
 import tempfile
 import warnings
-from typing import Optional, Tuple, List
-import time
 import io
+from typing import Optional, Tuple, List
 
 import matplotlib.pyplot as plt
 import streamlit as st
@@ -39,33 +22,33 @@ import soundfile as sf
 import numpy as np
 from pydub import AudioSegment
 
-# Suppress warnings
-warnings.filterwarnings("ignore")  # Suppress all warnings
-tf.get_logger().setLevel('ERROR')  # Suppress TensorFlow ERROR logs
+# Configure logging
+logger = logging.getLogger("streamlit-app")
 
-# Debug import paths
-logger.info(f"Python path: {sys.path}")
-logger.info(f"Current working directory: {os.getcwd()}")
+# Suppress TensorFlow and other warnings
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+warnings.filterwarnings("ignore")
+tf.get_logger().setLevel('ERROR')
 
-# Import chorus detection modules
+# Import components
 try:
+    from download_model import ensure_model_exists
     from chorus_detection.audio.data_processing import process_audio
     from chorus_detection.audio.processor import extract_audio
     from chorus_detection.models.crnn import load_CRNN_model, make_predictions
     from chorus_detection.utils.cli import is_youtube_url
     from chorus_detection.utils.logging import logger
-
     logger.info("Successfully imported chorus_detection modules")
 except ImportError as e:
-    logger.error(f"Error importing chorus_detection modules: {e}")
+    logger.error(f"Error importing modules: {e}")
    raise
 
-# Define the MODEL_PATH directly
+# Define model path
 MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
 if not os.path.exists(MODEL_PATH):
     MODEL_PATH = ensure_model_exists()
 
-# Define color scheme
+# UI theme colors
 THEME_COLORS = {
     'background': '#121212',
     'card_bg': '#181818',
@@ -79,15 +62,7 @@ THEME_COLORS = {
 
 
 def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
-    """Generate HTML for file download link.
-
-    Args:
-        bin_file: Path to the binary file
-        file_label: Label for the download link
-
-    Returns:
-        HTML string for the download link
-    """
+    """Generate HTML for file download link."""
     with open(bin_file, 'rb') as f:
         data = f.read()
     b64 = base64.b64encode(data).decode()
@@ -134,14 +109,7 @@ def set_custom_theme() -> None:
 
 
 def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
-    """Process a YouTube URL and extract audio.
-
-    Args:
-        url: YouTube URL
-
-    Returns:
-        Tuple of (audio_path, video_name)
-    """
+    """Process a YouTube URL and extract audio."""
     try:
         with st.spinner('Downloading audio from YouTube...'):
             audio_path, video_name = extract_audio(url)
@@ -153,17 +121,9 @@ def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
 
 
 def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
-    """Process an uploaded audio file.
-
-    Args:
-        uploaded_file: Streamlit UploadedFile object
-
-    Returns:
-        Tuple of (audio_path, file_name)
-    """
+    """Process an uploaded audio file."""
     try:
         with st.spinner('Processing uploaded file...'):
-            # Save the uploaded file to a temporary location
             temp_dir = tempfile.mkdtemp()
             file_name = uploaded_file.name
             temp_path = os.path.join(temp_dir, file_name)
@@ -180,24 +140,9 @@ def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
 
 def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
                             meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
-    """Extract chorus segments from predictions.
-
-    Args:
-        y: Audio data
-        sr: Sample rate
-        smoothed_predictions: Smoothed model predictions
-        meter_grid_times: Time grid for predictions
-
-    Returns:
-        List of (start_time, end_time, audio_segment) tuples
-    """
-    # Define threshold for chorus detection (probability > 0.5)
+    """Extract chorus segments from predictions."""
     threshold = 0.5
-
-    # Find the segments where the predictions are above the threshold
     chorus_mask = smoothed_predictions > threshold
-
-    # Group consecutive True values to identify segments
     segments = []
     current_segment = None
 
@@ -205,10 +150,8 @@ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
         time = meter_grid_times[i]
 
         if is_chorus and current_segment is None:
-            # Start a new segment
            current_segment = (time, None, None)
         elif not is_chorus and current_segment is not None:
-            # End the current segment
             start_time = current_segment[0]
             current_segment = (start_time, time, None)
             segments.append(current_segment)
@@ -222,13 +165,9 @@ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
     # Extract the actual audio for each segment
     segments_with_audio = []
     for start_time, end_time, _ in segments:
-        # Convert times to sample indices
         start_idx = int(start_time * sr)
         end_idx = int(end_time * sr)
-
-        # Extract the audio segment
         segment_audio = y[start_idx:end_idx]
-
         segments_with_audio.append((start_time, end_time, segment_audio))
 
     return segments_with_audio
@@ -236,49 +175,29 @@ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
 
 def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
                               sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
-    """Create a compilation of chorus segments.
-
-    Args:
-        segments: List of (start_time, end_time, audio_data) tuples
-        sr: Sample rate
-        fade_duration: Duration of fade in/out in seconds
-
-    Returns:
-        Tuple of (compilation_audio, description)
-    """
+    """Create a compilation of chorus segments."""
     if not segments:
         return np.array([]), "No chorus segments found"
 
-    # Calculate the number of samples for fading
     fade_samples = int(fade_duration * sr)
-
-    # Prepare a list to store the processed segments
     processed_segments = []
-
-    # Description of segments
    segment_descriptions = []
 
-    # Process each segment
     for i, (start_time, end_time, audio) in enumerate(segments):
-        # Apply fade in and fade out
         segment_length = len(audio)
 
         if segment_length <= 2 * fade_samples:
-            # Segment is too short for fading, skip it
             continue
 
-        # Create a linear fade in and fade out
         fade_in = np.linspace(0, 1, fade_samples)
         fade_out = np.linspace(1, 0, fade_samples)
 
-        # Apply the fades
         audio_faded = audio.copy()
        audio_faded[:fade_samples] *= fade_in
         audio_faded[-fade_samples:] *= fade_out
 
         processed_segments.append(audio_faded)
 
-        # Format the times for the description
         start_fmt = format_time(start_time)
         end_fmt = format_time(end_time)
         segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")
@@ -286,26 +205,14 @@ def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
     if not processed_segments:
         return np.array([]), "No chorus segments long enough for compilation"
 
-    # Concatenate all the processed segments
     compilation = np.concatenate(processed_segments)
-
-    # Join the descriptions
     description = "\n".join(segment_descriptions)
 
     return compilation, description
 
 
 def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
-    """Save audio data to a format suitable for Streamlit audio playback.
-
-    Args:
-        audio_data: Audio samples
-        sr: Sample rate
-        file_format: Output format ('mp3', 'wav', etc.)
-
-    Returns:
-        Audio bytes
-    """
+    """Save audio data to a format suitable for Streamlit audio playback."""
     with io.BytesIO() as buffer:
         sf.write(buffer, audio_data, sr, format=file_format)
         buffer.seek(0)
@@ -313,14 +220,7 @@ def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
 
 
 def format_time(seconds: float) -> str:
-    """Format seconds as MM:SS.
-
-    Args:
-        seconds: Time in seconds
-
-    Returns:
-        Formatted time string
-    """
+    """Format seconds as MM:SS."""
     minutes = int(seconds // 60)
     seconds = int(seconds % 60)
     return f"{minutes:02d}:{seconds:02d}"
@@ -385,11 +285,7 @@ def main() -> None:
                 with st.spinner('Processing audio...'):
                     # Load audio and extract features
                     y, sr = librosa.load(audio_path, sr=22050)
-
-                    # Create a temporary directory for model output
                     temp_output_dir = tempfile.mkdtemp()
-
-                    # Load the model
                     model = load_CRNN_model(MODEL_PATH)
 
                     # Process audio and make predictions
@@ -397,14 +293,10 @@ def main() -> None:
                     meter_grid_times, predictions = make_predictions(model, audio_features)
 
                     # Smooth predictions to avoid rapid transitions
-                    smoothed_predictions = np.convolve(predictions,
-                                                       np.ones(5)/5,
-                                                       mode='same')
+                    smoothed_predictions = np.convolve(predictions, np.ones(5)/5, mode='same')
 
-                    # Extract chorus segments
+                    # Extract chorus segments and create compilation
                     chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
-
-                    # Create a chorus compilation
                     compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)
 
                     # Display results
@@ -427,8 +319,6 @@ def main() -> None:
                 # Highlight chorus sections
                 for start_time, end_time, _ in chorus_segments:
                     ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
-
-                    # Add a label at the start of each chorus
                     ax.annotate('Chorus',
                                 xy=(start_time, 0.8 * max(y)),
                                 xytext=(start_time + 0.5, 0.9 * max(y)),
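The smoothing call collapsed onto one line above is a five-point moving average. A quick sketch of its effect on toy predictions (illustrative only, not part of this commit):

import numpy as np

preds = np.array([0.1, 0.2, 0.9, 0.95, 0.9, 0.2, 0.1])
smoothed = np.convolve(preds, np.ones(5) / 5, mode='same')
# mode='same' keeps the output the same length as the input; the implicit
# zero padding at the boundaries pulls the first and last values toward 0,
# which damps spurious one-beat chorus flips near the edges of a song.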
|