dennisvdang committed
Commit 606184e · 1 Parent(s): ad0da04

Refactor code and remove unnecessary files

.space/app-entrypoint.sh CHANGED
Binary files a/.space/app-entrypoint.sh and b/.space/app-entrypoint.sh differ
 
Dockerfile CHANGED
@@ -27,17 +27,14 @@ RUN pip install -e .
 # Make the entry point script executable
 RUN chmod +x .space/app-entrypoint.sh || echo "Could not chmod app-entrypoint.sh"
 
-# Ensure chorus_detection package is properly installed
+# Verify chorus_detection package installation
 RUN cd /app && \
-    python -c "import chorus_detection; print(f'Successfully imported chorus_detection module from {chorus_detection.__file__}')" || \
+    python -c "import chorus_detection; print(f'Successfully imported chorus_detection')" || \
     echo "Warning: chorus_detection module not properly installed"
 
-# Ensure model exists and debug info
-RUN echo "Debug: ls -la /app" && ls -la /app && \
-    echo "Debug: PYTHONPATH=$PYTHONPATH" && \
-    python -c "import sys; print(f'Python path: {sys.path}')" && \
-    python -c "import os; print(f'Working directory: {os.getcwd()}')" && \
-    python -c "from download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || echo "Warning: Model download failed during build"
+# Ensure model exists
+RUN python -c "from download_model import ensure_model_exists; ensure_model_exists(revision='${MODEL_REVISION}')" || \
+    echo "Warning: Model download failed during build"
 
 # Expose port for Streamlit
 EXPOSE 7860
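
For reference, the simplified build-time download step can be exercised outside Docker with a small sketch like the following; this is an assumption-laden local reproduction (it presumes download_model.py is importable and MODEL_REVISION is exported), not part of the commit itself. The warning mirrors the `|| echo` fallback in the RUN instruction.

import os
import subprocess

# Hypothetical local reproduction of the Dockerfile's model-download step.
revision = os.environ.get("MODEL_REVISION", "main")
cmd = [
    "python", "-c",
    "from download_model import ensure_model_exists; "
    f"ensure_model_exists(revision='{revision}')",
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
    # Mirrors the Dockerfile fallback so a failed download doesn't break the build.
    print("Warning: Model download failed during build")
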
app.py CHANGED
@@ -3,8 +3,6 @@
 
 """
 Main entry point for the Chorus Detection Streamlit app.
-This file is a simple wrapper that starts the Streamlit app
-without circular imports.
 """
 
 import os
@@ -29,9 +27,7 @@ if os.environ.get("SPACE_ID"):
 def main():
     """Main entry point for the Streamlit app."""
     logger.info("Starting Streamlit app...")
-    # Import the Streamlit app module directly
     import streamlit_app
-    # Run the Streamlit app
     streamlit_app.main()
 
 if __name__ == "__main__":
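
The surviving body of main() still relies on a deferred import, which is what the trimmed docstring used to explain; a minimal sketch of why deferring the import breaks an import cycle (module names here are hypothetical, not from the repo):

# a.py (hypothetical): importing b at module level while b imports a back
# would fail at load time; deferring the import to call time breaks the cycle.
def main() -> None:
    import b  # resolved only when main() runs, after a.py is fully loaded
    b.run()

# b.py (hypothetical):
# import a              # safe: a's module object already exists when b loads
# def run() -> None:
#     print("running")
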
download_model.py CHANGED
@@ -148,17 +148,15 @@ def ensure_model_exists(
     try:
         if HF_HUB_AVAILABLE:
             # Use huggingface_hub to download the model
-            logger.info(f"Downloading model from {repo_id}/{hf_model_filename} (revision: {revision}) using huggingface_hub")
+            logger.info(f"Downloading model from {repo_id}/{hf_model_filename}")
             downloaded_path = hf_hub_download(
                 repo_id=repo_id,
                 filename=hf_model_filename,
                 local_dir=model_dir,
                 local_dir_use_symlinks=False,
-                revision=revision  # Specify the exact revision to use
+                revision=revision
             )
 
-            logger.info(f"Downloaded to: {downloaded_path}")
-
             # Rename if necessary
             if os.path.basename(downloaded_path) != model_filename:
                 downloaded_path_obj = Path(downloaded_path)
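
The `revision=` pin kept in this hunk is the reproducibility control: it fixes the exact repo state the file is fetched from. A minimal standalone sketch of the same call, using the project's default repo and filename; the "main" revision value below is illustrative, whereas the app pins a specific hash via MODEL_REVISION:

from pathlib import Path
from huggingface_hub import hf_hub_download

# Pinning revision= keeps the download stable even if the repo's main
# branch later receives a different model file.
model_dir = Path("models/CRNN")
model_dir.mkdir(parents=True, exist_ok=True)
downloaded = hf_hub_download(
    repo_id="dennisvdang/chorus-detection",
    filename="chorus_detection_crnn.h5",
    local_dir=model_dir,
    revision="main",  # illustrative; the app passes a pinned hash instead
)
print(downloaded)
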
setup.py CHANGED
@@ -8,7 +8,6 @@ setup(
     version="0.1.0",
     packages=find_packages(),
     install_requires=[
-        # These are already in requirements.txt so no need to specify versions
         "numpy",
         "scipy",
         "tqdm",
src/chorus_detection/__init__.py DELETED
@@ -1,10 +0,0 @@
-"""Chorus Detection package for identifying choruses in music.
-
-This package contains modules for:
-- Audio processing and feature extraction
-- Machine learning models for chorus detection
-- Visualization tools for audio analysis
-- Utility functions
-"""
-
-__version__ = "0.1.0"
src/chorus_detection/audio/__init__.py DELETED
File without changes
src/chorus_detection/audio/data_processing.py DELETED
@@ -1,180 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for audio data processing including segmentation and positional encoding."""
-
-from typing import List, Optional, Tuple, Any
-
-import librosa
-import numpy as np
-
-from chorus_detection.audio.processor import AudioFeature
-from chorus_detection.config import SR, HOP_LENGTH, MAX_FRAMES, MAX_METERS, N_FEATURES
-from chorus_detection.utils.logging import logger
-
-
-def segment_data_meters(data: np.ndarray, meter_grid: np.ndarray) -> List[np.ndarray]:
-    """Divide song data into segments based on measure grid frames.
-
-    Args:
-        data: The song data to be segmented
-        meter_grid: The grid indicating the start of each measure
-
-    Returns:
-        A list of song data segments
-    """
-    # Create segments using vectorized operations
-    meter_segments = [data[s:e] for s, e in zip(meter_grid[:-1], meter_grid[1:])]
-
-    # Convert all segments to float32 for consistent processing
-    meter_segments = [segment.astype(np.float32) for segment in meter_segments]
-
-    return meter_segments
-
-
-def positional_encoding(position: int, d_model: int) -> np.ndarray:
-    """Generate a positional encoding for a given position and model dimension.
-
-    Args:
-        position: The position for which to generate the encoding
-        d_model: The dimension of the model
-
-    Returns:
-        The positional encoding
-    """
-    # Create position array
-    positions = np.arange(position)[:, np.newaxis]
-
-    # Calculate dimension-based scaling factors
-    dim_indices = np.arange(d_model)[np.newaxis, :]
-    angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
-
-    # Apply sine to even indices and cosine to odd indices
-    encodings = np.zeros((position, d_model), dtype=np.float32)
-    encodings[:, 0::2] = np.sin(angles[:, 0::2])
-    encodings[:, 1::2] = np.cos(angles[:, 1::2])
-
-    return encodings
-
-
-def apply_hierarchical_positional_encoding(segments: List[np.ndarray]) -> List[np.ndarray]:
-    """Apply positional encoding at the meter and frame levels to a list of segments.
-
-    Args:
-        segments: The list of segments to encode
-
-    Returns:
-        The list of segments with applied positional encoding
-    """
-    if not segments:
-        logger.warning("No segments to encode")
-        return []
-
-    n_features = segments[0].shape[1]
-
-    # Generate measure-level positional encodings
-    measure_level_encodings = positional_encoding(len(segments), n_features)
-
-    # Apply hierarchical encodings to each segment
-    encoded_segments = []
-    for i, segment in enumerate(segments):
-        # Generate frame-level positional encoding
-        frame_level_encoding = positional_encoding(len(segment), n_features)
-
-        # Combine frame-level and measure-level encodings
-        encoded_segment = segment + frame_level_encoding + measure_level_encodings[i]
-        encoded_segments.append(encoded_segment)
-
-    return encoded_segments
-
-
-def pad_song(encoded_segments: List[np.ndarray], max_frames: int = MAX_FRAMES,
-             max_meters: int = MAX_METERS, n_features: int = N_FEATURES) -> np.ndarray:
-    """Pad or truncate the encoded segments to have the specified dimensions.
-
-    Args:
-        encoded_segments: The encoded segments to pad or truncate
-        max_frames: The maximum number of frames per segment
-        max_meters: The maximum number of meters
-        n_features: The number of features per frame
-
-    Returns:
-        The padded or truncated song as a numpy array
-    """
-    if not encoded_segments:
-        logger.warning("No encoded segments to pad")
-        return np.zeros((max_meters, max_frames, n_features), dtype=np.float32)
-
-    # Pad or truncate each meter/segment to max_frames
-    padded_meters = []
-    for meter in encoded_segments:
-        # Truncate if longer than max_frames
-        truncated_meter = meter[:max_frames] if meter.shape[0] > max_frames else meter
-
-        # Pad if shorter than max_frames
-        if truncated_meter.shape[0] < max_frames:
-            padding = ((0, max_frames - truncated_meter.shape[0]), (0, 0))
-            padded_meter = np.pad(truncated_meter, padding, 'constant', constant_values=0)
-        else:
-            padded_meter = truncated_meter
-
-        padded_meters.append(padded_meter)
-
-    # Create padding meter (all zeros)
-    padding_meter = np.zeros((max_frames, n_features), dtype=np.float32)
-
-    # Truncate or pad to max_meters
-    if len(padded_meters) > max_meters:
-        padded_song = np.array(padded_meters[:max_meters])
-    else:
-        padded_song = np.array(padded_meters + [padding_meter] * (max_meters - len(padded_meters)))
-
-    return padded_song
-
-
-def process_audio(audio_path: str, trim_silence: bool = True, sr: int = SR,
-                  hop_length: int = HOP_LENGTH) -> Tuple[Optional[np.ndarray], Optional[AudioFeature]]:
-    """Process an audio file, extracting features and applying positional encoding.
-
-    Args:
-        audio_path: The path to the audio file
-        trim_silence: Whether to trim silence from the audio
-        sr: The sample rate to use when loading the audio
-        hop_length: The hop length to use for feature extraction
-
-    Returns:
-        A tuple containing the processed audio and its features
-    """
-    logger.info(f"Processing audio file: {audio_path}")
-
-    try:
-        # First optionally strip silence
-        if trim_silence:
-            from chorus_detection.audio.processor import strip_silence
-            strip_silence(audio_path)
-
-        # Create audio feature object and extract features
-        audio_features = AudioFeature(audio_path=audio_path, sr=sr, hop_length=hop_length)
-        audio_features.extract_features()
-        audio_features.create_meter_grid()
-
-        # Segment the audio data by meter grid
-        audio_segments = segment_data_meters(
-            audio_features.combined_features, audio_features.meter_grid)
-
-        # Apply positional encoding
-        encoded_audio_segments = apply_hierarchical_positional_encoding(audio_segments)
-
-        # Pad song to fixed dimensions and add batch dimension
-        processed_audio = np.expand_dims(pad_song(encoded_audio_segments), axis=0)
-
-        logger.info(f"Audio processing complete: {processed_audio.shape}")
-        return processed_audio, audio_features
-
-    except Exception as e:
-        logger.error(f"Error processing audio: {e}")
-
-        import traceback
-        logger.debug(traceback.format_exc())
-
-        return None, None
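
The deleted positional_encoding is the standard sinusoidal transformer encoding (sin on even dimensions, cos on odd). A quick self-contained shape-and-value check, reusing the same formulation; the 300 and 15 below are the MAX_FRAMES and N_FEATURES constants from the deleted config.py:

import numpy as np

def positional_encoding(position: int, d_model: int) -> np.ndarray:
    # Same formulation as the deleted helper above.
    positions = np.arange(position)[:, np.newaxis]
    dim_indices = np.arange(d_model)[np.newaxis, :]
    angles = positions / np.power(10000, (2 * (dim_indices // 2)) / np.float32(d_model))
    enc = np.zeros((position, d_model), dtype=np.float32)
    enc[:, 0::2] = np.sin(angles[:, 0::2])  # even dims: sine
    enc[:, 1::2] = np.cos(angles[:, 1::2])  # odd dims: cosine
    return enc

enc = positional_encoding(300, 15)  # MAX_FRAMES x N_FEATURES
print(enc.shape)   # (300, 15)
print(enc[0, :4])  # first row alternates sin(0)=0, cos(0)=1 -> [0. 1. 0. 1.]
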
src/chorus_detection/audio/processor.py DELETED
@@ -1,409 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for audio feature extraction and processing."""
-
-import os
-import subprocess
-import time
-from functools import reduce
-from pathlib import Path
-from typing import List, Tuple, Optional, Dict, Any, Union
-
-import librosa
-import numpy as np
-from pydub import AudioSegment
-from pydub.silence import detect_nonsilent
-from sklearn.preprocessing import StandardScaler
-
-from chorus_detection.config import SR, HOP_LENGTH, AUDIO_TEMP_PATH
-from chorus_detection.utils.logging import logger
-
-
-def extract_audio(url: str, output_path: str = str(AUDIO_TEMP_PATH)) -> Tuple[Optional[str], Optional[str]]:
-    """Download audio from YouTube URL and save as MP3 using yt-dlp.
-
-    Args:
-        url: YouTube URL of the audio file
-        output_path: Path to save the downloaded audio file
-
-    Returns:
-        Tuple containing path to the downloaded audio file and the video title, or None if download fails
-    """
-    try:
-        # Create output directory if it doesn't exist
-        os.makedirs(output_path, exist_ok=True)
-
-        # Create a unique filename using timestamp
-        timestamp = int(time.time())
-        output_file = os.path.join(output_path, f"audio_{timestamp}.mp3")
-
-        # Get the video title first
-        video_title = get_video_title(url) or f"Video_{timestamp}"
-
-        # Download the audio
-        success, error_msg = download_audio(url, output_file)
-
-        if not success:
-            handle_download_error(error_msg)
-            return None, None
-
-        # Check if file exists and is valid
-        if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
-            logger.info(f"Successfully downloaded: {video_title}")
-            return output_file, video_title
-        else:
-            logger.error("Download completed but file not found or empty")
-            return None, None
-
-    except Exception as e:
-        import traceback
-        error_details = traceback.format_exc()
-        logger.error(f"An error occurred during YouTube download: {e}")
-        logger.debug(f"Error details: {error_details}")
-
-        check_yt_dlp_installation()
-        return None, None
-
-
-def get_video_title(url: str) -> Optional[str]:
-    """Get the title of a YouTube video.
-
-    Args:
-        url: YouTube URL
-
-    Returns:
-        Video title if successful, None otherwise
-    """
-    try:
-        title_command = ['yt-dlp', '--get-title', '--no-warnings', url]
-        video_title = subprocess.check_output(title_command, universal_newlines=True).strip()
-        return video_title
-    except subprocess.CalledProcessError as e:
-        logger.warning(f"Could not retrieve video title: {str(e)}")
-        return None
-
-
-def download_audio(url: str, output_file: str) -> Tuple[bool, str]:
-    """Download audio from YouTube URL using yt-dlp.
-
-    Args:
-        url: YouTube URL
-        output_file: Output file path
-
-    Returns:
-        Tuple containing (success, error_message)
-    """
-    command = [
-        'yt-dlp',
-        '-f', 'bestaudio',
-        '--extract-audio',
-        '--audio-format', 'mp3',
-        '--audio-quality', '0',  # Best quality
-        '--output', output_file,
-        '--no-playlist',
-        '--verbose',
-        url
-    ]
-
-    process = subprocess.Popen(
-        command,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        universal_newlines=True
-    )
-    stdout, stderr = process.communicate()
-
-    if process.returncode != 0:
-        error_msg = f"Error downloading from YouTube (code {process.returncode}): {stderr}"
-        return False, error_msg
-
-    return True, ""
-
-
-def handle_download_error(error_msg: str) -> None:
-    """Handle common YouTube download errors with helpful messages.
-
-    Args:
-        error_msg: Error message from yt-dlp
-    """
-    logger.error(error_msg)
-
-    if "Sign in to confirm you're not a bot" in error_msg:
-        logger.error("YouTube is detecting automated access. Try using a local file instead.")
-    elif any(x in error_msg.lower() for x in ["unavailable video", "private video"]):
-        logger.error("The video appears to be private or unavailable. Please try another URL.")
-    elif "copyright" in error_msg.lower():
-        logger.error("The video may be blocked due to copyright restrictions.")
-    elif any(x in error_msg.lower() for x in ["rate limit", "429"]):
-        logger.error("YouTube rate limit reached. Please try again later.")
-
-
-def check_yt_dlp_installation() -> None:
-    """Check if yt-dlp is installed and provide guidance if it's not."""
-    try:
-        subprocess.run(['yt-dlp', '--version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
-    except FileNotFoundError:
-        logger.error("yt-dlp is not installed or not in PATH. Please install it with: pip install yt-dlp")
-
-
-def strip_silence(audio_path: str) -> None:
-    """Remove silent parts from an audio file.
-
-    Args:
-        audio_path: Path to the audio file
-    """
-    try:
-        sound = AudioSegment.from_file(audio_path)
-        nonsilent_ranges = detect_nonsilent(
-            sound, min_silence_len=500, silence_thresh=-50)
-
-        if not nonsilent_ranges:
-            logger.warning("No non-silent parts detected in the audio. Using original file.")
-            return
-
-        stripped = reduce(lambda acc, val: acc + sound[val[0]:val[1]],
-                          nonsilent_ranges, AudioSegment.empty())
-        stripped.export(audio_path, format='mp3')
-    except Exception as e:
-        logger.error(f"Error stripping silence: {e}")
-        logger.info("Proceeding with original audio file")
-
-
-class AudioFeature:
-    """Class for extracting and processing audio features."""
-
-    def __init__(self, audio_path: str, sr: int = SR, hop_length: int = HOP_LENGTH):
-        """Initialize the AudioFeature class.
-
-        Args:
-            audio_path: Path to the audio file
-            sr: Sample rate for audio processing
-            hop_length: Hop length for feature extraction
-        """
-        self.audio_path: str = audio_path
-        self.sr: int = sr
-        self.hop_length: int = hop_length
-        self.time_signature: int = 4
-
-        # Initialize all features as None
-        self.y: Optional[np.ndarray] = None
-        self.y_harm: Optional[np.ndarray] = None
-        self.y_perc: Optional[np.ndarray] = None
-        self.beats: Optional[np.ndarray] = None
-        self.chroma_acts: Optional[np.ndarray] = None
-        self.chromagram: Optional[np.ndarray] = None
-        self.combined_features: Optional[np.ndarray] = None
-        self.key: Optional[str] = None
-        self.mode: Optional[str] = None
-        self.mel_acts: Optional[np.ndarray] = None
-        self.melspectrogram: Optional[np.ndarray] = None
-        self.meter_grid: Optional[np.ndarray] = None
-        self.mfccs: Optional[np.ndarray] = None
-        self.mfcc_acts: Optional[np.ndarray] = None
-        self.n_frames: Optional[int] = None
-        self.onset_env: Optional[np.ndarray] = None
-        self.rms: Optional[np.ndarray] = None
-        self.spectrogram: Optional[np.ndarray] = None
-        self.tempo: Optional[float] = None
-        self.tempogram: Optional[np.ndarray] = None
-        self.tempogram_acts: Optional[np.ndarray] = None
-
-    def detect_key(self, chroma_vals: np.ndarray) -> Tuple[str, str]:
-        """Detect the key and mode (major or minor) of the audio segment.
-
-        Args:
-            chroma_vals: Chromagram values to analyze for key detection
-
-        Returns:
-            Tuple containing the detected key and mode
-        """
-        note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
-
-        # Key profiles (Krumhansl-Kessler profiles)
-        major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
-        minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
-
-        # Normalize profiles
-        major_profile /= np.linalg.norm(major_profile)
-        minor_profile /= np.linalg.norm(minor_profile)
-
-        # Calculate correlations for all possible rotations
-        major_correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
-        minor_correlations = [np.corrcoef(chroma_vals, np.roll(minor_profile, i))[0, 1] for i in range(12)]
-
-        # Find max correlation
-        max_major_idx = np.argmax(major_correlations)
-        max_minor_idx = np.argmax(minor_correlations)
-
-        # Determine mode
-        self.mode = 'major' if major_correlations[max_major_idx] > minor_correlations[max_minor_idx] else 'minor'
-        self.key = note_names[max_major_idx if self.mode == 'major' else max_minor_idx]
-
-        return self.key, self.mode
-
-    def calculate_ki_chroma(self, waveform: np.ndarray, sr: int, hop_length: int) -> np.ndarray:
-        """Calculate a normalized, key-invariant chromagram for the given audio waveform.
-
-        Args:
-            waveform: Audio waveform to analyze
-            sr: Sample rate of the waveform
-            hop_length: Hop length for feature extraction
-
-        Returns:
-            The key-invariant chromagram as a numpy array
-        """
-        # Calculate chromagram
-        chromagram = librosa.feature.chroma_cqt(
-            y=waveform, sr=sr, hop_length=hop_length, bins_per_octave=24)
-
-        # Normalize to [0, 1]
-        chromagram = (chromagram - chromagram.min()) / (chromagram.max() - chromagram.min() + 1e-8)
-
-        # Detect key
-        chroma_vals = np.sum(chromagram, axis=1)
-        key, mode = self.detect_key(chroma_vals)
-
-        # Make key-invariant
-        key_idx = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B'].index(key)
-        shift_amount = -key_idx if mode == 'major' else -(key_idx + 3) % 12
-
-        return librosa.util.normalize(np.roll(chromagram, shift_amount, axis=0), axis=1)
-
-    def extract_features(self) -> None:
-        """Extract various audio features from the loaded audio."""
-        # Load audio
-        self.y, self.sr = librosa.load(self.audio_path, sr=self.sr)
-
-        # Harmonic-percussive source separation
-        self.y_harm, self.y_perc = librosa.effects.hpss(self.y)
-
-        # Extract spectrogram
-        self.spectrogram, _ = librosa.magphase(librosa.stft(self.y, hop_length=self.hop_length))
-
-        # RMS energy
-        self.rms = librosa.feature.rms(S=self.spectrogram, hop_length=self.hop_length).astype(np.float32)
-
-        # Mel spectrogram and activations
-        self.melspectrogram = librosa.feature.melspectrogram(
-            y=self.y, sr=self.sr, n_mels=128, hop_length=self.hop_length).astype(np.float32)
-        self.mel_acts = librosa.decompose.decompose(self.melspectrogram, n_components=3, sort=True)[1].astype(np.float32)
-
-        # Chromagram and activations
-        self.chromagram = self.calculate_ki_chroma(self.y_harm, self.sr, self.hop_length).astype(np.float32)
-        self.chroma_acts = librosa.decompose.decompose(self.chromagram, n_components=4, sort=True)[1].astype(np.float32)
-
-        # Onset detection and tempogram
-        self.onset_env = librosa.onset.onset_strength(y=self.y_perc, sr=self.sr, hop_length=self.hop_length)
-        self.tempogram = np.clip(librosa.feature.tempogram(
-            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length), 0, None)
-        self.tempogram_acts = librosa.decompose.decompose(self.tempogram, n_components=3, sort=True)[1]
-
-        # MFCCs and activations
-        self.mfccs = librosa.feature.mfcc(y=self.y, sr=self.sr, n_mfcc=20, hop_length=self.hop_length)
-        self.mfccs += abs(np.min(self.mfccs) or 0)  # Handle negative values
-        self.mfcc_acts = librosa.decompose.decompose(self.mfccs, n_components=4, sort=True)[1].astype(np.float32)
-
-        # Combine features with weighted normalization
-        self._combine_features()
-
-    def _combine_features(self) -> None:
-        """Combine all extracted features with balanced weights."""
-        features = [self.rms, self.mel_acts, self.chroma_acts, self.tempogram_acts, self.mfcc_acts]
-        feature_names = ['rms', 'mel_acts', 'chroma_acts', 'tempogram_acts', 'mfcc_acts']
-
-        # Calculate dimension-based weights
-        dims = {name: feature.shape[0] for feature, name in zip(features, feature_names)}
-        total_inv_dim = sum(1 / dim for dim in dims.values())
-        weights = {name: 1 / (dims[name] * total_inv_dim) for name in feature_names}
-
-        # Normalize and weight each feature
-        std_weighted_features = [
-            StandardScaler().fit_transform(feature.T).T * weights[name]
-            for feature, name in zip(features, feature_names)
-        ]
-
-        # Combine features
-        self.combined_features = np.concatenate(std_weighted_features, axis=0).T.astype(np.float32)
-        self.n_frames = len(self.combined_features)
-
-    def create_meter_grid(self) -> np.ndarray:
-        """Create a grid based on the meter of the song using tempo and beats.
-
-        Returns:
-            Numpy array containing the meter grid frame positions
-        """
-        # Extract tempo and beat information
-        self.tempo, self.beats = librosa.beat.beat_track(
-            onset_envelope=self.onset_env, sr=self.sr, hop_length=self.hop_length)
-
-        # Adjust tempo if it's too slow or too fast
-        self.tempo = self._adjust_tempo(self.tempo)
-
-        # Create meter grid
-        self.meter_grid = self._create_meter_grid()
-        return self.meter_grid
-
-    def _adjust_tempo(self, tempo: float) -> float:
-        """Adjust tempo to a reasonable range.
-
-        Args:
-            tempo: Detected tempo
-
-        Returns:
-            Adjusted tempo
-        """
-        if tempo < 70:
-            return tempo * 2
-        elif tempo > 140:
-            return tempo / 2
-        return tempo
-
-    def _create_meter_grid(self) -> np.ndarray:
-        """Helper function to create a meter grid for the song.
-
-        Returns:
-            Numpy array containing the meter grid frame positions
-        """
-        # Calculate beat interval
-        seconds_per_beat = 60 / self.tempo
-        beat_interval = int(librosa.time_to_frames(seconds_per_beat, sr=self.sr, hop_length=self.hop_length))
-
-        # Find best matching start beat
-        if len(self.beats) >= 3:
-            best_match = max(
-                (1 - abs(np.mean(self.beats[i:i+3]) - beat_interval) / beat_interval, self.beats[i])
-                for i in range(len(self.beats) - 2)
-            )
-            anchor_frame = best_match[1] if best_match[0] > 0.95 else self.beats[0]
-        else:
-            anchor_frame = self.beats[0] if len(self.beats) > 0 else 0
-
-        first_beat_time = librosa.frames_to_time(anchor_frame, sr=self.sr, hop_length=self.hop_length)
-
-        # Calculate beats forward and backward
-        time_duration = librosa.frames_to_time(self.n_frames, sr=self.sr, hop_length=self.hop_length)
-        num_beats_forward = int((time_duration - first_beat_time) / seconds_per_beat)
-        num_beats_backward = int(first_beat_time / seconds_per_beat) + 1
-
-        # Create beat times
-        beat_times_forward = first_beat_time + np.arange(num_beats_forward) * seconds_per_beat
-        beat_times_backward = first_beat_time - np.arange(1, num_beats_backward) * seconds_per_beat
-
-        # Combine and create meter grid
-        beat_grid = np.concatenate((np.array([0.0]), beat_times_backward[::-1], beat_times_forward))
-        meter_indices = np.arange(0, len(beat_grid), self.time_signature)
-        meter_grid = beat_grid[meter_indices]
-
-        # Ensure grid starts at 0 and ends at frame duration
-        if meter_grid[0] != 0.0:
-            meter_grid = np.insert(meter_grid, 0, 0.0)
-
-        # Convert to frames
-        meter_grid = librosa.time_to_frames(meter_grid, sr=self.sr, hop_length=self.hop_length)
-
-        # Ensure grid ends at the last frame
-        if meter_grid[-1] != self.n_frames:
-            meter_grid = np.append(meter_grid, self.n_frames)
-
-        return meter_grid
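
The deleted detect_key works by correlating summed chroma energy against all twelve rotations of the Krumhansl-Kessler profiles. A compact sketch of that core step; the profile values are copied from the file, while the test chroma vector is synthetic:

import numpy as np

note_names = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])

# Synthetic chroma energy peaking on C, E, G (a C major triad).
chroma_vals = np.zeros(12)
chroma_vals[[0, 4, 7]] = 1.0

# Correlate against every rotation of the major profile, as detect_key does.
correlations = [np.corrcoef(chroma_vals, np.roll(major_profile, i))[0, 1] for i in range(12)]
print(note_names[int(np.argmax(correlations))])  # expected: 'C'
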
src/chorus_detection/config.py DELETED
@@ -1,54 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Configuration settings for the chorus detection system."""
-
-import os
-from pathlib import Path
-from typing import Dict, Any, Union, Optional
-
-# Audio processing settings
-SR: int = 12000  # Sample rate
-HOP_LENGTH: int = 128  # Hop length for signal processing
-MAX_FRAMES: int = 300  # Maximum frames per segment
-MAX_METERS: int = 201  # Maximum meters per song
-N_FEATURES: int = 15  # Number of features
-
-# Project paths
-PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.resolve()
-MODEL_DIR: Path = PROJECT_ROOT / "models" / "CRNN"
-MODEL_PATH: Path = MODEL_DIR / "best_model_V3.h5"
-AUDIO_TEMP_PATH: Path = PROJECT_ROOT / "output" / "temp"
-LOG_DIR: Path = PROJECT_ROOT / "logs"
-
-# Alternative Docker paths
-DOCKER_MODEL_PATH: str = "/app/models/CRNN/best_model_V3.h5"
-DOCKER_TEMP_PATH: str = "/app/output/temp"
-
-
-def get_env_path(env_var: str, default_path: Path) -> Path:
-    """Get a path from environment variable or use the default.
-
-    Args:
-        env_var: Name of the environment variable
-        default_path: Default path to use if environment variable is not set
-
-    Returns:
-        Path object for the specified location
-    """
-    env_value = os.environ.get(env_var)
-    if env_value:
-        return Path(env_value).resolve()
-    return default_path
-
-
-# Override paths with environment variables if provided
-MODEL_PATH = get_env_path("CHORUS_MODEL_PATH", MODEL_PATH)
-AUDIO_TEMP_PATH = get_env_path("CHORUS_TEMP_PATH", AUDIO_TEMP_PATH)
-LOG_DIR = get_env_path("CHORUS_LOG_DIR", LOG_DIR)
-
-# Create necessary directories
-MODEL_PATH = MODEL_PATH
-os.makedirs(MODEL_DIR, exist_ok=True)
-os.makedirs(AUDIO_TEMP_PATH, exist_ok=True)
-os.makedirs(LOG_DIR, exist_ok=True)
-os.makedirs(PROJECT_ROOT / "output", exist_ok=True)
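
The environment-variable override pattern in this config can be exercised in isolation like so; the path values below are illustrative only:

import os
from pathlib import Path

def get_env_path(env_var: str, default_path: Path) -> Path:
    # Same logic as the deleted helper: the env var wins, otherwise the default.
    env_value = os.environ.get(env_var)
    return Path(env_value).resolve() if env_value else default_path

os.environ["CHORUS_MODEL_PATH"] = "/tmp/custom_model.h5"  # illustrative override
print(get_env_path("CHORUS_MODEL_PATH", Path("models/CRNN/best_model_V3.h5")))
# -> /tmp/custom_model.h5 (resolved)
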
src/chorus_detection/models/__init__.py DELETED
File without changes
src/chorus_detection/models/crnn.py DELETED
@@ -1,186 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for loading and managing the CRNN model for chorus detection."""
-
-import os
-from typing import Any, Optional, List, Tuple, Union
-
-import numpy as np
-import tensorflow as tf
-
-from chorus_detection.config import MODEL_PATH, DOCKER_MODEL_PATH
-from chorus_detection.utils.logging import logger
-
-
-def load_CRNN_model(model_path: str = str(MODEL_PATH)) -> tf.keras.Model:
-    """Load a CRNN model with custom loss and accuracy functions.
-
-    Args:
-        model_path: Path to the saved model
-
-    Returns:
-        Loaded Keras model
-
-    Raises:
-        RuntimeError: If the model cannot be loaded
-    """
-    try:
-        # Define custom objects required for model loading
-        custom_objects = {
-            'custom_binary_crossentropy': lambda y_true, y_pred: y_pred,
-            'custom_accuracy': lambda y_true, y_pred: y_pred
-        }
-
-        # Try to load the model with custom objects
-        logger.info(f"Loading model from: {model_path}")
-        model = tf.keras.models.load_model(
-            model_path, custom_objects=custom_objects, compile=False)
-
-        # Compile the model with default optimizer and loss for prediction only
-        model.compile(optimizer='adam', loss='binary_crossentropy')
-
-        return model
-    except Exception as e:
-        logger.error(f"Error loading model from {model_path}: {e}")
-
-        # Try Docker container path as fallback
-        if model_path != DOCKER_MODEL_PATH and os.path.exists(DOCKER_MODEL_PATH):
-            logger.info(f"Trying Docker path: {DOCKER_MODEL_PATH}")
-            return load_CRNN_model(DOCKER_MODEL_PATH)
-
-        raise RuntimeError(f"Failed to load model: {e}")
-
-
-def smooth_predictions(predictions: np.ndarray) -> np.ndarray:
-    """Smooth predictions by correcting isolated mispredictions and removing short sequences.
-
-    Args:
-        predictions: Array of binary predictions
-
-    Returns:
-        Smoothed array of binary predictions
-    """
-    # Convert to numpy array if not already
-    data = np.array(predictions, copy=True) if not isinstance(predictions, np.ndarray) else predictions.copy()
-
-    # First pass: Correct isolated 0's (handle 0's surrounded by 1's)
-    for i in range(1, len(data) - 1):
-        if data[i] == 0 and data[i - 1] == 1 and data[i + 1] == 1:
-            data[i] = 1
-
-    # Second pass: Correct isolated 1's (handle 1's surrounded by 0's)
-    corrected_data = data.copy()
-    for i in range(1, len(data) - 1):
-        if data[i] == 1 and data[i - 1] == 0 and data[i + 1] == 0:
-            corrected_data[i] = 0
-
-    # Third pass: Remove short sequences of 1s (less than 5 consecutive 1's)
-    smoothed_data = corrected_data.copy()
-    sequence_start = None
-
-    for i in range(len(corrected_data)):
-        if corrected_data[i] == 1:
-            if sequence_start is None:
-                sequence_start = i
-        else:
-            if sequence_start is not None:
-                sequence_length = i - sequence_start
-                if sequence_length < 5:
-                    smoothed_data[sequence_start:i] = 0
-                sequence_start = None
-
-    # Handle the case where the sequence extends to the end
-    if sequence_start is not None:
-        sequence_length = len(corrected_data) - sequence_start
-        if sequence_length < 5:
-            smoothed_data[sequence_start:] = 0
-
-    return smoothed_data
-
-
-def make_predictions(model: tf.keras.Model, processed_audio: np.ndarray,
-                     audio_features: Any, url: Optional[str] = None,
-                     video_name: Optional[str] = None) -> np.ndarray:
-    """Generate predictions from the model and process them.
-
-    Args:
-        model: The loaded model for making predictions
-        processed_audio: The audio data that has been processed for prediction
-        audio_features: Audio features object containing necessary metadata
-        url: YouTube URL of the audio file (optional)
-        video_name: Name of the video (optional)
-
-    Returns:
-        The smoothed binary predictions
-    """
-    import librosa
-
-    logger.info("Generating predictions...")
-
-    # Make predictions
-    predictions = model.predict(processed_audio)[0]
-
-    # Convert to binary predictions and handle potential size mismatch
-    meter_grid_length = len(audio_features.meter_grid) - 1
-    if len(predictions) > meter_grid_length:
-        predictions = predictions[:meter_grid_length]
-
-    binary_predictions = np.round(predictions).flatten()
-
-    # Apply smoothing to improve prediction quality
-    smoothed_predictions = smooth_predictions(binary_predictions)
-
-    # Get times for identified chorus sections
-    meter_grid_times = librosa.frames_to_time(
-        audio_features.meter_grid,
-        sr=audio_features.sr,
-        hop_length=audio_features.hop_length
-    )
-
-    # Identify where choruses start
-    chorus_start_times = [
-        meter_grid_times[i] for i in range(len(smoothed_predictions))
-        if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)
-    ]
-
-    # Print results if URL and video name are provided (CLI mode)
-    if url and video_name:
-        _print_chorus_results(url, video_name, chorus_start_times)
-
-    return smoothed_predictions
-
-
-def _print_chorus_results(url: str, video_name: str, chorus_start_times: List[float]) -> None:
-    """Print formatted results showing identified choruses with links.
-
-    Args:
-        url: YouTube URL of the analyzed video
-        video_name: Name of the video
-        chorus_start_times: List of start times (in seconds) for identified choruses
-    """
-    # Create YouTube links with time stamps
-    youtube_links = [
-        f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\"
-        for start_time in chorus_start_times
-    ]
-
-    # Format the output
-    link_lengths = [len(link) for link in youtube_links]
-    max_length = max(link_lengths + [len(video_name), len(f"Number of choruses identified: {len(chorus_start_times)}")]) if link_lengths else 50
-    header_footer = "=" * (max_length + 4)
-
-    # Print the results
-    print("\n\n")
-    print(header_footer)
-    print(f"{video_name.center(max_length + 2)}")
-    print(f"Number of choruses identified: {len(chorus_start_times)}".center(max_length + 4))
-    print(header_footer)
-
-    if chorus_start_times:
-        for link in youtube_links:
-            print(link)
-    else:
-        print("No choruses identified.")
-
-    print(header_footer)
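
The deleted smooth_predictions makes three passes: fill isolated 0s, clear isolated 1s, then erase 1-runs shorter than five. A quick check against a hand-built sequence; the import assumes the pre-refactor module path shown above, and the input values are arbitrary:

import numpy as np
from chorus_detection.models.crnn import smooth_predictions  # pre-refactor path

# Hand-built sequence: an isolated 0 (index 2), an isolated 1 (index 7),
# and a too-short run of three 1s (indices 11-13).
preds = np.array([1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0])

# Pass 1 fills index 2, pass 2 clears index 7, pass 3 erases indices 11-13;
# only the initial run of five 1s survives.
print(smooth_predictions(preds))  # -> [1 1 1 1 1 0 0 0 0 0 0 0 0 0 0]
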
src/chorus_detection/utils/__init__.py DELETED
File without changes
src/chorus_detection/utils/cli.py DELETED
@@ -1,107 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Command-line interface utilities for the chorus detection system."""
-
-import argparse
-import os
-import sys
-from pathlib import Path
-from typing import Dict, Any, Optional, Tuple
-
-from chorus_detection.config import MODEL_PATH
-from chorus_detection.utils.logging import logger
-
-
-def parse_arguments() -> argparse.Namespace:
-    """Parse command-line arguments.
-
-    Returns:
-        Parsed command-line arguments
-    """
-    parser = argparse.ArgumentParser(
-        description="Chorus Finder - Detect choruses in songs from YouTube URLs or local audio files")
-
-    input_group = parser.add_mutually_exclusive_group()
-    input_group.add_argument("--url", type=str,
-                             help="YouTube URL of a song")
-    input_group.add_argument("--file", type=str,
-                             help="Path to a local audio file")
-
-    parser.add_argument("--model_path", type=str, default=str(MODEL_PATH),
-                        help=f"Path to the pretrained model (default: {MODEL_PATH})")
-    parser.add_argument("--verbose", action="store_true",
-                        help="Verbose output", default=True)
-    parser.add_argument("--plot", action="store_true",
-                        help="Display plot of the audio waveform", default=True)
-    parser.add_argument("--no-plot", dest="plot", action="store_false",
-                        help="Disable plot display (useful for headless environments)")
-
-    return parser.parse_args()
-
-
-def get_input_source(args: argparse.Namespace) -> Optional[str]:
-    """Get input source from arguments or user input.
-
-    Args:
-        args: Parsed command-line arguments
-
-    Returns:
-        Input source (URL or file path)
-    """
-    input_source = args.url or args.file
-    if not input_source:
-        print("\nChorus Detection Tool")
-        print("====================")
-        print("\nNote: YouTube download functionality may be temporarily unavailable")
-        print("due to YouTube's restrictions. If download fails, please use a local audio file.\n")
-        print("Choose input method:")
-        print("1. YouTube URL")
-        print("2. Local audio file")
-        choice = input("Enter choice (1 or 2): ")
-
-        if choice == "1":
-            input_source = input("Please enter the YouTube URL of the song: ")
-        elif choice == "2":
-            input_source = input("Please enter the path to the audio file: ")
-        else:
-            logger.error("Invalid choice")
-            sys.exit(1)
-
-    return input_source
-
-
-def is_youtube_url(input_source: str) -> bool:
-    """Check if the input source is a YouTube URL.
-
-    Args:
-        input_source: Input source to check
-
-    Returns:
-        True if the input source is a YouTube URL, False otherwise
-    """
-    return input_source.startswith(('http://', 'https://'))
-
-
-def validate_input_file(file_path: str) -> bool:
-    """Validate that the input file exists and is readable.
-
-    Args:
-        file_path: Path to the input file
-
-    Returns:
-        True if the file is valid, False otherwise
-    """
-    if not os.path.exists(file_path):
-        logger.error(f"Error: File not found at {file_path}")
-        return False
-
-    if not os.path.isfile(file_path):
-        logger.error(f"Error: {file_path} is not a file")
-        return False
-
-    if not os.access(file_path, os.R_OK):
-        logger.error(f"Error: No permission to read {file_path}")
-        return False
-
-    return True
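
A hedged sketch of driving the deleted parser programmatically; the import assumes the pre-refactor module path, and the URL is illustrative:

import sys
from chorus_detection.utils.cli import parse_arguments, is_youtube_url  # pre-refactor path

# Simulate: chorus-detection --url <URL> --no-plot
sys.argv = ["chorus-detection", "--url", "https://www.youtube.com/watch?v=example", "--no-plot"]
args = parse_arguments()
print(args.url, args.plot)       # -> https://www.youtube.com/watch?v=example False
print(is_youtube_url(args.url))  # -> True (the check is just an http/https prefix test)
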
src/chorus_detection/utils/logging.py DELETED
@@ -1,53 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Logging configuration for the chorus detection system."""
-
-import logging
-import os
-import sys
-from typing import Optional
-
-from chorus_detection.config import PROJECT_ROOT
-
-# Create logs directory
-LOGS_DIR = PROJECT_ROOT / "logs"
-os.makedirs(LOGS_DIR, exist_ok=True)
-
-
-def setup_logger(name: str = "chorus_detection", level: int = logging.INFO,
-                 log_file: Optional[str] = None) -> logging.Logger:
-    """Configure and return a logger with the specified name and level.
-
-    Args:
-        name: Name of the logger
-        level: Logging level (default: INFO)
-        log_file: Path to the log file (default: None)
-
-    Returns:
-        Configured logger instance
-    """
-    logger = logging.getLogger(name)
-    logger.setLevel(level)
-
-    # Create formatter
-    formatter = logging.Formatter(
-        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-    )
-
-    # Create console handler
-    console_handler = logging.StreamHandler(sys.stdout)
-    console_handler.setFormatter(formatter)
-    logger.addHandler(console_handler)
-
-    # Create file handler if log_file is specified
-    if log_file:
-        file_handler = logging.FileHandler(log_file)
-        file_handler.setFormatter(formatter)
-        logger.addHandler(file_handler)
-
-    return logger
-
-
-# Create default logger
-logger = setup_logger()
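
Typical use of the deleted setup_logger helper; the import assumes the pre-refactor module path and the log path is illustrative:

import logging
import os
from chorus_detection.utils.logging import setup_logger  # pre-refactor path

os.makedirs("logs", exist_ok=True)  # FileHandler needs the directory to exist
log = setup_logger(name="my-analysis", level=logging.DEBUG, log_file="logs/analysis.log")
log.debug("goes to stdout and logs/analysis.log in the shared timestamped format")
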
src/chorus_detection/visualization/__init__.py DELETED
File without changes
src/chorus_detection/visualization/plotter.py DELETED
@@ -1,78 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Module for visualizing audio data and chorus predictions."""
-
-from typing import List
-
-import librosa
-import librosa.display
-import matplotlib.pyplot as plt
-import numpy as np
-import os
-
-from chorus_detection.audio.processor import AudioFeature
-
-
-def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
-    """Draw meter grid lines on the plot.
-
-    Args:
-        ax: The matplotlib axes object to draw on
-        meter_grid_times: Array of times at which to draw the meter lines
-    """
-    for time in meter_grid_times:
-        ax.axvline(x=time, color='grey', linestyle='--',
-                   linewidth=1, alpha=0.6)
-
-
-def plot_predictions(audio_features: AudioFeature, binary_predictions: np.ndarray) -> None:
-    """Plot the audio waveform and overlay the predicted chorus locations.
-
-    Args:
-        audio_features: An object containing audio features and components
-        binary_predictions: Array of binary predictions indicating chorus locations
-    """
-    meter_grid_times = librosa.frames_to_time(
-        audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
-    fig, ax = plt.subplots(figsize=(12.5, 3), dpi=96)
-
-    # Display harmonic and percussive components
-    librosa.display.waveshow(audio_features.y_harm, sr=audio_features.sr,
-                             alpha=0.8, ax=ax, color='deepskyblue')
-    librosa.display.waveshow(audio_features.y_perc, sr=audio_features.sr,
-                             alpha=0.7, ax=ax, color='plum')
-    plot_meter_lines(ax, meter_grid_times)
-
-    # Highlight chorus sections
-    for i, prediction in enumerate(binary_predictions):
-        start_time = meter_grid_times[i]
-        end_time = meter_grid_times[i + 1] if i < len(
-            meter_grid_times) - 1 else len(audio_features.y) / audio_features.sr
-        if prediction == 1:
-            ax.axvspan(start_time, end_time, color='green', alpha=0.3,
-                       label='Predicted Chorus' if i == 0 else None)
-
-    # Set plot limits and labels
-    ax.set_xlim([0, len(audio_features.y) / audio_features.sr])
-    ax.set_ylabel('Amplitude')
-    audio_file_name = os.path.basename(audio_features.audio_path)
-    ax.set_title(
-        f'Chorus Predictions for {os.path.splitext(audio_file_name)[0]}')
-
-    # Add legend
-    chorus_patch = plt.Rectangle((0, 0), 1, 1, fc='green', alpha=0.3)
-    handles, labels = ax.get_legend_handles_labels()
-    handles.append(chorus_patch)
-    labels.append('Chorus')
-    ax.legend(handles=handles, labels=labels)
-
-    # Set x-tick labels in minutes:seconds format
-    duration = len(audio_features.y) / audio_features.sr
-    xticks = np.arange(0, duration, 10)
-    xlabels = [f"{int(tick // 60)}:{int(tick % 60):02d}" for tick in xticks]
-    ax.set_xticks(xticks)
-    ax.set_xticklabels(xlabels)
-
-    plt.tight_layout()
-    plt.show(block=False)
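
The conversion the plotter depends on, meter-grid frames to seconds, is just librosa.frames_to_time; a minimal sketch with the project's SR and HOP_LENGTH constants (the frame indices are arbitrary):

import librosa
import numpy as np

SR, HOP_LENGTH = 12000, 128           # values from the deleted config.py
meter_grid = np.array([0, 375, 750])  # arbitrary frame indices
times = librosa.frames_to_time(meter_grid, sr=SR, hop_length=HOP_LENGTH)
print(times)  # each frame is HOP_LENGTH/SR = 128/12000 s -> [0. 4. 8.]
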
src/download_model.py DELETED
@@ -1,188 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Script to download the chorus detection model from HuggingFace.
-
-This script checks if the model file exists locally, and if not, downloads it
-from the specified HuggingFace repository.
-"""
-
-import os
-import sys
-from pathlib import Path
-import logging
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-)
-logger = logging.getLogger("model-downloader")
-
-# Debug environment info
-logger.info(f"Current working directory: {os.getcwd()}")
-logger.info(f"Python path: {sys.path}")
-logger.info(f"MODEL_REVISION: {os.environ.get('MODEL_REVISION')}")
-logger.info(f"MODEL_HF_REPO: {os.environ.get('MODEL_HF_REPO')}")
-logger.info(f"HF_MODEL_FILENAME: {os.environ.get('HF_MODEL_FILENAME')}")
-
-# Use huggingface_hub for better integration with HF ecosystem
-try:
-    from huggingface_hub import hf_hub_download
-    HF_HUB_AVAILABLE = True
-    logger.info("huggingface_hub is available")
-except ImportError:
-    HF_HUB_AVAILABLE = False
-    logger.warning("huggingface_hub is not available, falling back to direct download")
-    import requests
-    from tqdm import tqdm
-
-    def download_file_with_progress(url: str, destination: Path) -> None:
-        """Download a file with a progress bar.
-
-        Args:
-            url: URL to download from
-            destination: Path to save the file to
-        """
-        # Create parent directories if they don't exist
-        destination.parent.mkdir(parents=True, exist_ok=True)
-
-        # Stream the download with progress bar
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
-
-        total_size = int(response.headers.get('content-length', 0))
-        block_size = 1024  # 1 Kibibyte
-
-        logger.info(f"Downloading model from {url}")
-        logger.info(f"File size: {total_size / (1024*1024):.1f} MB")
-
-        with open(destination, 'wb') as file, tqdm(
-            desc=destination.name,
-            total=total_size,
-            unit='iB',
-            unit_scale=True,
-            unit_divisor=1024,
-        ) as bar:
-            for data in response.iter_content(block_size):
-                size = file.write(data)
-                bar.update(size)
-
-def ensure_model_exists(
-    model_filename: str = "best_model_V3.h5",
-    repo_id: str = None,
-    model_dir: Path = None,
-    hf_model_filename: str = None,
-    revision: str = None
-) -> Path:
-    """Ensure the model file exists, downloading it if necessary.
-
-    Args:
-        model_filename: Local filename for the model
-        repo_id: HuggingFace repository ID
-        model_dir: Directory to save the model to
-        hf_model_filename: Filename of the model in the HuggingFace repo
-        revision: Specific version of the model to use (SHA-256 hash)
-
-    Returns:
-        Path to the model file
-    """
-    # Get parameters from environment variables if not provided
-    if repo_id is None:
-        repo_id = os.environ.get("MODEL_HF_REPO", "dennisvdang/chorus-detection")
-
-    if hf_model_filename is None:
-        hf_model_filename = os.environ.get("HF_MODEL_FILENAME", "chorus_detection_crnn.h5")
-
-    if revision is None:
-        revision = os.environ.get("MODEL_REVISION", "20e66eb3d0788373c3bdc5b28fa2f2587b0e475f3bbc47e8ab9ff0dbdbb2df32")
-
-    # Handle model directory paths for different environments
-    if model_dir is None:
-        # Check if we're in HF Spaces
-        if os.environ.get("SPACE_ID"):
-            # Try several possible locations
-            possible_dirs = [
-                Path("models/CRNN"),
-                Path("/home/user/app/models/CRNN"),
-                Path("/app/models/CRNN"),
-                Path(os.getcwd()) / "models" / "CRNN"
-            ]
-
-            for directory in possible_dirs:
-                if directory.exists() or directory.parent.exists():
-                    model_dir = directory
-                    break
-
-            # If none exist, use the first option and create it
-            if model_dir is None:
-                model_dir = possible_dirs[0]
-        else:
-            model_dir = Path("models/CRNN")
-
-    # Make sure model_dir is a Path object
-    if isinstance(model_dir, str):
-        model_dir = Path(model_dir)
-
-    logger.info(f"Using model directory: {model_dir}")
-
-    model_path = model_dir / model_filename
-
-    # Log environment info when running in HF Space
-    if os.environ.get("SPACE_ID"):
-        logger.info(f"Running in Hugging Face Space: {os.environ.get('SPACE_ID')}")
-        logger.info(f"Using model repo: {repo_id}")
-        logger.info(f"Using model file: {hf_model_filename}")
-        logger.info(f"Using revision: {revision}")
-
-    # Check if the model already exists
-    if model_path.exists():
-        logger.info(f"Model already exists at {model_path}")
-        return model_path
-
-    # Create model directory if it doesn't exist
-    model_dir.mkdir(parents=True, exist_ok=True)
-
-    logger.info(f"Model not found at {model_path}. Downloading...")
-
-    try:
-        if HF_HUB_AVAILABLE:
-            # Use huggingface_hub to download the model
-            logger.info(f"Downloading model from {repo_id}/{hf_model_filename} (revision: {revision}) using huggingface_hub")
-            downloaded_path = hf_hub_download(
-                repo_id=repo_id,
-                filename=hf_model_filename,
-                local_dir=model_dir,
-                local_dir_use_symlinks=False,
-                revision=revision  # Specify the exact revision to use
-            )
-
-            logger.info(f"Downloaded to: {downloaded_path}")
-
-            # Rename if necessary
-            if os.path.basename(downloaded_path) != model_filename:
-                downloaded_path_obj = Path(downloaded_path)
-                model_path.parent.mkdir(parents=True, exist_ok=True)
-                if model_path.exists():
-                    model_path.unlink()
-                downloaded_path_obj.rename(model_path)
-                logger.info(f"Renamed {downloaded_path} to {model_path}")
-        else:
-            # Fallback to direct download if huggingface_hub is not available
-            huggingface_url = f"https://huggingface.co/{repo_id}/resolve/{revision}/{hf_model_filename}"
-            download_file_with_progress(huggingface_url, model_path)
-
-        logger.info(f"Successfully downloaded model to {model_path}")
-        return model_path
-    except Exception as e:
-        logger.error(f"Failed to download model: {e}", exc_info=True)
-
-        # Handle error more gracefully in production environment
-        if os.environ.get("SPACE_ID"):
-            logger.warning("Continuing despite model download failure")
-            return model_path
-        else:
-            sys.exit(1)
-
-if __name__ == "__main__":
-    ensure_model_exists()
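
The fallback path in this file (used when huggingface_hub is missing) is a plain streamed GET with a tqdm progress bar. A condensed standalone sketch of that pattern; the commented-out URL follows the repo/revision/filename scheme used above and is illustrative:

from pathlib import Path
import requests
from tqdm import tqdm

def download_file_with_progress(url: str, destination: Path) -> None:
    # Streamed download so large model files never sit fully in memory.
    destination.parent.mkdir(parents=True, exist_ok=True)
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total = int(response.headers.get("content-length", 0))
    with open(destination, "wb") as f, tqdm(total=total, unit="iB", unit_scale=True) as bar:
        for chunk in response.iter_content(1024):
            bar.update(f.write(chunk))

# Illustrative invocation (URL pattern only; not executed here):
# download_file_with_progress(
#     "https://huggingface.co/dennisvdang/chorus-detection/resolve/main/chorus_detection_crnn.h5",
#     Path("models/CRNN/best_model_V3.h5"),
# )
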
src/streamlit_app.py DELETED
@@ -1,536 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-"""Streamlit web app for chorus detection in audio files.
-
-This module provides a web-based interface for the chorus detection system,
-allowing users to upload audio files or provide YouTube URLs for analysis.
-"""
-
-import os
-import sys
-import logging
-
-# Configure logging
-logger = logging.getLogger("streamlit-app")
-
-# Configure TensorFlow logging before importing TensorFlow
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow logs
-
-# Ensure proper import paths
-current_dir = os.path.dirname(os.path.abspath(__file__))
-root_dir = os.path.dirname(current_dir)
-if current_dir not in sys.path:
-    sys.path.insert(0, current_dir)
-if root_dir not in sys.path:
-    sys.path.insert(0, root_dir)
-
-# Import model downloader to ensure model is available
-try:
-    if os.path.exists(os.path.join(os.getcwd(), "download_model.py")):
-        # If in the root directory
-        from download_model import ensure_model_exists
-    else:
-        # If in the src directory
-        sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-        from download_model import ensure_model_exists
-except ImportError as e:
-    logger.error(f"Error importing ensure_model_exists: {e}")
-    try:
-        # Try alternative import
-        from src.download_model import ensure_model_exists
-    except ImportError as e2:
-        logger.error(f"Alternative import failed: {e2}")
-        raise
-
-import base64
-import tempfile
-import warnings
-from typing import Optional, Tuple, List
-import time
-import io
-
-import matplotlib.pyplot as plt
-import streamlit as st
-import tensorflow as tf
-import librosa
-import soundfile as sf
-import numpy as np
-from pydub import AudioSegment
-
-# Suppress warnings
-warnings.filterwarnings("ignore")  # Suppress all warnings
-tf.get_logger().setLevel('ERROR')  # Suppress TensorFlow ERROR logs
-
-# Debug import paths
-logger.info(f"Python path: {sys.path}")
-logger.info(f"Current working directory: {os.getcwd()}")
-
-# First try direct import with src in path
-try:
-    # Add src directory to Python path if not already there
-    src_path = os.path.dirname(current_dir)
-    if src_path not in sys.path:
-        sys.path.insert(0, src_path)
-
-    from chorus_detection.audio.data_processing import process_audio
-    from chorus_detection.audio.processor import extract_audio
-    from chorus_detection.models.crnn import load_CRNN_model, make_predictions
-    from chorus_detection.utils.cli import is_youtube_url
-    from chorus_detection.utils.logging import logger
-
-    logger.info("Successfully imported chorus_detection modules")
-except ImportError as e:
-    logger.error(f"Error importing chorus_detection modules: {e}")
-    logger.info("Trying alternative imports...")
-
-    # Try with manual path adjustment
-    try:
-        # Adjust import paths - try different directories
-        possible_paths = [
-            os.path.join(os.getcwd(), "src"),
-            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
-            os.path.dirname(os.getcwd())
-        ]
-
-        for path in possible_paths:
-            if path not in sys.path and os.path.exists(path):
-                sys.path.insert(0, path)
-                logger.info(f"Added path to sys.path: {path}")
-
-        # Try importing directly from chorus_detection module path
-        sys.path.insert(0, os.path.join(os.getcwd(), "src", "chorus_detection"))
-
-        from chorus_detection.audio.data_processing import process_audio
-        from chorus_detection.audio.processor import extract_audio
-        from chorus_detection.models.crnn import load_CRNN_model, make_predictions
-        from chorus_detection.utils.cli import is_youtube_url
-        from chorus_detection.utils.logging import logger
-
-        logger.info("Successfully imported chorus_detection modules after path adjustment")
-    except ImportError as e2:
-        logger.error(f"Alternative imports also failed: {e2}")
-        raise
-
-# Define the MODEL_PATH directly
-MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
-if not os.path.exists(MODEL_PATH):
-    MODEL_PATH = ensure_model_exists()
-
-# Define color scheme
-THEME_COLORS = {
-    'background': '#121212',
-    'card_bg': '#181818',
-    'primary': '#1DB954',
-    'secondary': '#1ED760',
-    'text': '#FFFFFF',
-    'subtext': '#B3B3B3',
-    'highlight': '#1DB954',
-    'border': '#333333',
-}
-
-
-def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
-    """Generate HTML for file download link.
-
-    Args:
-        bin_file: Path to the binary file
-        file_label: Label for the download link
-
-    Returns:
-        HTML string for the download link
-    """
-    with open(bin_file, 'rb') as f:
-        data = f.read()
-    b64 = base64.b64encode(data).decode()
-    return f'<a href="data:application/octet-stream;base64,{b64}" download="{os.path.basename(bin_file)}">{file_label}</a>'
-
-
-def set_custom_theme() -> None:
-    """Apply custom Spotify-inspired theme to Streamlit UI."""
-    custom_theme = f"""
-    <style>
-    .stApp {{
-        background-color: {THEME_COLORS['background']};
-        color: {THEME_COLORS['text']};
-    }}
-    .css-18e3th9 {{
-        padding-top: 2rem;
-        padding-bottom: 10rem;
-        padding-left: 5rem;
-        padding-right: 5rem;
-    }}
-    h1, h2, h3, h4, h5, h6 {{
-        color: {THEME_COLORS['text']} !important;
-        font-weight: 700 !important;
-    }}
-    .stSidebar .sidebar-content {{
-        background-color: {THEME_COLORS['card_bg']};
-    }}
-    .stButton>button {{
-        background-color: {THEME_COLORS['primary']};
-        color: white;
-        border-radius: 500px;
-        padding: 8px 32px;
-        font-weight: 600;
-        border: none;
-        transition: all 0.3s ease;
-    }}
-    .stButton>button:hover {{
-        background-color: {THEME_COLORS['secondary']};
-        transform: scale(1.04);
-    }}
-    </style>
-    """
-    st.markdown(custom_theme, unsafe_allow_html=True)
-
-
-def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
-    """Process a YouTube URL and extract audio.
-
-    Args:
-        url: YouTube URL
-
-    Returns:
-        Tuple of (audio_path, video_name)
-    """
-    try:
-        with st.spinner('Downloading audio from YouTube...'):
-            audio_path, video_name = extract_audio(url)
-            return audio_path, video_name
-    except Exception as e:
-        st.error(f"Error processing YouTube URL: {e}")
-        logger.error(f"Error processing YouTube URL: {e}", exc_info=True)
-        return None, None
-
-
-def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
208
- """Process an uploaded audio file.
209
-
210
- Args:
211
- uploaded_file: Streamlit UploadedFile object
212
-
213
- Returns:
214
- Tuple of (audio_path, file_name)
215
- """
216
- try:
217
- with st.spinner('Processing uploaded file...'):
218
- # Save the uploaded file to a temporary location
219
- temp_dir = tempfile.mkdtemp()
220
- file_name = uploaded_file.name
221
- temp_path = os.path.join(temp_dir, file_name)
222
-
223
- with open(temp_path, 'wb') as f:
224
- f.write(uploaded_file.getbuffer())
225
-
226
- return temp_path, file_name.split('.')[0]
227
- except Exception as e:
228
- st.error(f"Error processing uploaded file: {e}")
229
- logger.error(f"Error processing uploaded file: {e}", exc_info=True)
230
- return None, None
231
-
232
-
233
- def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
234
- meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
235
- """Extract chorus segments from predictions.
236
-
237
- Args:
238
- y: Audio data
239
- sr: Sample rate
240
- smoothed_predictions: Smoothed model predictions
241
- meter_grid_times: Time grid for predictions
242
-
243
- Returns:
244
- List of (start_time, end_time, audio_segment) tuples
245
- """
246
- # Define threshold for chorus detection (probability > 0.5)
247
- threshold = 0.5
248
-
249
- # Find the segments where the predictions are above the threshold
250
- chorus_mask = smoothed_predictions > threshold
251
-
252
- # Group consecutive True values to identify segments
253
- segments = []
254
- current_segment = None
255
-
256
- for i, is_chorus in enumerate(chorus_mask):
257
- time = meter_grid_times[i]
258
-
259
- if is_chorus and current_segment is None:
260
- # Start a new segment
261
- current_segment = (time, None, None)
262
- elif not is_chorus and current_segment is not None:
263
- # End the current segment
264
- start_time = current_segment[0]
265
- current_segment = (start_time, time, None)
266
- segments.append(current_segment)
267
- current_segment = None
268
-
269
- # Handle the case where the last segment extends to the end of the song
270
- if current_segment is not None:
271
- start_time = current_segment[0]
272
- segments.append((start_time, meter_grid_times[-1], None))
273
-
274
- # Extract the actual audio for each segment
275
- segments_with_audio = []
276
- for start_time, end_time, _ in segments:
277
- # Convert times to sample indices
278
- start_idx = int(start_time * sr)
279
- end_idx = int(end_time * sr)
280
-
281
- # Extract the audio segment
282
- segment_audio = y[start_idx:end_idx]
283
-
284
- segments_with_audio.append((start_time, end_time, segment_audio))
285
-
286
- return segments_with_audio
287
-
288
-
289
- def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
290
- sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
291
- """Create a compilation of chorus segments.
292
-
293
- Args:
294
- segments: List of (start_time, end_time, audio_data) tuples
295
- sr: Sample rate
296
- fade_duration: Duration of fade in/out in seconds
297
-
298
- Returns:
299
- Tuple of (compilation_audio, description)
300
- """
301
- if not segments:
302
- return np.array([]), "No chorus segments found"
303
-
304
- # Calculate the number of samples for fading
305
- fade_samples = int(fade_duration * sr)
306
-
307
- # Prepare a list to store the processed segments
308
- processed_segments = []
309
-
310
- # Description of segments
311
- segment_descriptions = []
312
-
313
- # Process each segment
314
- for i, (start_time, end_time, audio) in enumerate(segments):
315
- # Apply fade in and fade out
316
- segment_length = len(audio)
317
-
318
- if segment_length <= 2 * fade_samples:
319
- # Segment is too short for fading, skip it
320
- continue
321
-
322
- # Create a linear fade in and fade out
323
- fade_in = np.linspace(0, 1, fade_samples)
324
- fade_out = np.linspace(1, 0, fade_samples)
325
-
326
- # Apply the fades
327
- audio_faded = audio.copy()
328
- audio_faded[:fade_samples] *= fade_in
329
- audio_faded[-fade_samples:] *= fade_out
330
-
331
- processed_segments.append(audio_faded)
332
-
333
- # Format the times for the description
334
- start_fmt = format_time(start_time)
335
- end_fmt = format_time(end_time)
336
- segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")
337
-
338
- if not processed_segments:
339
- return np.array([]), "No chorus segments long enough for compilation"
340
-
341
- # Concatenate all the processed segments
342
- compilation = np.concatenate(processed_segments)
343
-
344
- # Join the descriptions
345
- description = "\n".join(segment_descriptions)
346
-
347
- return compilation, description
348
-
349
-
350
- def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
351
- """Save audio data to a format suitable for Streamlit audio playback.
352
-
353
- Args:
354
- audio_data: Audio samples
355
- sr: Sample rate
356
- file_format: Output format ('mp3', 'wav', etc.)
357
-
358
- Returns:
359
- Audio bytes
360
- """
361
- with io.BytesIO() as buffer:
362
- sf.write(buffer, audio_data, sr, format=file_format)
363
- buffer.seek(0)
364
- return buffer.read()
365
-
366
-
367
- def format_time(seconds: float) -> str:
368
- """Format seconds as MM:SS.
369
-
370
- Args:
371
- seconds: Time in seconds
372
-
373
- Returns:
374
- Formatted time string
375
- """
376
- minutes = int(seconds // 60)
377
- seconds = int(seconds % 60)
378
- return f"{minutes:02d}:{seconds:02d}"
379
-
380
-
381
- def main() -> None:
382
- """Main function for the Streamlit app."""
383
- # Set page config
384
- st.set_page_config(
385
- page_title="Chorus Detection",
386
- page_icon="🎵",
387
- layout="wide",
388
- initial_sidebar_state="collapsed",
389
- )
390
-
391
- # Apply custom theme
392
- set_custom_theme()
393
-
394
- # App title and description
395
- st.title("Chorus Detection")
396
- st.markdown("""
397
- <div class="subheader">
398
- Upload a song or enter a YouTube URL to automatically detect chorus sections using AI
399
- </div>
400
- """, unsafe_allow_html=True)
401
-
402
- # User input section
403
- col1, col2 = st.columns(2)
404
-
405
- with col1:
406
- st.markdown('<div class="input-option">', unsafe_allow_html=True)
407
- st.subheader("Option 1: Upload an audio file")
408
- uploaded_file = st.file_uploader("Choose an audio file", type=['mp3', 'wav', 'ogg', 'flac', 'm4a'])
409
- st.markdown('</div>', unsafe_allow_html=True)
410
-
411
- with col2:
412
- st.markdown('<div class="input-option">', unsafe_allow_html=True)
413
- st.subheader("Option 2: YouTube URL")
414
- youtube_url = st.text_input("Enter a YouTube URL", placeholder="https://www.youtube.com/watch?v=...")
415
- st.markdown('</div>', unsafe_allow_html=True)
416
-
417
- # Process button
418
- if st.button("Analyze"):
419
- # Check the input method
420
- audio_path = None
421
- file_name = None
422
-
423
- if uploaded_file is not None:
424
- audio_path, file_name = process_uploaded_file(uploaded_file)
425
- elif youtube_url:
426
- if is_youtube_url(youtube_url):
427
- audio_path, file_name = process_youtube(youtube_url)
428
- else:
429
- st.error("Invalid YouTube URL. Please enter a valid YouTube URL.")
430
- else:
431
- st.error("Please upload an audio file or enter a YouTube URL.")
432
-
433
- # If we have a valid audio path, process it
434
- if audio_path and file_name:
435
- try:
436
- # Load and process the audio file
437
- with st.spinner('Processing audio...'):
438
- # Load audio and extract features
439
- y, sr = librosa.load(audio_path, sr=22050)
440
-
441
- # Create a temporary directory for model output
442
- temp_output_dir = tempfile.mkdtemp()
443
-
444
- # Load the model
445
- model = load_CRNN_model(MODEL_PATH)
446
-
447
- # Process audio and make predictions
448
- audio_features, _ = process_audio(audio_path, output_path=temp_output_dir)
449
- meter_grid_times, predictions = make_predictions(model, audio_features)
450
-
451
- # Smooth predictions to avoid rapid transitions
452
- smoothed_predictions = np.convolve(predictions,
453
- np.ones(5)/5,
454
- mode='same')
455
-
456
- # Extract chorus segments
457
- chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
458
-
459
- # Create a chorus compilation
460
- compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)
461
-
462
- # Display results
463
- st.markdown(f"""
464
- <div class="result-container">
465
- <div class="song-title">{file_name}</div>
466
- </div>
467
- """, unsafe_allow_html=True)
468
-
469
- # Display waveform with highlighted chorus sections
470
- fig, ax = plt.subplots(figsize=(14, 5))
471
-
472
- # Plot the waveform
473
- times = np.linspace(0, len(y)/sr, len(y))
474
- ax.plot(times, y, color='#b3b3b3', alpha=0.5, linewidth=1)
475
- ax.set_xlabel('Time (s)')
476
- ax.set_ylabel('Amplitude')
477
- ax.set_title('Audio Waveform with Chorus Sections Highlighted')
478
-
479
- # Highlight chorus sections
480
- for start_time, end_time, _ in chorus_segments:
481
- ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
482
-
483
- # Add a label at the start of each chorus
484
- ax.annotate('Chorus',
485
- xy=(start_time, 0.8 * max(y)),
486
- xytext=(start_time + 0.5, 0.9 * max(y)),
487
- color=THEME_COLORS['primary'],
488
- weight='bold')
489
-
490
- # Customize plot appearance
491
- ax.set_facecolor(THEME_COLORS['card_bg'])
492
- fig.patch.set_facecolor(THEME_COLORS['background'])
493
- ax.spines['top'].set_visible(False)
494
- ax.spines['right'].set_visible(False)
495
- ax.spines['bottom'].set_color(THEME_COLORS['border'])
496
- ax.spines['left'].set_color(THEME_COLORS['border'])
497
- ax.tick_params(axis='x', colors=THEME_COLORS['text'])
498
- ax.tick_params(axis='y', colors=THEME_COLORS['text'])
499
- ax.xaxis.label.set_color(THEME_COLORS['text'])
500
- ax.yaxis.label.set_color(THEME_COLORS['text'])
501
- ax.title.set_color(THEME_COLORS['text'])
502
-
503
- st.pyplot(fig)
504
-
505
- # Display chorus segments
506
- if chorus_segments:
507
- st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
508
- st.subheader("Chorus Segments")
509
- for i, (start_time, end_time, segment_audio) in enumerate(chorus_segments):
510
- st.markdown(f"""
511
- <div class="time-stamp">Chorus {i+1}: {format_time(start_time)} - {format_time(end_time)}</div>
512
- """, unsafe_allow_html=True)
513
-
514
- # Convert segment audio to bytes for playback
515
- audio_bytes = save_audio_for_streamlit(segment_audio, sr)
516
- st.audio(audio_bytes, format='audio/mp3')
517
- st.markdown('</div>', unsafe_allow_html=True)
518
-
519
- # Chorus compilation
520
- if len(compilation_audio) > 0:
521
- st.markdown('<div class="chorus-card">', unsafe_allow_html=True)
522
- st.subheader("Chorus Compilation")
523
- st.markdown("All chorus segments combined into one track:")
524
-
525
- compilation_bytes = save_audio_for_streamlit(compilation_audio, sr)
526
- st.audio(compilation_bytes, format='audio/mp3')
527
- st.markdown('</div>', unsafe_allow_html=True)
528
- else:
529
- st.info("No chorus sections detected in this audio.")
530
-
531
- except Exception as e:
532
- st.error(f"Error processing audio: {e}")
533
- logger.error(f"Error processing audio: {e}", exc_info=True)
534
-
535
- if __name__ == "__main__":
536
- main()
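Aside: the chorus-extraction core that survives this refactor (moving-average smoothing, threshold grouping, faded concatenation) can be exercised in isolation. The sketch below is a runnable restatement under synthetic inputs, not repository code; the prediction values and the random waveform are illustrative stand-ins for the model outputs and the librosa-loaded audio.

import numpy as np

# Stand-ins for make_predictions() output and the loaded waveform (illustrative only)
predictions = np.array([0.1, 0.2, 0.8, 0.9, 0.85, 0.3, 0.2, 0.7, 0.8, 0.1])
meter_grid_times = np.linspace(0.0, 18.0, num=len(predictions))
sr = 22050
y = np.random.default_rng(0).standard_normal(int(meter_grid_times[-1] * sr))

# Centered 5-frame moving average, as applied in main()
smoothed = np.convolve(predictions, np.ones(5) / 5, mode='same')

# Group consecutive above-threshold frames into (start, end) time pairs,
# mirroring extract_chorus_segments()
threshold = 0.5
segments, start = [], None
for t, is_chorus in zip(meter_grid_times, smoothed > threshold):
    if is_chorus and start is None:
        start = t
    elif not is_chorus and start is not None:
        segments.append((start, t))
        start = None
if start is not None:  # final segment runs to the end of the grid
    segments.append((start, meter_grid_times[-1]))

# Slice, fade, and concatenate, mirroring create_chorus_compilation()
fade = int(0.3 * sr)
clips = []
for start_t, end_t in segments:
    clip = y[int(start_t * sr):int(end_t * sr)].copy()
    if len(clip) <= 2 * fade:  # too short for both a fade-in and a fade-out
        continue
    clip[:fade] *= np.linspace(0, 1, fade)
    clip[-fade:] *= np.linspace(1, 0, fade)
    clips.append(clip)
compilation = np.concatenate(clips) if clips else np.array([])

print(segments)           # e.g. [(4.0, 14.0)] for the values above
print(compilation.shape)

With these inputs the loop yields a single segment; segments shorter than twice the fade length are dropped, matching the guard kept in both versions of create_chorus_compilation().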
 
streamlit_app.py CHANGED
@@ -1,35 +1,18 @@
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
 
4
- """Streamlit web app for chorus detection in audio files.
5
-
6
- This module provides a web-based interface for the chorus detection system,
7
- allowing users to upload audio files or provide YouTube URLs for analysis.
8
  """
9
 
10
  import os
11
  import sys
12
  import logging
13
-
14
- # Configure logging
15
- logger = logging.getLogger("streamlit-app")
16
-
17
- # Configure TensorFlow logging before importing TensorFlow
18
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppress TensorFlow logs
19
-
20
- # Import model downloader to ensure model is available
21
- try:
22
- from download_model import ensure_model_exists
23
- except ImportError as e:
24
- logger.error(f"Error importing ensure_model_exists: {e}")
25
- raise
26
-
27
  import base64
28
  import tempfile
29
  import warnings
30
- from typing import Optional, Tuple, List
31
- import time
32
  import io
 
33
 
34
  import matplotlib.pyplot as plt
35
  import streamlit as st
@@ -39,33 +22,33 @@ import soundfile as sf
39
  import numpy as np
40
  from pydub import AudioSegment
41
 
42
- # Suppress warnings
43
- warnings.filterwarnings("ignore") # Suppress all warnings
44
- tf.get_logger().setLevel('ERROR') # Suppress TensorFlow ERROR logs
45
 
46
- # Debug import paths
47
- logger.info(f"Python path: {sys.path}")
48
- logger.info(f"Current working directory: {os.getcwd()}")
 
49
 
50
- # Import modules
51
  try:
 
52
  from chorus_detection.audio.data_processing import process_audio
53
  from chorus_detection.audio.processor import extract_audio
54
  from chorus_detection.models.crnn import load_CRNN_model, make_predictions
55
  from chorus_detection.utils.cli import is_youtube_url
56
  from chorus_detection.utils.logging import logger
57
-
58
  logger.info("Successfully imported chorus_detection modules")
59
  except ImportError as e:
60
- logger.error(f"Error importing chorus_detection modules: {e}")
61
  raise
62
 
63
- # Define the MODEL_PATH directly
64
  MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
65
  if not os.path.exists(MODEL_PATH):
66
  MODEL_PATH = ensure_model_exists()
67
 
68
- # Define color scheme
69
  THEME_COLORS = {
70
  'background': '#121212',
71
  'card_bg': '#181818',
@@ -79,15 +62,7 @@ THEME_COLORS = {
79
 
80
 
81
  def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
82
- """Generate HTML for file download link.
83
-
84
- Args:
85
- bin_file: Path to the binary file
86
- file_label: Label for the download link
87
-
88
- Returns:
89
- HTML string for the download link
90
- """
91
  with open(bin_file, 'rb') as f:
92
  data = f.read()
93
  b64 = base64.b64encode(data).decode()
@@ -134,14 +109,7 @@ def set_custom_theme() -> None:
134
 
135
 
136
  def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
137
- """Process a YouTube URL and extract audio.
138
-
139
- Args:
140
- url: YouTube URL
141
-
142
- Returns:
143
- Tuple of (audio_path, video_name)
144
- """
145
  try:
146
  with st.spinner('Downloading audio from YouTube...'):
147
  audio_path, video_name = extract_audio(url)
@@ -153,17 +121,9 @@ def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
153
 
154
 
155
  def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
156
- """Process an uploaded audio file.
157
-
158
- Args:
159
- uploaded_file: Streamlit UploadedFile object
160
-
161
- Returns:
162
- Tuple of (audio_path, file_name)
163
- """
164
  try:
165
  with st.spinner('Processing uploaded file...'):
166
- # Save the uploaded file to a temporary location
167
  temp_dir = tempfile.mkdtemp()
168
  file_name = uploaded_file.name
169
  temp_path = os.path.join(temp_dir, file_name)
@@ -180,24 +140,9 @@ def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
180
 
181
  def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
182
  meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
183
- """Extract chorus segments from predictions.
184
-
185
- Args:
186
- y: Audio data
187
- sr: Sample rate
188
- smoothed_predictions: Smoothed model predictions
189
- meter_grid_times: Time grid for predictions
190
-
191
- Returns:
192
- List of (start_time, end_time, audio_segment) tuples
193
- """
194
- # Define threshold for chorus detection (probability > 0.5)
195
  threshold = 0.5
196
-
197
- # Find the segments where the predictions are above the threshold
198
  chorus_mask = smoothed_predictions > threshold
199
-
200
- # Group consecutive True values to identify segments
201
  segments = []
202
  current_segment = None
203
 
@@ -205,10 +150,8 @@ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.nda
205
  time = meter_grid_times[i]
206
 
207
  if is_chorus and current_segment is None:
208
- # Start a new segment
209
  current_segment = (time, None, None)
210
  elif not is_chorus and current_segment is not None:
211
- # End the current segment
212
  start_time = current_segment[0]
213
  current_segment = (start_time, time, None)
214
  segments.append(current_segment)
@@ -222,13 +165,9 @@ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.nda
222
  # Extract the actual audio for each segment
223
  segments_with_audio = []
224
  for start_time, end_time, _ in segments:
225
- # Convert times to sample indices
226
  start_idx = int(start_time * sr)
227
  end_idx = int(end_time * sr)
228
-
229
- # Extract the audio segment
230
  segment_audio = y[start_idx:end_idx]
231
-
232
  segments_with_audio.append((start_time, end_time, segment_audio))
233
 
234
  return segments_with_audio
@@ -236,49 +175,29 @@ def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.nda
236
 
237
  def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
238
  sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
239
- """Create a compilation of chorus segments.
240
-
241
- Args:
242
- segments: List of (start_time, end_time, audio_data) tuples
243
- sr: Sample rate
244
- fade_duration: Duration of fade in/out in seconds
245
-
246
- Returns:
247
- Tuple of (compilation_audio, description)
248
- """
249
  if not segments:
250
  return np.array([]), "No chorus segments found"
251
 
252
- # Calculate the number of samples for fading
253
  fade_samples = int(fade_duration * sr)
254
-
255
- # Prepare a list to store the processed segments
256
  processed_segments = []
257
-
258
- # Description of segments
259
  segment_descriptions = []
260
 
261
- # Process each segment
262
  for i, (start_time, end_time, audio) in enumerate(segments):
263
- # Apply fade in and fade out
264
  segment_length = len(audio)
265
 
266
  if segment_length <= 2 * fade_samples:
267
- # Segment is too short for fading, skip it
268
  continue
269
 
270
- # Create a linear fade in and fade out
271
  fade_in = np.linspace(0, 1, fade_samples)
272
  fade_out = np.linspace(1, 0, fade_samples)
273
 
274
- # Apply the fades
275
  audio_faded = audio.copy()
276
  audio_faded[:fade_samples] *= fade_in
277
  audio_faded[-fade_samples:] *= fade_out
278
 
279
  processed_segments.append(audio_faded)
280
 
281
- # Format the times for the description
282
  start_fmt = format_time(start_time)
283
  end_fmt = format_time(end_time)
284
  segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")
@@ -286,26 +205,14 @@ def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
286
  if not processed_segments:
287
  return np.array([]), "No chorus segments long enough for compilation"
288
 
289
- # Concatenate all the processed segments
290
  compilation = np.concatenate(processed_segments)
291
-
292
- # Join the descriptions
293
  description = "\n".join(segment_descriptions)
294
 
295
  return compilation, description
296
 
297
 
298
  def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
299
- """Save audio data to a format suitable for Streamlit audio playback.
300
-
301
- Args:
302
- audio_data: Audio samples
303
- sr: Sample rate
304
- file_format: Output format ('mp3', 'wav', etc.)
305
-
306
- Returns:
307
- Audio bytes
308
- """
309
  with io.BytesIO() as buffer:
310
  sf.write(buffer, audio_data, sr, format=file_format)
311
  buffer.seek(0)
@@ -313,14 +220,7 @@ def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str =
313
 
314
 
315
  def format_time(seconds: float) -> str:
316
- """Format seconds as MM:SS.
317
-
318
- Args:
319
- seconds: Time in seconds
320
-
321
- Returns:
322
- Formatted time string
323
- """
324
  minutes = int(seconds // 60)
325
  seconds = int(seconds % 60)
326
  return f"{minutes:02d}:{seconds:02d}"
@@ -385,11 +285,7 @@ def main() -> None:
385
  with st.spinner('Processing audio...'):
386
  # Load audio and extract features
387
  y, sr = librosa.load(audio_path, sr=22050)
388
-
389
- # Create a temporary directory for model output
390
  temp_output_dir = tempfile.mkdtemp()
391
-
392
- # Load the model
393
  model = load_CRNN_model(MODEL_PATH)
394
 
395
  # Process audio and make predictions
@@ -397,14 +293,10 @@ def main() -> None:
397
  meter_grid_times, predictions = make_predictions(model, audio_features)
398
 
399
  # Smooth predictions to avoid rapid transitions
400
- smoothed_predictions = np.convolve(predictions,
401
- np.ones(5)/5,
402
- mode='same')
403
 
404
- # Extract chorus segments
405
  chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
406
-
407
- # Create a chorus compilation
408
  compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)
409
 
410
  # Display results
@@ -427,8 +319,6 @@ def main() -> None:
427
  # Highlight chorus sections
428
  for start_time, end_time, _ in chorus_segments:
429
  ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
430
-
431
- # Add a label at the start of each chorus
432
  ax.annotate('Chorus',
433
  xy=(start_time, 0.8 * max(y)),
434
  xytext=(start_time + 0.5, 0.9 * max(y)),
 
1
  #!/usr/bin/env python3
2
  # -*- coding: utf-8 -*-
3
 
4
+ """
5
+ Streamlit web app for chorus detection in audio files.
 
 
6
  """
7
 
8
  import os
9
  import sys
10
  import logging
11
  import base64
12
  import tempfile
13
  import warnings
 
 
14
  import io
15
+ from typing import Optional, Tuple, List
16
 
17
  import matplotlib.pyplot as plt
18
  import streamlit as st
 
22
  import numpy as np
23
  from pydub import AudioSegment
24
 
25
+ # Configure logging
26
+ logger = logging.getLogger("streamlit-app")
 
27
 
28
+ # Suppress TensorFlow and other warnings
29
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
30
+ warnings.filterwarnings("ignore")
31
+ tf.get_logger().setLevel('ERROR')
32
 
33
+ # Import components
34
  try:
35
+ from download_model import ensure_model_exists
36
  from chorus_detection.audio.data_processing import process_audio
37
  from chorus_detection.audio.processor import extract_audio
38
  from chorus_detection.models.crnn import load_CRNN_model, make_predictions
39
  from chorus_detection.utils.cli import is_youtube_url
40
  from chorus_detection.utils.logging import logger
 
41
  logger.info("Successfully imported chorus_detection modules")
42
  except ImportError as e:
43
+ logger.error(f"Error importing modules: {e}")
44
  raise
45
 
46
+ # Define model path
47
  MODEL_PATH = os.path.join(os.getcwd(), "models", "CRNN", "best_model_V3.h5")
48
  if not os.path.exists(MODEL_PATH):
49
  MODEL_PATH = ensure_model_exists()
50
 
51
+ # UI theme colors
52
  THEME_COLORS = {
53
  'background': '#121212',
54
  'card_bg': '#181818',
 
62
 
63
 
64
  def get_binary_file_downloader_html(bin_file: str, file_label: str = 'File') -> str:
65
+ """Generate HTML for file download link."""
66
  with open(bin_file, 'rb') as f:
67
  data = f.read()
68
  b64 = base64.b64encode(data).decode()
 
109
 
110
 
111
  def process_youtube(url: str) -> Tuple[Optional[str], Optional[str]]:
112
+ """Process a YouTube URL and extract audio."""
113
  try:
114
  with st.spinner('Downloading audio from YouTube...'):
115
  audio_path, video_name = extract_audio(url)
 
121
 
122
 
123
  def process_uploaded_file(uploaded_file) -> Tuple[Optional[str], Optional[str]]:
124
+ """Process an uploaded audio file."""
125
  try:
126
  with st.spinner('Processing uploaded file...'):
 
127
  temp_dir = tempfile.mkdtemp()
128
  file_name = uploaded_file.name
129
  temp_path = os.path.join(temp_dir, file_name)
 
140
 
141
  def extract_chorus_segments(y: np.ndarray, sr: int, smoothed_predictions: np.ndarray,
142
  meter_grid_times: np.ndarray) -> List[Tuple[float, float, np.ndarray]]:
143
+ """Extract chorus segments from predictions."""
144
  threshold = 0.5
 
 
145
  chorus_mask = smoothed_predictions > threshold
 
 
146
  segments = []
147
  current_segment = None
148
 
 
150
  time = meter_grid_times[i]
151
 
152
  if is_chorus and current_segment is None:
 
153
  current_segment = (time, None, None)
154
  elif not is_chorus and current_segment is not None:
 
155
  start_time = current_segment[0]
156
  current_segment = (start_time, time, None)
157
  segments.append(current_segment)
 
165
  # Extract the actual audio for each segment
166
  segments_with_audio = []
167
  for start_time, end_time, _ in segments:
 
168
  start_idx = int(start_time * sr)
169
  end_idx = int(end_time * sr)
 
 
170
  segment_audio = y[start_idx:end_idx]
 
171
  segments_with_audio.append((start_time, end_time, segment_audio))
172
 
173
  return segments_with_audio
 
175
 
176
  def create_chorus_compilation(segments: List[Tuple[float, float, np.ndarray]],
177
  sr: int, fade_duration: float = 0.3) -> Tuple[np.ndarray, str]:
178
+ """Create a compilation of chorus segments."""
179
  if not segments:
180
  return np.array([]), "No chorus segments found"
181
 
 
182
  fade_samples = int(fade_duration * sr)
 
 
183
  processed_segments = []
 
 
184
  segment_descriptions = []
185
 
 
186
  for i, (start_time, end_time, audio) in enumerate(segments):
 
187
  segment_length = len(audio)
188
 
189
  if segment_length <= 2 * fade_samples:
 
190
  continue
191
 
 
192
  fade_in = np.linspace(0, 1, fade_samples)
193
  fade_out = np.linspace(1, 0, fade_samples)
194
 
 
195
  audio_faded = audio.copy()
196
  audio_faded[:fade_samples] *= fade_in
197
  audio_faded[-fade_samples:] *= fade_out
198
 
199
  processed_segments.append(audio_faded)
200
 
 
201
  start_fmt = format_time(start_time)
202
  end_fmt = format_time(end_time)
203
  segment_descriptions.append(f"Chorus {i+1}: {start_fmt} - {end_fmt}")
 
205
  if not processed_segments:
206
  return np.array([]), "No chorus segments long enough for compilation"
207
 
 
208
  compilation = np.concatenate(processed_segments)
 
 
209
  description = "\n".join(segment_descriptions)
210
 
211
  return compilation, description
212
 
213
 
214
  def save_audio_for_streamlit(audio_data: np.ndarray, sr: int, file_format: str = 'mp3') -> bytes:
215
+ """Save audio data to a format suitable for Streamlit audio playback."""
216
  with io.BytesIO() as buffer:
217
  sf.write(buffer, audio_data, sr, format=file_format)
218
  buffer.seek(0)
 
220
 
221
 
222
  def format_time(seconds: float) -> str:
223
+ """Format seconds as MM:SS."""
224
  minutes = int(seconds // 60)
225
  seconds = int(seconds % 60)
226
  return f"{minutes:02d}:{seconds:02d}"
 
285
  with st.spinner('Processing audio...'):
286
  # Load audio and extract features
287
  y, sr = librosa.load(audio_path, sr=22050)
 
 
288
  temp_output_dir = tempfile.mkdtemp()
 
 
289
  model = load_CRNN_model(MODEL_PATH)
290
 
291
  # Process audio and make predictions
 
293
  meter_grid_times, predictions = make_predictions(model, audio_features)
294
 
295
  # Smooth predictions to avoid rapid transitions
296
+ smoothed_predictions = np.convolve(predictions, np.ones(5)/5, mode='same')
 
 
297
 
298
+ # Extract chorus segments and create compilation
299
  chorus_segments = extract_chorus_segments(y, sr, smoothed_predictions, meter_grid_times)
 
 
300
  compilation_audio, segments_desc = create_chorus_compilation(chorus_segments, sr)
301
 
302
  # Display results
 
319
  # Highlight chorus sections
320
  for start_time, end_time, _ in chorus_segments:
321
  ax.axvspan(start_time, end_time, alpha=0.3, color=THEME_COLORS['primary'])
 
 
322
  ax.annotate('Chorus',
323
  xy=(start_time, 0.8 * max(y)),
324
  xytext=(start_time + 0.5, 0.9 * max(y)),