dennisvdang committed on
Commit ad0da04 · 1 Parent(s): 88140b5

Flatten directory structure for simpler imports

.space/app-entrypoint.sh CHANGED
Binary files a/.space/app-entrypoint.sh and b/.space/app-entrypoint.sh differ
 
Dockerfile CHANGED
@@ -3,7 +3,7 @@ FROM python:3.9-slim
 # Set environment variables
 ENV PYTHONDONTWRITEBYTECODE=1
 ENV PYTHONUNBUFFERED=1
-ENV PYTHONPATH="${PYTHONPATH}:/app:/app/src:/home/user/app:/home/user/app/src:."
+ENV PYTHONPATH="${PYTHONPATH}:/app:/home/user/app:."
 ENV MODEL_REVISION="20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32"
 ENV MODEL_HF_REPO="dennisvdang/chorus-detection"
 ENV HF_MODEL_FILENAME="chorus_detection_crnn.h5"
@@ -27,21 +27,17 @@ RUN pip install -e .
 # Make the entry point script executable
 RUN chmod +x .space/app-entrypoint.sh || echo "Could not chmod app-entrypoint.sh"
 
-# Create symlinks to ensure proper imports
-RUN ln -sf /app/src /src || echo "Could not create symlink"
-
 # Ensure chorus_detection package is properly installed
 RUN cd /app && \
     python -c "import chorus_detection; print(f'Successfully imported chorus_detection module from {chorus_detection.__file__}')" || \
-    echo "Warning: chorus_detection module not properly installed, will be addressed at runtime"
+    echo "Warning: chorus_detection module not properly installed"
 
 # Ensure model exists and debug info
 RUN echo "Debug: ls -la /app" && ls -la /app && \
-    echo "Debug: ls -la /app/src" && ls -la /app/src && \
     echo "Debug: PYTHONPATH=$PYTHONPATH" && \
    python -c "import sys; print(f'Python path: {sys.path}')" && \
    python -c "import os; print(f'Working directory: {os.getcwd()}')" && \
-    python -c "from src.download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || echo "Warning: Model download failed during build"
+    python -c "from download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || echo "Warning: Model download failed during build"
 
 # Expose port for Streamlit
 EXPOSE 7860
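
Since the point of the commit is that chorus_detection, download_model, and streamlit_app now resolve from /app directly, a minimal sketch of an import check that could be run inside the built image (the module list is assumed from the Dockerfile and app.py; the script itself is not part of this commit):

import importlib.util

# Print where each top-level module resolves from, or flag it as missing.
for mod in ("chorus_detection", "download_model", "streamlit_app"):
    spec = importlib.util.find_spec(mod)
    print(mod, "->", spec.origin if spec else "NOT FOUND")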
app.py CHANGED
@@ -10,14 +10,6 @@ without circular imports.
 import os
 import sys
 import logging
-import subprocess
-
-# Add src directory to path
-current_dir = os.path.dirname(os.path.abspath(__file__))
-src_dir = os.path.join(current_dir, "src")
-if src_dir not in sys.path:
-    sys.path.insert(0, src_dir)
-sys.path.insert(0, current_dir)
 
 # Configure logging
 logging.basicConfig(
@@ -33,34 +25,14 @@ if os.environ.get("SPACE_ID"):
     logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
     logger.info(f"Current working directory: {os.getcwd()}")
     logger.info(f"Directory contents: {os.listdir()}")
-    if os.path.exists('src'):
-        logger.info(f"src directory contents: {os.listdir('src')}")
-
-    # Install package in development mode
-    try:
-        logger.info("Installing package in development mode...")
-        subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."])
-        logger.info("Package installation successful")
-    except Exception as e:
-        logger.error(f"Error installing package: {e}")
-
-    # Verify python path
-    logger.info(f"Python path after update: {sys.path}")
-
-    # Try importing chorus_detection to verify
-    try:
-        import chorus_detection
-        logger.info(f"Successfully imported chorus_detection from {chorus_detection.__file__}")
-    except ImportError as e:
-        logger.error(f"Failed to import chorus_detection: {e}")
 
 def main():
     """Main entry point for the Streamlit app."""
     logger.info("Starting Streamlit app...")
     # Import the Streamlit app module directly
-    import src.streamlit_app
+    import streamlit_app
     # Run the Streamlit app
-    src.streamlit_app.main()
+    streamlit_app.main()
 
 if __name__ == "__main__":
     try:
chorus_detection/__init__.py ADDED
@@ -0,0 +1,10 @@
+"""Chorus Detection package for identifying choruses in music.
+
+This package contains modules for:
+- Audio processing and feature extraction
+- Machine learning models for chorus detection
+- Visualization tools for audio analysis
+- Utility functions
+"""
+
+__version__ = "0.1.0"
chorus_detection/audio/__init__.py ADDED
File without changes
chorus_detection/audio/data_processing.py ADDED
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Module for audio data processing including segmentation and positional encoding."""
+
+from typing import List, Optional, Tuple, Any
+
+import librosa
+import numpy as np
+
+from chorus_detection.audio.processor import AudioFeature
+from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
+from chorus_detection.utils.logging import logger
+
+
+def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
+    """Divide song data into segments based on measure grid frames.
+
+    Args:
+        data: The song data to be segmented
+        meter_grid: The grid indicating the start of each measure
+
+    Returns:
+        A list of song data segments
+    """
+    # Create segments using vectorized operations
+    meter_segments = [data[s:e] for s, e in zip(meter_grid[:-1], meter_grid[1:])]
+
+    # Convert all segments to float32 for consistent processing
+    meter_segments = [segment.astype(np.float32) for segment in meter_segments]
+
+    return meter_segments
+
+
+def positional_encoding(position: int, d_model: int) -> np.ndarray:
+    """Generate a positional encoding for a given position and model dimension.
+
+    Args:
+        position: The position for which to generate the encoding
+        d_model: The dimension of the model
+
+    Returns:
+        The positional encoding
+    """
+    # Create position array
+    positions = np.arange(position)[:, np.newaxis]
+
+    # Calculate dimension-based scaling factors
+    dim_indices = np.arange(d_model)[np.newaxis, :]
+    angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
+
+    # Apply sine to even indices and cosine to odd indices
+    encodings = np.zeros((position, d_model), dtype=np.float32)
+    encodings[:, 0::2] = np.sin(angles[:, 0::2])
+    encodings[:, 1::2] = np.cos(angles[:, 1::2])
+
+    return encodings
+
+
+def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
+    """Apply positional encoding at the meter and frame levels to a list of segments.
+
+    Args:
+        segments: The list of segments to encode
+
+    Returns:
+        The list of segments with applied positional encoding
+    """
+    if not segments:
+        logger.warning("No segments to encode")
+        return []
+
+    n_features = segments[0].shape[1]
+
+    # Generate measure-level positional encodings
+    measure_level_encodings = positional_encoding(len(segments), n_features)
+
+    # Apply hierarchical encodings to each segment
+    encoded_segments = []
+    for i, segment in enumerate(segments):
+        # Generate frame-level positional encoding
+        frame_level_encoding = positional_encoding(len(segment), n_features)
+
+        # Combine frame-level and measure-level encodings
+        encoded_segment = segment + frame_level_encoding + measure_level_encodings[i]
+        encoded_segments.append(encoded_segment)
+
+    return encoded_segments
+
+
+def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
+             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
+    """Pad or truncate the encoded segments to have the specified dimensions.
+
+    Args:
+        encoded_segments: The encoded segments to pad or truncate
+        max_frames: The maximum number of frames per segment
+        max_meters: The maximum number of meters
+        n_features: The number of features per frame
+
+    Returns:
+        The padded or truncated song as a numpy array
+    """
+    if not encoded_segments:
+        logger.warning("No encoded segments to pad")
+        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)
+
+    # Pad or truncate each meter/segment to max_frames
+    padded_meters = []
+    for meter in encoded_segments:
+        # Truncate if longer than max_frames
+        truncated_meter = meter[:max_frames] if meter.shape[0] > max_frames else meter
+
+        # Pad if shorter than max_frames
+        if truncated_meter.shape[0] < max_frames:
+            padding = ((0, max_frames - truncated_meter.shape[0]), (0, 0))
+            padded_meter = np.pad(truncated_meter, padding, 'constant', constant_values=0)
+        else:
+            padded_meter = truncated_meter
+
+        padded_meters.append(padded_meter)
+
+    # Create padding meter (all zeros)
+    padding_meter = np.zeros((max_frames, n_features), dtype=np.float32)
+
+    # Truncate or pad to max_meters
+    if len(padded_meters) > max_meters:
+        padded_song = np.array(padded_meters[:max_meters])
+    else:
+        padded_song = np.array(padded_meters + [padding_meter] * (max_meters - len(padded_meters)))
+
+    return padded_song
+
+
+def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
+                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
+    """Process an audio file, extracting features and applying positional encoding.
+
+    Args:
+        audio_path: The path to the audio file
+        trim_silence: Whether to trim silence from the audio
+        sr: The sample rate to use when loading the audio
+        hop_length: The hop length to use for feature extraction
+
+    Returns:
+        A tuple containing the processed audio and its features
+    """
+    logger.info(f"Processing audio file: {audio_path}")
+
+    try:
+        # First optionally strip silence
+        if trim_silence:
+            from chorus_detection.audio.processor import strip_silence
+            strip_silence(audio_path)
+
+        # Create audio feature object and extract features
+        audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
+        audio_features.extract_features()
+        audio_features.create_meter_grid()
+
+        # Segment the audio data by meter grid
+        audio_segments = segment_data_meters(
+            audio_features.combined_features, audio_features.meter_grid)
+
+        # Apply positional encoding
+        encoded_audio_segments = apply_hierarchical_positional_encoding(audio_segments)
+
+        # Pad song to fixed dimensions and add batch dimension
+        processed_audio = np.expand_dims(pad_song(encoded_audio_segments), axis=0)
+
+        logger.info(f"Audio processing complete: {processed_audio.shape}")
+        return processed_audio, audio_features
+
+    except Exception as e:
+        logger.error(f"Error processing audio: {e}")
+
+        import traceback
+        logger.debug(traceback.format_exc())
+
+        return None, None
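
To see the sinusoidal scheme above in isolation, a standalone sketch that re-implements positional_encoding with plain NumPy (re-implemented so it runs without the package's audio dependencies):

import numpy as np

def positional_encoding(position: int, d_model: int) -> np.ndarray:
    # Same scheme as above: sine on even feature indices, cosine on odd ones.
    positions = np.arange(position)[:, np.newaxis]
    dim_indices = np.arange(d_model)[np.newaxis, :]
    angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
    enc = np.zeros((position, d_model), dtype=np.float32)
    enc[:, 0::2] = np.sin(angles[:, 0::2])
    enc[:, 1::2] = np.cos(angles[:, 1::2])
    return enc

enc = positional_encoding(300, 15)  # MAX_FRAMES frames, N_FEATURES dims
print(enc.shape)   # (300, 15)
print(enc[0, :4])  # position 0 -> [sin(0), cos(0), sin(0), cos(0)] = [0. 1. 0. 1.]

apply_hierarchical_positional_encoding then adds one such matrix at the frame level plus a single measure-level row per segment, so each frame carries both its position within the measure and the measure's position within the song.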
chorus_detection/audio/processor.py ADDED
@@ -0,0 +1,409 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Module for audio feature extraction and processing."""
+
+import os
+import subprocess
+import time
+from functools import reduce
+from pathlib import Path
+from typing import List, Tuple, Optional, Dict, Any, Union
+
+import librosa
+import numpy as np
+from pydub import AudioSegment
+from pydub.silence import detect_nonsilent
+from sklearn.preprocessing import StandardScaler
+
+from chorus_detection.config import SR, HOP_LENGTH, AUDIO_TEMP_PATH
+from chorus_detection.utils.logging import logger
+
+
+def extract_audio(url: str, output_path: str = str(AUDIO_TEMP_PATH)) -> Tuple[Optional[str], Optional[str]]:
+    """Download audio from YouTube URL and save as MP3 using yt-dlp.
+
+    Args:
+        url: YouTube URL of the audio file
+        output_path: Path to save the downloaded audio file
+
+    Returns:
+        Tuple containing path to the downloaded audio file and the video title, or None if download fails
+    """
+    try:
+        # Create output directory if it doesn't exist
+        os.makedirs(output_path, exist_ok=True)
+
+        # Create a unique filename using timestamp
+        timestamp = int(time.time())
+        output_file = os.path.join(output_path, f"audio_{timestamp}.mp3")
+
+        # Get the video title first
+        video_title = get_video_title(url) or f"Video_{timestamp}"
+
+        # Download the audio
+        success, error_msg = download_audio(url, output_file)
+
+        if not success:
+            handle_download_error(error_msg)
+            return None, None
+
+        # Check if file exists and is valid
+        if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
+            logger.info(f"Successfully downloaded: {video_title}")
+            return output_file, video_title
+        else:
+            logger.error("Download completed but file not found or empty")
+            return None, None
+
+    except Exception as e:
+        import traceback
+        error_details = traceback.format_exc()
+        logger.error(f"An error occurred during YouTube download: {e}")
+        logger.debug(f"Error details: {error_details}")
+
+        check_yt_dlp_installation()
+        return None, None
+
+
+def get_video_title(url: str) -> Optional[str]:
+    """Get the title of a YouTube video.
+
+    Args:
+        url: YouTube URL
+
+    Returns:
+        Video title if successful, None otherwise
+    """
+    try:
+        title_command = ['yt-dlp', '--get-title', '--no-warnings', url]
+        video_title = subprocess.check_output(title_command, universal_newlines=True).strip()
+        return video_title
+    except subprocess.CalledProcessError as e:
+        logger.warning(f"Could not retrieve video title: {str(e)}")
+        return None
+
+
+def download_audio(url: str, output_file: str) -> Tuple[bool, str]:
+    """Download audio from YouTube URL using yt-dlp.
+
+    Args:
+        url: YouTube URL
+        output_file: Output file path
+
+    Returns:
+        Tuple containing (success, error_message)
+    """
+    command = [
+        'yt-dlp',
+        '-f', 'bestaudio',
+        '--extract-audio',
+        '--audio-format', 'mp3',
+        '--audio-quality', '0',  # Best quality
+        '--output', output_file,
+        '--no-playlist',
+        '--verbose',
+        url
+    ]
+
+    process = subprocess.Popen(
+        command,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        universal_newlines=True
+    )
+    stdout, stderr = process.communicate()
+
+    if process.returncode != 0:
+        error_msg = f"Error downloading from YouTube (code {process.returncode}): {stderr}"
+        return False, error_msg
+
+    return True, ""
+
+
+def handle_download_error(error_msg: str) -> None:
+    """Handle common YouTube download errors with helpful messages.
+
+    Args:
+        error_msg: Error message from yt-dlp
+    """
+    logger.error(error_msg)
+
+    if "Sign in to confirm you're not a bot" in error_msg:
+        logger.error("YouTube is detecting automated access. Try using a local file instead.")
+    elif any(x in error_msg.lower() for x in ["unavailable video", "private video"]):
+        logger.error("The video appears to be private or unavailable. Please try another URL.")
+    elif "copyright" in error_msg.lower():
+        logger.error("The video may be blocked due to copyright restrictions.")
+    elif any(x in error_msg.lower() for x in ["rate limit", "429"]):
+        logger.error("YouTube rate limit reached. Please try again later.")
+
+
+def check_yt_dlp_installation() -> None:
+    """Check if yt-dlp is installed and provide guidance if it's not."""
+    try:
+        subprocess.run(['yt-dlp', '--version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    except FileNotFoundError:
+        logger.error("yt-dlp is not installed or not in PATH. Please install it with: pip install yt-dlp")
+
+
+def strip_silence(audio_path: str) -> None:
+    """Remove silent parts from an audio file.
+
+    Args:
+        audio_path: Path to the audio file
+    """
+    try:
+        sound = AudioSegment.from_file(audio_path)
+        nonsilent_ranges = detect_nonsilent(
+            sound, min_silence_len=500, silence_thresh=-50)
+
+        if not nonsilent_ranges:
+            logger.warning("No non-silent parts detected in the audio. Using original file.")
+            return
+
+        stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]],
+                          nonsilent_ranges, AudioSegment.empty())
+        stripped.export(audio_path, format='mp3')
+    except Exception as e:
+        logger.error(f"Error stripping silence: {e}")
+        logger.info("Proceeding with original audio file")
+
+
+class AudioFeature:
+    """Class for extracting and processing audio features."""
+
+    def __init__(self, audio_path: str, sr: int = SR, hop_length: int = HOP_LENGTH):
+        """Initialize the AudioFeature class.
+
+        Args:
+            audio_path: Path to the audio file
+            sr: Sample rate for audio processing
+            hop_length: Hop length for feature extraction
+        """
+        self.audio_path: str = audio_path
+        self.sr: int = sr
+        self.hop_length: int = hop_length
+        self.time_signature: int = 4
+
+        # Initialize all features as None
+        self.y: Optional[np.ndarray] = None
+        self.y_harm: Optional[np.ndarray] = None
+        self.y_perc: Optional[np.ndarray] = None
+        self.beats: Optional[np.ndarray] = None
+        self.chroma_acts: Optional[np.ndarray] = None
+        self.chromagram: Optional[np.ndarray] = None
+        self.combined_features: Optional[np.ndarray] = None
+        self.key: Optional[str] = None
+        self.mode: Optional[str] = None
+        self.mel_acts: Optional[np.ndarray] = None
+        self.melspectrogram: Optional[np.ndarray] = None
+        self.meter_grid: Optional[np.ndarray] = None
+        self.mfccs: Optional[np.ndarray] = None
+        self.mfcc_acts: Optional[np.ndarray] = None
+        self.n_frames: Optional[int] = None
+        self.onset_env: Optional[np.ndarray] = None
+        self.rms: Optional[np.ndarray] = None
+        self.spectrogram: Optional[np.ndarray] = None
+        self.tempo: Optional[float] = None
+        self.tempogram: Optional[np.ndarray] = None
+        self.tempogram_acts: Optional[np.ndarray] = None
+
+    def detect_key(self, chroma_vals: np.ndarray) -> Tuple[str, str]:
+        """Detect the key and mode (major or minor) of the audio segment.
+
+        Args:
+            chroma_vals: Chromagram values to analyze for key detection
+
+        Returns:
+            Tuple containing the detected key and mode
+        """
+        note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
+
+        # Key profiles (Krumhansl-Kessler profiles)
+        major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
+        minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
+
+        # Normalize profiles
+        major_profile /= np.linalg.norm(major_profile)
+        minor_profile /= np.linalg.norm(minor_profile)
+
+        # Calculate correlations for all possible rotations
+        major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
+        minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]
+
+        # Find max correlation
+        max_major_idx = np.argmax(major_correlations)
+        max_minor_idx = np.argmax(minor_correlations)
+
+        # Determine mode
+        self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
+        self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
+
+        return self.key, self.mode
+
+    def calculate_ki_chroma(self, waveform: np.ndarray, sr: int, hop_length: int) -> np.ndarray:
+        """Calculate a normalized, key-invariant chromagram for the given audio waveform.
+
+        Args:
+            waveform: Audio waveform to analyze
+            sr: Sample rate of the waveform
+            hop_length: Hop length for feature extraction
+
+        Returns:
+            The key-invariant chromagram as a numpy array
+        """
+        # Calculate chromagram
+        chromagram = librosa.feature.chroma_cqt(
+            y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
+
+        # Normalize to [0, 1]
+        chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min() + 1e-8)
+
+        # Detect key
+        chroma_vals = np.sum(chromagram, axis=1)
+        key, mode = self.detect_key(chroma_vals)
+
+        # Make key-invariant
+        key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
+        shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
+
+        return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
+
+    def extract_features(self) -> None:
+        """Extract various audio features from the loaded audio."""
+        # Load audio
+        self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
+
+        # Harmonic-percussive source separation
+        self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
+
+        # Extract spectrogram
+        self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
+
+        # RMS energy
+        self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
+
+        # Mel spectrogram and activations
+        self.melspectrogram = librosa.feature.melspectrogram(
+            y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
+        self.mel_acts = librosa.decompose.decompose(self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
+
+        # Chromagram and activations
+        self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
+        self.chroma_acts = librosa.decompose.decompose(self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
+
+        # Onset detection and tempogram
+        self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
+        self.tempogram = np.clip(librosa.feature.tempogram(
+            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, None)
+        self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1]
+
+        # MFCCs and activations
+        self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=20, hop_length=self.hop_length)
+        self.mfccs += abs(np.min(self.mfccs) or 0)  # Handle negative values
+        self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=4, sort=True)[1].astype(np.float32)
+
+        # Combine features with weighted normalization
+        self._combine_features()
+
+    def _combine_features(self) -> None:
+        """Combine all extracted features with balanced weights."""
+        features = [self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts]
+        feature_names = ['rms', 'mel_acts', 'chroma_acts', 'tempogram_acts', 'mfcc_acts']
+
+        # Calculate dimension-based weights
+        dims = {name: feature.shape[0] for feature, name in zip(features, feature_names)}
+        total_inv_dim = sum(1 / dim for dim in dims.values())
+        weights = {name: 1 / (dims[name] * total_inv_dim) for name in feature_names}
+
+        # Normalize and weight each feature
+        std_weighted_features = [
+            StandardScaler().fit_transform(feature.T).T * weights[name]
+            for feature, name in zip(features, feature_names)
+        ]
+
+        # Combine features
+        self.combined_features = np.concatenate(std_weighted_features, axis=0).T.astype(np.float32)
+        self.n_frames = len(self.combined_features)
+
+    def create_meter_grid(self) -> np.ndarray:
+        """Create a grid based on the meter of the song using tempo and beats.
+
+        Returns:
+            Numpy array containing the meter grid frame positions
+        """
+        # Extract tempo and beat information
+        self.tempo, self.beats = librosa.beat.beat_track(
+            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
+
+        # Adjust tempo if it's too slow or too fast
+        self.tempo = self._adjust_tempo(self.tempo)
+
+        # Create meter grid
+        self.meter_grid = self._create_meter_grid()
+        return self.meter_grid
+
+    def _adjust_tempo(self, tempo: float) -> float:
+        """Adjust tempo to a reasonable range.
+
+        Args:
+            tempo: Detected tempo
+
+        Returns:
+            Adjusted tempo
+        """
+        if tempo < 70:
+            return tempo * 2
+        elif tempo > 140:
+            return tempo / 2
+        return tempo
+
+    def _create_meter_grid(self) -> np.ndarray:
+        """Helper function to create a meter grid for the song.
+
+        Returns:
+            Numpy array containing the meter grid frame positions
+        """
+        # Calculate beat interval
+        seconds_per_beat = 60 / self.tempo
+        beat_interval = int(librosa.time_to_frames(seconds_per_beat, sr=self.sr, hop_length=self.hop_length))
+
+        # Find best matching start beat
+        if len(self.beats) >= 3:
+            best_match = max(
+                (1 - abs(np.mean(self.beats[i:i+3]) - beat_interval) / beat_interval, self.beats[i])
+                for i in range(len(self.beats) - 2)
+            )
+            anchor_frame = best_match[1] if best_match[0] > 0.95 else self.beats[0]
+        else:
+            anchor_frame = self.beats[0] if len(self.beats) > 0 else 0
+
+        first_beat_time = librosa.frames_to_time(anchor_frame, sr=self.sr, hop_length=self.hop_length)
+
+        # Calculate beats forward and backward
+        time_duration = librosa.frames_to_time(self.n_frames, sr=self.sr, hop_length=self.hop_length)
+        num_beats_forward = int((time_duration - first_beat_time) / seconds_per_beat)
+        num_beats_backward = int(first_beat_time / seconds_per_beat) + 1
+
+        # Create beat times
+        beat_times_forward = first_beat_time + np.arange(num_beats_forward) * seconds_per_beat
+        beat_times_backward = first_beat_time - np.arange(1, num_beats_backward) * seconds_per_beat
+
+        # Combine and create meter grid
+        beat_grid = np.concatenate((np.array([0.0]), beat_times_backward[::-1], beat_times_forward))
+        meter_indices = np.arange(0, len(beat_grid), self.time_signature)
+        meter_grid = beat_grid[meter_indices]
+
+        # Ensure grid starts at 0 and ends at frame duration
+        if meter_grid[0] != 0.0:
+            meter_grid = np.insert(meter_grid, 0, 0.0)
+
+        # Convert to frames
+        meter_grid = librosa.time_to_frames(meter_grid, sr=self.sr, hop_length=self.hop_length)
+
+        # Ensure grid ends at the last frame
+        if meter_grid[-1] != self.n_frames:
+            meter_grid = np.append(meter_grid, self.n_frames)
+
+        return meter_grid
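
detect_key above is a profile-correlation method; a small self-contained illustration with a toy chroma vector (the triad values are made up for the example, the major profile is the same Krumhansl-Kessler profile used in the class):

import numpy as np

NOTE_NAMES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
MAJOR_PROFILE = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])

# Toy chroma histogram containing only the C major triad (C, E, G).
chroma_vals = np.zeros(12)
chroma_vals[[0, 4, 7]] = 1.0

# Correlate against the profile rotated to every candidate tonic;
# the best-scoring rotation names the key.
correlations = [np.corrcoef(chroma_vals, np.roll(MAJOR_PROFILE, i))[0, 1] for i in range(12)]
print(NOTE_NAMES[int(np.argmax(correlations))])  # -> C

calculate_ki_chroma then rolls the chromagram by the detected tonic (majors aligned to C, minors to A), which is what makes the downstream features key-invariant.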
chorus_detection/config.py ADDED
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Configuration settings for the chorus detection system."""
+
+import os
+from pathlib import Path
+from typing import Dict, Any, Union, Optional
+
+# Audio processing settings
+SR: int = 12000        # Sample rate
+HOP_LENGTH: int = 128  # Hop length for signal processing
+MAX_FRAMES: int = 300  # Maximum frames per segment
+MAX_METERS: int = 201  # Maximum meters per song
+N_FEATURES: int = 15   # Number of features
+
+# Project paths
+PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.resolve()
+MODEL_DIR: Path = PROJECT_ROOT / "models" / "CRNN"
+MODEL_PATH: Path = MODEL_DIR / "best_model_V3.h5"
+AUDIO_TEMP_PATH: Path = PROJECT_ROOT / "output" / "temp"
+LOG_DIR: Path = PROJECT_ROOT / "logs"
+
+# Alternative Docker paths
+DOCKER_MODEL_PATH: str = "/app/models/CRNN/best_model_V3.h5"
+DOCKER_TEMP_PATH: str = "/app/output/temp"
+
+
+def get_env_path(env_var: str, default_path: Path) -> Path:
+    """Get a path from environment variable or use the default.
+
+    Args:
+        env_var: Name of the environment variable
+        default_path: Default path to use if environment variable is not set
+
+    Returns:
+        Path object for the specified location
+    """
+    env_value = os.environ.get(env_var)
+    if env_value:
+        return Path(env_value).resolve()
+    return default_path
+
+
+# Override paths with environment variables if provided
+MODEL_PATH = get_env_path("CHORUS_MODEL_PATH", MODEL_PATH)
+AUDIO_TEMP_PATH = get_env_path("CHORUS_TEMP_PATH", AUDIO_TEMP_PATH)
+LOG_DIR = get_env_path("CHORUS_LOG_DIR", LOG_DIR)
+
+# Create necessary directories
+os.makedirs(MODEL_DIR, exist_ok=True)
+os.makedirs(AUDIO_TEMP_PATH, exist_ok=True)
+os.makedirs(LOG_DIR, exist_ok=True)
+os.makedirs(PROJECT_ROOT / "output", exist_ok=True)
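
A short usage sketch of the override mechanism (the /tmp path is an example, not a project default). Because the overrides are resolved when the module loads, the environment variable must be set before chorus_detection.config is first imported:

import os

os.environ["CHORUS_MODEL_PATH"] = "/tmp/models/best_model_V3.h5"  # example path

from chorus_detection import config  # applies env overrides and creates directories
print(config.MODEL_PATH)             # -> /tmp/models/best_model_V3.h5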
chorus_detection/models/__init__.py ADDED
File without changes
chorus_detection/models/crnn.py ADDED
@@ -0,0 +1,186 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Module for loading and managing the CRNN model for chorus detection."""
+
+import os
+from typing import Any, Optional, List, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+
+from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
+from chorus_detection.utils.logging import logger
+
+
+def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
+    """Load a CRNN model with custom loss and accuracy functions.
+
+    Args:
+        model_path: Path to the saved model
+
+    Returns:
+        Loaded Keras model
+
+    Raises:
+        RuntimeError: If the model cannot be loaded
+    """
+    try:
+        # Define custom objects required for model loading
+        custom_objects = {
+            'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
+            'custom_accuracy': lambda y_true, y_pred: y_pred
+        }
+
+        # Try to load the model with custom objects
+        logger.info(f"Loading model from: {model_path}")
+        model = tf.keras.models.load_model(
+            model_path, custom_objects=custom_objects, compile=False)
+
+        # Compile the model with default optimizer and loss for prediction only
+        model.compile(optimizer='adam', loss='binary_crossentropy')
+
+        return model
+    except Exception as e:
+        logger.error(f"Error loading model from {model_path}: {e}")
+
+        # Try Docker container path as fallback
+        if model_path != DOCKER_MODEL_PATH and os.path.exists(DOCKER_MODEL_PATH):
+            logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
+            return load_CRNN_model(DOCKER_MODEL_PATH)
+
+        raise RuntimeError(f"Failed to load model: {e}")
+
+
+def smooth_predictions(predictions: np.ndarray) -> np.ndarray:
+    """Smooth predictions by correcting isolated mispredictions and removing short sequences.
+
+    Args:
+        predictions: Array of binary predictions
+
+    Returns:
+        Smoothed array of binary predictions
+    """
+    # Convert to numpy array if not already
+    data = np.array(predictions, copy=True) if not isinstance(predictions, np.ndarray) else predictions.copy()
+
+    # First pass: Correct isolated 0's (handle 0's surrounded by 1's)
+    for i in range(1, len(data) - 1):
+        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
+            data[i] = 1
+
+    # Second pass: Correct isolated 1's (handle 1's surrounded by 0's)
+    corrected_data = data.copy()
+    for i in range(1, len(data) - 1):
+        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
+            corrected_data[i] = 0
+
+    # Third pass: Remove short sequences of 1s (less than 5 consecutive 1's)
+    smoothed_data = corrected_data.copy()
+    sequence_start = None
+
+    for i in range(len(corrected_data)):
+        if corrected_data[i] == 1:
+            if sequence_start is None:
+                sequence_start = i
+        else:
+            if sequence_start is not None:
+                sequence_length = i - sequence_start
+                if sequence_length < 5:
+                    smoothed_data[sequence_start:i] = 0
+                sequence_start = None
+
+    # Handle the case where the sequence extends to the end
+    if sequence_start is not None:
+        sequence_length = len(corrected_data) - sequence_start
+        if sequence_length < 5:
+            smoothed_data[sequence_start:] = 0
+
+    return smoothed_data
+
+
+def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
+                     audio_features: Any, url: Optional[str] = None,
+                     video_name: Optional[str] = None) -> np.ndarray:
+    """Generate predictions from the model and process them.
+
+    Args:
+        model: The loaded model for making predictions
+        processed_audio: The audio data that has been processed for prediction
+        audio_features: Audio features object containing necessary metadata
+        url: YouTube URL of the audio file (optional)
+        video_name: Name of the video (optional)
+
+    Returns:
+        The smoothed binary predictions
+    """
+    import librosa
+
+    logger.info("Generating predictions...")
+
+    # Make predictions
+    predictions = model.predict(processed_audio)[0]
+
+    # Convert to binary predictions and handle potential size mismatch
+    meter_grid_length = len(audio_features.meter_grid) - 1
+    if len(predictions) > meter_grid_length:
+        predictions = predictions[:meter_grid_length]
+
+    binary_predictions = np.round(predictions).flatten()
+
+    # Apply smoothing to improve prediction quality
+    smoothed_predictions = smooth_predictions(binary_predictions)
+
+    # Get times for identified chorus sections
+    meter_grid_times = librosa.frames_to_time(
+        audio_features.meter_grid,
+        sr=audio_features.sr,
+        hop_length=audio_features.hop_length
+    )
+
+    # Identify where choruses start
+    chorus_start_times = [
+        meter_grid_times[i] for i in range(len(smoothed_predictions))
+        if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
+    ]
+
+    # Print results if URL and video name are provided (CLI mode)
+    if url and video_name:
+        _print_chorus_results(url, video_name, chorus_start_times)
+
+    return smoothed_predictions
+
+
+def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
+    """Print formatted results showing identified choruses with links.
+
+    Args:
+        url: YouTube URL of the analyzed video
+        video_name: Name of the video
+        chorus_start_times: List of start times (in seconds) for identified choruses
+    """
+    # Create YouTube links with time stamps
+    youtube_links = [
+        f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\"
+        for start_time in chorus_start_times
+    ]
+
+    # Format the output
+    link_lengths = [len(link) for link in youtube_links]
+    max_length = max(link_lengths + [len(video_name), len(f"Number of choruses identified: {len(chorus_start_times)}")]) if link_lengths else 50
+    header_footer = "=" * (max_length + 4)
+
+    # Print the results
+    print("\n\n")
+    print(header_footer)
+    print(f"{video_name.center(max_length + 2)}")
+    print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
+    print(header_footer)
+
+    if chorus_start_times:
+        for link in youtube_links:
+            print(link)
+    else:
+        print("No choruses identified.")
+
+    print(header_footer)
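
To make the three smoothing passes concrete, a toy sequence chosen by hand (assumes the package and TensorFlow are importable, since crnn.py imports tensorflow at module level):

import numpy as np
from chorus_detection.models.crnn import smooth_predictions

raw = np.array([1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1])
print(smooth_predictions(raw))
# The isolated 0 at index 2 is bridged (pass 1), the isolated 1 at index 8 is
# dropped (pass 2), and the 3-long run at indices 11-13 is erased for being
# shorter than 5 meters (pass 3), leaving:
# [1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1]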
chorus_detection/utils/__init__.py ADDED
File without changes
chorus_detection/utils/cli.py ADDED
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Command-line interface utilities for the chorus detection system."""
+
+import argparse
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Any, Optional, Tuple
+
+from chorus_detection.config import MODEL_PATH
+from chorus_detection.utils.logging import logger
+
+
+def parse_arguments() -> argparse.Namespace:
+    """Parse command-line arguments.
+
+    Returns:
+        Parsed command-line arguments
+    """
+    parser = argparse.ArgumentParser(
+        description="Chorus Finder - Detect choruses in songs from YouTube URLs or local audio files")
+
+    input_group = parser.add_mutually_exclusive_group()
+    input_group.add_argument("--url", type=str,
+                             help="YouTube URL of a song")
+    input_group.add_argument("--file", type=str,
+                             help="Path to a local audio file")
+
+    parser.add_argument("--model_path", type=str, default=str(MODEL_PATH),
+                        help=f"Path to the pretrained model (default: {MODEL_PATH})")
+    parser.add_argument("--verbose", action="store_true",
+                        help="Verbose output", default=True)
+    parser.add_argument("--plot", action="store_true",
+                        help="Display plot of the audio waveform", default=True)
+    parser.add_argument("--no-plot", dest="plot", action="store_false",
+                        help="Disable plot display (useful for headless environments)")
+
+    return parser.parse_args()
+
+
+def get_input_source(args: argparse.Namespace) -> Optional[str]:
+    """Get input source from arguments or user input.
+
+    Args:
+        args: Parsed command-line arguments
+
+    Returns:
+        Input source (URL or file path)
+    """
+    input_source = args.url or args.file
+    if not input_source:
+        print("\nChorus Detection Tool")
+        print("====================")
+        print("\nNote: YouTube download functionality may be temporarily unavailable")
+        print("due to YouTube's restrictions. If download fails, please use a local audio file.\n")
+        print("Choose input method:")
+        print("1. YouTube URL")
+        print("2. Local audio file")
+        choice = input("Enter choice (1 or 2): ")
+
+        if choice == "1":
+            input_source = input("Please enter the YouTube URL of the song: ")
+        elif choice == "2":
+            input_source = input("Please enter the path to the audio file: ")
+        else:
+            logger.error("Invalid choice")
+            sys.exit(1)
+
+    return input_source
+
+
+def is_youtube_url(input_source: str) -> bool:
+    """Check if the input source is a YouTube URL.
+
+    Args:
+        input_source: Input source to check
+
+    Returns:
+        True if the input source is a YouTube URL, False otherwise
+    """
+    return input_source.startswith(('http://', 'https://'))
+
+
+def validate_input_file(file_path: str) -> bool:
+    """Validate that the input file exists and is readable.
+
+    Args:
+        file_path: Path to the input file
+
+    Returns:
+        True if the file is valid, False otherwise
+    """
+    if not os.path.exists(file_path):
+        logger.error(f"Error: File not found at {file_path}")
+        return False
+
+    if not os.path.isfile(file_path):
+        logger.error(f"Error: {file_path} is not a file")
+        return False
+
+    if not os.access(file_path, os.R_OK):
+        logger.error(f"Error: No permission to read {file_path}")
+        return False
+
+    return True
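
A quick sketch of the two validators above (paths are examples):

from chorus_detection.utils.cli import is_youtube_url, validate_input_file

print(is_youtube_url("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))  # True
print(is_youtube_url("/music/song.mp3"))                              # False
print(validate_input_file("/music/song.mp3"))  # False unless the file exists and is readable

Note that is_youtube_url only checks for an http(s) prefix, so any web URL passes; YouTube-specific validation is left to yt-dlp at download time.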
chorus_detection/utils/logging.py ADDED
@@ -0,0 +1,53 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Logging configuration for the chorus detection system."""
+
+import logging
+import os
+import sys
+from typing import Optional
+
+from chorus_detection.config import PROJECT_ROOT
+
+# Create logs directory
+LOGS_DIR = PROJECT_ROOT / "logs"
+os.makedirs(LOGS_DIR, exist_ok=True)
+
+
+def setup_logger(name: str = "chorus_detection", level: int = logging.INFO,
+                 log_file: Optional[str] = None) -> logging.Logger:
+    """Configure and return a logger with the specified name and level.
+
+    Args:
+        name: Name of the logger
+        level: Logging level (default: INFO)
+        log_file: Path to the log file (default: None)
+
+    Returns:
+        Configured logger instance
+    """
+    logger = logging.getLogger(name)
+    logger.setLevel(level)
+
+    # Create formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    # Create console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(formatter)
+    logger.addHandler(console_handler)
+
+    # Create file handler if log_file is specified
+    if log_file:
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setFormatter(formatter)
+        logger.addHandler(file_handler)
+
+    return logger
+
+
+# Create default logger
+logger = setup_logger()
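
Usage sketch (the logger name and file path are examples). One caveat worth knowing: setup_logger attaches a fresh handler on every call, so calling it twice with the same name duplicates console output:

import logging
from chorus_detection.utils.logging import setup_logger

log = setup_logger("my-run", level=logging.DEBUG, log_file="/tmp/my-run.log")
log.debug("goes to stdout and to /tmp/my-run.log")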
chorus_detection/visualization/__init__.py ADDED
File without changes
chorus_detection/visualization/plotter.py ADDED
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+"""Module for visualizing audio data and chorus predictions."""
+
+from typing import List
+
+import librosa
+import librosa.display
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+
+from chorus_detection.audio.processor import AudioFeature
+
+
+def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
+    """Draw meter grid lines on the plot.
+
+    Args:
+        ax: The matplotlib axes object to draw on
+        meter_grid_times: Array of times at which to draw the meter lines
+    """
+    for time in meter_grid_times:
+        ax.axvline(x=time, color='grey', linestyle='--',
+                   linewidth=1, alpha=0.6)
+
+
+def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
+    """Plot the audio waveform and overlay the predicted chorus locations.
+
+    Args:
+        audio_features: An object containing audio features and components
+        binary_predictions: Array of binary predictions indicating chorus locations
+    """
+    meter_grid_times = librosa.frames_to_time(
+        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
+    fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)
+
+    # Display harmonic and percussive components
+    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
+                             alpha=0.8, ax=ax, color='deepskyblue')
+    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
+                             alpha=0.7, ax=ax, color='plum')
+    plot_meter_lines(ax, meter_grid_times)
+
+    # Highlight chorus sections
+    for i, prediction in enumerate(binary_predictions):
+        start_time = meter_grid_times[i]
+        end_time = meter_grid_times[i + 1] if i < len(
+            meter_grid_times) - 1 else len(audio_features.y) / audio_features.sr
+        if prediction == 1:
+            ax.axvspan(start_time, end_time, color='green', alpha=0.3,
+                       label='Predicted Chorus' if i == 0 else None)
+
+    # Set plot limits and labels
+    ax.set_xlim([0, len(audio_features.y) / audio_features.sr])
+    ax.set_ylabel('Amplitude')
+    audio_file_name = os.path.basename(audio_features.audio_path)
+    ax.set_title(
+        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')
+
+    # Add legend
+    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
+    handles, labels = ax.get_legend_handles_labels()
+    handles.append(chorus_patch)
+    labels.append('Chorus')
+    ax.legend(handles=handles, labels=labels)
+
+    # Set x-tick labels in minutes:seconds format
+    duration = len(audio_features.y) / audio_features.sr
+    xticks = np.arange(0, duration, 10)
+    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
+    ax.set_xticks(xticks)
+    ax.set_xticklabels(xlabels)
+
+    plt.tight_layout()
+    plt.show(block=False)
setup.py CHANGED
@@ -6,8 +6,7 @@ from setuptools import setup, find_packages
 setup(
     name="chorus_detection",
     version="0.1.0",
-    packages=find_packages(where="src"),
-    package_dir={"": "src"},
+    packages=find_packages(),
     install_requires=[
         # These are already in requirements.txt so no need to specify versions
        "numpy",
streamlit_app.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ """Streamlit web app for chorus detection in audio files.
5
+
6
+ This module provides a web-based interface for the chorus detection system,
7
+ allowing users to upload audio files or provide YouTube URLs for analysis.
8
+ """
9
+
10
+ import os
11
+ import sys
12
+ import logging
13
+
14
+ # Configure logging
15
+ logger = logging.getLogger("streamlit-app")
16
+
17
+ # Configure TensorFlow logging before importing TensorFlow
18
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logs
19
+
20
+ # Import model downloader to ensure model is available
21
+ try:
22
+ from download_model import ensure_model_exists
23
+ except ImportError as e:
24
+ logger.error(f"Error importing ensure_model_exists: {e}")
25
+ raise
26
+
27
+ import base64
28
+ import tempfile
29
+ import warnings
30
+ from typing import Optional, Tuple, List
31
+ import time
32
+ import io
33
+
34
+ import matplotlib.pyplot as plt
35
+ import streamlit as st
36
+ import tensorflow as tf
37
+ import librosa
38
+ import soundfile as sf
39
+ import numpy as np
40
+ from pydub import AudioSegment
41
+
42
+ # Suppress warnings
43
+ warnings.filterwarnings("ignore") # Suppress all warnings
44
+ tf.get_logger().setLevel('ERROR') # Suppress TensorFlow ERROR logs
45
+
46
+ # Debug import paths
47
+ logger.info(f"Python path: {sys.path}")
48
+ logger.info(f"Current working directory: {os.getcwd()}")
49
+
50
+ # Import modules
51
+ try:
52
+ from chorus_detection.audio.data_processing import process_audio
53
+ from chorus_detection.audio.processor import extract_audio
54
+ from chorus_detection.models.crnn import load_CRNN_model, make_predictions
55
+ from chorus_detection.utils.cli import is_youtube_url
56
+ from chorus_detection.utils.logging import logger
57
+
58
+ logger.info("Successfully imported chorus_detection modules")
59
+ except ImportError as e:
60
+ logger.error(f"Error importing chorus_detection modules: {e}")
61
+ raise
62
+
63
+ # Define the MODEL_PATH directly
64
+ MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
65
+ if not os.path.exists(MODEL_PATH):
66
+ MODEL_PATH = ensure_model_exists()
67
+
68
+ # Define color scheme
69
+ THEME_COLORS = {
70
+ 'background': '#121212',
71
+ 'card_bg': '#181818',
72
+ 'primary': '#1DB954',
73
+ 'secondary': '#1ED760',
74
+ 'text': '#FFFFFF',
75
+ 'subtext': '#B3B3B3',
76
+ 'highlight': '#1DB954',
77
+ 'border': '#333333',
78
+ }
79
+
80
+
81
+ def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
82
+ """Generate HTML for file download link.
83
+
84
+ Args:
85
+ bin_file: Path to the binary file
86
+ file_label: Label for the download link
87
+
88
+ Returns:
89
+ HTML string for the download link
90
+ """
91
+ with open(bin_file, 'rb') as f:
92
+ data = f.read()
93
+ b64 = base64.b64encode(data).decode()
94
+ return f'<a href="data:application/octet-stream;base64,{b64}" download="{os.path.basename(bin_file)}">{file_label}</a>'
95
+
96
+
97
+ def set_custom_theme() -> None:
98
+ """Apply custom Spotify-inspired theme to Streamlit UI."""
99
+ custom_theme = f"""
100
+ <style>
101
+ .stApp {{
102
+ background-color: {THEME_COLORS['background']};
103
+ color: {THEME_COLORS['text']};
104
+ }}
105
+ .css-18e3th9 {{
106
+ padding-top: 2rem;
107
+ padding-bottom: 10rem;
108
+ padding-left: 5rem;
109
+ padding-right: 5rem;
110
+ }}
111
+ h1, h2, h3, h4, h5, h6 {{
112
+ color: {THEME_COLORS['text']} !important;
113
+ font-weight: 700 !important;
114
+ }}
115
+ .stSidebar .sidebar-content {{
116
+ background-color: {THEME_COLORS['card_bg']};
117
+ }}
118
+ .stButton>button {{
119
+ background-color: {THEME_COLORS['primary']};
120
+ color: white;
121
+ border-radius: 500px;
122
+ padding: 8px 32px;
123
+ font-weight: 600;
124
+ border: none;
125
+ transition: all 0.3s ease;
126
+ }}
127
+ .stButton>button:hover {{
128
+ background-color: {THEME_COLORS['secondary']};
129
+ transform: scale(1.04);
130
+ }}
131
+ </style>
132
+ """
133
+ st.markdown(custom_theme, unsafe_allow_html=True)
134
+
135
+
136
+ def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
137
+ """Process a YouTube URL and extract audio.
138
+
139
+ Args:
140
+ url: YouTube URL
141
+
142
+ Returns:
143
+ Tuple of (audio_path, video_name)
144
+ """
145
+ try:
146
+ with st.spinner('Downloading audio from YouTube...'):
147
+ audio_path, video_name = extract_audio(url)
148
+ return audio_path, video_name
149
+ except Exception as e:
150
+ st.error(f"Error processing YouTube URL: {e}")
151
+ logger.error(f"Error processing YouTube URL: {e}", exc_info=True)
152
+ return None, None
153
+
154
+
155
+ def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
156
+ """Process an uploaded audio file.
157
+
158
+ Args:
159
+ uploaded_file: Streamlit UploadedFile object
160
+
161
+ Returns:
162
+ Tuple of (audio_path, file_name)
163
+ """
164
+ try:
165
+ with st.spinner('Processing uploaded file...'):
166
+ # Save the uploaded file to a temporary location
167
+ temp_dir = tempfile.mkdtemp()
168
+ file_name = uploaded_file.name
169
+ temp_path = os.path.join(temp_dir, file_name)
170
+
171
+ with open(temp_path, 'wb') as f:
172
+ f.write(uploaded_file.getbuffer())
173
+
174
+ return temp_path, file_name.split('.')[0]
175
+ except Exception as e:
176
+ st.error(f"Error processing uploaded file: {e}")
177
+ logger.error(f"Error processing uploaded file: {e}", exc_info=True)
178
+ return None, None
179
+
180
+
181
+ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
182
+ meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
183
+ """Extract chorus segments from predictions.
184
+
185
+ Args:
186
+ y: Audio data
187
+ sr: Sample rate
188
+ smoothed_predictions: Smoothed model predictions
189
+ meter_grid_times: Time grid for predictions
190
+
191
+ Returns:
192
+ List of (start_time, end_time, audio_segment) tuples
193
+ """
194
+ # Define threshold for chorus detection (probability > 0.5)
195
+ threshold = 0.5
196
+
197
+ # Find the segments where the predictions are above the threshold
198
+ chorus_mask = smoothed_predictions > threshold
199
+
200
+ # Group consecutive True values to identify segments
201
+ segments = []
202
+ current_segment = None
203
+
204
+ for i, is_chorus in enumerate(chorus_mask):
205
+ time = meter_grid_times[i]
206
+
207
+ if is_chorus and current_segment is None:
208
+ # Start a new segment
209
+ current_segment = (time, None, None)
210
+ elif not is_chorus and current_segment is not None:
211
+ # End the current segment
212
+ start_time = current_segment[0]
213
+ current_segment = (start_time, time, None)
214
+ segments.append(current_segment)
215
+ current_segment = None
216
+
217
+ # Handle the case where the last segment extends to the end of the song
218
+ if current_segment is not None:
219
+ start_time = current_segment[0]
220
+ segments.append((start_time, meter_grid_times[-1], None))
221
+
222
+ # Extract the actual audio for each segment
223
+ segments_with_audio = []
224
+ for start_time, end_time, _ in segments:
225
+ # Convert times to sample indices
226
+ start_idx = int(start_time * sr)
227
+ end_idx = int(end_time * sr)
228
+
229
+ # Extract the audio segment
230
+ segment_audio = y[start_idx:end_idx]
231
+
232
+ segments_with_audio.append((start_time, end_time, segment_audio))
233
+
234
+ return segments_with_audio
235
+
236
+ def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
+                               sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
+     """Create a compilation of chorus segments.
+ 
+     Args:
+         segments: List of (start_time, end_time, audio_data) tuples
+         sr: Sample rate
+         fade_duration: Duration of fade in/out in seconds
+ 
+     Returns:
+         Tuple of (compilation_audio, description)
+     """
+     if not segments:
+         return np.array([]), "No chorus segments found"
+ 
+     # Calculate the number of samples for fading
+     fade_samples = int(fade_duration * sr)
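+     # For example, at sr=22050 and the default fade_duration=0.3,
+     # fade_samples = int(0.3 * 22050) = 6615 samples per fade.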
+ 
+     # Prepare a list to store the processed segments
+     processed_segments = []
+ 
+     # Description of segments
+     segment_descriptions = []
+ 
+     # Process each segment
+     for i, (start_time, end_time, audio) in enumerate(segments):
+         # Apply fade in and fade out
+         segment_length = len(audio)
+ 
+         if segment_length <= 2 * fade_samples:
+             # Segment is too short for fading, skip it
+             continue
+ 
+         # Create a linear fade in and fade out
+         fade_in = np.linspace(0, 1, fade_samples)
+         fade_out = np.linspace(1, 0, fade_samples)
+ 
+         # Apply the fades
+         audio_faded = audio.copy()
+         audio_faded[:fade_samples] *= fade_in
+         audio_faded[-fade_samples:] *= fade_out
+ 
+         processed_segments.append(audio_faded)
+ 
+         # Format the times for the description
+         start_fmt = format_time(start_time)
+         end_fmt = format_time(end_time)
+         segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")
+ 
+     if not processed_segments:
+         return np.array([]), "No chorus segments long enough for compilation"
+ 
+     # Concatenate all the processed segments
+     compilation = np.concatenate(processed_segments)
+ 
+     # Join the descriptions
+     description = "\n".join(segment_descriptions)
+ 
+     return compilation, description
+ 
+ 
+ def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
+     """Save audio data to a format suitable for Streamlit audio playback.
+ 
+     Args:
+         audio_data: Audio samples
+         sr: Sample rate
+         file_format: Output format ('mp3', 'wav', etc.)
+ 
+     Returns:
+         Audio bytes
+     """
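+     # Note: MP3 output through soundfile requires libsndfile >= 1.1.0; if the
+     # deployed image ships an older build, passing file_format='wav' is a safe
+     # fallback (an assumption about the environment, not tested here).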
+     with io.BytesIO() as buffer:
+         sf.write(buffer, audio_data, sr, format=file_format)
+         buffer.seek(0)
+         return buffer.read()
+ 
+ 
+ def format_time(seconds: float) -> str:
+     """Format seconds as MM:SS.
+ 
+     Args:
+         seconds: Time in seconds
+ 
+     Returns:
+         Formatted time string
+     """
+     minutes = int(seconds // 60)
+     seconds = int(seconds % 60)
+     return f"{minutes:02d}:{seconds:02d}"
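+ # Example (illustrative): format_time(75.4) returns "01:15"; durations over an
+ # hour fold into the minutes field, e.g. format_time(3700) returns "61:40".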
+ 
+ 
+ def main() -> None:
+     """Main function for the Streamlit app."""
+     # Set page config
+     st.set_page_config(
+         page_title="Chorus Detection",
+         page_icon="🎵",
+         layout="wide",
+         initial_sidebar_state="collapsed",
+     )
+ 
+     # Apply custom theme
+     set_custom_theme()
+ 
+     # App title and description
+     st.title("Chorus Detection")
+     st.markdown("""
+     <div class="subheader">
+     Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
+     </div>
+     """, unsafe_allow_html=True)
+ 
+     # User input section
+     col1, col2 = st.columns(2)
+ 
+     with col1:
+         st.markdown('<div class="input-option">', unsafe_allow_html=True)
+         st.subheader("Option 1: Upload an audio file")
+         uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
+         st.markdown('</div>', unsafe_allow_html=True)
+ 
+     with col2:
+         st.markdown('<div class="input-option">', unsafe_allow_html=True)
+         st.subheader("Option 2: YouTube URL")
+         youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
+         st.markdown('</div>', unsafe_allow_html=True)
+ 
+     # Process button
+     if st.button("Analyze"):
+         # Check the input method
+         audio_path = None
+         file_name = None
+ 
+         if uploaded_file is not None:
+             audio_path, file_name = process_uploaded_file(uploaded_file)
+         elif youtube_url:
+             if is_youtube_url(youtube_url):
+                 audio_path, file_name = process_youtube(youtube_url)
+             else:
+                 st.error("Invalid YouTube URL. Please check the link and try again.")
+         else:
+             st.error("Please upload an audio file or enter a YouTube URL.")
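+         # Note: if both an upload and a URL are supplied, the uploaded file
+         # takes precedence, since the URL branch above is an elif.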
+ 
+         # If we have a valid audio path, process it
+         if audio_path and file_name:
+             try:
+                 # Load and process the audio file
+                 with st.spinner('Processing audio...'):
+                     # Load audio and extract features
+                     y, sr = librosa.load(audio_path, sr=22050)
+ 
+                     # Create a temporary directory for model output
+                     temp_output_dir = tempfile.mkdtemp()
+ 
+                     # Load the model
+                     model = load_CRNN_model(MODEL_PATH)
+ 
+                     # Process audio and make predictions
+                     audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
+                     meter_grid_times, predictions = make_predictions(model, audio_features)
+ 
+                     # Smooth predictions to avoid rapid transitions
+                     smoothed_predictions = np.convolve(predictions,
+                                                        np.ones(5) / 5,
+                                                        mode='same')
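+                     # np.ones(5)/5 is a 5-point moving-average kernel: each
+                     # prediction becomes the mean of itself and its four
+                     # neighbours on the meter grid. The window width of 5 is a
+                     # tuning choice (assumed, not documented in this commit).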
+ 
+                     # Extract chorus segments
+                     chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
+ 
+                     # Create a chorus compilation
+                     compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)
+ 
+                     # Display results
+                     st.markdown(f"""
+                     <div class="result-container">
+                         <div class="song-title">{file_name}</div>
+                     </div>
+                     """, unsafe_allow_html=True)
+ 
+                     # Display waveform with highlighted chorus sections
+                     fig, ax = plt.subplots(figsize=(14, 5))
+ 
+                     # Plot the waveform
+                     times = np.linspace(0, len(y) / sr, len(y))
+                     ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
+                     ax.set_xlabel('Time (s)')
+                     ax.set_ylabel('Amplitude')
+                     ax.set_title('Audio Waveform with Chorus Sections Highlighted')
+ 
+                     # Highlight chorus sections
+                     for start_time, end_time, _ in chorus_segments:
+                         ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
+ 
+                         # Add a label at the start of each chorus
+                         ax.annotate('Chorus',
+                                     xy=(start_time, 0.8 * np.max(y)),
+                                     xytext=(start_time + 0.5, 0.9 * np.max(y)),
+                                     color=THEME_COLORS['primary'],
+                                     weight='bold')
+ 
+                     # Customize plot appearance
+                     ax.set_facecolor(THEME_COLORS['card_bg'])
+                     fig.patch.set_facecolor(THEME_COLORS['background'])
+                     ax.spines['top'].set_visible(False)
+                     ax.spines['right'].set_visible(False)
+                     ax.spines['bottom'].set_color(THEME_COLORS['border'])
+                     ax.spines['left'].set_color(THEME_COLORS['border'])
+                     ax.tick_params(axis='x', colors=THEME_COLORS['text'])
+                     ax.tick_params(axis='y', colors=THEME_COLORS['text'])
+                     ax.xaxis.label.set_color(THEME_COLORS['text'])
+                     ax.yaxis.label.set_color(THEME_COLORS['text'])
+                     ax.title.set_color(THEME_COLORS['text'])
+ 
+                     st.pyplot(fig)
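+                     # Closing the figure after rendering keeps matplotlib from
+                     # accumulating open figures across Streamlit reruns
+                     # (a suggested addition, not in the original commit).
+                     plt.close(fig)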
+ 
+                     # Display chorus segments
+                     if chorus_segments:
+                         st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
+                         st.subheader("Chorus Segments")
+                         for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
+                             st.markdown(f"""
+                             <div class="time-stamp">Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}</div>
+                             """, unsafe_allow_html=True)
+ 
+                             # Convert segment audio to bytes for playback
+                             audio_bytes = save_audio_for_streamlit(segment_audio, sr)
+                             st.audio(audio_bytes, format='audio/mp3')
+                         st.markdown('</div>', unsafe_allow_html=True)
+ 
+                         # Chorus compilation
+                         if len(compilation_audio) > 0:
+                             st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
+                             st.subheader("Chorus Compilation")
+                             st.markdown("All chorus segments combined into one track:")
+ 
+                             compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
+                             st.audio(compilation_bytes, format='audio/mp3')
+                             st.markdown('</div>', unsafe_allow_html=True)
+                     else:
+                         st.info("No chorus sections detected in this audio.")
+ 
+             except Exception as e:
+                 st.error(f"Error processing audio: {e}")
+                 logger.error(f"Error processing audio: {e}", exc_info=True)
+ 
+ 
+ if __name__ == "__main__":
+     main()